diff --git a/scripts/openvino-embeddings-server.py b/scripts/openvino-embeddings-server.py index a74dfa3..e902b8e 100755 --- a/scripts/openvino-embeddings-server.py +++ b/scripts/openvino-embeddings-server.py @@ -67,22 +67,29 @@ class EmbeddingService: cfg.batch_size = 1 self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg) - def embed_one(self, text: str) -> dict[str, Any]: + def embed_one(self, text: str, *, purpose: str = "query") -> dict[str, Any]: text = str(text or "") if not text.strip(): raise ValueError("embedding input text is empty") + if purpose not in {"query", "document"}: + raise ValueError("embedding purpose must be 'query' or 'document'") before = npu_busy_time_us() started = time.perf_counter() # TextEmbeddingPipeline is a native object; serialize calls until proven # safe under concurrent NPU use. Tiny silicon clown-car avoidance clause. with self.lock: - vec = self.pipeline.embed_query(text) + if purpose == "document": + # batch_size=1 means embed_documents must receive exactly one doc. + vec = self.pipeline.embed_documents([text])[0] + else: + vec = self.pipeline.embed_query(text) after = npu_busy_time_us() vector = [float(x) for x in vec] self.embedding_dim = len(vector) return { "embedding": vector, "dim": len(vector), + "purpose": purpose, "duration_ms": round((time.perf_counter() - started) * 1000, 3), "npu_busy_delta_us": None if before is None or after is None else after - before, } @@ -136,17 +143,19 @@ class Handler(BaseHTTPRequestHandler): payload = self.read_json() if path == "/api/embed": texts = normalize_input(payload.get("input")) - results = [self.svc.embed_one(text) for text in texts] + purpose = str(payload.get("purpose") or payload.get("task") or "document") + results = [self.svc.embed_one(text, purpose=purpose) for text in texts] self.write_json({ "model": payload.get("model") or self.svc.model_name, "embeddings": [item["embedding"] for item in results], "embedding_dim": results[0]["dim"] if results else None, + "purpose": purpose, "npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results), "durations_ms": [item["duration_ms"] for item in results], }) elif path == "/api/embeddings": text = payload.get("prompt") or payload.get("input") - result = self.svc.embed_one(str(text or "")) + result = self.svc.embed_one(str(text or ""), purpose="query") self.write_json({ "model": payload.get("model") or self.svc.model_name, "embedding": result["embedding"], @@ -156,7 +165,8 @@ class Handler(BaseHTTPRequestHandler): }) elif path == "/v1/embeddings": texts = normalize_input(payload.get("input")) - results = [self.svc.embed_one(text) for text in texts] + purpose = str(payload.get("purpose") or payload.get("task") or "query") + results = [self.svc.embed_one(text, purpose=purpose) for text in texts] self.write_json({ "object": "list", "model": payload.get("model") or self.svc.model_name, @@ -166,6 +176,7 @@ class Handler(BaseHTTPRequestHandler): ], "usage": {"prompt_tokens": 0, "total_tokens": 0}, "embedding_dim": results[0]["dim"] if results else None, + "purpose": purpose, "npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results), "durations_ms": [item["duration_ms"] for item in results], })