feat(npu): add local context gate advisory
This commit is contained in:
@@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from openvino_context_gate.context_gate import ( # noqa: E402
|
||||
AUTHORITY,
|
||||
ClassifierResult,
|
||||
ContextGateError,
|
||||
build_plan,
|
||||
classify_live,
|
||||
compact_json,
|
||||
compact_line,
|
||||
)
|
||||
|
||||
|
||||
def fake_classifier(
|
||||
labels: dict,
|
||||
*,
|
||||
endpoint_delta: int | None = 120,
|
||||
sysfs_delta: int | None = 120,
|
||||
outer_delta: int | None = 80,
|
||||
) -> ClassifierResult:
|
||||
return ClassifierResult(
|
||||
labels=labels,
|
||||
npu_busy_delta_us=endpoint_delta,
|
||||
sysfs_npu_busy_delta_us=sysfs_delta,
|
||||
outer_sysfs_delta_us=outer_delta,
|
||||
live=True,
|
||||
)
|
||||
|
||||
|
||||
def labels(category: str, *, tool: bool = False, safety: bool = False, memory: str = "none") -> dict:
|
||||
return {
|
||||
"tool_needed": {"value": tool, "confidence": 0.8 if tool else 0.4},
|
||||
"memory_candidate": {"value": memory, "confidence": 0.8 if memory != "none" else 0.3},
|
||||
"urgency": {"value": "normal", "confidence": 0.6},
|
||||
"workflow_category": {"value": category, "confidence": 0.86},
|
||||
"safety_confirmation_required": {"value": safety, "confidence": 0.9 if safety else 0.1},
|
||||
}
|
||||
|
||||
|
||||
def test_current_npu_debug_query_selects_ops_live_and_repo_sources() -> None:
|
||||
plan = build_plan(
|
||||
"How do I check whether the RAG reranker is using the NPU?",
|
||||
context={"platform": "cli", "repo_path": "/home/will/lab/swarm"},
|
||||
classifier=fake_classifier(labels("devops", tool=True)),
|
||||
)
|
||||
assert plan["schema"] == "atlas_context_gate_plan_v1"
|
||||
assert plan["bundle_plan"]["bundle_name"] == "OpsDebugBundle"
|
||||
assert [s["source"] for s in plan["source_plan"]][:2] == ["live_system", "repo_files"]
|
||||
assert plan["npu_proof"]["verified"] is True
|
||||
assert plan["authority"] == AUTHORITY
|
||||
assert all(value.startswith("closed_") for value in plan["gates"].values())
|
||||
|
||||
|
||||
def test_prior_plan_query_uses_session_or_rag_and_coding_for_kanban() -> None:
|
||||
plan = build_plan(
|
||||
"Where did we leave the NPU context gate implementation plan?",
|
||||
context={"platform": "kanban", "task_id": "t_example", "repo_path": "/home/will/lab/swarm"},
|
||||
classifier=fake_classifier(labels("coding", tool=True)),
|
||||
)
|
||||
sources = [s["source"] for s in plan["source_plan"]]
|
||||
assert plan["bundle_plan"]["bundle_name"] == "CodingTaskBundle"
|
||||
assert "repo_files" in sources
|
||||
assert "session_search" in sources
|
||||
assert "rag_search" in sources
|
||||
|
||||
|
||||
def test_simple_creative_query_no_retrieval_offline_no_npu_claim() -> None:
|
||||
plan = build_plan("Write a haiku about Seattle rain.")
|
||||
assert plan["bundle_plan"]["bundle_name"] == "SimpleResponseBundle"
|
||||
assert [s["source"] for s in plan["source_plan"]] == ["no_retrieval"]
|
||||
assert plan["npu_proof"]["verified"] is False
|
||||
assert "npu_proof_inconclusive" in plan["warnings"]
|
||||
assert "offline_heuristic_classifier_no_npu_claim" in plan["warnings"]
|
||||
|
||||
|
||||
def test_unsafe_live_routing_request_keeps_authority_closed_and_blocks_side_effect() -> None:
|
||||
plan = build_plan(
|
||||
"Change Hermes live routing to use the classifier automatically.",
|
||||
context={"repo_path": "/home/will/lab/swarm"},
|
||||
classifier=fake_classifier(labels("coding", tool=True, safety=True)),
|
||||
)
|
||||
assert plan["authority"] == AUTHORITY
|
||||
assert plan["authority"]["may_route"] is False
|
||||
assert any(field["field"] == "authority_side_effect" for field in plan["bundle_plan"]["blocked_fields"])
|
||||
assert plan["gates"]["live_routing_change"] == "closed_requires_explicit_approval"
|
||||
|
||||
|
||||
def test_rejects_non_dry_run_and_private_text_options() -> None:
|
||||
with pytest.raises(ContextGateError, match="dry_run_must_remain_true"):
|
||||
build_plan("hello", options={"dry_run": False})
|
||||
with pytest.raises(ContextGateError, match="include_private_text"):
|
||||
build_plan("hello", options={"include_private_text": True})
|
||||
|
||||
|
||||
def test_compact_outputs_are_small_and_parseable() -> None:
|
||||
plan = build_plan("How do I check whether port 18819 is healthy?")
|
||||
line = compact_line(plan)
|
||||
assert "schema=atlas_context_gate_plan_v1" in line
|
||||
assert "gates=closed:" in line
|
||||
parsed = json.loads(compact_json(plan))
|
||||
assert parsed["schema"] == "atlas_context_gate_plan_v1"
|
||||
assert isinstance(parsed["sources"], list)
|
||||
assert "authority" in parsed
|
||||
|
||||
|
||||
def test_cli_offline_compact_json_smoke() -> None:
|
||||
script = REPO_ROOT / "scripts" / "context-gate-advisory.py"
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(script), "--offline", "--query", "Write a haiku about Seattle rain.", "--format", "compact-json"],
|
||||
check=True,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
cwd=REPO_ROOT,
|
||||
)
|
||||
parsed = json.loads(result.stdout)
|
||||
assert parsed["ok"] is True
|
||||
assert parsed["bundle_name"] == "SimpleResponseBundle"
|
||||
assert parsed["sources"] == ["no_retrieval"]
|
||||
assert parsed["npu_proof"]["verified"] is False
|
||||
|
||||
|
||||
def test_npu_proof_requires_positive_sysfs_delta() -> None:
|
||||
classifier = fake_classifier(labels("devops", tool=True), endpoint_delta=120, sysfs_delta=0, outer_delta=None)
|
||||
plan = build_plan("How do I check whether the RAG reranker is using the NPU?", classifier=classifier)
|
||||
assert plan["npu_proof"]["verified"] is False
|
||||
assert "npu_proof_inconclusive" in plan["warnings"]
|
||||
|
||||
endpoint_sysfs_plan = build_plan(
|
||||
"How do I check whether the RAG reranker is using the NPU?",
|
||||
classifier=fake_classifier(labels("devops", tool=True), endpoint_delta=120, sysfs_delta=1, outer_delta=None),
|
||||
)
|
||||
assert endpoint_sysfs_plan["npu_proof"]["verified"] is True
|
||||
|
||||
outer_sysfs_plan = build_plan(
|
||||
"How do I check whether the RAG reranker is using the NPU?",
|
||||
classifier=fake_classifier(labels("devops", tool=True), endpoint_delta=120, sysfs_delta=0, outer_delta=1),
|
||||
)
|
||||
assert outer_sysfs_plan["npu_proof"]["verified"] is True
|
||||
|
||||
|
||||
def test_classifier_url_must_be_loopback_or_localhost() -> None:
|
||||
for url in [
|
||||
"http://example.com/v1/classify",
|
||||
"https://10.0.0.5/v1/classify",
|
||||
"http://0.0.0.0:18819/v1/classify",
|
||||
"ftp://127.0.0.1/v1/classify",
|
||||
]:
|
||||
with pytest.raises(ContextGateError, match="invalid_classifier_url"):
|
||||
classify_live("hello", classifier_url=url, timeout=0.01)
|
||||
|
||||
|
||||
def test_classifier_url_redirect_to_non_loopback_is_not_followed(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
requests: list[str] = []
|
||||
|
||||
class RedirectHandler(BaseHTTPRequestHandler):
|
||||
def do_POST(self) -> None: # noqa: N802 - stdlib callback name
|
||||
requests.append(self.path)
|
||||
self.send_response(302)
|
||||
self.send_header("Location", "http://example.com/v1/classify")
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format: str, *args: object) -> None:
|
||||
return
|
||||
|
||||
original_create_connection = socket.create_connection
|
||||
|
||||
def guarded_create_connection(address, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
host = address[0]
|
||||
if host not in {"127.0.0.1", "localhost"}:
|
||||
raise AssertionError(f"attempted non-loopback redirect connection to {host}")
|
||||
return original_create_connection(address, *args, **kwargs)
|
||||
|
||||
server = ThreadingHTTPServer(("127.0.0.1", 0), RedirectHandler)
|
||||
thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
monkeypatch.setattr(socket, "create_connection", guarded_create_connection)
|
||||
try:
|
||||
url = f"http://127.0.0.1:{server.server_port}/v1/classify"
|
||||
with pytest.raises(ContextGateError, match="classifier_unavailable"):
|
||||
classify_live("hello", classifier_url=url, timeout=1.0)
|
||||
finally:
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
thread.join(timeout=2)
|
||||
|
||||
assert requests == ["/v1/classify"]
|
||||
Reference in New Issue
Block a user