feat(rag): add optional NPU reranker fallback
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from unittest import mock
|
||||
|
||||
# The server script has a hyphenated file name, so it cannot be imported with a
# normal ``import`` statement and is loaded from its path instead.
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "obsidian-reindex-server.py"


def load_module():
    """Load the reindex server script as a fresh module object.

    The module is registered in ``sys.modules`` under the name
    ``obsidian_reindex_server`` before it is executed, replacing any entry
    left by a previous call.

    Returns:
        The newly executed module.

    Raises:
        ImportError: If no import spec or loader can be built for MODULE_PATH.
        Exception: Whatever executing the script raises; in that case the
            partially initialised module is removed from ``sys.modules``.
    """
    spec = importlib.util.spec_from_file_location("obsidian_reindex_server", MODULE_PATH)
    if spec is None or spec.loader is None:
        # Raised explicitly because ``assert`` statements vanish under ``python -O``.
        raise ImportError(f"cannot build an import spec for {MODULE_PATH}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        if sys.modules.get(spec.name) is module:
            del sys.modules[spec.name]
        raise
    return cast(types.ModuleType, module)
|
||||
|
||||
|
||||
class SemanticSearchRerankTests(unittest.TestCase):
    """Tests for ``run_semantic_search`` with the external calls replaced by fakes.

    ``subprocess.run`` (the vector search) and ``request.urlopen`` (the
    reranker) are patched on the loaded server module, so no test starts a
    process or opens a connection.
    """

    def setUp(self):
        # A fresh module per test keeps the RAG_RERANK_* overrides from leaking.
        self.server = load_module()
        # load_module registers the module in sys.modules; drop it afterwards.
        self.addCleanup(sys.modules.pop, self.server.__name__, None)
        self.results = [
            {"id": "a", "text": "alpha doc", "path": "a.md", "score": 0.1},
            {"id": "b", "text": "beta doc", "path": "b.md", "score": 0.2},
            {"id": "c", "text": "gamma doc", "path": "c.md", "score": 0.3},
        ]

    def _mock_search_run(self, expected_top_k=None):
        """Build a ``subprocess.run`` replacement that returns ``self.results``.

        Args:
            expected_top_k: When not None, the fake asserts that the value
                following ``--top-k`` in the command equals this number.

        Returns:
            A callable accepting the keyword arguments the fake is called with
            (``capture_output``, ``text``, ``timeout``, ``env``) and returning
            a successful ``CompletedProcess`` whose stdout is a JSON document.
        """

        def fake_run(cmd, capture_output, text, timeout, env):
            if expected_top_k is not None:
                self.assertEqual(cmd[cmd.index("--top-k") + 1], str(expected_top_k))
            return subprocess.CompletedProcess(
                cmd,
                0,
                stdout=json.dumps({"index": "obsidian_bge_npu", "results": self.results}),
                stderr="",
            )

        return fake_run

    @staticmethod
    def _fake_response(body):
        """Return a context-manager object whose ``read()`` yields ``body`` as JSON bytes."""
        encoded = json.dumps(body).encode()

        class FakeResponse:
            def __enter__(self):
                return self

            def __exit__(self, *args):
                return False

            def read(self):
                return encoded

        return FakeResponse()

    def _enable_rerank(self, **overrides):
        """Switch reranking on with an initial k of 3 and a final k of 2.

        Args:
            **overrides: Extra module attributes to set on the server module.
        """
        settings = {
            "RAG_RERANK_ENABLED": True,
            "RAG_RERANK_INITIAL_K": 3,
            "RAG_RERANK_TOP_K": 2,
        }
        settings.update(overrides)
        for name, value in settings.items():
            setattr(self.server, name, value)

    @staticmethod
    def _result_ids(payload):
        """Return the ``id`` of each entry in ``payload["results"]``, in order."""
        return [item["id"] for item in payload["results"]]

    def test_disabled_rerank_preserves_vector_order(self):
        """With reranking off, results keep vector order and urlopen is never called."""
        setattr(self.server, "RAG_RERANK_ENABLED", False)
        run_patch = mock.patch.object(
            self.server.subprocess, "run", self._mock_search_run(expected_top_k=2)
        )
        # Patch urlopen too so a regression cannot reach the network from this test.
        with run_patch, mock.patch.object(self.server.request, "urlopen") as urlopen:
            payload = self.server.run_semantic_search("npu smoke", top_k=2)
        urlopen.assert_not_called()
        self.assertTrue(payload["ok"])
        self.assertEqual(payload["search_k"], 2)
        self.assertEqual(self._result_ids(payload), ["a", "b"])
        self.assertEqual(payload["rerank"]["reason"], "disabled")
        self.assertFalse(payload["rerank"]["attempted"])

    def test_enabled_rerank_reorders_matching_results(self):
        """A successful reranker reply decides the order of the returned results."""
        self._enable_rerank()
        response = self._fake_response(
            {
                "ok": True,
                "model": "synthetic-reranker",
                "device": "NPU",
                "npu_busy_delta_us": 123,
                "results": [
                    {"id": "c", "score": 9.0},
                    {"id": "a", "score": 7.0},
                ],
            }
        )
        run_patch = mock.patch.object(
            self.server.subprocess, "run", self._mock_search_run(expected_top_k=3)
        )
        urlopen_patch = mock.patch.object(self.server.request, "urlopen", return_value=response)
        with run_patch, urlopen_patch:
            payload = self.server.run_semantic_search("npu smoke", top_k=2)
        self.assertEqual(self._result_ids(payload), ["c", "a"])
        self.assertTrue(payload["rerank"]["attempted"])
        self.assertTrue(payload["rerank"]["ok"])
        self.assertEqual(payload["rerank"]["npu_busy_delta_us"], 123)
        self.assertEqual(payload["results"][0]["rerank_rank"], 1)

    def test_enabled_rerank_error_falls_back_to_vector_order(self):
        """An OSError from urlopen yields vector-order results and a recorded error."""
        self._enable_rerank()
        run_patch = mock.patch.object(
            self.server.subprocess, "run", self._mock_search_run(expected_top_k=3)
        )
        urlopen_patch = mock.patch.object(
            self.server.request, "urlopen", side_effect=OSError("reranker unavailable")
        )
        with run_patch, urlopen_patch:
            payload = self.server.run_semantic_search("npu smoke", top_k=2)
        self.assertEqual(self._result_ids(payload), ["a", "b"])
        self.assertTrue(payload["rerank"]["attempted"])
        self.assertFalse(payload["rerank"]["ok"])
        self.assertIn("reranker unavailable", payload["rerank"]["error"])

    def test_enabled_rerank_requires_positive_npu_proof(self):
        """With proof required, a reply reporting ``npu_busy_delta_us`` of 0 is rejected."""
        self._enable_rerank(RAG_RERANK_REQUIRE_NPU_PROOF=True)
        response = self._fake_response(
            {
                "ok": True,
                "device": "NPU",
                "npu_busy_delta_us": 0,
                "results": [{"id": "c", "score": 9.0}],
            }
        )
        run_patch = mock.patch.object(
            self.server.subprocess, "run", self._mock_search_run(expected_top_k=3)
        )
        urlopen_patch = mock.patch.object(self.server.request, "urlopen", return_value=response)
        with run_patch, urlopen_patch:
            payload = self.server.run_semantic_search("npu smoke", top_k=2)
        self.assertEqual(self._result_ids(payload), ["a", "b"])
        self.assertTrue(payload["rerank"]["attempted"])
        self.assertFalse(payload["rerank"]["ok"])
        self.assertIn("positive npu_busy_delta_us", payload["rerank"]["error"])
|
||||
|
||||
|
||||
# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user