#!/usr/bin/env python3 from __future__ import annotations import importlib.util import json import sys import unittest from pathlib import Path ROOT = Path(__file__).resolve().parents[1] MODULE_PATH = ROOT / "router_classifier.py" spec = importlib.util.spec_from_file_location("router_classifier", MODULE_PATH) assert spec and spec.loader router_classifier = importlib.util.module_from_spec(spec) sys.modules["router_classifier"] = router_classifier spec.loader.exec_module(router_classifier) class FakeClient: def embed(self, texts, *, purpose="query"): # Deterministic toy embeddings based on keyword buckets. The tests focus on # rule safety and API shape; live smoke tests cover the real NPU upstream. vectors = [] for text in texts: t = text.lower() vec = [0.0] * 8 if any(w in t for w in ["time", "current", "weather", "news", "port", "git", "logs", "systemd"]): vec[0] = 1.0 if any(w in t for w in ["remember", "prefer", "preference"]): vec[1] = 1.0 if any(w in t for w in ["urgent", "down", "outage", "critical"]): vec[2] = 1.0 if any(w in t for w in ["code", "pytest", "debug", "git", "diff"]): vec[3] = 1.0 if any(w in t for w in ["service", "systemd", "port", "gateway", "docker"]): vec[4] = 1.0 if any(w in t for w in ["kanban", "task", "blocked", "review"]): vec[5] = 1.0 if any(w in t for w in ["light", "thermostat"]): vec[6] = 1.0 if any(w in t for w in ["transcribe", "voice", "memo", "audio"]): vec[7] = 1.0 if not any(vec): vec[0] = 0.2 vectors.append(vec) return router_classifier.EmbedResult(vectors=vectors, npu_busy_delta_us=123, duration_ms=1.0, embedding_dim=8) class RouterClassifierTests(unittest.TestCase): def service(self): svc = router_classifier.ClassifierService("http://fake.local/v1/embeddings") svc.client = FakeClient() svc.warmup() return svc def test_health_and_label_schema(self): svc = self.service() health = svc.health() self.assertEqual(health["service"], "atlas-router-classifier") self.assertEqual(health["mode"], "dry_run") self.assertIn("tool_needed", health["labels"]) labels = svc.labels() self.assertIn("workflow_category", labels["enums"]) self.assertIn("safety_confirmation_required", labels["thresholds"]) def test_explicit_preference_is_memory_candidate(self): result = self.service().classify("pref", "Remember that I prefer concise terminal replies.") self.assertEqual(result["labels"]["memory_candidate"]["value"], "user_preference") self.assertGreaterEqual(result["labels"]["memory_candidate"]["confidence"], 0.78) self.assertFalse(result["labels"]["safety_confirmation_required"]["value"]) def test_current_local_state_needs_tool(self): result = self.service().classify("port", "Check whether port 18819 is listening and inspect systemd logs.") self.assertTrue(result["labels"]["tool_needed"]["value"]) self.assertIn("local_state_requested", result["labels"]["tool_needed"]["reason_codes"]) def test_live_gateway_restart_requires_confirmation(self): result = self.service().classify("safe", "Restart the live Atlas gateway and switch primary routing.") self.assertTrue(result["labels"]["safety_confirmation_required"]["value"]) self.assertIn("live_service_or_routing_change", result["labels"]["safety_confirmation_required"]["reason_codes"]) def test_batch_shape(self): result = self.service().batch_classify([ {"id": "a", "text": "What time is it?"}, {"id": "b", "text": "Delete the existing collection and reindex it in place."}, ]) self.assertEqual(result["model"], router_classifier.MODEL) self.assertEqual(len(result["results"]), 2) self.assertGreater(result["npu_busy_delta_us"], 0) def test_batch_limits_are_enforced(self): svc = self.service() with self.assertRaisesRegex(ValueError, "at least one"): svc.batch_classify([]) too_many = [{"id": str(i), "text": "What time is it?"} for i in range(router_classifier.DEFAULT_MAX_BATCH_SIZE + 1)] with self.assertRaisesRegex(ValueError, "max_batch_size"): svc.batch_classify(too_many) def test_fixture_file_is_valid_jsonl(self): fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl" rows = [json.loads(line) for line in fixture.read_text().splitlines() if line.strip()] self.assertGreaterEqual(len(rows), 8) for row in rows: self.assertIn("id", row) self.assertIn("text", row) self.assertIn("expected", row) def test_synthetic_fixture_expectations(self): svc = self.service() fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl" rows = [json.loads(line) for line in fixture.read_text().splitlines() if line.strip()] for row in rows: with self.subTest(row=row["id"]): result = svc.classify(row["id"], row["text"], {"include_evidence": False}) labels = result["labels"] for label_name, expected_value in row["expected"].items(): self.assertEqual(labels[label_name]["value"], expected_value) if __name__ == "__main__": unittest.main()