56 lines
2.3 KiB
Python
56 lines
2.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Unit checks for reranker request validation helpers.

These tests intentionally avoid loading an OpenVINO model; they only cover the
stdlib validation helpers used before inference.
"""
|
|
from __future__ import annotations
|
|
|
|
import socket
|
|
import unittest
|
|
|
|
from server import assert_port_available, normalize_documents, parse_top_k
|
|
|
|
|
|
class ValidationTests(unittest.TestCase):
    """Exercise the request-validation helpers imported from ``server``.

    Covers ``normalize_documents``, ``parse_top_k`` and
    ``assert_port_available``. Explanations inside the test methods are
    written as ``#`` comments rather than docstrings so that unittest's
    verbose output keeps showing the method names.
    """

    def test_normalize_accepts_strings_and_objects(self) -> None:
        # One bare string and one dict carrying id/text/metadata.
        mixed_input = [
            "plain text document",
            {"id": "obj", "text": "object document", "metadata": {"source": "synthetic"}},
        ]
        normalized = normalize_documents(mixed_input, max_documents=2)

        # The bare string is expected back as a {"text": ...} mapping.
        self.assertEqual(normalized[0], {"text": "plain text document"})
        # The dict entry is expected to keep its id and metadata.
        self.assertEqual(normalized[1]["id"], "obj")
        self.assertEqual(normalized[1]["metadata"], {"source": "synthetic"})

    def test_normalize_rejects_empty_or_too_many_documents(self) -> None:
        # (expected message pattern, offending payload), checked in order:
        # empty list, more entries than max_documents, and an empty "text".
        rejected_payloads = (
            ("non-empty", []),
            ("max_documents", ["a", "b", "c"]),
            ("non-empty string", [{"id": "empty", "text": ""}]),
        )
        for message_pattern, payload in rejected_payloads:
            with self.assertRaisesRegex(ValueError, message_pattern):
                normalize_documents(payload, max_documents=2)

    def test_parse_top_k_defaults_clamps_and_rejects_invalid_values(self) -> None:
        # (raw top_k, expected result) with three documents available:
        # None falls back to 3, 2 is kept as-is, 99 comes back as 3.
        accepted = ((None, 3), (2, 2), (99, 3))
        for raw_top_k, expected in accepted:
            self.assertEqual(parse_top_k(raw_top_k, document_count=3), expected)

        # Non-positive numbers, booleans, floats and strings must all fail
        # with the same "positive integer" message.
        invalid_values = (0, -1, True, False, 1.5, "2", "nope")
        for bad_value in invalid_values:
            with self.subTest(value=bad_value), self.assertRaisesRegex(
                ValueError, "positive integer"
            ):
                parse_top_k(bad_value, document_count=3)

    def test_assert_port_available_detects_listener_conflict(self) -> None:
        loopback = "127.0.0.1"
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as occupied:
            # Port 0 asks the OS for any free port; keep the socket listening
            # so the port stays taken while the helper is called.
            occupied.bind((loopback, 0))
            occupied.listen(1)
            taken_port = occupied.getsockname()[1]

            with self.assertRaisesRegex(RuntimeError, "cannot bind"):
                assert_port_available(loopback, taken_port)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run directly as a script: hand control to unittest's command-line runner.
    unittest.main()
|