[verified] refresh OpenVINO router classifier prototype
This commit is contained in:
@@ -30,6 +30,7 @@ MODEL = "bge-base-en-v1.5-int8-ov/prototype-router-v0"
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 18819
|
||||
DEFAULT_EMBED_URL = "http://127.0.0.1:18817/v1/embeddings"
|
||||
DEFAULT_MAX_BATCH_SIZE = 32
|
||||
NPU_BUSY_FILE = Path("/sys/class/accel/accel0/device/npu_busy_time_us")
|
||||
|
||||
WORKFLOW_CATEGORIES = [
|
||||
@@ -150,6 +151,26 @@ def npu_busy_time_us() -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def env_int(name: str, default: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
try:
|
||||
return int(raw)
|
||||
except ValueError as exc:
|
||||
raise SystemExit(f"{name} must be an integer, got {raw!r}") from exc
|
||||
|
||||
|
||||
def env_float(name: str, default: float) -> float:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
try:
|
||||
return float(raw)
|
||||
except ValueError as exc:
|
||||
raise SystemExit(f"{name} must be a number, got {raw!r}") from exc
|
||||
|
||||
|
||||
def clamp01(value: float) -> float:
|
||||
return max(0.0, min(1.0, value))
|
||||
|
||||
@@ -220,9 +241,10 @@ class EmbeddingClient:
|
||||
|
||||
|
||||
class ClassifierService:
|
||||
def __init__(self, embed_url: str, *, timeout_s: float = 30.0) -> None:
|
||||
def __init__(self, embed_url: str, *, timeout_s: float = 30.0, max_batch_size: int = DEFAULT_MAX_BATCH_SIZE) -> None:
|
||||
self.embed_url = embed_url
|
||||
self.client = EmbeddingClient(embed_url, timeout_s=timeout_s)
|
||||
self.max_batch_size = max(1, int(max_batch_size))
|
||||
self.loaded_at = time.time()
|
||||
self.prototype_texts: list[str] = []
|
||||
self.prototype_keys: list[str] = []
|
||||
@@ -255,6 +277,7 @@ class ClassifierService:
|
||||
"labels": ["tool_needed", "memory_candidate", "urgency", "workflow_category", "safety_confirmation_required"],
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"prototype_count": len(self.prototype_texts),
|
||||
"max_batch_size": self.max_batch_size,
|
||||
"prototype_npu_busy_delta_us": self.prototype_npu_busy_delta_us,
|
||||
"npu_busy_time_us": npu_busy_time_us(),
|
||||
"uptime_s": round(time.time() - self.loaded_at, 3),
|
||||
@@ -271,6 +294,7 @@ class ClassifierService:
|
||||
"workflow_category": 0.52,
|
||||
},
|
||||
"enums": {"memory_candidate": MEMORY_VALUES, "urgency": URGENCY_VALUES, "workflow_category": WORKFLOW_CATEGORIES},
|
||||
"limits": {"max_batch_size": self.max_batch_size},
|
||||
"prototype_ids": sorted(PROTOTYPES),
|
||||
}
|
||||
|
||||
@@ -351,6 +375,10 @@ class ClassifierService:
|
||||
return response
|
||||
|
||||
def batch_classify(self, items: list[dict[str, Any]], options: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
if not items:
|
||||
raise ValueError("items must contain at least one classification request")
|
||||
if len(items) > self.max_batch_size:
|
||||
raise ValueError(f"items exceeds max_batch_size={self.max_batch_size}")
|
||||
started = time.perf_counter()
|
||||
results = [self.classify(item.get("id"), str(item.get("text") or ""), options) for item in items]
|
||||
return {
|
||||
@@ -400,13 +428,15 @@ class ClassifierService:
|
||||
high_rule, high_codes, high_ev = best_rule(text, "urgency_high")
|
||||
critical_rule, critical_codes, critical_ev = best_rule(text, "urgency_critical")
|
||||
low_rule = 0.82 if re.search(r"\b(no rush|whenever convenient|low priority|someday|backlog)\b", text, re.I) else 0.0
|
||||
# Urgency is safety-sensitive for notifications. Prefer explicit rules;
|
||||
# use prototype scores only when they are unusually strong.
|
||||
# Urgency is safety-sensitive for notifications, so require explicit
|
||||
# language instead of relying on broad prototype similarity.
|
||||
score_map = {
|
||||
"low": max(low_rule, scores.get("urgency_low", 0.0) if scores.get("urgency_low", 0.0) >= 0.9 else 0.0),
|
||||
# Urgency should be explicit; broad embedding similarity otherwise
|
||||
# turns neutral requests such as "what time is it" into low/high/critical urgency.
|
||||
"low": low_rule,
|
||||
"normal": 0.68,
|
||||
"high": max(high_rule, scores.get("urgency_high", 0.0) if scores.get("urgency_high", 0.0) >= 0.9 else 0.0),
|
||||
"critical": max(critical_rule, scores.get("urgency_critical", 0.0) if scores.get("urgency_critical", 0.0) >= 0.92 else 0.0),
|
||||
"high": high_rule,
|
||||
"critical": critical_rule,
|
||||
}
|
||||
if score_map["critical"] >= 0.9:
|
||||
score_map["normal"] = 0.05
|
||||
@@ -509,13 +539,14 @@ class Handler(BaseHTTPRequestHandler):
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Dry-run Atlas/Hermes router classifier")
|
||||
parser.add_argument("--host", default=os.environ.get("OPENVINO_CLASSIFIER_HOST", DEFAULT_HOST))
|
||||
parser.add_argument("--port", type=int, default=int(os.environ.get("OPENVINO_CLASSIFIER_PORT", DEFAULT_PORT)))
|
||||
parser.add_argument("--port", type=int, default=env_int("OPENVINO_CLASSIFIER_PORT", DEFAULT_PORT))
|
||||
parser.add_argument("--embed-url", default=os.environ.get("OPENVINO_CLASSIFIER_EMBED_URL", DEFAULT_EMBED_URL))
|
||||
parser.add_argument("--timeout-s", type=float, default=float(os.environ.get("OPENVINO_CLASSIFIER_TIMEOUT_S", "30")))
|
||||
parser.add_argument("--timeout-s", type=float, default=env_float("OPENVINO_CLASSIFIER_TIMEOUT_S", 30.0))
|
||||
parser.add_argument("--max-batch-size", type=int, default=env_int("OPENVINO_CLASSIFIER_MAX_BATCH_SIZE", DEFAULT_MAX_BATCH_SIZE))
|
||||
parser.add_argument("--no-warmup", action="store_true", help="skip prototype embedding warmup until first request")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = ClassifierService(args.embed_url, timeout_s=args.timeout_s)
|
||||
service = ClassifierService(args.embed_url, timeout_s=args.timeout_s, max_batch_size=args.max_batch_size)
|
||||
if not args.no_warmup:
|
||||
service.warmup()
|
||||
httpd = ThreadingHTTPServer((args.host, args.port), Handler)
|
||||
|
||||
Reference in New Issue
Block a user