fix(rag): distinguish query and document embeddings
This commit is contained in:
@@ -67,15 +67,21 @@ class EmbeddingService:
|
|||||||
cfg.batch_size = 1
|
cfg.batch_size = 1
|
||||||
self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg)
|
self.pipeline = ovg.TextEmbeddingPipeline(self.model_dir, self.device, cfg)
|
||||||
|
|
||||||
def embed_one(self, text: str) -> dict[str, Any]:
|
def embed_one(self, text: str, *, purpose: str = "query") -> dict[str, Any]:
|
||||||
text = str(text or "")
|
text = str(text or "")
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
raise ValueError("embedding input text is empty")
|
raise ValueError("embedding input text is empty")
|
||||||
|
if purpose not in {"query", "document"}:
|
||||||
|
raise ValueError("embedding purpose must be 'query' or 'document'")
|
||||||
before = npu_busy_time_us()
|
before = npu_busy_time_us()
|
||||||
started = time.perf_counter()
|
started = time.perf_counter()
|
||||||
# TextEmbeddingPipeline is a native object; serialize calls until proven
|
# TextEmbeddingPipeline is a native object; serialize calls until proven
|
||||||
# safe under concurrent NPU use. Tiny silicon clown-car avoidance clause.
|
# safe under concurrent NPU use. Tiny silicon clown-car avoidance clause.
|
||||||
with self.lock:
|
with self.lock:
|
||||||
|
if purpose == "document":
|
||||||
|
# batch_size=1 means embed_documents must receive exactly one doc.
|
||||||
|
vec = self.pipeline.embed_documents([text])[0]
|
||||||
|
else:
|
||||||
vec = self.pipeline.embed_query(text)
|
vec = self.pipeline.embed_query(text)
|
||||||
after = npu_busy_time_us()
|
after = npu_busy_time_us()
|
||||||
vector = [float(x) for x in vec]
|
vector = [float(x) for x in vec]
|
||||||
@@ -83,6 +89,7 @@ class EmbeddingService:
|
|||||||
return {
|
return {
|
||||||
"embedding": vector,
|
"embedding": vector,
|
||||||
"dim": len(vector),
|
"dim": len(vector),
|
||||||
|
"purpose": purpose,
|
||||||
"duration_ms": round((time.perf_counter() - started) * 1000, 3),
|
"duration_ms": round((time.perf_counter() - started) * 1000, 3),
|
||||||
"npu_busy_delta_us": None if before is None or after is None else after - before,
|
"npu_busy_delta_us": None if before is None or after is None else after - before,
|
||||||
}
|
}
|
||||||
@@ -136,17 +143,19 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
payload = self.read_json()
|
payload = self.read_json()
|
||||||
if path == "/api/embed":
|
if path == "/api/embed":
|
||||||
texts = normalize_input(payload.get("input"))
|
texts = normalize_input(payload.get("input"))
|
||||||
results = [self.svc.embed_one(text) for text in texts]
|
purpose = str(payload.get("purpose") or payload.get("task") or "document")
|
||||||
|
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
|
||||||
self.write_json({
|
self.write_json({
|
||||||
"model": payload.get("model") or self.svc.model_name,
|
"model": payload.get("model") or self.svc.model_name,
|
||||||
"embeddings": [item["embedding"] for item in results],
|
"embeddings": [item["embedding"] for item in results],
|
||||||
"embedding_dim": results[0]["dim"] if results else None,
|
"embedding_dim": results[0]["dim"] if results else None,
|
||||||
|
"purpose": purpose,
|
||||||
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
|
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
|
||||||
"durations_ms": [item["duration_ms"] for item in results],
|
"durations_ms": [item["duration_ms"] for item in results],
|
||||||
})
|
})
|
||||||
elif path == "/api/embeddings":
|
elif path == "/api/embeddings":
|
||||||
text = payload.get("prompt") or payload.get("input")
|
text = payload.get("prompt") or payload.get("input")
|
||||||
result = self.svc.embed_one(str(text or ""))
|
result = self.svc.embed_one(str(text or ""), purpose="query")
|
||||||
self.write_json({
|
self.write_json({
|
||||||
"model": payload.get("model") or self.svc.model_name,
|
"model": payload.get("model") or self.svc.model_name,
|
||||||
"embedding": result["embedding"],
|
"embedding": result["embedding"],
|
||||||
@@ -156,7 +165,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
})
|
})
|
||||||
elif path == "/v1/embeddings":
|
elif path == "/v1/embeddings":
|
||||||
texts = normalize_input(payload.get("input"))
|
texts = normalize_input(payload.get("input"))
|
||||||
results = [self.svc.embed_one(text) for text in texts]
|
purpose = str(payload.get("purpose") or payload.get("task") or "query")
|
||||||
|
results = [self.svc.embed_one(text, purpose=purpose) for text in texts]
|
||||||
self.write_json({
|
self.write_json({
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"model": payload.get("model") or self.svc.model_name,
|
"model": payload.get("model") or self.svc.model_name,
|
||||||
@@ -166,6 +176,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
],
|
],
|
||||||
"usage": {"prompt_tokens": 0, "total_tokens": 0},
|
"usage": {"prompt_tokens": 0, "total_tokens": 0},
|
||||||
"embedding_dim": results[0]["dim"] if results else None,
|
"embedding_dim": results[0]["dim"] if results else None,
|
||||||
|
"purpose": purpose,
|
||||||
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
|
"npu_busy_delta_us": sum((item.get("npu_busy_delta_us") or 0) for item in results),
|
||||||
"durations_ms": [item["duration_ms"] for item in results],
|
"durations_ms": [item["duration_ms"] for item in results],
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user