201 lines
7.9 KiB
Python
201 lines
7.9 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
# Repository root: the directory one level above the folder that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of sys.path (only once) so the project
# import that follows can be resolved -- presumably for runs where the package
# is not installed; confirm against the project's packaging setup.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
from openvino_context_gate.context_gate import ( # noqa: E402
|
|
AUTHORITY,
|
|
ClassifierResult,
|
|
ContextGateError,
|
|
build_plan,
|
|
classify_live,
|
|
compact_json,
|
|
compact_line,
|
|
)
|
|
|
|
|
|
def fake_classifier(
    labels: dict,
    *,
    endpoint_delta: int | None = 120,
    sysfs_delta: int | None = 120,
    outer_delta: int | None = 80,
) -> ClassifierResult:
    """Build a ``ClassifierResult`` marked ``live=True`` for use as a test double.

    Args:
        labels: Label mapping stored on the result unchanged.
        endpoint_delta: Value for ``npu_busy_delta_us``.
        sysfs_delta: Value for ``sysfs_npu_busy_delta_us``.
        outer_delta: Value for ``outer_sysfs_delta_us``.

    Returns:
        A ``ClassifierResult`` populated with the given values.
    """
    result_fields = {
        "labels": labels,
        "npu_busy_delta_us": endpoint_delta,
        "sysfs_npu_busy_delta_us": sysfs_delta,
        "outer_sysfs_delta_us": outer_delta,
        "live": True,
    }
    return ClassifierResult(**result_fields)
|
|
|
|
|
|
def labels(category: str, *, tool: bool = False, safety: bool = False, memory: str = "none") -> dict:
    """Return a classifier label mapping with fixed, flag-dependent confidences.

    Args:
        category: Value stored under ``workflow_category``.
        tool: Value stored under ``tool_needed``; a truthy value raises its
            confidence from 0.4 to 0.8.
        safety: Value stored under ``safety_confirmation_required``; a truthy
            value raises its confidence from 0.1 to 0.9.
        memory: Value stored under ``memory_candidate``; anything other than
            ``"none"`` raises its confidence from 0.3 to 0.8.

    Returns:
        A dict of five ``{"value": ..., "confidence": ...}`` entries.
    """
    if tool:
        tool_confidence = 0.8
    else:
        tool_confidence = 0.4

    if memory != "none":
        memory_confidence = 0.8
    else:
        memory_confidence = 0.3

    if safety:
        safety_confidence = 0.9
    else:
        safety_confidence = 0.1

    result = {}
    result["tool_needed"] = {"value": tool, "confidence": tool_confidence}
    result["memory_candidate"] = {"value": memory, "confidence": memory_confidence}
    result["urgency"] = {"value": "normal", "confidence": 0.6}
    result["workflow_category"] = {"value": category, "confidence": 0.86}
    result["safety_confirmation_required"] = {"value": safety, "confidence": safety_confidence}
    return result
|
|
|
|
|
|
def test_current_npu_debug_query_selects_ops_live_and_repo_sources() -> None:
    """A devops-labelled NPU debugging query should produce the ops debug bundle.

    Expects ``live_system`` then ``repo_files`` as the first two sources, a
    verified NPU proof, the unchanged ``AUTHORITY`` value and every gate value
    starting with ``closed_``.
    """
    request_context = {"platform": "cli", "repo_path": "/home/will/lab/swarm"}
    devops_classifier = fake_classifier(labels("devops", tool=True))
    plan = build_plan(
        "How do I check whether the RAG reranker is using the NPU?",
        context=request_context,
        classifier=devops_classifier,
    )

    assert plan["schema"] == "atlas_context_gate_plan_v1"
    assert plan["bundle_plan"]["bundle_name"] == "OpsDebugBundle"

    source_names = [entry["source"] for entry in plan["source_plan"]]
    assert source_names[:2] == ["live_system", "repo_files"]

    assert plan["npu_proof"]["verified"] is True
    assert plan["authority"] == AUTHORITY

    gate_states = plan["gates"].values()
    assert all(state.startswith("closed_") for state in gate_states)
|
|
|
|
|
|
def test_prior_plan_query_uses_session_or_rag_and_coding_for_kanban() -> None:
    """A coding-labelled kanban query about a prior plan should use the coding bundle.

    Expects ``repo_files``, ``session_search`` and ``rag_search`` to all appear
    among the planned sources.
    """
    kanban_context = {
        "platform": "kanban",
        "task_id": "t_example",
        "repo_path": "/home/will/lab/swarm",
    }
    coding_classifier = fake_classifier(labels("coding", tool=True))
    plan = build_plan(
        "Where did we leave the NPU context gate implementation plan?",
        context=kanban_context,
        classifier=coding_classifier,
    )
    selected = [entry["source"] for entry in plan["source_plan"]]

    assert plan["bundle_plan"]["bundle_name"] == "CodingTaskBundle"
    for expected_source in ("repo_files", "session_search", "rag_search"):
        assert expected_source in selected
|
|
|
|
|
|
def test_simple_creative_query_no_retrieval_offline_no_npu_claim() -> None:
    """A creative query planned without a classifier should need no retrieval.

    Expects the simple response bundle, ``no_retrieval`` as the only source,
    an unverified NPU proof and both related warnings.
    """
    plan = build_plan("Write a haiku about Seattle rain.")
    source_names = [entry["source"] for entry in plan["source_plan"]]
    warnings = plan["warnings"]

    assert plan["bundle_plan"]["bundle_name"] == "SimpleResponseBundle"
    assert source_names == ["no_retrieval"]
    assert plan["npu_proof"]["verified"] is False
    for expected_warning in (
        "npu_proof_inconclusive",
        "offline_heuristic_classifier_no_npu_claim",
    ):
        assert expected_warning in warnings
|
|
|
|
|
|
def test_unsafe_live_routing_request_keeps_authority_closed_and_blocks_side_effect() -> None:
    """A safety-flagged request to change live routing must not gain authority.

    Expects the unchanged ``AUTHORITY`` value with ``may_route`` False, an
    ``authority_side_effect`` entry among the blocked fields, and the
    ``live_routing_change`` gate closed pending explicit approval.
    """
    safety_classifier = fake_classifier(labels("coding", tool=True, safety=True))
    plan = build_plan(
        "Change Hermes live routing to use the classifier automatically.",
        context={"repo_path": "/home/will/lab/swarm"},
        classifier=safety_classifier,
    )

    authority = plan["authority"]
    assert authority == AUTHORITY
    assert authority["may_route"] is False

    blocked_fields = plan["bundle_plan"]["blocked_fields"]
    assert any(blocked["field"] == "authority_side_effect" for blocked in blocked_fields)

    assert plan["gates"]["live_routing_change"] == "closed_requires_explicit_approval"
|
|
|
|
|
|
def test_rejects_non_dry_run_and_private_text_options() -> None:
    """``build_plan`` should raise ``ContextGateError`` for each disallowed option."""
    # Each case pairs an options mapping with the pattern the error must match.
    rejected_cases = (
        ({"dry_run": False}, "dry_run_must_remain_true"),
        ({"include_private_text": True}, "include_private_text"),
    )
    for bad_options, expected_pattern in rejected_cases:
        with pytest.raises(ContextGateError, match=expected_pattern):
            build_plan("hello", options=bad_options)
|
|
|
|
|
|
def test_compact_outputs_are_small_and_parseable() -> None:
    """The compact line and compact JSON renderings should expose the key fields.

    NOTE(review): the test name mentions size, but no length bound is asserted
    here -- confirm whether one is intended.
    """
    plan = build_plan("How do I check whether port 18819 is healthy?")

    summary_line = compact_line(plan)
    for fragment in ("schema=atlas_context_gate_plan_v1", "gates=closed:"):
        assert fragment in summary_line

    decoded = json.loads(compact_json(plan))
    assert decoded["schema"] == "atlas_context_gate_plan_v1"
    assert isinstance(decoded["sources"], list)
    assert "authority" in decoded
|
|
|
|
|
|
def test_cli_offline_compact_json_smoke() -> None:
    """Run the advisory CLI script in offline mode and check its compact JSON output.

    NOTE(review): the subprocess call has no timeout, so a hung script would
    block the test run -- consider whether a bound is wanted.
    """
    cli_script = REPO_ROOT.joinpath("scripts", "context-gate-advisory.py")
    command = [
        sys.executable,
        str(cli_script),
        "--offline",
        "--query",
        "Write a haiku about Seattle rain.",
        "--format",
        "compact-json",
    ]
    # check=True turns a non-zero exit status into CalledProcessError.
    completed = subprocess.run(
        command,
        cwd=REPO_ROOT,
        capture_output=True,
        text=True,
        check=True,
    )

    payload = json.loads(completed.stdout)
    assert payload["ok"] is True
    assert payload["bundle_name"] == "SimpleResponseBundle"
    assert payload["sources"] == ["no_retrieval"]
    assert payload["npu_proof"]["verified"] is False
|
|
|
|
|
|
def test_npu_proof_requires_positive_sysfs_delta() -> None:
    """NPU proof should be verified only when a sysfs delta is positive.

    Three scenarios share one query and an endpoint delta of 120:
    a zero sysfs delta with no outer delta (expected unverified, with a
    warning), a positive sysfs delta, and a positive outer delta (both
    expected verified).
    """
    query = "How do I check whether the RAG reranker is using the NPU?"

    def plan_for(sysfs_delta, outer_delta):  # type: ignore[no-untyped-def]
        # Only the two sysfs-related deltas vary between scenarios.
        classifier = fake_classifier(
            labels("devops", tool=True),
            endpoint_delta=120,
            sysfs_delta=sysfs_delta,
            outer_delta=outer_delta,
        )
        return build_plan(query, classifier=classifier)

    unproven_plan = plan_for(0, None)
    assert unproven_plan["npu_proof"]["verified"] is False
    assert "npu_proof_inconclusive" in unproven_plan["warnings"]

    endpoint_sysfs_plan = plan_for(1, None)
    assert endpoint_sysfs_plan["npu_proof"]["verified"] is True

    outer_sysfs_plan = plan_for(0, 1)
    assert outer_sysfs_plan["npu_proof"]["verified"] is True
|
|
|
|
|
|
def test_classifier_url_must_be_loopback_or_localhost() -> None:
    """``classify_live`` should reject each listed URL with ``invalid_classifier_url``.

    The cases cover a public hostname, a private non-loopback address, the
    ``0.0.0.0`` address and a non-HTTP scheme on a loopback address.
    """
    rejected_urls = (
        "http://example.com/v1/classify",
        "https://10.0.0.5/v1/classify",
        "http://0.0.0.0:18819/v1/classify",
        "ftp://127.0.0.1/v1/classify",
    )
    for bad_url in rejected_urls:
        with pytest.raises(ContextGateError, match="invalid_classifier_url"):
            classify_live("hello", classifier_url=bad_url, timeout=0.01)
|
|
|
|
|
|
def test_classifier_url_redirect_to_non_loopback_is_not_followed(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 302 from a loopback classifier to a remote host must not be followed.

    A local HTTP server answers every POST with a redirect to
    ``http://example.com/v1/classify``. ``classify_live`` is expected to raise
    ``ContextGateError`` matching ``classifier_unavailable`` after exactly one
    request to the local server, without attempting any non-loopback
    connection.

    Args:
        monkeypatch: pytest fixture used to replace ``socket.create_connection``
            with a guard for the duration of the test.
    """
    # Paths of the POST requests received by the local redirecting server.
    requests: list[str] = []
    # Hosts the guard refused to connect to. The guard's AssertionError is
    # raised inside the code under test, where it may be caught and re-raised
    # as ContextGateError, so violations are also recorded here and checked
    # after the call.
    blocked_hosts: list[str] = []

    class RedirectHandler(BaseHTTPRequestHandler):
        def do_POST(self) -> None:  # noqa: N802 - stdlib callback name
            requests.append(self.path)
            self.send_response(302)
            self.send_header("Location", "http://example.com/v1/classify")
            self.end_headers()

        def log_message(self, format: str, *args: object) -> None:
            # Keep the test output free of per-request server logging.
            return

    original_create_connection = socket.create_connection

    def guarded_create_connection(address, *args, **kwargs):  # type: ignore[no-untyped-def]
        host = address[0]
        if host not in {"127.0.0.1", "localhost"}:
            blocked_hosts.append(host)
            raise AssertionError(f"attempted non-loopback redirect connection to {host}")
        return original_create_connection(address, *args, **kwargs)

    # Port 0 lets the OS choose a free port; it is read back via server_port.
    server = ThreadingHTTPServer(("127.0.0.1", 0), RedirectHandler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    monkeypatch.setattr(socket, "create_connection", guarded_create_connection)
    try:
        url = f"http://127.0.0.1:{server.server_port}/v1/classify"
        with pytest.raises(ContextGateError, match="classifier_unavailable"):
            classify_live("hello", classifier_url=url, timeout=1.0)
    finally:
        server.shutdown()
        server.server_close()
        thread.join(timeout=2)

    assert requests == ["/v1/classify"]
    # Fails even when the guard's AssertionError was swallowed by the caller.
    assert blocked_hosts == []
|