feat(rag): add optional NPU reranker fallback
This commit is contained in:
@@ -21,14 +21,32 @@ import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from urllib import request, error
|
||||
|
||||
PORT = int(os.environ.get("PORT", 18810))
|
||||
REINDEX_TIMEOUT = int(os.environ.get("REINDEX_TIMEOUT", "1800"))
|
||||
RAG_COLLECTION = os.environ.get("RAG_COLLECTION", "obsidian").strip() or "obsidian"
|
||||
RAG_EMBED_MODEL = os.environ.get("RAG_EMBED_MODEL", "nomic-embed-text").strip() or "nomic-embed-text"
|
||||
OLLAMA_BASE_URL = (os.environ.get("OLLAMA_BASE_URL") or "http://127.0.0.1:18807").rstrip("/")
|
||||
RAG_RERANK_ENABLED = (os.environ.get("RAG_RERANK_ENABLED") or "false").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
RAG_RERANK_URL = (os.environ.get("RAG_RERANK_URL") or "http://127.0.0.1:18818/rerank").strip()
|
||||
RAG_RERANK_INITIAL_K = max(1, int(os.environ.get("RAG_RERANK_INITIAL_K") or "20"))
|
||||
RAG_RERANK_TOP_K = max(1, int(os.environ.get("RAG_RERANK_TOP_K") or "5"))
|
||||
RAG_RERANK_TIMEOUT_MS = max(1, int(os.environ.get("RAG_RERANK_TIMEOUT_MS") or "3000"))
|
||||
RAG_RERANK_REQUIRE_NPU_PROOF = (os.environ.get("RAG_RERANK_REQUIRE_NPU_PROOF") or "true").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
|
||||
REINDEX_SCRIPT = str(
|
||||
Path.home()
|
||||
@@ -102,12 +120,125 @@ def get_status() -> dict:
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
def _result_text(result: dict) -> str:
|
||||
"""Return the text field sent to the reranker without changing response shape."""
|
||||
return str(result.get("text") or result.get("content") or "")
|
||||
|
||||
|
||||
def _apply_rerank(query: str, results: list[dict], final_k: int) -> tuple[list[dict], dict]:
|
||||
"""Optionally rerank semantic results, falling back to vector order on any error."""
|
||||
metadata = {
|
||||
"enabled": RAG_RERANK_ENABLED,
|
||||
"attempted": False,
|
||||
"ok": False,
|
||||
"url": RAG_RERANK_URL,
|
||||
"initial_k": len(results),
|
||||
"top_k": final_k,
|
||||
}
|
||||
if not RAG_RERANK_ENABLED:
|
||||
metadata["ok"] = True
|
||||
metadata["reason"] = "disabled"
|
||||
return results[:final_k], metadata
|
||||
if not results:
|
||||
metadata["ok"] = True
|
||||
metadata["reason"] = "no_results"
|
||||
return [], metadata
|
||||
|
||||
metadata["attempted"] = True
|
||||
documents = []
|
||||
for idx, item in enumerate(results):
|
||||
text = _result_text(item)
|
||||
if not text:
|
||||
continue
|
||||
documents.append(
|
||||
{
|
||||
"id": str(item.get("id") or idx),
|
||||
"text": text,
|
||||
"metadata": {
|
||||
"index": idx,
|
||||
"path": item.get("path"),
|
||||
"source": item.get("source"),
|
||||
"chunk": item.get("chunk"),
|
||||
},
|
||||
}
|
||||
)
|
||||
if not documents:
|
||||
metadata["ok"] = True
|
||||
metadata["reason"] = "no_text_documents"
|
||||
return results[:final_k], metadata
|
||||
|
||||
started = time.monotonic()
|
||||
try:
|
||||
body = json.dumps(
|
||||
{
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_k": final_k,
|
||||
"return_documents": False,
|
||||
}
|
||||
).encode("utf-8")
|
||||
req = request.Request(
|
||||
RAG_RERANK_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
with request.urlopen(req, timeout=RAG_RERANK_TIMEOUT_MS / 1000.0) as resp:
|
||||
payload = json.loads(resp.read().decode("utf-8"))
|
||||
except (OSError, TimeoutError, json.JSONDecodeError, error.URLError, error.HTTPError) as exc:
|
||||
metadata["duration_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||
metadata["error"] = f"{type(exc).__name__}: {exc}"
|
||||
return results[:final_k], metadata
|
||||
|
||||
metadata["duration_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||
metadata["ok"] = bool(payload.get("ok", True))
|
||||
metadata["model"] = payload.get("model")
|
||||
metadata["device"] = payload.get("device")
|
||||
metadata["npu_busy_delta_us"] = payload.get("npu_busy_delta_us")
|
||||
metadata["require_npu_proof"] = RAG_RERANK_REQUIRE_NPU_PROOF
|
||||
metadata["input_count"] = payload.get("input_count")
|
||||
ranked = payload.get("results") or []
|
||||
if RAG_RERANK_REQUIRE_NPU_PROOF and int(payload.get("npu_busy_delta_us") or 0) <= 0:
|
||||
metadata["ok"] = False
|
||||
metadata["error"] = "reranker response lacked positive npu_busy_delta_us"
|
||||
return results[:final_k], metadata
|
||||
if not metadata["ok"] or not ranked:
|
||||
metadata["error"] = payload.get("error") or "reranker returned no ranked results"
|
||||
return results[:final_k], metadata
|
||||
|
||||
by_id = {str(item.get("id") or idx): item for idx, item in enumerate(results)}
|
||||
reranked = []
|
||||
for rank, ranked_item in enumerate(ranked):
|
||||
source_item = None
|
||||
if "id" in ranked_item:
|
||||
source_item = by_id.get(str(ranked_item.get("id")))
|
||||
if source_item is None and isinstance(ranked_item.get("index"), int):
|
||||
idx = ranked_item["index"]
|
||||
if 0 <= idx < len(results):
|
||||
source_item = results[idx]
|
||||
if source_item is None:
|
||||
continue
|
||||
merged = dict(source_item)
|
||||
merged["rerank_score"] = ranked_item.get("score")
|
||||
merged["rerank_rank"] = rank + 1
|
||||
reranked.append(merged)
|
||||
if len(reranked) >= final_k:
|
||||
break
|
||||
if not reranked:
|
||||
metadata["ok"] = False
|
||||
metadata["error"] = "reranker result IDs did not match search results"
|
||||
return results[:final_k], metadata
|
||||
return reranked, metadata
|
||||
|
||||
|
||||
def run_semantic_search(query: str, top_k: int = 5) -> dict:
|
||||
"""Query the local Obsidian Chroma index via the rag-search script."""
|
||||
query = (query or "").strip()
|
||||
if not query:
|
||||
return {"ok": False, "error": "query is required", "results": []}
|
||||
top_k = max(1, min(int(top_k or 5), 20))
|
||||
search_k = max(top_k, min(RAG_RERANK_INITIAL_K, 100)) if RAG_RERANK_ENABLED else top_k
|
||||
final_k = min(top_k, RAG_RERANK_TOP_K) if RAG_RERANK_ENABLED else top_k
|
||||
env = os.environ.copy()
|
||||
env.setdefault("RAG_COLLECTION", RAG_COLLECTION)
|
||||
env.setdefault("RAG_EMBED_MODEL", RAG_EMBED_MODEL)
|
||||
@@ -119,7 +250,7 @@ def run_semantic_search(query: str, top_k: int = 5) -> dict:
|
||||
"--index",
|
||||
RAG_COLLECTION,
|
||||
"--top-k",
|
||||
str(top_k),
|
||||
str(search_k),
|
||||
"--raw",
|
||||
query,
|
||||
],
|
||||
@@ -133,17 +264,27 @@ def run_semantic_search(query: str, top_k: int = 5) -> dict:
|
||||
"ok": False,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"search_k": search_k,
|
||||
"error": result.stderr.strip()[-2000:] or result.stdout.strip()[-2000:],
|
||||
"results": [],
|
||||
"rerank": {
|
||||
"enabled": RAG_RERANK_ENABLED,
|
||||
"attempted": False,
|
||||
"ok": False,
|
||||
"error": "vector search failed before rerank",
|
||||
},
|
||||
}
|
||||
payload = json.loads(result.stdout)
|
||||
results = payload.get("results") or []
|
||||
results, rerank_meta = _apply_rerank(query, results, final_k)
|
||||
return {
|
||||
"ok": True,
|
||||
"query": query,
|
||||
"index": payload.get("index", RAG_COLLECTION),
|
||||
"top_k": top_k,
|
||||
"search_k": search_k,
|
||||
"result_count": len(results),
|
||||
"rerank": rerank_meta,
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user