[verified] refresh OpenVINO NPU reranker prototype

This commit is contained in:
William Valentin
2026-06-04 12:16:15 -07:00
parent 4dc77bb0c7
commit 418be69f96
4 changed files with 93 additions and 3 deletions
+24
View File
@@ -16,6 +16,7 @@ import argparse
import json
import math
import os
import socket
import sys
import threading
import time
@@ -251,6 +252,27 @@ def normalize_documents(value: Any, max_documents: int) -> list[dict[str, Any]]:
return docs
def parse_top_k(value: Any, document_count: int) -> int:
    """Resolve the requested result count before any inference runs.

    A missing value (``None``) selects every document. Any other value must
    be a genuine ``int`` of at least 1 and is capped at ``document_count``.

    Validation is done up front so a malformed value surfaces as a
    ``ValueError``; the caller is presumably what turns that into an
    HTTP 400 response (that mapping is not visible here).

    Args:
        value: Raw ``top_k``/``top_n`` value taken from the request payload.
        document_count: Number of documents available for ranking.

    Returns:
        The number of results to return, never more than ``document_count``.

    Raises:
        ValueError: If ``value`` is a bool, is not an int, or is below 1.
    """
    if value is None:
        return document_count

    # bool is a subclass of int, so it has to be ruled out explicitly.
    is_real_int = isinstance(value, int) and not isinstance(value, bool)
    if not is_real_int or value < 1:
        raise ValueError("top_k/top_n must be a positive integer")

    return min(value, document_count)
def assert_port_available(host: str, port: int) -> None:
    """Fail fast when ``host:port`` cannot be bound.

    Opens a throwaway IPv4 TCP socket, tries to bind it to the requested
    address, and closes it again when the ``with`` block exits. This is a
    best-effort probe: the address could still be taken between this check
    and the real listener being created.

    Args:
        host: Address to bind, as accepted by ``socket.bind`` for ``AF_INET``.
        port: TCP port number to bind.

    Raises:
        RuntimeError: If the bind fails for any reason, including a listener
            conflict, an unresolvable or non-local host, or a port outside
            0-65535.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # NOTE(review): presumably mirrors the address-reuse setting of the
        # real HTTP server so the probe is not stricter than the actual
        # listener -- confirm against the server construction code.
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            probe.bind((host, port))
        except (OSError, OverflowError) as exc:
            # An out-of-range port raises OverflowError rather than OSError;
            # catch both so every bad bind is reported the same way.
            raise RuntimeError(f"cannot bind {host}:{port}; listener conflict or invalid bind: {exc}") from exc
class Handler(BaseHTTPRequestHandler):
server_version = "OpenVINOReranker/0.1"
@@ -293,6 +315,7 @@ class Handler(BaseHTTPRequestHandler):
raise ValueError("query is required")
top_k = payload.get("top_k", payload.get("top_n"))
documents = normalize_documents(payload.get("documents"), self.max_documents)
top_k = parse_top_k(top_k, len(documents))
return_documents = bool(payload.get("return_documents", True))
response = self.svc.rerank(query.strip(), documents, top_k=top_k, return_documents=return_documents)
self.write_json(response)
@@ -342,6 +365,7 @@ def main() -> int:
parser.add_argument("--skip-startup-smoke", action="store_true", default=os.environ.get("OPENVINO_RERANKER_SKIP_STARTUP_SMOKE", "").lower() in {"1", "true", "yes"})
args = parser.parse_args()
assert_port_available(args.host, args.port)
service = RerankerService(
Path(args.model_dir).expanduser(),
args.model,