fix(rag): distinguish query and document embeddings

This commit is contained in:
William Valentin
2026-06-03 19:51:55 -07:00
parent fe4dea0f07
commit bcc652e5aa
+16 -5
View File
@@ -67,22 +67,29 @@ class EmbeddingService:
cfg.batch_size = 1
self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg)
def embed_one(self, text: str) -> dict[str, Any]:
def embed_one(self, text: str, *, purpose: str = "query") -> dict[str, Any]:
text = str(text or "")
if not text.strip():
raise ValueError("embedding input text is empty")
if purpose not in {"query", "document"}:
raise ValueError("embedding purpose must be 'query' or 'document'")
before = npu_busy_time_us()
started = time.perf_counter()
# TextEmbeddingPipeline is a native object; serialize calls until proven
# safe under concurrent NPU use. Tiny silicon clown-car avoidance clause.
with self.lock:
vec = self.pipeline.embed_query(text)
if purpose == "document":
# batch_size=1 means embed_documents must receive exactly one doc.
vec = self.pipeline.embed_documents([text])[0]
else:
vec = self.pipeline.embed_query(text)
after = npu_busy_time_us()
vector = [float(x) for x in vec]
self.embedding_dim = len(vector)
return {
"embedding": vector,
"dim": len(vector),
"purpose": purpose,
"duration_ms": round((time.perf_counter() - started) * 1000, 3),
"npu_busy_delta_us": None if before is None or after is None else after - before,
}
@@ -136,17 +143,19 @@ class Handler(BaseHTTPRequestHandler):
payload = self.read_json()
if path == "/api/embed":
texts = normalize_input(payload.get("input"))
results = [self.svc.embed_one(text) for text in texts]
purpose = str(payload.get("purpose") or payload.get("task") or "document")
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
self.write_json({
"model": payload.get("model") or self.svc.model_name,
"embeddings": [item["embedding"] for item in results],
"embedding_dim": results[0]["dim"] if results else None,
"purpose": purpose,
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
"durations_ms": [item["duration_ms"] for item in results],
})
elif path == "/api/embeddings":
text = payload.get("prompt") or payload.get("input")
result = self.svc.embed_one(str(text or ""))
result = self.svc.embed_one(str(text or ""), purpose="query")
self.write_json({
"model": payload.get("model") or self.svc.model_name,
"embedding": result["embedding"],
@@ -156,7 +165,8 @@ class Handler(BaseHTTPRequestHandler):
})
elif path == "/v1/embeddings":
texts = normalize_input(payload.get("input"))
results = [self.svc.embed_one(text) for text in texts]
purpose = str(payload.get("purpose") or payload.get("task") or "query")
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
self.write_json({
"object": "list",
"model": payload.get("model") or self.svc.model_name,
@@ -166,6 +176,7 @@ class Handler(BaseHTTPRequestHandler):
],
"usage": {"prompt_tokens": 0, "total_tokens": 0},
"embedding_dim": results[0]["dim"] if results else None,
"purpose": purpose,
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
"durations_ms": [item["duration_ms"] for item in results],
})