564 lines
26 KiB
Python
564 lines
26 KiB
Python
#!/usr/bin/env python3
|
|
"""Dry-run Atlas/Hermes router classifier backed by the local OpenVINO NPU embedding service.

Default port: 18819
Default upstream: http://127.0.0.1:18817/v1/embeddings

This service is intentionally advisory only. It does not write memory, mutate routing,
restart services, or call external APIs; its only outbound call is to the configured
embedding endpoint. NPU execution is evidenced by the upstream embedding service's
npu_busy_delta_us and by reading the local sysfs busy counter.
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
import json
import math
import os
import re
import sys
import threading
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
|
|
|
|
# Identity reported by the health endpoint and stamped on every classification response.
VERSION = "0.1.0"
SERVICE = "atlas-router-classifier"
MODEL = "bge-base-en-v1.5-int8-ov/prototype-router-v0"

# Network defaults; main() lets CLI flags or OPENVINO_CLASSIFIER_* env vars override them.
DEFAULT_HOST, DEFAULT_PORT = "127.0.0.1", 18819
DEFAULT_EMBED_URL = f"http://{DEFAULT_HOST}:18817/v1/embeddings"
DEFAULT_MAX_BATCH_SIZE = 32

# sysfs counter read best-effort to corroborate NPU activity (the _us suffix suggests microseconds).
NPU_BUSY_FILE = Path("/sys/class/accel/accel0/device") / "npu_busy_time_us"
|
|
|
|
# Closed set of workflow_category values; "unknown" is the fallback when no category
# scores above the workflow threshold.
WORKFLOW_CATEGORIES = [
    "chat", "research", "coding", "debugging", "devops",
    "smart_home", "media", "note_taking", "productivity", "kanban",
    "unknown",
]

# Closed set of memory_candidate values; "none" means no memory write is suggested.
MEMORY_VALUES = [
    "none",
    "user_preference",
    "durable_user_fact",
    "environment_fact",
    "workflow_convention",
    "skill_candidate",
]

# Closed set of urgency values, least to most urgent.
URGENCY_VALUES = "low normal high critical".split()
|
|
|
|
# Example texts embedded once at warmup; a message's similarity to each group is the
# best cosine score against that group's examples. Key order is preserved because the
# texts are sent to the embedding service in exactly this order.
PROTOTYPES: dict[str, list[str]] = {
    # Requests that need live data, local system state, or an external action.
    "tool_needed": [
        "check the current date time weather news versions or live facts",
        "inspect files git branches logs ports processes disk memory or system state",
        "send a message create a cron job call an API or interact with a local service",
        "search the web browse a website download or verify current information",
    ],
    # One key per non-"none" memory_candidate value.
    "memory_user_preference": [
        "remember that I prefer concise replies and a direct style",
        "my preference is use short answers and avoid unnecessary detail",
        "please remember I like this convention for future sessions",
    ],
    "memory_durable_user_fact": [
        "remember that I live in Seattle and work on local AI infrastructure",
        "my name role location identity or durable personal detail is",
    ],
    "memory_environment_fact": [
        "this project uses pytest and this server runs linux with openvino npu",
        "remember this repository convention service port path or environment setup",
    ],
    "memory_workflow_convention": [
        "for this workflow use this recurring procedure convention or process",
        "the team convention is to run checks before code review and use a worktree",
    ],
    "memory_skill_candidate": [
        "we discovered a reusable multi step workflow that should become a skill",
        "save this procedure as a reusable skill after solving a tricky task",
    ],
    # Urgency prototypes (the urgency label itself is decided by explicit language rules).
    "urgency_low": ["whenever convenient no rush low priority idea someday backlog"],
    "urgency_high": [
        "urgent asap high priority today please handle soon production issue",
        "service is degraded broken failing down users are blocked",
    ],
    "urgency_critical": [
        "critical outage security incident data loss production down emergency now",
        "stop the bleeding rollback immediately credentials leaked destructive incident",
    ],
    # One single-example prototype per workflow category, keyed "workflow_<category>".
    **{
        f"workflow_{category}": [example]
        for category, example in (
            ("chat", "answer a general question explain a concept brainstorm rewrite text chat casually"),
            ("research", "research compare summarize sources papers market docs web search literature review"),
            ("coding", "implement code write tests refactor add feature fix type errors create a branch"),
            ("debugging", "debug failing tests inspect logs reproduce error traceback diagnose regression"),
            ("devops", "operate services systemd docker kubernetes ports health checks deploy infrastructure"),
            ("smart_home", "turn on lights adjust thermostat control tv speaker home assistant hue wiz"),
            ("media", "transcribe audio process video image gif spotify music youtube media file"),
            ("note_taking", "obsidian notes daily diary memory knowledge base document personal context"),
            ("productivity", "calendar email spreadsheet presentation notion airtable linear task planning"),
            ("kanban", "kanban task board card assignee handoff review required blocked complete worker"),
        )
    },
}
|
|
|
|
# Deterministic keyword rules per label group. Each entry is
# (compiled pattern, reason code, score); best_rule() reports the highest score among
# the patterns that match, plus every matching reason code.
#
# Noun alternatives accept plural/inflected forms ("logs", "credentials", "API keys",
# "credentials leaked"). With singular-only words inside \b...\b, common phrasings such
# as "rotate the credentials" or "check the logs" silently missed the rule. That
# weakened the safety and critical-urgency labels, which are rule-only.
RULES: dict[str, list[tuple[re.Pattern[str], str, float]]] = {
    "tool_needed": [
        (re.compile(r"\b(current|today|now|latest|weather|news|versions?|prices?|stocks?)\b", re.I), "current_fact_requested", 0.88),
        (
            re.compile(
                r"\b(files?|director(?:y|ies)|git|branch(?:es)?|commits?|diffs?|logs?|ports?|process(?:es)?|disks?|memory|cpus?|gpus?|npus?|services?|systemd|reindex)\b",
                re.I,
            ),
            "local_state_requested",
            0.84,
        ),
        (
            re.compile(
                r"\b(send|schedule|create cron|call api|download|browse|search web|open website|turn on|turn off|set the thermostat|transcribe|restart|switch primary routing|work kanban|kanban task)\b",
                re.I,
            ),
            "external_or_tool_action_requested",
            0.86,
        ),
    ],
    "safety": [
        (re.compile(r"\b(delete|remove|overwrite|drop|truncate|wipe|reindex|reset --hard|force push)\b", re.I), "destructive_or_irreversible_action", 0.92),
        (re.compile(r"\b(restart|stop|deploy|expose|public|0\.0\.0\.0|route live|primary routing|gateway)\b", re.I), "live_service_or_routing_change", 0.88),
        (
            re.compile(
                r"\b(secrets?|tokens?|api keys?|credentials?|passwords?|private documents?|external uploads?|send messages?|spend money|purchase)\b",
                re.I,
            ),
            "credential_privacy_or_external_side_effect",
            0.9,
        ),
    ],
    "memory": [
        (re.compile(r"\b(remember that|please remember|don'?t forget|my preference|I prefer|call me)\b", re.I), "explicit_memory_language", 0.9),
        (re.compile(r"\b(always|for future|going forward|convention|workflow|standard practice)\b", re.I), "durable_convention_language", 0.78),
    ],
    "urgency_high": [
        (re.compile(r"\b(urgent|asap|immediately|high priority|production|down|broken|blocked)\b", re.I), "urgent_language", 0.84),
    ],
    "urgency_critical": [
        (
            re.compile(r"\b(critical|emergency|outages?|data loss|credentials? leak(?:ed|s)?|security incident|prod down)\b", re.I),
            "critical_incident_language",
            0.94,
        ),
    ],
}
|
|
|
|
|
|
def npu_busy_time_us(path: Path | str | None = None) -> int | None:
    """Read the NPU busy-time counter from sysfs, best-effort.

    Args:
        path: Counter file to read. Defaults to ``NPU_BUSY_FILE``; overridable for
            other accel devices or for testing.

    Returns:
        The integer counter value, or None when the file is missing or unreadable,
        or does not contain an integer (e.g. no NPU driver on this host).
    """
    target = NPU_BUSY_FILE if path is None else Path(path)
    try:
        return int(target.read_text().strip())
    except (OSError, ValueError):
        # Best-effort probe: a missing or odd counter must never fail a request. Only
        # I/O and parse errors are swallowed, so programming errors still surface.
        return None
|
|
|
|
|
|
def env_int(name: str, default: int) -> int:
    """Return the environment variable *name* parsed as an int, or *default* if unset.

    Raises:
        SystemExit: when the variable is set but is not a valid integer literal, so a
            misconfigured service fails fast at startup with a readable message.
    """
    value = os.environ.get(name)
    if value is not None:
        try:
            return int(value)
        except ValueError as err:
            raise SystemExit(f"{name} must be an integer, got {value!r}") from err
    return default
|
|
|
|
|
|
def env_float(name: str, default: float) -> float:
    """Return the environment variable *name* parsed as a float, or *default* if unset.

    Raises:
        SystemExit: when the variable is set but cannot be parsed by ``float()``.
    """
    value = os.environ.get(name)
    if value is not None:
        try:
            return float(value)
        except ValueError as err:
            raise SystemExit(f"{name} must be a number, got {value!r}") from err
    return default
|
|
|
|
|
|
def clamp01(value: float) -> float:
    """Clamp *value* into the closed interval [0.0, 1.0]; NaN maps to 0.0.

    The explicit NaN check matters. Every comparison with NaN is False, so
    ``min(1.0, nan)`` returns 1.0. A NaN similarity (e.g. from an embedding containing
    NaN/inf) would otherwise be reported as full confidence.
    """
    if math.isnan(value):
        return 0.0
    return max(0.0, min(1.0, value))
|
|
|
|
|
|
def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of *a* and *b*, rescaled from [-1, 1] to [0, 1].

    Returns 0.0 when either vector is empty, their lengths differ, or either has zero
    magnitude, so an unusable comparison never looks like a match.
    """
    if not (a and b) or len(a) != len(b):
        return 0.0

    def magnitude(vec: list[float]) -> float:
        return math.sqrt(sum(c * c for c in vec))

    norm_a, norm_b = magnitude(a), magnitude(b)
    if 0.0 in (norm_a, norm_b):
        return 0.0
    similarity = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    # Map [-1, 1] to [0, 1] so the value reads like a confidence.
    return clamp01((similarity + 1.0) / 2.0)
|
|
|
|
|
|
def best_rule(text: str, group: str) -> tuple[float, list[str], list[dict[str, Any]]]:
    """Evaluate every RULES[group] pattern against *text*.

    Returns:
        (best_score, reason_codes, evidence): the highest score among matching
        patterns (0.0 if none match or the group is unknown), the sorted, de-duplicated
        reason codes of the matches, and one evidence dict per matching pattern
        recording the first matched substring.
    """
    top_score = 0.0
    hit_codes: set[str] = set()
    hits: list[dict[str, Any]] = []
    for pattern, code, score in RULES.get(group, []):
        found = pattern.search(text)
        if found is None:
            continue
        top_score = max(top_score, score)
        hit_codes.add(code)
        hits.append({"label": group, "source": "rule", "matched": found.group(0), "reason_code": code, "score": score})
    return top_score, sorted(hit_codes), hits
|
|
|
|
|
|
@dataclass
class EmbedResult:
    """Parsed result of one embedding-service call, as built by EmbeddingClient.embed."""

    # Embedding vectors from the response's "data" list, converted to floats.
    vectors: list[list[float]]
    # NPU busy time the upstream service reported for this call; None if it omitted it.
    npu_busy_delta_us: int | None
    # Client-side wall-clock time for the call, in milliseconds.
    duration_ms: float
    # Dimension reported upstream, else the length of the first vector (None if no vectors).
    embedding_dim: int | None
|
|
|
|
|
|
class EmbeddingClient:
    """Blocking JSON client for the upstream embedding service.

    Every upstream failure is raised as ``RuntimeError``: transport error, HTTP error
    status, non-JSON or malformed body, or a vector count that differs from the
    request. ``ValueError`` is deliberately avoided because the HTTP handler in this
    file reports ValueError as a client-side 400.
    """

    def __init__(self, url: str, timeout_s: float = 30.0) -> None:
        self.url = url
        self.timeout_s = timeout_s

    def embed(self, texts: list[str], *, purpose: str = "query") -> EmbedResult:
        """POST *texts* to the embedding service and return one vector per input.

        Args:
            texts: Strings to embed.
            purpose: Sent verbatim as the request's ``purpose`` field (callers in this
                file use "query" for messages and "document" for prototypes).

        Returns:
            EmbedResult holding the vectors, the upstream ``npu_busy_delta_us``, the
            wall-clock duration in ms, and the embedding dimension.

        Raises:
            RuntimeError: upstream unreachable or timed out, HTTP error status,
                invalid/malformed response, or ``len(vectors) != len(texts)``.
        """
        payload = json.dumps({"input": texts, "purpose": purpose}).encode("utf-8")
        request = urllib.request.Request(
            self.url,
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        started = time.perf_counter()
        try:
            with urllib.request.urlopen(request, timeout=self.timeout_s) as response:  # noqa: S310 - local configured URL
                body = response.read().decode("utf-8", "replace")
        except urllib.error.HTTPError as exc:
            detail = exc.read().decode("utf-8", "replace")
            raise RuntimeError(f"embedding service HTTP {exc.code}: {detail}") from exc
        except urllib.error.URLError as exc:
            raise RuntimeError(f"embedding service unavailable at {self.url}: {exc.reason}") from exc
        except OSError as exc:
            # e.g. a read timeout after the connection was established (not a URLError).
            raise RuntimeError(f"embedding service request to {self.url} failed: {exc}") from exc

        # json.JSONDecodeError is a ValueError; re-raise so the failure is attributed
        # to the upstream service, not to the caller's request.
        try:
            data = json.loads(body)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"embedding service returned invalid JSON: {exc}") from exc
        if not isinstance(data, dict):
            raise RuntimeError("embedding service response must be a JSON object")
        try:
            vectors = [[float(x) for x in item["embedding"]] for item in data.get("data") or []]
        except (KeyError, TypeError, ValueError) as exc:
            raise RuntimeError(f"embedding service returned malformed embeddings: {exc!r}") from exc
        # A short or long response would silently misalign vectors with their inputs
        # (e.g. prototype keys zipped against prototype vectors).
        if len(vectors) != len(texts):
            raise RuntimeError(f"embedding service returned {len(vectors)} vectors for {len(texts)} inputs")

        return EmbedResult(
            vectors=vectors,
            npu_busy_delta_us=data.get("npu_busy_delta_us"),
            duration_ms=round((time.perf_counter() - started) * 1000, 3),
            embedding_dim=data.get("embedding_dim") or (len(vectors[0]) if vectors else None),
        )
|
|
|
|
|
|
class ClassifierService:
    """Dry-run router classifier combining keyword rules with prototype similarity.

    Every PROTOTYPES example is embedded once (``warmup``). Each request embeds the
    message through the upstream service and scores it against those vectors with
    ``cosine``. RULES hits drive most labels; prototype similarity decides a label on
    its own only where noted. Output is advisory: nothing here writes memory, changes
    routing, or acts on the classification.
    """

    # Decision thresholds, published verbatim by labels().
    TOOL_THRESHOLD = 0.72
    MEMORY_THRESHOLD = 0.78
    SAFETY_THRESHOLD = 0.80
    WORKFLOW_THRESHOLD = 0.52
    # Similarity a prototype needs to decide tool_needed / memory_candidate without a rule hit.
    STRONG_PROTOTYPE_SCORE = 0.88

    # Ordered workflow keyword rules; the first matching category wins.
    # The \b anchors wrap the whole alternation. Written as r"\ba|b|c\b", only "a"
    # (start) and "c" (end) were anchored, so middle alternatives matched inside other
    # words: "reports" -> devops ("ports"), "blog"/"technology" -> debugging ("log"),
    # "decode" -> coding ("code"). Inflections that the unanchored forms used to catch
    # by prefix (services, deployment, images, ...) are now listed explicitly.
    _WORKFLOW_RULES: tuple[tuple[str, re.Pattern[str]], ...] = (
        ("chat", re.compile(r"\b(?:what time is it|what date is it|general questions?)\b", re.I)),
        ("kanban", re.compile(r"\b(?:kanban|task cards?|review-required|blocked)\b", re.I)),
        ("smart_home", re.compile(r"\b(?:lights?|thermostats?|home assistant|hue|wiz)\b", re.I)),
        ("media", re.compile(r"\b(?:transcri(?:be|bed|bes|bing|ption|ptions)|voice memos?|audio|videos?|images?|spotify|youtube)\b", re.I)),
        ("research", re.compile(r"\b(?:research(?:ing|ed|ers?)?|compare sources|papers?|literature|web search)\b", re.I)),
        ("devops", re.compile(r"\b(?:systemd|docker(?:file)?|kubernetes|services?|ports?|gateways?|deploy(?:s|ed|ing|ments?)?|infrastructure)\b", re.I)),
        ("debugging", re.compile(r"\b(?:debug(?:s|ged|ging|ger)?|failing|tracebacks?|logs?|logging|reproduce|diagnose)\b", re.I)),
        ("coding", re.compile(r"\b(?:implement(?:s|ed|ing|ation)?|codes?|codebase|pytest|refactor(?:s|ed|ing)?|features?|PRs?)\b", re.I)),
        ("note_taking", re.compile(r"\b(?:obsidian|notes?|notebooks?|memory|diary|chroma|reindex)\b", re.I)),
        ("productivity", re.compile(r"\b(?:calendars?|emails?|spreadsheets?|presentations?|notion|airtable|linear)\b", re.I)),
    )
    _LOW_URGENCY = re.compile(r"\b(no rush|whenever convenient|low priority|someday|backlog)\b", re.I)
    _DURABLE_FACT_HINT = re.compile(
        r"\b(project uses|repo uses|environment uses|runs on|standard practice|convention|workflow convention)\b", re.I
    )

    def __init__(self, embed_url: str, *, timeout_s: float = 30.0, max_batch_size: int = DEFAULT_MAX_BATCH_SIZE) -> None:
        self.embed_url = embed_url
        self.client = EmbeddingClient(embed_url, timeout_s=timeout_s)
        self.max_batch_size = max(1, int(max_batch_size))
        self.loaded_at = time.time()
        # Flattened prototypes: prototype_keys[i] names the PROTOTYPES entry that
        # prototype_texts[i] came from.
        self.prototype_texts: list[str] = []
        self.prototype_keys: list[str] = []
        for key, examples in PROTOTYPES.items():
            for example in examples:
                self.prototype_keys.append(key)
                self.prototype_texts.append(example)
        # Filled by warmup(); None means the prototypes have not been embedded yet.
        self.prototype_vectors: list[list[float]] | None = None
        self.prototype_npu_busy_delta_us: int | None = None
        self.embedding_dim: int | None = None
        # Service-level warnings, copied into health() and every classification response.
        self.warnings: list[str] = []
        # Serializes lazy warmup: ThreadingHTTPServer can run several first requests at once.
        self._warmup_lock = threading.Lock()

    def warmup(self) -> None:
        """Embed all prototype texts and cache the resulting vectors.

        Raises:
            RuntimeError: if the embedding service fails, or returns a different number
                of vectors than there are prototype texts (zip() in _prototype_scores
                would otherwise silently misalign keys and vectors).
        """
        result = self.client.embed(self.prototype_texts, purpose="document")
        if len(result.vectors) != len(self.prototype_texts):
            raise RuntimeError(
                f"embedding service returned {len(result.vectors)} prototype vectors "
                f"for {len(self.prototype_texts)} prototype texts"
            )
        self.prototype_npu_busy_delta_us = result.npu_busy_delta_us
        self.embedding_dim = result.embedding_dim
        if not result.npu_busy_delta_us or result.npu_busy_delta_us <= 0:
            self._warn_once("prototype embedding warmup did not report positive NPU busy delta")
        # Publish the vectors last: a non-None value is what marks the service as ready.
        self.prototype_vectors = result.vectors

    def health(self) -> dict[str, Any]:
        """Return a JSON-serializable status snapshot; status is "starting" until warmup succeeds."""
        return {
            "status": "ok" if self.prototype_vectors else "starting",
            "service": SERVICE,
            "version": VERSION,
            "mode": "dry_run",
            "model": MODEL,
            "embed_url": self.embed_url,
            "device": "NPU-via-embedding-service",
            "labels": ["tool_needed", "memory_candidate", "urgency", "workflow_category", "safety_confirmation_required"],
            "embedding_dim": self.embedding_dim,
            "prototype_count": len(self.prototype_texts),
            "max_batch_size": self.max_batch_size,
            "prototype_npu_busy_delta_us": self.prototype_npu_busy_delta_us,
            "npu_busy_time_us": npu_busy_time_us(),
            "uptime_s": round(time.time() - self.loaded_at, 3),
            "warnings": list(self.warnings),
        }

    def labels(self) -> dict[str, Any]:
        """Describe the label schema: thresholds, enum values, limits and prototype ids."""
        return {
            "model": MODEL,
            "thresholds": {
                "tool_needed": self.TOOL_THRESHOLD,
                "memory_candidate": self.MEMORY_THRESHOLD,
                "safety_confirmation_required": self.SAFETY_THRESHOLD,
                "workflow_category": self.WORKFLOW_THRESHOLD,
            },
            "enums": {"memory_candidate": MEMORY_VALUES, "urgency": URGENCY_VALUES, "workflow_category": WORKFLOW_CATEGORIES},
            "limits": {"max_batch_size": self.max_batch_size},
            "prototype_ids": sorted(PROTOTYPES),
        }

    def classify(self, item_id: str | None, text: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Classify one message and return the dry-run label payload.

        Args:
            item_id: Caller-supplied identifier, echoed back as ``id``.
            text: Message to classify; must contain a non-whitespace character.
            options: Optional flags ``include_evidence`` (default True),
                ``include_embedding_debug`` (default False) and ``dry_run``
                (default True; only echoed back, nothing is ever executed).

        Returns:
            Dict with ``labels`` (tool_needed, memory_candidate, urgency,
            workflow_category, safety_confirmation_required), NPU busy deltas,
            timing, warnings, and optionally evidence / embedding debug info.

        Raises:
            ValueError: if ``text`` is empty or whitespace-only.
            RuntimeError: if warmup or the embedding call fails, or no vector comes back.
        """
        options = options or {}
        include_evidence = bool(options.get("include_evidence", True))
        include_embedding_debug = bool(options.get("include_embedding_debug", False))
        dry_run = bool(options.get("dry_run", True))
        text = str(text or "")
        # Validate before any upstream call so a bad request never triggers warmup.
        if not text.strip():
            raise ValueError("text must be a non-empty string")
        self._ensure_warm()
        started = time.perf_counter()

        sysfs_before = npu_busy_time_us()
        embedded = self.client.embed([text], purpose="query")
        sysfs_after = npu_busy_time_us()
        if not embedded.vectors:
            raise RuntimeError("embedding service returned no vectors")
        message_vec = embedded.vectors[0]
        similarities = self._prototype_scores(message_vec)

        evidence: list[dict[str, Any]] = []
        labels: dict[str, Any] = {}

        # tool_needed: similarity alone is too broad for action classification, so it
        # only counts when very strong; otherwise a deterministic rule must fire.
        tool_rule, tool_codes, tool_evidence = best_rule(text, "tool_needed")
        tool_proto = similarities.get("tool_needed", 0.0)
        tool_conf = round(max(tool_rule, tool_proto if tool_proto >= self.STRONG_PROTOTYPE_SCORE else 0.0), 3)
        labels["tool_needed"] = {
            "value": tool_conf >= self.TOOL_THRESHOLD,
            "confidence": tool_conf,
            "threshold": self.TOOL_THRESHOLD,
            "reason_codes": tool_codes,
        }
        evidence.extend(tool_evidence)
        if tool_proto > 0:
            evidence.append({"label": "tool_needed", "source": "prototype_similarity", "prototype": "tool_needed", "score": round(tool_proto, 3)})

        mem_label, mem_conf, mem_codes, mem_ev = self._memory_label(text, similarities)
        labels["memory_candidate"] = {
            "value": mem_label,
            "confidence": round(mem_conf, 3),
            "threshold": self.MEMORY_THRESHOLD,
            "reason_codes": mem_codes,
        }
        evidence.extend(mem_ev)

        urgency_value, urgency_conf, urgency_scores, urgency_codes, urgency_ev = self._urgency_label(text, similarities)
        labels["urgency"] = {
            "value": urgency_value,
            "confidence": round(urgency_conf, 3),
            "scores": {k: round(v, 3) for k, v in urgency_scores.items()},
            "reason_codes": urgency_codes,
        }
        evidence.extend(urgency_ev)

        workflow_value, workflow_conf, workflow_scores, workflow_ev = self._workflow_label(similarities, text)
        labels["workflow_category"] = {
            "value": workflow_value,
            "confidence": round(workflow_conf, 3),
            "scores": {k: round(v, 3) for k, v in workflow_scores.items()},
        }
        evidence.extend(workflow_ev)

        # Safety confirmation is rule-only; no prototype contributes to it.
        safety_rule, safety_codes, safety_evidence = best_rule(text, "safety")
        safety_conf = round(safety_rule, 3)
        labels["safety_confirmation_required"] = {
            "value": safety_conf >= self.SAFETY_THRESHOLD,
            "confidence": safety_conf,
            "threshold": self.SAFETY_THRESHOLD,
            "reason_codes": safety_codes,
        }
        evidence.extend(safety_evidence)

        npu_delta = embedded.npu_busy_delta_us
        sysfs_delta = None if sysfs_before is None or sysfs_after is None else sysfs_after - sysfs_before
        warnings = list(self.warnings)
        if not npu_delta or npu_delta <= 0:
            warnings.append("embedding call did not report positive npu_busy_delta_us; NPU execution not proven for this request")
        if sysfs_delta is not None and sysfs_delta <= 0:
            warnings.append("sysfs npu_busy_time_us did not increase during classification request")
        prototype_dim = len(self.prototype_vectors[0]) if self.prototype_vectors else None
        if prototype_dim is not None and len(message_vec) != prototype_dim:
            # cosine() returns 0.0 for mismatched lengths, so every prototype score is 0.
            warnings.append(
                f"query embedding dimension {len(message_vec)} differs from prototype dimension {prototype_dim}; "
                "prototype similarity scores are 0"
            )

        response: dict[str, Any] = {
            "id": item_id,
            "model": MODEL,
            "created": int(time.time()),
            "duration_ms": round((time.perf_counter() - started) * 1000, 3),
            "npu_busy_delta_us": npu_delta,
            "sysfs_npu_busy_delta_us": sysfs_delta,
            "dry_run": dry_run,
            "labels": labels,
            "warnings": warnings,
        }
        if include_evidence:
            response["evidence"] = evidence[:30]
        if include_embedding_debug:
            response["embedding_debug"] = {
                "embedding_dim": len(message_vec),
                "prototype_scores": {k: round(v, 3) for k, v in similarities.items()},
            }
        return response

    def batch_classify(self, items: list[dict[str, Any]], options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Classify each item (``{"id": ..., "text": ...}``) sequentially with shared *options*.

        Raises:
            ValueError: if *items* is empty, longer than max_batch_size, or contains a
                non-object entry, or if any item's text is empty (whole batch rejected).
        """
        if not items:
            raise ValueError("items must contain at least one classification request")
        if len(items) > self.max_batch_size:
            raise ValueError(f"items exceeds max_batch_size={self.max_batch_size}")
        # Reject non-object entries as client errors instead of an AttributeError (500).
        for index, item in enumerate(items):
            if not isinstance(item, dict):
                raise ValueError(f"items[{index}] must be a JSON object")
        started = time.perf_counter()
        results = [self.classify(item.get("id"), str(item.get("text") or ""), options) for item in items]
        return {
            "model": MODEL,
            "duration_ms": round((time.perf_counter() - started) * 1000, 3),
            "npu_busy_delta_us": sum((r.get("npu_busy_delta_us") or 0) for r in results),
            "results": results,
        }

    def _ensure_warm(self) -> None:
        """Run warmup() once, even when several request threads arrive before it completes."""
        if self.prototype_vectors is not None:
            return
        with self._warmup_lock:
            # Re-check under the lock: another thread may have finished warmup meanwhile.
            if self.prototype_vectors is None:
                self.warmup()

    def _warn_once(self, message: str) -> None:
        """Record a service-level warning without duplicating it on repeated warmups."""
        if message not in self.warnings:
            self.warnings.append(message)

    def _prototype_scores(self, vec: list[float]) -> dict[str, float]:
        """Return, per PROTOTYPES key, the best cosine score of *vec* against that key's examples."""
        if self.prototype_vectors is None:
            raise RuntimeError("prototype vectors are not loaded; call warmup() first")
        scores: dict[str, float] = {}
        for key, prototype_vec in zip(self.prototype_keys, self.prototype_vectors):
            scores[key] = max(scores.get(key, 0.0), cosine(vec, prototype_vec))
        return scores

    def _memory_label(self, text: str, scores: dict[str, float]) -> tuple[str, float, list[str], list[dict[str, Any]]]:
        """Choose a memory_candidate value (one of MEMORY_VALUES) for *text*.

        Returns:
            (label, confidence, reason_codes, evidence). When the label is "none" the
            confidence is capped at 0.77, just below MEMORY_THRESHOLD.
        """
        rule_score, codes, evidence = best_rule(text, "memory")
        candidates = {
            "user_preference": scores.get("memory_user_preference", 0.0),
            "durable_user_fact": scores.get("memory_durable_user_fact", 0.0),
            "environment_fact": scores.get("memory_environment_fact", 0.0),
            "workflow_convention": scores.get("memory_workflow_convention", 0.0),
            "skill_candidate": scores.get("memory_skill_candidate", 0.0),
        }
        label, proto_score = max(candidates.items(), key=lambda kv: kv[1])
        confidence = max(proto_score, rule_score)
        explicit_memory = rule_score >= self.MEMORY_THRESHOLD
        durable_fact_hint = bool(self._DURABLE_FACT_HINT.search(text))
        if explicit_memory:
            # Explicit memory language: let keyword cues override the prototype argmax.
            if re.search(r"\b(prefer|preference|call me|my name|I live|I am)\b", text, re.I):
                label = "user_preference" if re.search(r"\b(prefer|preference)\b", text, re.I) else "durable_user_fact"
            elif durable_fact_hint:
                label = "environment_fact"
            elif re.search(r"\b(skill|procedure|workflow)\b", text, re.I):
                label = "skill_candidate"
        # BGE prototype similarities are advisory but broad; avoid recommending memory
        # writes from similarity alone unless the text also has durable-fact language
        # or an unusually strong prototype match.
        if confidence < self.MEMORY_THRESHOLD or (
            not explicit_memory and not durable_fact_hint and proto_score < self.STRONG_PROTOTYPE_SCORE
        ):
            label = "none"
        else:
            # Report the similarity of the prototype actually named: proto_score belongs
            # to the argmax prototype, which differs from `label` after a keyword override.
            evidence.append(
                {"label": "memory_candidate", "source": "prototype_similarity", "prototype": f"memory_{label}", "score": round(candidates[label], 3)}
            )
        if label == "none":
            confidence = max(0.0, min(confidence, 0.77))
        return label, confidence, codes, evidence

    def _urgency_label(self, text: str, scores: dict[str, float]) -> tuple[str, float, dict[str, float], list[str], list[dict[str, Any]]]:
        """Choose an urgency value (one of URGENCY_VALUES) from explicit language only.

        Urgency drives notifications, so *scores* (prototype similarity) is ignored:
        broad similarity would otherwise turn neutral requests such as "what time is
        it" into low/high/critical urgency.

        Returns:
            (value, confidence, score_map, reason_codes, evidence).
        """
        high_rule, high_codes, high_ev = best_rule(text, "urgency_high")
        critical_rule, critical_codes, critical_ev = best_rule(text, "urgency_critical")
        low_match = self._LOW_URGENCY.search(text)
        low_rule = 0.82 if low_match else 0.0
        # Low urgency gets a reason code and evidence like the high/critical rules do.
        low_codes = ["low_priority_language"] if low_match else []
        low_ev = (
            [{"label": "urgency_low", "source": "rule", "matched": low_match.group(0), "reason_code": "low_priority_language", "score": low_rule}]
            if low_match
            else []
        )
        score_map = {"low": low_rule, "normal": 0.68, "high": high_rule, "critical": critical_rule}
        # Demote "normal" once explicit language fires so the explicit level wins.
        if score_map["critical"] >= 0.9:
            score_map["normal"] = 0.05
        elif score_map["high"] >= 0.8 or score_map["low"] >= 0.8:
            score_map["normal"] = 0.2
        value, confidence = max(score_map.items(), key=lambda kv: kv[1])
        evidence = high_ev + critical_ev + low_ev
        return value, confidence, score_map, sorted(set(high_codes + critical_codes + low_codes)), evidence

    def _workflow_label(self, scores: dict[str, float], text: str = "") -> tuple[str, float, dict[str, float], list[dict[str, Any]]]:
        """Choose a workflow_category: first keyword rule hit, else best prototype score.

        A rule hit sets confidence to at least 0.86. Without one, a best prototype
        score below WORKFLOW_THRESHOLD yields "unknown" at that threshold.

        Returns:
            (value, confidence, score_map including "unknown", evidence).
        """
        score_map = {category: scores.get(f"workflow_{category}", 0.0) for category in WORKFLOW_CATEGORIES if category != "unknown"}
        rule_value = next((category for category, pattern in self._WORKFLOW_RULES if pattern.search(text)), None)
        if rule_value:
            value = rule_value
            confidence = max(0.86, score_map.get(rule_value, 0.0))
            score_map[rule_value] = confidence
            source = "rule"
        else:
            value, confidence = max(score_map.items(), key=lambda kv: kv[1])
            source = "prototype_similarity"
            if confidence < self.WORKFLOW_THRESHOLD:
                value = "unknown"
                confidence = self.WORKFLOW_THRESHOLD
        score_map["unknown"] = 1.0 - confidence if value != "unknown" else confidence
        evidence = [{"label": "workflow_category", "source": source, "prototype": f"workflow_{value}", "score": round(confidence, 3)}]
        return value, confidence, score_map, evidence
|
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """JSON HTTP front end for the ClassifierService attached to the server by main().

    GET  /, /health, /healthz, /readyz -> ClassifierService.health()
    GET  /v1/labels                    -> ClassifierService.labels()
    POST /v1/classify                  -> ClassifierService.classify()
    POST /v1/batch_classify            -> ClassifierService.batch_classify()

    ValueError raised while parsing or validating a POST becomes HTTP 400; any other
    exception becomes HTTP 500 carrying the exception type and message.
    """

    server_version = "AtlasRouterClassifier/0.1"
    # Largest request body read_json() accepts; bounds memory used per request.
    max_body_bytes = 8 * 1024 * 1024
    _POST_ROUTES = frozenset({"/v1/classify", "/v1/batch_classify"})

    @property
    def svc(self) -> ClassifierService:
        """The ClassifierService instance main() stores on the HTTP server."""
        return self.server.classifier_service  # type: ignore[attr-defined]

    def _route(self) -> str:
        """Request path without query string or trailing slash ("/" for the root)."""
        return self.path.split("?", 1)[0].rstrip("/") or "/"

    def do_GET(self) -> None:
        path = self._route()
        if path in {"/", "/healthz", "/readyz", "/health"}:
            self.write_json(self.svc.health())
        elif path == "/v1/labels":
            self.write_json(self.svc.labels())
        else:
            self.write_json({"error": "not found"}, status=404)

    def do_POST(self) -> None:
        path = self._route()
        # Route before reading the body so unknown paths get 404, not a JSON parse error.
        if path not in self._POST_ROUTES:
            self.write_json({"error": "not found"}, status=404)
            return
        try:
            payload = self.read_json()
            raw_options = payload.get("options")
            options = raw_options if isinstance(raw_options, dict) else {}
            if path == "/v1/classify":
                result = self.svc.classify(payload.get("id"), str(payload.get("text") or ""), options)
            else:
                items = payload.get("items")
                if not isinstance(items, list):
                    raise ValueError("items must be a list")
                result = self.svc.batch_classify(items, options)
        except ValueError as exc:
            result, status = {"error": str(exc)}, 400
        except Exception as exc:  # request boundary: report the failure instead of dropping the connection
            result, status = {"error": f"{type(exc).__name__}: {exc}"}, 500
        else:
            status = 200
        # Single write outside the try, so a failure while sending can't trigger a second response.
        self.write_json(result, status=status)

    def read_json(self) -> dict[str, Any]:
        """Read the request body as a JSON object; a missing or empty body counts as ``{}``.

        Raises:
            ValueError: on a malformed, negative or oversized Content-Length, invalid
                JSON, or a JSON value that is not an object.
        """
        raw_length = self.headers.get("Content-Length")
        try:
            length = int(raw_length or 0)
        except ValueError as exc:
            raise ValueError(f"invalid Content-Length header: {raw_length!r}") from exc
        # rfile.read() with a negative size reads until EOF, which would block on an
        # open client connection.
        if length < 0:
            raise ValueError("Content-Length must not be negative")
        if length > self.max_body_bytes:
            raise ValueError(f"request body exceeds {self.max_body_bytes} bytes")
        body = self.rfile.read(length).decode("utf-8", "replace") if length else "{}"
        payload = json.loads(body or "{}")
        if not isinstance(payload, dict):
            raise ValueError("JSON body must be an object")
        return payload

    def write_json(self, payload: dict[str, Any], status: int = 200) -> None:
        """Send *payload* as a UTF-8 JSON response with the given status code."""
        body = json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002 - stdlib override name
        """Log each request line to stderr, flushed immediately."""
        print(f"{self.address_string()} - {format % args}", file=sys.stderr, flush=True)
|
|
|
|
|
|
def main() -> int:
    """Parse CLI/env configuration, warm up prototypes, and serve until interrupted.

    Every flag falls back to an OPENVINO_CLASSIFIER_* environment variable, then to the
    module default.

    Returns:
        Process exit code: 0 after Ctrl-C, 1 if prototype warmup fails or the listening
        socket cannot be opened.
    """
    parser = argparse.ArgumentParser(description="Dry-run Atlas/Hermes router classifier")
    parser.add_argument("--host", default=os.environ.get("OPENVINO_CLASSIFIER_HOST", DEFAULT_HOST))
    parser.add_argument("--port", type=int, default=env_int("OPENVINO_CLASSIFIER_PORT", DEFAULT_PORT))
    parser.add_argument("--embed-url", default=os.environ.get("OPENVINO_CLASSIFIER_EMBED_URL", DEFAULT_EMBED_URL))
    parser.add_argument("--timeout-s", type=float, default=env_float("OPENVINO_CLASSIFIER_TIMEOUT_S", 30.0))
    parser.add_argument("--max-batch-size", type=int, default=env_int("OPENVINO_CLASSIFIER_MAX_BATCH_SIZE", DEFAULT_MAX_BATCH_SIZE))
    parser.add_argument("--no-warmup", action="store_true", help="skip prototype embedding warmup until first request")
    args = parser.parse_args()

    service = ClassifierService(args.embed_url, timeout_s=args.timeout_s, max_batch_size=args.max_batch_size)
    if not args.no_warmup:
        try:
            service.warmup()
        except RuntimeError as exc:
            # Same exit code as an uncaught exception, but a readable one-line cause.
            print(f"{SERVICE}: prototype warmup failed: {exc}", file=sys.stderr, flush=True)
            return 1

    try:
        httpd = ThreadingHTTPServer((args.host, args.port), Handler)
    except OSError as exc:
        print(f"{SERVICE}: cannot listen on {args.host}:{args.port}: {exc}", file=sys.stderr, flush=True)
        return 1
    httpd.classifier_service = service  # type: ignore[attr-defined]
    print(f"{SERVICE} listening on {args.host}:{args.port} embed_url={args.embed_url} mode=dry_run", flush=True)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket on every exit path.
        httpd.server_close()
    return 0
|
|
|
|
|
|
# Script entry point: the exit status is main()'s return value.
if __name__ == "__main__":
    sys.exit(main())
|