fix(rag): distinguish query and document embeddings

This commit is contained in:
William Valentin
2026-06-03 19:51:55 -07:00
parent fe4dea0f07
commit bcc652e5aa
+15 -4
View File
@@ -67,15 +67,21 @@ class EmbeddingService:
cfg.batch_size = 1 cfg.batch_size = 1
self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg) self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg)
def embed_one(self, text: str) -> dict[str, Any]: def embed_one(self, text: str, *, purpose: str = "query") -> dict[str, Any]:
text = str(text or "") text = str(text or "")
if not text.strip(): if not text.strip():
raise ValueError("embedding input text is empty") raise ValueError("embedding input text is empty")
if purpose not in {"query", "document"}:
raise ValueError("embedding purpose must be 'query' or 'document'")
before = npu_busy_time_us() before = npu_busy_time_us()
started = time.perf_counter() started = time.perf_counter()
# TextEmbeddingPipeline is a native object; serialize calls until proven # TextEmbeddingPipeline is a native object; serialize calls until proven
# safe under concurrent NPU use. Tiny silicon clown-car avoidance clause. # safe under concurrent NPU use. Tiny silicon clown-car avoidance clause.
with self.lock: with self.lock:
if purpose == "document":
# batch_size=1 means embed_documents must receive exactly one doc.
vec = self.pipeline.embed_documents([text])[0]
else:
vec = self.pipeline.embed_query(text) vec = self.pipeline.embed_query(text)
after = npu_busy_time_us() after = npu_busy_time_us()
vector = [float(x) for x in vec] vector = [float(x) for x in vec]
@@ -83,6 +89,7 @@ class EmbeddingService:
return { return {
"embedding": vector, "embedding": vector,
"dim": len(vector), "dim": len(vector),
"purpose": purpose,
"duration_ms": round((time.perf_counter() - started) * 1000, 3), "duration_ms": round((time.perf_counter() - started) * 1000, 3),
"npu_busy_delta_us": None if before is None or after is None else after - before, "npu_busy_delta_us": None if before is None or after is None else after - before,
} }
@@ -136,17 +143,19 @@ class Handler(BaseHTTPRequestHandler):
payload = self.read_json() payload = self.read_json()
if path == "/api/embed": if path == "/api/embed":
texts = normalize_input(payload.get("input")) texts = normalize_input(payload.get("input"))
results = [self.svc.embed_one(text) for text in texts] purpose = str(payload.get("purpose") or payload.get("task") or "document")
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
self.write_json({ self.write_json({
"model": payload.get("model") or self.svc.model_name, "model": payload.get("model") or self.svc.model_name,
"embeddings": [item["embedding"] for item in results], "embeddings": [item["embedding"] for item in results],
"embedding_dim": results[0]["dim"] if results else None, "embedding_dim": results[0]["dim"] if results else None,
"purpose": purpose,
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results), "npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
"durations_ms": [item["duration_ms"] for item in results], "durations_ms": [item["duration_ms"] for item in results],
}) })
elif path == "/api/embeddings": elif path == "/api/embeddings":
text = payload.get("prompt") or payload.get("input") text = payload.get("prompt") or payload.get("input")
result = self.svc.embed_one(str(text or "")) result = self.svc.embed_one(str(text or ""), purpose="query")
self.write_json({ self.write_json({
"model": payload.get("model") or self.svc.model_name, "model": payload.get("model") or self.svc.model_name,
"embedding": result["embedding"], "embedding": result["embedding"],
@@ -156,7 +165,8 @@ class Handler(BaseHTTPRequestHandler):
}) })
elif path == "/v1/embeddings": elif path == "/v1/embeddings":
texts = normalize_input(payload.get("input")) texts = normalize_input(payload.get("input"))
results = [self.svc.embed_one(text) for text in texts] purpose = str(payload.get("purpose") or payload.get("task") or "query")
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
self.write_json({ self.write_json({
"object": "list", "object": "list",
"model": payload.get("model") or self.svc.model_name, "model": payload.get("model") or self.svc.model_name,
@@ -166,6 +176,7 @@ class Handler(BaseHTTPRequestHandler):
], ],
"usage": {"prompt_tokens": 0, "total_tokens": 0}, "usage": {"prompt_tokens": 0, "total_tokens": 0},
"embedding_dim": results[0]["dim"] if results else None, "embedding_dim": results[0]["dim"] if results else None,
"purpose": purpose,
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results), "npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
"durations_ms": [item["duration_ms"] for item in results], "durations_ms": [item["duration_ms"] for item in results],
}) })