feat: add OpenVINO router classifier prototype
This commit is contained in:
committed by
William Valentin
parent
4a065de754
commit
4003198ba9
@@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
MODULE_PATH = ROOT / "router_classifier.py"
|
||||
spec = importlib.util.spec_from_file_location("router_classifier", MODULE_PATH)
|
||||
assert spec and spec.loader
|
||||
router_classifier = importlib.util.module_from_spec(spec)
|
||||
sys.modules["router_classifier"] = router_classifier
|
||||
spec.loader.exec_module(router_classifier)
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def embed(self, texts, *, purpose="query"):
|
||||
# Deterministic toy embeddings based on keyword buckets. The tests focus on
|
||||
# rule safety and API shape; live smoke tests cover the real NPU upstream.
|
||||
vectors = []
|
||||
for text in texts:
|
||||
t = text.lower()
|
||||
vec = [0.0] * 8
|
||||
if any(w in t for w in ["time", "current", "weather", "news", "port", "git", "logs", "systemd"]):
|
||||
vec[0] = 1.0
|
||||
if any(w in t for w in ["remember", "prefer", "preference"]):
|
||||
vec[1] = 1.0
|
||||
if any(w in t for w in ["urgent", "down", "outage", "critical"]):
|
||||
vec[2] = 1.0
|
||||
if any(w in t for w in ["code", "pytest", "debug", "git", "diff"]):
|
||||
vec[3] = 1.0
|
||||
if any(w in t for w in ["service", "systemd", "port", "gateway", "docker"]):
|
||||
vec[4] = 1.0
|
||||
if any(w in t for w in ["kanban", "task", "blocked", "review"]):
|
||||
vec[5] = 1.0
|
||||
if any(w in t for w in ["light", "thermostat"]):
|
||||
vec[6] = 1.0
|
||||
if any(w in t for w in ["transcribe", "voice", "memo", "audio"]):
|
||||
vec[7] = 1.0
|
||||
if not any(vec):
|
||||
vec[0] = 0.2
|
||||
vectors.append(vec)
|
||||
return router_classifier.EmbedResult(vectors=vectors, npu_busy_delta_us=123, duration_ms=1.0, embedding_dim=8)
|
||||
|
||||
|
||||
class RouterClassifierTests(unittest.TestCase):
|
||||
def service(self):
|
||||
svc = router_classifier.ClassifierService("http://fake.local/v1/embeddings")
|
||||
svc.client = FakeClient()
|
||||
svc.warmup()
|
||||
return svc
|
||||
|
||||
def test_health_and_label_schema(self):
|
||||
svc = self.service()
|
||||
health = svc.health()
|
||||
self.assertEqual(health["service"], "atlas-router-classifier")
|
||||
self.assertEqual(health["mode"], "dry_run")
|
||||
self.assertIn("tool_needed", health["labels"])
|
||||
labels = svc.labels()
|
||||
self.assertIn("workflow_category", labels["enums"])
|
||||
self.assertIn("safety_confirmation_required", labels["thresholds"])
|
||||
|
||||
def test_explicit_preference_is_memory_candidate(self):
|
||||
result = self.service().classify("pref", "Remember that I prefer concise terminal replies.")
|
||||
self.assertEqual(result["labels"]["memory_candidate"]["value"], "user_preference")
|
||||
self.assertGreaterEqual(result["labels"]["memory_candidate"]["confidence"], 0.78)
|
||||
self.assertFalse(result["labels"]["safety_confirmation_required"]["value"])
|
||||
|
||||
def test_current_local_state_needs_tool(self):
|
||||
result = self.service().classify("port", "Check whether port 18819 is listening and inspect systemd logs.")
|
||||
self.assertTrue(result["labels"]["tool_needed"]["value"])
|
||||
self.assertIn("local_state_requested", result["labels"]["tool_needed"]["reason_codes"])
|
||||
|
||||
def test_live_gateway_restart_requires_confirmation(self):
|
||||
result = self.service().classify("safe", "Restart the live Atlas gateway and switch primary routing.")
|
||||
self.assertTrue(result["labels"]["safety_confirmation_required"]["value"])
|
||||
self.assertIn("live_service_or_routing_change", result["labels"]["safety_confirmation_required"]["reason_codes"])
|
||||
|
||||
def test_batch_shape(self):
|
||||
result = self.service().batch_classify([
|
||||
{"id": "a", "text": "What time is it?"},
|
||||
{"id": "b", "text": "Delete the existing collection and reindex it in place."},
|
||||
])
|
||||
self.assertEqual(result["model"], router_classifier.MODEL)
|
||||
self.assertEqual(len(result["results"]), 2)
|
||||
self.assertGreater(result["npu_busy_delta_us"], 0)
|
||||
|
||||
def test_fixture_file_is_valid_jsonl(self):
|
||||
fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl"
|
||||
rows = [json.loads(line) for line in fixture.read_text().splitlines() if line.strip()]
|
||||
self.assertGreaterEqual(len(rows), 8)
|
||||
for row in rows:
|
||||
self.assertIn("id", row)
|
||||
self.assertIn("text", row)
|
||||
self.assertIn("expected", row)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user