Files
swarm-master/openvino-classifier-npu/tests/test_router_classifier.py
2026-06-04 13:07:51 -07:00

122 lines
5.5 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import importlib.util
import json
import sys
import unittest
from pathlib import Path
# Directory one level above the folder that contains this test file; the
# classifier module (and, in the tests below, the fixtures/ directory) are
# located relative to it.
ROOT = Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "router_classifier.py"
# Load router_classifier.py straight from its file path instead of using a
# regular ``import`` statement.
# NOTE(review): presumably because the project directory is not an importable
# package / not on sys.path — confirm.
spec = importlib.util.spec_from_file_location("router_classifier", MODULE_PATH)
# Stop here if no spec or loader could be created for the path.
assert spec and spec.loader
router_classifier = importlib.util.module_from_spec(spec)
# Register the module under its name *before* executing it, so any lookup of
# "router_classifier" through sys.modules resolves to this instance.
sys.modules["router_classifier"] = router_classifier
spec.loader.exec_module(router_classifier)
class FakeClient:
    """Stand-in embedding client that returns deterministic toy vectors.

    Each of the eight embedding dimensions is tied to a bucket of keywords.
    The tests focus on rule safety and API shape; live smoke tests cover the
    real NPU upstream.
    """

    # Bucket ``i`` drives dimension ``i`` of the vector: the dimension is set
    # to 1.0 when any keyword of the bucket occurs as a substring of the
    # lower-cased text.
    _KEYWORD_BUCKETS = (
        ("time", "current", "weather", "news", "port", "git", "logs", "systemd"),
        ("remember", "prefer", "preference"),
        ("urgent", "down", "outage", "critical"),
        ("code", "pytest", "debug", "git", "diff"),
        ("service", "systemd", "port", "gateway", "docker"),
        ("kanban", "task", "blocked", "review"),
        ("light", "thermostat"),
        ("transcribe", "voice", "memo", "audio"),
    )

    def embed(self, texts, *, purpose="query"):
        """Embed every entry of *texts* and wrap the vectors in an EmbedResult.

        ``purpose`` is accepted as a keyword argument but is not used by this
        fake. The reported NPU busy delta, duration and embedding dimension
        are fixed constants.
        """
        toy_vectors = [self._toy_vector(text) for text in texts]
        return router_classifier.EmbedResult(
            vectors=toy_vectors,
            npu_busy_delta_us=123,
            duration_ms=1.0,
            embedding_dim=8,
        )

    def _toy_vector(self, text):
        """Return the 8-dimensional keyword-bucket vector for one text."""
        lowered = text.lower()
        vector = [
            1.0 if any(word in lowered for word in bucket) else 0.0
            for bucket in self._KEYWORD_BUCKETS
        ]
        if not any(vector):
            # No bucket matched: emit a weak signal on dimension 0 so the
            # vector is never all zeros.
            vector[0] = 0.2
        return vector
class RouterClassifierTests(unittest.TestCase):
    """Tests for ``router_classifier.ClassifierService`` using a fake embedding client."""

    def service(self):
        """Return a warmed-up ClassifierService whose client is a FakeClient."""
        svc = router_classifier.ClassifierService("http://fake.local/v1/embeddings")
        svc.client = FakeClient()
        svc.warmup()
        return svc

    def _fixture_rows(self):
        """Parse the JSONL message fixture into a list of dicts, skipping blank lines.

        The file is decoded explicitly as UTF-8 so the result does not depend
        on the locale's default encoding.
        """
        fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl"
        content = fixture.read_text(encoding="utf-8")
        return [json.loads(line) for line in content.splitlines() if line.strip()]

    def test_health_and_label_schema(self):
        """health() and labels() expose the expected service metadata and keys."""
        svc = self.service()
        health = svc.health()
        self.assertEqual(health["service"], "atlas-router-classifier")
        self.assertEqual(health["mode"], "dry_run")
        self.assertIn("tool_needed", health["labels"])
        labels = svc.labels()
        self.assertIn("workflow_category", labels["enums"])
        self.assertIn("safety_confirmation_required", labels["thresholds"])

    def test_explicit_preference_is_memory_candidate(self):
        """A stated preference is labelled user_preference without a safety prompt."""
        result = self.service().classify("pref", "Remember that I prefer concise terminal replies.")
        memory = result["labels"]["memory_candidate"]
        self.assertEqual(memory["value"], "user_preference")
        self.assertGreaterEqual(memory["confidence"], 0.78)
        self.assertFalse(result["labels"]["safety_confirmation_required"]["value"])

    def test_current_local_state_needs_tool(self):
        """A request to inspect local ports/logs is flagged as needing a tool."""
        result = self.service().classify("port", "Check whether port 18819 is listening and inspect systemd logs.")
        tool_needed = result["labels"]["tool_needed"]
        self.assertTrue(tool_needed["value"])
        self.assertIn("local_state_requested", tool_needed["reason_codes"])

    def test_live_gateway_restart_requires_confirmation(self):
        """Restarting a live gateway / switching routing requires confirmation."""
        result = self.service().classify("safe", "Restart the live Atlas gateway and switch primary routing.")
        confirmation = result["labels"]["safety_confirmation_required"]
        self.assertTrue(confirmation["value"])
        self.assertIn("live_service_or_routing_change", confirmation["reason_codes"])

    def test_batch_shape(self):
        """batch_classify reports the model, one result per item and a positive NPU delta."""
        result = self.service().batch_classify([
            {"id": "a", "text": "What time is it?"},
            {"id": "b", "text": "Delete the existing collection and reindex it in place."},
        ])
        self.assertEqual(result["model"], router_classifier.MODEL)
        self.assertEqual(len(result["results"]), 2)
        self.assertGreater(result["npu_busy_delta_us"], 0)

    def test_batch_limits_are_enforced(self):
        """Empty batches and batches above DEFAULT_MAX_BATCH_SIZE raise ValueError."""
        svc = self.service()
        with self.assertRaisesRegex(ValueError, "at least one"):
            svc.batch_classify([])
        too_many = [
            {"id": str(i), "text": "What time is it?"}
            for i in range(router_classifier.DEFAULT_MAX_BATCH_SIZE + 1)
        ]
        with self.assertRaisesRegex(ValueError, "max_batch_size"):
            svc.batch_classify(too_many)

    def test_fixture_file_is_valid_jsonl(self):
        """The fixture has at least eight rows, each with id, text and expected."""
        rows = self._fixture_rows()
        self.assertGreaterEqual(len(rows), 8)
        for row in rows:
            self.assertIn("id", row)
            self.assertIn("text", row)
            self.assertIn("expected", row)

    def test_synthetic_fixture_expectations(self):
        """Every fixture row classifies to the label values it declares as expected."""
        svc = self.service()
        for row in self._fixture_rows():
            with self.subTest(row=row["id"]):
                result = svc.classify(row["id"], row["text"], {"include_evidence": False})
                labels = result["labels"]
                for label_name, expected_value in row["expected"].items():
                    # Report a missing label as a test failure with context
                    # rather than an unexplained KeyError.
                    self.assertIn(label_name, labels)
                    self.assertEqual(labels[label_name]["value"], expected_value)
# Run the suite with unittest's command-line runner when this file is executed
# as a script; importing the module does not trigger it.
if __name__ == "__main__":
    unittest.main()