[verified] refresh OpenVINO router classifier prototype
This commit is contained in:
@@ -191,7 +191,7 @@ Response:
|
|||||||
|
|
||||||
Batch limits for prototype review:
|
Batch limits for prototype review:
|
||||||
|
|
||||||
- Keep batches small, ideally <= 32 items.
|
- Keep batches small; the prototype rejects empty batches and batches larger than `OPENVINO_CLASSIFIER_MAX_BATCH_SIZE` (default `32`).
|
||||||
- Use only synthetic fixtures unless Will explicitly approves a real non-private sample set.
|
- Use only synthetic fixtures unless Will explicitly approves a real non-private sample set.
|
||||||
- Do not retain request bodies to disk.
|
- Do not retain request bodies to disk.
|
||||||
|
|
||||||
@@ -213,6 +213,7 @@ Required flags/env:
|
|||||||
- `--port` / `OPENVINO_CLASSIFIER_PORT`; default `18819`.
|
- `--port` / `OPENVINO_CLASSIFIER_PORT`; default `18819`.
|
||||||
- `--embed-url` / `OPENVINO_CLASSIFIER_EMBED_URL`; default `http://127.0.0.1:18817/v1/embeddings`.
|
- `--embed-url` / `OPENVINO_CLASSIFIER_EMBED_URL`; default `http://127.0.0.1:18817/v1/embeddings`.
|
||||||
- `--timeout-s` / `OPENVINO_CLASSIFIER_TIMEOUT_S`; default `30`.
|
- `--timeout-s` / `OPENVINO_CLASSIFIER_TIMEOUT_S`; default `30`.
|
||||||
|
- `--max-batch-size` / `OPENVINO_CLASSIFIER_MAX_BATCH_SIZE`; default `32`.
|
||||||
- `--no-warmup` to defer prototype embedding until first request.
|
- `--no-warmup` to defer prototype embedding until first request.
|
||||||
|
|
||||||
A future dedicated CLI mode may be added for one-shot JSONL classification, but foreground HTTP review is sufficient for the dry-run contract.
|
A future dedicated CLI mode may be added for one-shot JSONL classification, but foreground HTTP review is sufficient for the dry-run contract.
|
||||||
@@ -280,6 +281,13 @@ echo "$response" | jq '{npu_busy_delta_us, sysfs_npu_busy_delta_us, warnings}'
|
|||||||
echo "outer_sysfs_npu_busy_delta_us=$((after-before))"
|
echo "outer_sysfs_npu_busy_delta_us=$((after-before))"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Optional localhost smoke helper, after starting the foreground service:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/home/will/.venvs/npu/bin/python openvino-classifier-npu/smoke_classifier.py \
|
||||||
|
--base-url http://127.0.0.1:18819
|
||||||
|
```
|
||||||
|
|
||||||
Acceptance for an NPU-backed classification request:
|
Acceptance for an NPU-backed classification request:
|
||||||
|
|
||||||
- HTTP request succeeds.
|
- HTTP request succeeds.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ services, or send external messages.
|
|||||||
- Default port: `18819`
|
- Default port: `18819`
|
||||||
- Default bind: `127.0.0.1`
|
- Default bind: `127.0.0.1`
|
||||||
- Upstream: `http://127.0.0.1:18817/v1/embeddings`
|
- Upstream: `http://127.0.0.1:18817/v1/embeddings`
|
||||||
|
- Batch limit: `OPENVINO_CLASSIFIER_MAX_BATCH_SIZE`, default `32`
|
||||||
- Model label: `bge-base-en-v1.5-int8-ov/prototype-router-v0`
|
- Model label: `bge-base-en-v1.5-int8-ov/prototype-router-v0`
|
||||||
- NPU proof: `/sys/class/accel/accel0/device/npu_busy_time_us` before/after plus upstream `npu_busy_delta_us`
|
- NPU proof: `/sys/class/accel/accel0/device/npu_busy_time_us` before/after plus upstream `npu_busy_delta_us`
|
||||||
|
|
||||||
@@ -90,6 +91,10 @@ cd /home/will/lab/swarm/openvino-classifier-npu
|
|||||||
/home/will/.venvs/npu/bin/python router_classifier.py --host 127.0.0.1 --port 18819
|
/home/will/.venvs/npu/bin/python router_classifier.py --host 127.0.0.1 --port 18819
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Environment variables mirror the flags: `OPENVINO_CLASSIFIER_HOST`,
|
||||||
|
`OPENVINO_CLASSIFIER_PORT`, `OPENVINO_CLASSIFIER_EMBED_URL`,
|
||||||
|
`OPENVINO_CLASSIFIER_TIMEOUT_S`, and `OPENVINO_CLASSIFIER_MAX_BATCH_SIZE`.
|
||||||
|
|
||||||
Then from another shell:
|
Then from another shell:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -102,6 +107,15 @@ curl -fsS http://127.0.0.1:18819/v1/classify \
|
|||||||
A valid NPU-backed response must have positive `npu_busy_delta_us`; HTTP 200 by
|
A valid NPU-backed response must have positive `npu_busy_delta_us`; HTTP 200 by
|
||||||
itself is not considered proof.
|
itself is not considered proof.
|
||||||
|
|
||||||
|
Synthetic fixture smoke helper, after the foreground service is running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/home/will/.venvs/npu/bin/python smoke_classifier.py --base-url http://127.0.0.1:18819
|
||||||
|
```
|
||||||
|
|
||||||
|
The helper refuses non-local URLs, checks fixture label expectations, and prints
|
||||||
|
both the response-reported and the outer sysfs NPU busy deltas.
|
||||||
|
|
||||||
## Tests
|
## Tests
|
||||||
|
|
||||||
Unit tests use a fake embedding client and do not touch the NPU:
|
Unit tests use a fake embedding client and do not touch the NPU:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ WorkingDirectory=/home/will/lab/swarm/openvino-classifier-npu
|
|||||||
Environment=OPENVINO_CLASSIFIER_HOST=127.0.0.1
|
Environment=OPENVINO_CLASSIFIER_HOST=127.0.0.1
|
||||||
Environment=OPENVINO_CLASSIFIER_PORT=18819
|
Environment=OPENVINO_CLASSIFIER_PORT=18819
|
||||||
Environment=OPENVINO_CLASSIFIER_EMBED_URL=http://127.0.0.1:18817/v1/embeddings
|
Environment=OPENVINO_CLASSIFIER_EMBED_URL=http://127.0.0.1:18817/v1/embeddings
|
||||||
|
Environment=OPENVINO_CLASSIFIER_MAX_BATCH_SIZE=32
|
||||||
ExecStart=/home/will/.venvs/npu/bin/python /home/will/lab/swarm/openvino-classifier-npu/router_classifier.py
|
ExecStart=/home/will/.venvs/npu/bin/python /home/will/lab/swarm/openvino-classifier-npu/router_classifier.py
|
||||||
Restart=on-failure
|
Restart=on-failure
|
||||||
RestartSec=5
|
RestartSec=5
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ MODEL = "bge-base-en-v1.5-int8-ov/prototype-router-v0"
|
|||||||
DEFAULT_HOST = "127.0.0.1"
|
DEFAULT_HOST = "127.0.0.1"
|
||||||
DEFAULT_PORT = 18819
|
DEFAULT_PORT = 18819
|
||||||
DEFAULT_EMBED_URL = "http://127.0.0.1:18817/v1/embeddings"
|
DEFAULT_EMBED_URL = "http://127.0.0.1:18817/v1/embeddings"
|
||||||
|
DEFAULT_MAX_BATCH_SIZE = 32
|
||||||
NPU_BUSY_FILE = Path("/sys/class/accel/accel0/device/npu_busy_time_us")
|
NPU_BUSY_FILE = Path("/sys/class/accel/accel0/device/npu_busy_time_us")
|
||||||
|
|
||||||
WORKFLOW_CATEGORIES = [
|
WORKFLOW_CATEGORIES = [
|
||||||
@@ -150,6 +151,26 @@ def npu_busy_time_us() -> int | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Returns ``default`` unchanged when ``name`` is not set.  Otherwise the
    variable's text is parsed with ``int``; text that does not parse ends the
    process through ``SystemExit`` with a message naming the variable.
    """
    text = os.environ.get(name)
    if text is not None:
        try:
            return int(text)
        except ValueError as exc:
            raise SystemExit(f"{name} must be an integer, got {text!r}") from exc
    return default
|
||||||
|
|
||||||
|
|
||||||
|
def env_float(name: str, default: float) -> float:
    """Read a finite floating-point setting from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned unchanged when the variable is not set.

    Returns:
        The parsed value, or ``default`` when ``name`` is unset.

    Raises:
        SystemExit: If the variable is set but is not parseable as a float,
            or parses to NaN or an infinity.
    """
    import math  # local import so this helper needs nothing beyond ``os``

    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        value = float(raw)
    except ValueError as exc:
        raise SystemExit(f"{name} must be a number, got {raw!r}") from exc
    # float() happily parses "nan", "inf" and "-inf"; reject them here so a
    # bad setting is reported at startup with the variable name attached.
    if not math.isfinite(value):
        raise SystemExit(f"{name} must be a finite number, got {raw!r}")
    return value
|
||||||
|
|
||||||
|
|
||||||
def clamp01(value: float) -> float:
    """Clamp ``value`` into the closed interval [0.0, 1.0]."""
    # Apply the upper bound first, then the lower bound.
    capped = min(1.0, value)
    return max(0.0, capped)
|
||||||
|
|
||||||
@@ -220,9 +241,10 @@ class EmbeddingClient:
|
|||||||
|
|
||||||
|
|
||||||
class ClassifierService:
|
class ClassifierService:
|
||||||
def __init__(self, embed_url: str, *, timeout_s: float = 30.0) -> None:
|
def __init__(self, embed_url: str, *, timeout_s: float = 30.0, max_batch_size: int = DEFAULT_MAX_BATCH_SIZE) -> None:
|
||||||
self.embed_url = embed_url
|
self.embed_url = embed_url
|
||||||
self.client = EmbeddingClient(embed_url, timeout_s=timeout_s)
|
self.client = EmbeddingClient(embed_url, timeout_s=timeout_s)
|
||||||
|
self.max_batch_size = max(1, int(max_batch_size))
|
||||||
self.loaded_at = time.time()
|
self.loaded_at = time.time()
|
||||||
self.prototype_texts: list[str] = []
|
self.prototype_texts: list[str] = []
|
||||||
self.prototype_keys: list[str] = []
|
self.prototype_keys: list[str] = []
|
||||||
@@ -255,6 +277,7 @@ class ClassifierService:
|
|||||||
"labels": ["tool_needed", "memory_candidate", "urgency", "workflow_category", "safety_confirmation_required"],
|
"labels": ["tool_needed", "memory_candidate", "urgency", "workflow_category", "safety_confirmation_required"],
|
||||||
"embedding_dim": self.embedding_dim,
|
"embedding_dim": self.embedding_dim,
|
||||||
"prototype_count": len(self.prototype_texts),
|
"prototype_count": len(self.prototype_texts),
|
||||||
|
"max_batch_size": self.max_batch_size,
|
||||||
"prototype_npu_busy_delta_us": self.prototype_npu_busy_delta_us,
|
"prototype_npu_busy_delta_us": self.prototype_npu_busy_delta_us,
|
||||||
"npu_busy_time_us": npu_busy_time_us(),
|
"npu_busy_time_us": npu_busy_time_us(),
|
||||||
"uptime_s": round(time.time() - self.loaded_at, 3),
|
"uptime_s": round(time.time() - self.loaded_at, 3),
|
||||||
@@ -271,6 +294,7 @@ class ClassifierService:
|
|||||||
"workflow_category": 0.52,
|
"workflow_category": 0.52,
|
||||||
},
|
},
|
||||||
"enums": {"memory_candidate": MEMORY_VALUES, "urgency": URGENCY_VALUES, "workflow_category": WORKFLOW_CATEGORIES},
|
"enums": {"memory_candidate": MEMORY_VALUES, "urgency": URGENCY_VALUES, "workflow_category": WORKFLOW_CATEGORIES},
|
||||||
|
"limits": {"max_batch_size": self.max_batch_size},
|
||||||
"prototype_ids": sorted(PROTOTYPES),
|
"prototype_ids": sorted(PROTOTYPES),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,6 +375,10 @@ class ClassifierService:
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def batch_classify(self, items: list[dict[str, Any]], options: dict[str, Any] | None = None) -> dict[str, Any]:
|
def batch_classify(self, items: list[dict[str, Any]], options: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
|
if not items:
|
||||||
|
raise ValueError("items must contain at least one classification request")
|
||||||
|
if len(items) > self.max_batch_size:
|
||||||
|
raise ValueError(f"items exceeds max_batch_size={self.max_batch_size}")
|
||||||
started = time.perf_counter()
|
started = time.perf_counter()
|
||||||
results = [self.classify(item.get("id"), str(item.get("text") or ""), options) for item in items]
|
results = [self.classify(item.get("id"), str(item.get("text") or ""), options) for item in items]
|
||||||
return {
|
return {
|
||||||
@@ -400,13 +428,15 @@ class ClassifierService:
|
|||||||
high_rule, high_codes, high_ev = best_rule(text, "urgency_high")
|
high_rule, high_codes, high_ev = best_rule(text, "urgency_high")
|
||||||
critical_rule, critical_codes, critical_ev = best_rule(text, "urgency_critical")
|
critical_rule, critical_codes, critical_ev = best_rule(text, "urgency_critical")
|
||||||
low_rule = 0.82 if re.search(r"\b(no rush|whenever convenient|low priority|someday|backlog)\b", text, re.I) else 0.0
|
low_rule = 0.82 if re.search(r"\b(no rush|whenever convenient|low priority|someday|backlog)\b", text, re.I) else 0.0
|
||||||
# Urgency is safety-sensitive for notifications. Prefer explicit rules;
|
# Urgency is safety-sensitive for notifications, so require explicit
|
||||||
# use prototype scores only when they are unusually strong.
|
# language instead of relying on broad prototype similarity.
|
||||||
score_map = {
|
score_map = {
|
||||||
"low": max(low_rule, scores.get("urgency_low", 0.0) if scores.get("urgency_low", 0.0) >= 0.9 else 0.0),
|
# Urgency should be explicit; broad embedding similarity otherwise
|
||||||
|
# turns neutral requests such as "what time is it" into low/high/critical urgency.
|
||||||
|
"low": low_rule,
|
||||||
"normal": 0.68,
|
"normal": 0.68,
|
||||||
"high": max(high_rule, scores.get("urgency_high", 0.0) if scores.get("urgency_high", 0.0) >= 0.9 else 0.0),
|
"high": high_rule,
|
||||||
"critical": max(critical_rule, scores.get("urgency_critical", 0.0) if scores.get("urgency_critical", 0.0) >= 0.92 else 0.0),
|
"critical": critical_rule,
|
||||||
}
|
}
|
||||||
if score_map["critical"] >= 0.9:
|
if score_map["critical"] >= 0.9:
|
||||||
score_map["normal"] = 0.05
|
score_map["normal"] = 0.05
|
||||||
@@ -509,13 +539,14 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
def main() -> int:
|
def main() -> int:
|
||||||
parser = argparse.ArgumentParser(description="Dry-run Atlas/Hermes router classifier")
|
parser = argparse.ArgumentParser(description="Dry-run Atlas/Hermes router classifier")
|
||||||
parser.add_argument("--host", default=os.environ.get("OPENVINO_CLASSIFIER_HOST", DEFAULT_HOST))
|
parser.add_argument("--host", default=os.environ.get("OPENVINO_CLASSIFIER_HOST", DEFAULT_HOST))
|
||||||
parser.add_argument("--port", type=int, default=int(os.environ.get("OPENVINO_CLASSIFIER_PORT", DEFAULT_PORT)))
|
parser.add_argument("--port", type=int, default=env_int("OPENVINO_CLASSIFIER_PORT", DEFAULT_PORT))
|
||||||
parser.add_argument("--embed-url", default=os.environ.get("OPENVINO_CLASSIFIER_EMBED_URL", DEFAULT_EMBED_URL))
|
parser.add_argument("--embed-url", default=os.environ.get("OPENVINO_CLASSIFIER_EMBED_URL", DEFAULT_EMBED_URL))
|
||||||
parser.add_argument("--timeout-s", type=float, default=float(os.environ.get("OPENVINO_CLASSIFIER_TIMEOUT_S", "30")))
|
parser.add_argument("--timeout-s", type=float, default=env_float("OPENVINO_CLASSIFIER_TIMEOUT_S", 30.0))
|
||||||
|
parser.add_argument("--max-batch-size", type=int, default=env_int("OPENVINO_CLASSIFIER_MAX_BATCH_SIZE", DEFAULT_MAX_BATCH_SIZE))
|
||||||
parser.add_argument("--no-warmup", action="store_true", help="skip prototype embedding warmup until first request")
|
parser.add_argument("--no-warmup", action="store_true", help="skip prototype embedding warmup until first request")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
service = ClassifierService(args.embed_url, timeout_s=args.timeout_s)
|
service = ClassifierService(args.embed_url, timeout_s=args.timeout_s, max_batch_size=args.max_batch_size)
|
||||||
if not args.no_warmup:
|
if not args.no_warmup:
|
||||||
service.warmup()
|
service.warmup()
|
||||||
httpd = ThreadingHTTPServer((args.host, args.port), Handler)
|
httpd = ThreadingHTTPServer((args.host, args.port), Handler)
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Local-only smoke test for the dry-run OpenVINO router classifier.
|
||||||
|
|
||||||
|
This script uses only synthetic fixture messages. It assumes router_classifier.py is
|
||||||
|
already running on localhost and never installs/enables a persistent service.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "http://127.0.0.1:18819"
|
||||||
|
BUSY_FILE = Path("/sys/class/accel/accel0/device/npu_busy_time_us")
|
||||||
|
FIXTURE = Path(__file__).resolve().parent / "fixtures" / "atlas_hermes_messages.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def npu_busy_time_us() -> int | None:
    """Return the integer stored in ``BUSY_FILE``, or ``None`` if unavailable.

    The read is best-effort: when the file cannot be read, or its content is
    not an integer, ``None`` is returned instead of raising.
    """
    try:
        return int(BUSY_FILE.read_text().strip())
    except (OSError, ValueError):
        # OSError: file missing or unreadable.
        # ValueError: undecodable or non-integer content.
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_json(url: str, timeout_s: float) -> dict[str, Any]:
    """Fetch ``url`` and return its body decoded as UTF-8 and parsed as JSON.

    ``timeout_s`` is passed to ``urllib.request.urlopen`` as its timeout.
    """
    with urllib.request.urlopen(url, timeout=timeout_s) as response:  # noqa: S310 - localhost smoke URL
        body = response.read()
    return json.loads(body.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def post_json(url: str, payload: dict[str, Any], timeout_s: float) -> dict[str, Any]:
    """POST ``payload`` to ``url`` as a JSON body and return the parsed JSON reply.

    The request carries a ``Content-Type: application/json`` header, and the
    reply body is decoded as UTF-8 before parsing.
    """
    encoded = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        url,
        data=encoded,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout_s) as response:  # noqa: S310 - localhost smoke URL
        reply = response.read()
    return json.loads(reply.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def load_fixture(limit: int) -> list[dict[str, Any]]:
    """Load up to ``limit`` rows from the JSONL file at ``FIXTURE``.

    Blank lines are skipped and every remaining line is parsed as one JSON
    document.  The file is read as UTF-8 explicitly, so parsing does not
    depend on the locale of the environment running the smoke test.

    ``limit`` is applied as a plain slice (``rows[:limit]``), so a negative
    value drops rows from the end rather than returning nothing.
    """
    text = FIXTURE.read_text(encoding="utf-8")
    rows = [json.loads(line) for line in text.splitlines() if line.strip()]
    return rows[:limit]
|
||||||
|
|
||||||
|
|
||||||
|
def assert_expected(result: dict[str, Any], expected: dict[str, Any]) -> list[str]:
    """Compare a classification result against expected label values.

    Args:
        result: One classify response; its ``"labels"`` entry maps each label
            name to a mapping holding the chosen ``"value"``.
        expected: Label name -> expected value.

    Returns:
        One human-readable message per mismatching label, in ``expected``
        iteration order; an empty list when everything matches.

    A missing ``"labels"`` mapping, a missing label, or a label entry that is
    not a mapping (for example ``null``) is treated as an actual value of
    ``None`` and reported as a mismatch instead of raising.
    """
    labels = result.get("labels")
    if not isinstance(labels, dict):
        labels = {}
    failures: list[str] = []
    for key, value in expected.items():
        entry = labels.get(key)
        actual_value = entry.get("value") if isinstance(entry, dict) else None
        if actual_value != value:
            failures.append(f"{result.get('id')}: {key} expected {value!r}, got {actual_value!r}")
    return failures
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Run the localhost smoke test and return a process exit status.

    Fetches ``/healthz`` and ``/v1/labels`` from the running classifier, posts
    each fixture row to ``/v1/classify``, compares the returned labels with
    the row's expectations, and prints a JSON summary.

    Returns:
        ``0`` when every expectation matched and NPU use was proven (the
        summed response ``npu_busy_delta_us`` is positive and, when the sysfs
        counter is readable, the outer delta is positive too); ``1`` otherwise.

    Raises:
        SystemExit: If the base URL is not a plain-HTTP URL whose host is
            ``127.0.0.1`` or ``localhost``, or if a request, the fixture read,
            or response decoding fails.
    """
    from urllib.parse import urlsplit  # local import; only the URL guard needs it

    parser = argparse.ArgumentParser(description="Smoke-test a running localhost router classifier")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--timeout-s", type=float, default=30.0)
    parser.add_argument("--limit", type=int, default=10)
    args = parser.parse_args()

    # Check the parsed host, not a string prefix: a prefix test accepts
    # "http://127.0.0.1:1@other-host/", where "127.0.0.1:1" is only userinfo.
    try:
        target = urlsplit(args.base_url)
        is_local = target.scheme == "http" and target.hostname in {"127.0.0.1", "localhost"}
    except ValueError:
        is_local = False
    if not is_local:
        raise SystemExit("refusing non-local base URL; this smoke is localhost-only")

    base_url = args.base_url.rstrip("/")
    before = npu_busy_time_us()
    started = time.perf_counter()
    results: list[dict[str, Any]] = []
    failures: list[str] = []
    try:
        health = get_json(f"{base_url}/healthz", args.timeout_s)
        labels = get_json(f"{base_url}/v1/labels", args.timeout_s)
        rows = load_fixture(args.limit)
        for row in rows:
            result = post_json(
                f"{base_url}/v1/classify",
                {"id": row["id"], "text": row["text"], "options": {"include_evidence": False, "dry_run": True}},
                args.timeout_s,
            )
            results.append(result)
            failures.extend(assert_expected(result, row.get("expected", {})))
        after = npu_busy_time_us()
    except (OSError, ValueError) as exc:
        # OSError covers urllib.error.URLError/HTTPError plus socket timeouts
        # and fixture read errors; ValueError covers undecodable or non-JSON
        # bodies (json.JSONDecodeError, UnicodeDecodeError).
        raise SystemExit(f"smoke failed: {exc}") from exc

    response_npu_delta = sum((r.get("npu_busy_delta_us") or 0) for r in results)
    outer_sysfs_delta = None if before is None or after is None else after - before
    # The response-reported delta must be positive; the outer sysfs delta is
    # only required to be positive when the counter could be read.
    npu_proven = response_npu_delta > 0 and (outer_sysfs_delta is None or outer_sysfs_delta > 0)
    summary = {
        "ok": not failures,
        "service": health.get("service"),
        "mode": health.get("mode"),
        "model": health.get("model"),
        "label_count": len(labels.get("prototype_ids", [])),
        "fixture_count": len(results),
        "duration_ms": round((time.perf_counter() - started) * 1000, 3),
        "response_npu_busy_delta_us": response_npu_delta,
        "outer_sysfs_npu_busy_delta_us": outer_sysfs_delta,
        "npu_proven": npu_proven,
        "failures": failures,
    }
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0 if not failures and npu_proven else 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -88,6 +88,14 @@ class RouterClassifierTests(unittest.TestCase):
|
|||||||
self.assertEqual(len(result["results"]), 2)
|
self.assertEqual(len(result["results"]), 2)
|
||||||
self.assertGreater(result["npu_busy_delta_us"], 0)
|
self.assertGreater(result["npu_busy_delta_us"], 0)
|
||||||
|
|
||||||
|
def test_batch_limits_are_enforced(self):
    """batch_classify rejects an empty batch and one item over the default limit."""
    classifier = self.service()
    self.assertRaisesRegex(ValueError, "at least one", classifier.batch_classify, [])
    oversized = []
    for index in range(router_classifier.DEFAULT_MAX_BATCH_SIZE + 1):
        oversized.append({"id": str(index), "text": "What time is it?"})
    self.assertRaisesRegex(ValueError, "max_batch_size", classifier.batch_classify, oversized)
|
||||||
|
|
||||||
def test_fixture_file_is_valid_jsonl(self):
|
def test_fixture_file_is_valid_jsonl(self):
|
||||||
fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl"
|
fixture = ROOT / "fixtures" / "atlas_hermes_messages.jsonl"
|
||||||
rows = [json.loads(line) for line in fixture.read_text().splitlines() if line.strip()]
|
rows = [json.loads(line) for line in fixture.read_text().splitlines() if line.strip()]
|
||||||
@@ -97,6 +105,17 @@ class RouterClassifierTests(unittest.TestCase):
|
|||||||
self.assertIn("text", row)
|
self.assertIn("text", row)
|
||||||
self.assertIn("expected", row)
|
self.assertIn("expected", row)
|
||||||
|
|
||||||
|
def test_synthetic_fixture_expectations(self):
    """Each fixture row classifies to the label values it lists under "expected"."""
    classifier = self.service()
    fixture_path = ROOT / "fixtures" / "atlas_hermes_messages.jsonl"
    # Parse the whole file up front so a malformed line fails before any subtest runs.
    rows = []
    for line in fixture_path.read_text().splitlines():
        if line.strip():
            rows.append(json.loads(line))
    for row in rows:
        with self.subTest(row=row["id"]):
            outcome = classifier.classify(row["id"], row["text"], {"include_evidence": False})
            labels = outcome["labels"]
            for label_name, expected_value in row["expected"].items():
                self.assertEqual(labels[label_name]["value"], expected_value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user