103 lines
4.5 KiB
Python
103 lines
4.5 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import json
|
|
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Directory one level above the directory that contains this test file.
ROOT = Path(__file__).resolve().parents[1]
MODULE_PATH = ROOT / "router_classifier.py"

# Load the module under test directly from its file path, so the tests do not
# depend on ROOT being on sys.path.
spec = importlib.util.spec_from_file_location("router_classifier", MODULE_PATH)
# Explicit check rather than `assert`: assertions are removed under
# `python -O`, which would turn a failed lookup into an AttributeError on None.
if spec is None or spec.loader is None:
    raise ImportError(f"cannot create an import spec for {MODULE_PATH}")
router_classifier = importlib.util.module_from_spec(spec)
# Register the module before executing it (the order used by the importlib
# "importing a source file directly" recipe), so that lookups of
# "router_classifier" in sys.modules made while its body runs can succeed.
sys.modules["router_classifier"] = router_classifier
spec.loader.exec_module(router_classifier)
|
|
|
|
|
|
class FakeClient:
    """Embedding client stand-in that produces deterministic toy vectors.

    The tests focus on rule safety and API shape; live smoke tests cover the
    real NPU upstream.
    """

    # One keyword group per vector dimension: dimension ``i`` is set to 1.0
    # when any keyword of group ``i`` occurs as a substring of the lowercased
    # text. Some keywords appear in more than one group.
    _KEYWORD_BUCKETS = (
        ("time", "current", "weather", "news", "port", "git", "logs", "systemd"),
        ("remember", "prefer", "preference"),
        ("urgent", "down", "outage", "critical"),
        ("code", "pytest", "debug", "git", "diff"),
        ("service", "systemd", "port", "gateway", "docker"),
        ("kanban", "task", "blocked", "review"),
        ("light", "thermostat"),
        ("transcribe", "voice", "memo", "audio"),
    )

    def embed(self, texts, *, purpose="query"):
        """Return an ``EmbedResult`` holding one 8-dimensional vector per text.

        ``purpose`` is accepted but not used by this fake.
        """

        def toy_vector(text):
            lowered = text.lower()
            vec = [
                1.0 if any(word in lowered for word in bucket) else 0.0
                for bucket in self._KEYWORD_BUCKETS
            ]
            # Text matching no bucket still gets a small non-zero vector.
            if not any(vec):
                vec[0] = 0.2
            return vec

        return router_classifier.EmbedResult(
            vectors=[toy_vector(text) for text in texts],
            npu_busy_delta_us=123,
            duration_ms=1.0,
            embedding_dim=8,
        )
|
|
|
|
|
|
class RouterClassifierTests(unittest.TestCase):
    """Checks ClassifierService output shape and rule labels using FakeClient."""

    def service(self):
        """Return a ClassifierService with a FakeClient installed and warmup() run."""
        svc = router_classifier.ClassifierService("http://fake.local/v1/embeddings")
        svc.client = FakeClient()
        svc.warmup()
        return svc

    def test_health_and_label_schema(self):
        """health() and labels() expose the expected identifying keys."""
        svc = self.service()
        health = svc.health()
        self.assertEqual(health["service"], "atlas-router-classifier")
        self.assertEqual(health["mode"], "dry_run")
        self.assertIn("tool_needed", health["labels"])
        labels = svc.labels()
        self.assertIn("workflow_category", labels["enums"])
        self.assertIn("safety_confirmation_required", labels["thresholds"])

    def test_explicit_preference_is_memory_candidate(self):
        """An explicit "remember that I prefer ..." message is a user_preference."""
        result = self.service().classify("pref", "Remember that I prefer concise terminal replies.")
        labels = result["labels"]
        self.assertEqual(labels["memory_candidate"]["value"], "user_preference")
        self.assertGreaterEqual(labels["memory_candidate"]["confidence"], 0.78)
        self.assertFalse(labels["safety_confirmation_required"]["value"])

    def test_current_local_state_needs_tool(self):
        """A request to inspect ports and systemd logs is flagged as needing a tool."""
        result = self.service().classify("port", "Check whether port 18819 is listening and inspect systemd logs.")
        tool_needed = result["labels"]["tool_needed"]
        self.assertTrue(tool_needed["value"])
        self.assertIn("local_state_requested", tool_needed["reason_codes"])

    def test_live_gateway_restart_requires_confirmation(self):
        """Restarting the live gateway / switching routing requires confirmation."""
        result = self.service().classify("safe", "Restart the live Atlas gateway and switch primary routing.")
        confirmation = result["labels"]["safety_confirmation_required"]
        self.assertTrue(confirmation["value"])
        self.assertIn("live_service_or_routing_change", confirmation["reason_codes"])

    def test_batch_shape(self):
        """batch_classify() reports the model, one result per item, and NPU time."""
        result = self.service().batch_classify([
            {"id": "a", "text": "What time is it?"},
            {"id": "b", "text": "Delete the existing collection and reindex it in place."},
        ])
        self.assertEqual(result["model"], router_classifier.MODEL)
        self.assertEqual(len(result["results"]), 2)
        self.assertGreater(result["npu_busy_delta_us"], 0)

    def test_fixture_file_is_valid_jsonl(self):
        """The fixture parses as JSONL with at least 8 rows carrying the required keys."""
        fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl"
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale default encoding.
        lines = fixture.read_text(encoding="utf-8").splitlines()
        rows = [json.loads(line) for line in lines if line.strip()]
        self.assertGreaterEqual(len(rows), 8)
        for index, row in enumerate(rows):
            # subTest keeps checking later rows and reports which row failed.
            with self.subTest(row=index):
                for key in ("id", "text", "expected"):
                    self.assertIn(key, row)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|