feat(rag): add optional NPU reranker fallback

This commit is contained in:
William Valentin
2026-06-04 14:50:41 -07:00
parent 06f235d26b
commit 71f3c05587
5 changed files with 303 additions and 9 deletions
+138
View File
@@ -0,0 +1,138 @@
import importlib.util
import json
import subprocess
import sys
import types
import unittest
from pathlib import Path
from typing import cast
from unittest import mock
# Absolute path of the script under test: ``scripts/obsidian-reindex-server.py``
# inside the directory one level above the directory that holds this file.
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "obsidian-reindex-server.py"


def load_module():
    """Execute the server script as a fresh module and return it.

    The script's filename contains hyphens, so it cannot be pulled in with a
    plain ``import`` statement; it is loaded from ``MODULE_PATH`` and
    registered in ``sys.modules`` as ``obsidian_reindex_server``. Every call
    re-executes the script and replaces any previously registered copy.

    Returns:
        The newly executed module object.

    Raises:
        ImportError: If no spec or loader can be created for ``MODULE_PATH``.
        Exception: Anything raised while the script executes (for example
            ``FileNotFoundError`` when the script is missing). The partially
            initialised module is removed from ``sys.modules`` before the
            exception propagates.
    """
    spec = importlib.util.spec_from_file_location("obsidian_reindex_server", MODULE_PATH)
    if spec is None or spec.loader is None:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ImportError(
            f"cannot create an import spec for {MODULE_PATH}",
            path=str(MODULE_PATH),
        )

    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the regular import machinery does, so code
    # that looks the module up by name while it is executing can find it.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind; only remove the entry
        # if it is still the one registered above.
        if sys.modules.get(spec.name) is module:
            del sys.modules[spec.name]
        raise
    return cast(types.ModuleType, module)
class _FakeRerankResponse:
    """Stand-in for the context manager returned by ``urlopen``.

    ``read()`` returns the payload given at construction, JSON-encoded as
    bytes.
    """

    def __init__(self, payload):
        self._body = json.dumps(payload).encode()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # False: never swallow an exception raised inside the ``with`` body.
        return False

    def read(self):
        return self._body


class SemanticSearchRerankTests(unittest.TestCase):
    """Tests for the optional rerank step of ``run_semantic_search``.

    The vector-search subprocess call and the reranker ``urlopen`` call are
    replaced with mocks; each test then checks the result order and the
    ``rerank`` status block of the returned payload.
    """

    def setUp(self):
        # Load a fresh copy of the script for every test so the RAG_RERANK_*
        # overrides set by one test cannot leak into another.
        self.server = load_module()
        # load_module() registers the script in sys.modules; drop that entry
        # once the test is over so it does not outlive this test case.
        self.addCleanup(sys.modules.pop, self.server.__name__, None)
        self.results = [
            {"id": "a", "text": "alpha doc", "path": "a.md", "score": 0.1},
            {"id": "b", "text": "beta doc", "path": "b.md", "score": 0.2},
            {"id": "c", "text": "gamma doc", "path": "c.md", "score": 0.3},
        ]

    def _mock_search_run(self, expected_top_k=None):
        """Return a replacement for ``subprocess.run`` that serves ``self.results``.

        Args:
            expected_top_k: When not ``None``, the fake asserts that the value
                following ``--top-k`` in the command equals this number.

        Returns:
            A callable producing a successful ``CompletedProcess`` whose stdout
            is a JSON document containing ``self.results``.
        """

        # The explicit parameter list (no **kwargs) makes the fake fail loudly
        # if the server calls subprocess.run with a different set of arguments.
        def fake_run(cmd, capture_output, text, timeout, env):
            if expected_top_k is not None:
                self.assertEqual(cmd[cmd.index("--top-k") + 1], str(expected_top_k))
            return subprocess.CompletedProcess(
                cmd,
                0,
                stdout=json.dumps({"index": "obsidian_bge_npu", "results": self.results}),
                stderr="",
            )

        return fake_run

    def _enable_rerank(self, require_npu_proof=None):
        """Switch reranking on: fetch 3 candidates, keep the top 2.

        Args:
            require_npu_proof: When not ``None``, also overrides
                ``RAG_RERANK_REQUIRE_NPU_PROOF`` with this value; otherwise the
                module's own setting is left untouched.
        """
        setattr(self.server, "RAG_RERANK_ENABLED", True)
        setattr(self.server, "RAG_RERANK_INITIAL_K", 3)
        setattr(self.server, "RAG_RERANK_TOP_K", 2)
        if require_npu_proof is not None:
            setattr(self.server, "RAG_RERANK_REQUIRE_NPU_PROOF", require_npu_proof)

    def _search_with_reranker(self, **urlopen_mock_kwargs):
        """Run the search with the subprocess and ``urlopen`` both mocked.

        Args:
            **urlopen_mock_kwargs: Passed to ``mock.patch.object`` for
                ``urlopen`` (for example ``return_value=`` or ``side_effect=``).

        Returns:
            The payload returned by ``run_semantic_search("npu smoke", top_k=2)``.
        """
        run_patch = mock.patch.object(
            self.server.subprocess, "run", self._mock_search_run(expected_top_k=3)
        )
        urlopen_patch = mock.patch.object(self.server.request, "urlopen", **urlopen_mock_kwargs)
        with run_patch, urlopen_patch:
            return self.server.run_semantic_search("npu smoke", top_k=2)

    def test_disabled_rerank_preserves_vector_order(self):
        setattr(self.server, "RAG_RERANK_ENABLED", False)

        with mock.patch.object(
            self.server.subprocess, "run", self._mock_search_run(expected_top_k=2)
        ):
            payload = self.server.run_semantic_search("npu smoke", top_k=2)

        self.assertTrue(payload["ok"])
        self.assertEqual(payload["search_k"], 2)
        self.assertEqual([item["id"] for item in payload["results"]], ["a", "b"])
        self.assertEqual(payload["rerank"]["reason"], "disabled")
        self.assertFalse(payload["rerank"]["attempted"])

    def test_enabled_rerank_reorders_matching_results(self):
        self._enable_rerank()
        response = _FakeRerankResponse(
            {
                "ok": True,
                "model": "synthetic-reranker",
                "device": "NPU",
                "npu_busy_delta_us": 123,
                "results": [
                    {"id": "c", "score": 9.0},
                    {"id": "a", "score": 7.0},
                ],
            }
        )

        payload = self._search_with_reranker(return_value=response)

        self.assertEqual([item["id"] for item in payload["results"]], ["c", "a"])
        self.assertTrue(payload["rerank"]["attempted"])
        self.assertTrue(payload["rerank"]["ok"])
        self.assertEqual(payload["rerank"]["npu_busy_delta_us"], 123)
        self.assertEqual(payload["results"][0]["rerank_rank"], 1)

    def test_enabled_rerank_error_falls_back_to_vector_order(self):
        self._enable_rerank()

        payload = self._search_with_reranker(side_effect=OSError("reranker unavailable"))

        self.assertEqual([item["id"] for item in payload["results"]], ["a", "b"])
        self.assertTrue(payload["rerank"]["attempted"])
        self.assertFalse(payload["rerank"]["ok"])
        self.assertIn("reranker unavailable", payload["rerank"]["error"])

    def test_enabled_rerank_requires_positive_npu_proof(self):
        self._enable_rerank(require_npu_proof=True)
        # The reranker reports success but zero NPU busy time, so the rerank
        # result is expected to be rejected.
        response = _FakeRerankResponse(
            {
                "ok": True,
                "device": "NPU",
                "npu_busy_delta_us": 0,
                "results": [{"id": "c", "score": 9.0}],
            }
        )

        payload = self._search_with_reranker(return_value=response)

        self.assertEqual([item["id"] for item in payload["results"]], ["a", "b"])
        self.assertTrue(payload["rerank"]["attempted"])
        self.assertFalse(payload["rerank"]["ok"])
        self.assertIn("positive npu_busy_delta_us", payload["rerank"]["error"])
# Run the test cases in this file when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()