fix(rag): distinguish query and document embeddings
This commit is contained in:
@@ -67,22 +67,29 @@ class EmbeddingService:
|
||||
cfg.batch_size = 1
|
||||
self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg)
|
||||
|
||||
def embed_one(self, text: str) -> dict[str, Any]:
|
||||
def embed_one(self, text: str, *, purpose: str = "query") -> dict[str, Any]:
|
||||
text = str(text or "")
|
||||
if not text.strip():
|
||||
raise ValueError("embedding input text is empty")
|
||||
if purpose not in {"query", "document"}:
|
||||
raise ValueError("embedding purpose must be 'query' or 'document'")
|
||||
before = npu_busy_time_us()
|
||||
started = time.perf_counter()
|
||||
# TextEmbeddingPipeline is a native object; serialize calls until proven
|
||||
# safe under concurrent NPU use. Tiny silicon clown-car avoidance clause.
|
||||
with self.lock:
|
||||
vec = self.pipeline.embed_query(text)
|
||||
if purpose == "document":
|
||||
# batch_size=1 means embed_documents must receive exactly one doc.
|
||||
vec = self.pipeline.embed_documents([text])[0]
|
||||
else:
|
||||
vec = self.pipeline.embed_query(text)
|
||||
after = npu_busy_time_us()
|
||||
vector = [float(x) for x in vec]
|
||||
self.embedding_dim = len(vector)
|
||||
return {
|
||||
"embedding": vector,
|
||||
"dim": len(vector),
|
||||
"purpose": purpose,
|
||||
"duration_ms": round((time.perf_counter() - started) * 1000, 3),
|
||||
"npu_busy_delta_us": None if before is None or after is None else after - before,
|
||||
}
|
||||
@@ -136,17 +143,19 @@ class Handler(BaseHTTPRequestHandler):
|
||||
payload = self.read_json()
|
||||
if path == "/api/embed":
|
||||
texts = normalize_input(payload.get("input"))
|
||||
results = [self.svc.embed_one(text) for text in texts]
|
||||
purpose = str(payload.get("purpose") or payload.get("task") or "document")
|
||||
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
|
||||
self.write_json({
|
||||
"model": payload.get("model") or self.svc.model_name,
|
||||
"embeddings": [item["embedding"] for item in results],
|
||||
"embedding_dim": results[0]["dim"] if results else None,
|
||||
"purpose": purpose,
|
||||
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
|
||||
"durations_ms": [item["duration_ms"] for item in results],
|
||||
})
|
||||
elif path == "/api/embeddings":
|
||||
text = payload.get("prompt") or payload.get("input")
|
||||
result = self.svc.embed_one(str(text or ""))
|
||||
result = self.svc.embed_one(str(text or ""), purpose="query")
|
||||
self.write_json({
|
||||
"model": payload.get("model") or self.svc.model_name,
|
||||
"embedding": result["embedding"],
|
||||
@@ -156,7 +165,8 @@ class Handler(BaseHTTPRequestHandler):
|
||||
})
|
||||
elif path == "/v1/embeddings":
|
||||
texts = normalize_input(payload.get("input"))
|
||||
results = [self.svc.embed_one(text) for text in texts]
|
||||
purpose = str(payload.get("purpose") or payload.get("task") or "query")
|
||||
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
|
||||
self.write_json({
|
||||
"object": "list",
|
||||
"model": payload.get("model") or self.svc.model_name,
|
||||
@@ -166,6 +176,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "total_tokens": 0},
|
||||
"embedding_dim": results[0]["dim"] if results else None,
|
||||
"purpose": purpose,
|
||||
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
|
||||
"durations_ms": [item["duration_ms"] for item in results],
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user