feat(npu): add OpenVINO reranker prototype
This commit is contained in:
Executable
+393
@@ -0,0 +1,393 @@
|
||||
#!/usr/bin/env python3
|
||||
"""OpenVINO NPU cross-encoder reranker HTTP service.
|
||||
|
||||
Default port: 18818
Default model: cross-encoder/ms-marco-MiniLM-L6-v2 exported as OpenVINO IR
Default device: NPU

Endpoints:
  GET  /, /healthz (alias: /health), /readyz
  POST /rerank
  POST /v1/rerank
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import openvino as ov
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Model identifier echoed back in API responses; the weights themselves are
# loaded from the model directory below.
DEFAULT_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L6-v2"
# NOTE(review): absolute path under one user's home directory. On any other
# machine pass --model-dir or set OPENVINO_RERANKER_MODEL_DIR.
DEFAULT_MODEL_DIR = Path("/home/will/.cache/openvino-models/rerankers/ms-marco-MiniLM-L6-v2-int8-ov")
DEFAULT_PORT = 18818
# Token length every (query, document) pair is padded/truncated to.
DEFAULT_MAX_LENGTH = 512
# Upper bound on documents accepted in a single rerank request.
DEFAULT_MAX_DOCUMENTS = 100
# Upper bound on the request body size: 5 MiB.
DEFAULT_MAX_BODY_BYTES = 5 * 1024**2
# sysfs file read by npu_busy_time_us(); presumably a cumulative NPU busy-time
# counter in microseconds (name-based assumption, confirm against the driver).
NPU_BUSY_FILE = Path("/sys/class/accel/accel0/device/npu_busy_time_us")
|
||||
|
||||
|
||||
def npu_busy_time_us() -> int | None:
    """Return the integer stored in ``NPU_BUSY_FILE``, or ``None`` if unavailable.

    The file content is stripped of surrounding whitespace and parsed as a
    base-10 integer.

    Returns:
        The parsed counter value, or ``None`` when the file cannot be read
        (missing, permission denied, undecodable) or does not hold an integer.
    """
    try:
        return int(NPU_BUSY_FILE.read_text().strip())
    except (OSError, ValueError):
        # OSError: file missing/unreadable. ValueError: non-integer content
        # (UnicodeDecodeError is a ValueError subclass). Best-effort by design.
        return None
|
||||
|
||||
|
||||
def sigmoid(x: float) -> float:
    """Return the logistic function ``1 / (1 + e**-x)`` of *x*.

    The branch on the sign of *x* keeps the argument passed to ``math.exp``
    non-positive, so the exponential cannot overflow for large ``|x|``.
    """
    if x < 0:
        grown = math.exp(x)
        return grown / (1.0 + grown)
    decayed = math.exp(-x)
    return 1.0 / (1.0 + decayed)
|
||||
|
||||
|
||||
def softmax_prob(logits: np.ndarray, index: int = 1) -> float:
    """Return the softmax probability of one entry of *logits*.

    Args:
        logits: Array-like of logits; it is flattened to one dimension first.
        index: Position (in the flattened array) whose probability is returned.

    Returns:
        ``softmax(logits)[index]`` as a Python float.

    Raises:
        ValueError: If *logits* is empty (``np.max`` of an empty array).
        IndexError: If *index* is out of range for the flattened array.
    """
    row = np.asarray(logits, dtype=np.float64).reshape(-1)
    # Subtracting the maximum keeps every exponent <= 0, avoiding overflow.
    # The exponentials are computed once and reused for the normaliser.
    weights = np.exp(row - np.max(row))
    return float(weights[index] / np.sum(weights))
|
||||
|
||||
|
||||
class RerankerService:
    """Cross-encoder reranker backed by a compiled OpenVINO model.

    Construction reads ``openvino_model.xml`` and the tokenizer from
    ``model_dir``, reshapes the recognised inputs to ``[1, max_length]``,
    compiles the model for ``device`` and (optionally) runs one smoke
    inference. Inference is serialised through ``self.lock``.
    """

    def __init__(
        self,
        model_dir: Path,
        model_id: str,
        device: str,
        max_length: int,
        startup_smoke: bool = True,
    ) -> None:
        """Load, compile and optionally smoke-test the model.

        Args:
            model_dir: Directory holding ``openvino_model.xml`` and tokenizer files.
            model_id: Identifier echoed in responses and health output.
            device: OpenVINO device name; must appear in ``Core.available_devices``.
            max_length: Token length each pair is padded/truncated to.
            startup_smoke: When true, run one rerank call before marking ready.

        Raises:
            FileNotFoundError: ``model_dir`` or the IR XML file does not exist.
            RuntimeError: ``device`` is unavailable, or the smoke run on ``"NPU"``
                did not increase the NPU busy counter.
        """
        self.model_dir = model_dir
        self.model_id = model_id
        self.device = device
        self.max_length = int(max_length)
        self.loaded_at = time.time()
        self.lock = threading.Lock()
        self.last_inference: dict[str, Any] | None = None
        # Note: this attribute (smoke result summary) shares its name with the
        # boolean ``startup_smoke`` parameter, which is only used locally below.
        self.startup_smoke: dict[str, Any] | None = None
        self.ready = False
        self.ready_error: str | None = None

        if not self.model_dir.exists():
            raise FileNotFoundError(f"model directory not found: {self.model_dir}")

        self.core = ov.Core()
        self.available_devices = list(self.core.available_devices)
        if self.device not in self.available_devices:
            raise RuntimeError(f"OpenVINO device {self.device!r} unavailable; available={self.available_devices}")

        xml_path = self.model_dir / "openvino_model.xml"
        if not xml_path.exists():
            raise FileNotFoundError(f"OpenVINO IR not found: {xml_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(str(self.model_dir), local_files_only=True)
        model = self.core.read_model(str(xml_path))
        self._reshape_static(model)
        self.compiled = self.core.compile_model(model, self.device)
        self.input_names = {inp.get_any_name() for inp in self.compiled.inputs}
        self.output = self.compiled.output(0)

        if startup_smoke:
            try:
                smoke = self.rerank(
                    "npu busy time",
                    [{"id": "smoke", "text": "OpenVINO NPU usage is verified by npu_busy_time_us."}],
                    top_k=1,
                    return_documents=False,
                )
                self.startup_smoke = {
                    "ok": bool(smoke.get("ok")),
                    "duration_ms": smoke.get("duration_ms"),
                    "npu_busy_delta_us": smoke.get("npu_busy_delta_us"),
                }
                # Fail closed: an unreadable counter (None) is treated like "no increase".
                if self.device == "NPU" and int(smoke.get("npu_busy_delta_us") or 0) <= 0:
                    raise RuntimeError("startup smoke did not increase npu_busy_time_us")
            except Exception as exc:
                self.ready_error = f"startup smoke failed: {type(exc).__name__}: {exc}"
                raise

        self.ready = True

    def _reshape_static(self, model: ov.Model) -> None:
        """Reshape the recognised text inputs of *model* to ``[1, max_length]``.

        Only inputs named ``input_ids``, ``attention_mask`` or
        ``token_type_ids`` are touched; if none are present the model is left
        as-is. Presumably needed because the target device wants static
        shapes (assumption, confirm against OpenVINO NPU requirements).
        """
        shape_by_name: dict[str, list[int]] = {}
        for inp in model.inputs:
            name = inp.get_any_name()
            if name in {"input_ids", "attention_mask", "token_type_ids"}:
                shape_by_name[name] = [1, self.max_length]
        if shape_by_name:
            model.reshape(shape_by_name)

    def _tokenize(self, query: str, document: str) -> dict[str, np.ndarray]:
        """Tokenize a (query, document) pair, padded/truncated to ``max_length``.

        Returns only the tokenizer outputs whose names match inputs of the
        compiled model.
        """
        tokens = self.tokenizer(
            query,
            document,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )
        return {name: np.asarray(value) for name, value in tokens.items() if name in self.input_names}

    def _score_pair(self, query: str, document: str) -> dict[str, float | None]:
        """Run one inference and return ``score``, ``raw_score`` and ``probability``.

        With a single output logit the probability is its sigmoid. With two or
        more logits the score is the logit at index 1 and the probability is
        its softmax (presumably index 1 is the "relevant" class; verify
        against the exported model head).

        Raises:
            RuntimeError: The tokenizer did not produce ``input_ids`` or
                ``attention_mask`` although the model expects them, or the
                model returned an empty output.
        """
        inputs = self._tokenize(query, document)
        missing = self.input_names - set(inputs)
        # Some exported BERT models do not use token_type_ids. input_ids and attention_mask are required.
        required_missing = missing & {"input_ids", "attention_mask"}
        if required_missing:
            raise RuntimeError(f"tokenizer did not produce required inputs: {sorted(required_missing)}")
        outputs = self.compiled(inputs)
        logits = np.asarray(outputs[self.output])
        flat = logits.reshape(-1)
        if flat.size == 1:
            raw = float(flat[0])
            return {"score": raw, "raw_score": raw, "probability": sigmoid(raw)}
        if flat.size >= 2:
            raw = float(flat[1])
            return {"score": raw, "raw_score": raw, "probability": softmax_prob(flat, 1)}
        raise RuntimeError(f"unexpected empty logits shape: {list(logits.shape)}")

    def rerank(
        self,
        query: str,
        documents: list[dict[str, Any]],
        *,
        top_k: int | None,
        return_documents: bool = True,
    ) -> dict[str, Any]:
        """Score every document against *query* and return the best ``top_k``.

        Args:
            query: Query text.
            documents: Dicts with a ``"text"`` key and optional ``"id"`` /
                ``"metadata"`` keys.
            top_k: Number of results to return; ``None`` means all. Values are
                clamped into ``[1, len(documents)]``.
            return_documents: When true, each result also carries the
                document ``text`` and ``metadata``.

        Returns:
            A response dict with timing, the NPU busy-counter delta and the
            results sorted by descending score (ties keep input order).
        """
        # Wall-clock timing starts before the lock, so duration_ms includes
        # time spent waiting behind other requests.
        started = time.perf_counter()
        results: list[dict[str, Any]] = []
        with self.lock:
            # Sample the busy counter while holding the lock so the delta only
            # covers this call's inferences, not those of requests we queued behind.
            before = npu_busy_time_us()
            for idx, doc in enumerate(documents):
                scored = self._score_pair(query, str(doc["text"]))
                item: dict[str, Any] = {
                    "index": idx,
                    "score": scored["score"],
                    "raw_score": scored["raw_score"],
                    "probability": scored["probability"],
                }
                if doc.get("id") is not None:
                    item["id"] = doc.get("id")
                if return_documents:
                    item["text"] = doc["text"]
                    item["metadata"] = doc.get("metadata") if isinstance(doc.get("metadata"), dict) else {}
                results.append(item)
            after = npu_busy_time_us()
        results.sort(key=lambda item: (-float(item["score"]), int(item["index"])))
        clamped_top_k = len(results) if top_k is None else max(1, min(int(top_k), len(results)))
        duration_ms = round((time.perf_counter() - started) * 1000, 3)
        npu_delta = None if before is None or after is None else after - before
        payload = {
            "ok": True,
            "model": self.model_id,
            "model_dir": str(self.model_dir),
            "device": self.device,
            "query": query,
            "input_count": len(documents),
            "top_k": clamped_top_k,
            "duration_ms": duration_ms,
            "npu_busy_delta_us": npu_delta,
            "results": results[:clamped_top_k],
        }
        self.last_inference = {
            "duration_ms": duration_ms,
            "docs": len(documents),
            "npu_busy_delta_us": npu_delta,
        }
        return payload

    def health(self) -> dict[str, Any]:
        """Return a status snapshot: readiness, configuration, uptime, the
        current NPU busy counter and the last smoke/inference summaries."""
        status = "ok" if self.ready else "degraded"
        return {
            "status": status,
            "ok": self.ready,
            "service": "openvino-reranker",
            "model": self.model_id,
            "model_dir": str(self.model_dir),
            "device": self.device,
            "available_devices": self.available_devices,
            "max_length": self.max_length,
            "input_names": sorted(self.input_names),
            "uptime_s": round(time.time() - self.loaded_at, 3),
            "npu_busy_time_us": npu_busy_time_us(),
            "startup_smoke": self.startup_smoke,
            "last_inference": self.last_inference,
            "ready_error": self.ready_error,
        }
|
||||
|
||||
|
||||
def normalize_documents(value: Any, max_documents: int) -> list[dict[str, Any]]:
    """Validate the request's ``documents`` field and normalise it to dicts.

    Plain strings become ``{"text": ...}``. Objects become
    ``{"id": ..., "text": ..., "metadata": {...}}``, where a non-dict
    ``metadata`` is replaced by ``{}``.

    Raises:
        ValueError: *value* is not a non-empty list, is longer than
            *max_documents*, contains an entry that is neither string nor
            object, or contains an entry whose text is not a non-blank string.
    """
    if not (isinstance(value, list) and value):
        raise ValueError("documents must be a non-empty list")
    if len(value) > max_documents:
        raise ValueError(f"documents exceeds max_documents={max_documents}")

    normalized: list[dict[str, Any]] = []
    for position, entry in enumerate(value):
        if isinstance(entry, dict):
            metadata = entry.get("metadata")
            candidate: dict[str, Any] = {
                "id": entry.get("id"),
                "text": entry.get("text"),
                "metadata": metadata if isinstance(metadata, dict) else {},
            }
        elif isinstance(entry, str):
            candidate = {"text": entry}
        else:
            raise ValueError(f"documents[{position}] must be a string or object")

        body = candidate["text"]
        if not (isinstance(body, str) and body.strip()):
            raise ValueError(f"documents[{position}].text must be a non-empty string")
        normalized.append(candidate)
    return normalized
|
||||
|
||||
|
||||
def parse_top_k(value: Any, document_count: int) -> int:
    """Validate top_k/top_n before inference so schema errors return HTTP 400.

    ``None`` selects every document. Otherwise *value* must be an ``int``
    (booleans are rejected) of at least 1, and the result is capped at
    *document_count*.

    Raises:
        ValueError: *value* is not a positive integer.
    """
    if value is None:
        return document_count
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    if not is_plain_int or value < 1:
        raise ValueError("top_k/top_n must be a positive integer")
    return value if value <= document_count else document_count
|
||||
|
||||
|
||||
def assert_port_available(host: str, port: int) -> None:
    """Fail fast on listener conflicts before compiling the OpenVINO model.

    Opens a throw-away IPv4 TCP socket, attempts to bind it to
    ``(host, port)`` and closes it again.

    Raises:
        RuntimeError: The bind failed; the original ``OSError`` is chained.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        address = (host, port)
        try:
            probe.bind(address)
        except OSError as exc:
            raise RuntimeError(f"cannot bind {host}:{port}; listener conflict or invalid bind: {exc}") from exc
    finally:
        # Always release the probe socket so the real server can bind afterwards.
        probe.close()
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler exposing health endpoints and the rerank API.

    The owning server object is expected to carry ``reranker_service``,
    ``max_body_bytes`` and ``max_documents`` attributes.
    """

    server_version = "OpenVINOReranker/0.1"

    @property
    def svc(self) -> RerankerService:
        """The reranker service attached to the server."""
        return self.server.reranker_service  # type: ignore[attr-defined]

    @property
    def max_body_bytes(self) -> int:
        """Largest request body (in bytes) that will be read."""
        return self.server.max_body_bytes  # type: ignore[attr-defined]

    @property
    def max_documents(self) -> int:
        """Largest number of documents accepted per request."""
        return self.server.max_documents  # type: ignore[attr-defined]

    def do_GET(self) -> None:
        """Serve ``/``, ``/healthz`` (alias ``/health``) and ``/readyz``.

        ``/healthz`` always answers 200; ``/readyz`` answers 503 until the
        service reports ready. Unknown paths get a 404 JSON body.
        """
        # Ignore the query string and any trailing slash when routing.
        path = self.path.split("?", 1)[0].rstrip("/") or "/"
        if path == "/":
            self.write_json({"ok": True, "service": "openvino-reranker", "endpoints": ["/healthz", "/readyz", "/rerank", "/v1/rerank"]})
        elif path in {"/healthz", "/health"}:
            self.write_json(self.svc.health(), status=200)
        elif path == "/readyz":
            health = self.svc.health()
            self.write_json(health, status=200 if health.get("ok") else 503)
        else:
            self.write_json({"ok": False, "error": "not found", "results": []}, status=404)

    def do_POST(self) -> None:
        """Handle ``POST /rerank`` and ``POST /v1/rerank``.

        Status mapping: 404 unknown path, 503 service not ready, 413 body too
        large, 400 validation/JSON errors, 500 anything else.
        """
        path = self.path.split("?", 1)[0].rstrip("/") or "/"
        try:
            if path not in {"/rerank", "/v1/rerank"}:
                self.write_json({"ok": False, "error": "not found", "results": []}, status=404)
                return
            if not self.svc.ready:
                self.write_json({"ok": False, "error": self.svc.ready_error or "model not ready", "results": []}, status=503)
                return
            payload = self.read_json()
            query = payload.get("query")
            if not isinstance(query, str) or not query.strip():
                raise ValueError("query is required")
            # "top_n" is accepted as an alias when "top_k" is absent.
            top_k = payload.get("top_k", payload.get("top_n"))
            documents = normalize_documents(payload.get("documents"), self.max_documents)
            top_k = parse_top_k(top_k, len(documents))
            return_documents = bool(payload.get("return_documents", True))
            response = self.svc.rerank(query.strip(), documents, top_k=top_k, return_documents=return_documents)
            self.write_json(response)
        # RequestTooLarge subclasses ValueError, so it must be caught first.
        except RequestTooLarge as exc:
            self.write_json({"ok": False, "error": str(exc), "results": []}, status=413)
        except ValueError as exc:
            self.write_json({"ok": False, "error": str(exc), "results": []}, status=400)
        except Exception as exc:
            self.write_json({"ok": False, "error": f"{type(exc).__name__}: {exc}", "results": []}, status=500)

    def read_json(self) -> dict[str, Any]:
        """Read the request body and parse it as a JSON object.

        A missing or zero ``Content-Length`` yields ``{}``.

        Raises:
            RequestTooLarge: ``Content-Length`` exceeds ``max_body_bytes``.
            ValueError: ``Content-Length`` is not a non-negative integer, the
                body is not valid JSON, or the JSON value is not an object.
        """
        header = self.headers.get("Content-Length")
        try:
            length = int(header) if header else 0
        except ValueError:
            raise ValueError("Content-Length must be a non-negative integer") from None
        if length < 0:
            # A negative size would make rfile.read() read until EOF and could
            # block this worker thread indefinitely.
            raise ValueError("Content-Length must be a non-negative integer")
        if length > self.max_body_bytes:
            raise RequestTooLarge(f"request body exceeds {self.max_body_bytes} bytes")
        body = self.rfile.read(length).decode("utf-8", "replace") if length else "{}"
        payload = json.loads(body or "{}")
        if not isinstance(payload, dict):
            raise ValueError("JSON body must be an object")
        return payload

    def write_json(self, payload: dict[str, Any], status: int = 200) -> None:
        """Send *payload* as a UTF-8 JSON response with the given *status*."""
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002 - stdlib override name
        """Write one access-log line to stderr, flushed immediately."""
        print(f"{self.address_string()} - {format % args}", file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
class RequestTooLarge(ValueError):
    """Raised when a request's declared body size exceeds the configured limit."""
|
||||
|
||||
|
||||
def main() -> int:
    """Parse options, build the reranker service and serve HTTP until interrupted.

    Every option can also be supplied through an ``OPENVINO_RERANKER_*``
    environment variable; command-line flags take precedence.

    Returns:
        Process exit code (always 0; startup failures propagate as exceptions).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default=os.environ.get("OPENVINO_RERANKER_HOST", "127.0.0.1"))
    parser.add_argument("--port", type=int, default=int(os.environ.get("OPENVINO_RERANKER_PORT", DEFAULT_PORT)))
    parser.add_argument("--model-dir", default=os.environ.get("OPENVINO_RERANKER_MODEL_DIR", str(DEFAULT_MODEL_DIR)))
    parser.add_argument("--model", default=os.environ.get("OPENVINO_RERANKER_MODEL", DEFAULT_MODEL_ID))
    parser.add_argument("--device", default=os.environ.get("OPENVINO_RERANKER_DEVICE", "NPU"))
    parser.add_argument("--max-length", type=int, default=int(os.environ.get("OPENVINO_RERANKER_MAX_LENGTH", str(DEFAULT_MAX_LENGTH))))
    parser.add_argument("--max-documents", type=int, default=int(os.environ.get("OPENVINO_RERANKER_MAX_DOCUMENTS", str(DEFAULT_MAX_DOCUMENTS))))
    parser.add_argument("--max-body-bytes", type=int, default=int(os.environ.get("OPENVINO_RERANKER_MAX_BODY_BYTES", str(DEFAULT_MAX_BODY_BYTES))))
    parser.add_argument("--skip-startup-smoke", action="store_true", default=os.environ.get("OPENVINO_RERANKER_SKIP_STARTUP_SMOKE", "").lower() in {"1", "true", "yes"})
    args = parser.parse_args()

    # Probe the port first so a listener conflict is reported before the
    # model is loaded and compiled.
    assert_port_available(args.host, args.port)
    service = RerankerService(
        Path(args.model_dir).expanduser(),
        args.model,
        args.device,
        args.max_length,
        startup_smoke=not args.skip_startup_smoke,
    )
    httpd = ThreadingHTTPServer((args.host, args.port), Handler)
    # Handler reads these attributes back through ``self.server``.
    httpd.reranker_service = service  # type: ignore[attr-defined]
    httpd.max_body_bytes = args.max_body_bytes  # type: ignore[attr-defined]
    httpd.max_documents = args.max_documents  # type: ignore[attr-defined]
    print(
        f"openvino-reranker listening on {args.host}:{args.port} model={args.model} "
        f"model_dir={args.model_dir} device={args.device} max_length={args.max_length}",
        flush=True,
    )
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket on every exit path.
        httpd.server_close()
    return 0
|
||||
|
||||
|
||||
# Script entry point: exit the process with main()'s return code.
if __name__ == "__main__":
    sys.exit(main())
|
||||
Reference in New Issue
Block a user