171 lines
7.1 KiB
Python
171 lines
7.1 KiB
Python
import importlib.util
|
|
import json
|
|
import sys
|
|
import types
|
|
import unittest
|
|
from argparse import Namespace
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from typing import cast
|
|
from unittest import mock
|
|
|
|
# Script under test: "scripts/npu_voice_audio_pipeline.py", located one level
# above the directory that holds this test file.
MODULE_PATH = Path(__file__).resolve().parents[1].joinpath(
    "scripts", "npu_voice_audio_pipeline.py"
)
|
|
|
|
|
|
def load_module() -> types.ModuleType:
    """Load the pipeline script at ``MODULE_PATH`` as a freshly executed module.

    The module is registered in ``sys.modules`` under the spec name before its
    body runs. If running the body fails, that registration is removed again
    so a half-initialised module is not left behind for later lookups.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec (or no loader) can be built for
            ``MODULE_PATH``.
    """
    spec = importlib.util.spec_from_file_location("npu_voice_audio_pipeline", MODULE_PATH)
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {MODULE_PATH}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: a module whose body failed to execute must
        # not remain importable. The original exception is re-raised untouched.
        sys.modules.pop(spec.name, None)
        raise
    return module
|
|
|
|
|
|
class NpuVoiceAudioPipelineTests(unittest.TestCase):
    """Tests for the ``npu_voice_audio_pipeline`` script module.

    ``setUp`` loads the module through ``load_module()`` before every test and
    exposes it as ``self.pipeline``.
    """

    def setUp(self):
        self.pipeline = load_module()

    @staticmethod
    def _smoke_args(audio, whisper_url):
        """Return the dry-run argument namespace used by the ``run_pipeline`` tests.

        Only the audio path and the whisper endpoint vary between those tests;
        every other option is fixed here.
        """
        return Namespace(
            audio=audio,
            id="voice-smoke",
            source="local_file",
            title="synthetic smoke",
            language="en",
            whisper_url=whisper_url,
            classifier_url="http://127.0.0.1:18819/v1/classify",
            dry_run=True,
            include_transcript=False,
            include_transcript_preview_chars=0,
            include_raw=False,
            max_bytes=1024 * 1024,
            max_audio_seconds=300,
            max_transcript_chars=6000,
            timeout=1,
        )

    def test_rejects_relative_audio_path(self):
        # A bare file name is not an absolute path and must be refused.
        expected_error = self.pipeline.PipelineError
        with self.assertRaisesRegex(expected_error, "audio_path_must_be_absolute"):
            self.pipeline.validate_audio_path("memo.wav", max_bytes=1024, max_audio_seconds=300)

    def test_rejects_symlink_audio_path(self):
        # The link points at a real file, so only the symlink itself can be
        # the reason for the rejection.
        with TemporaryDirectory() as tmp:
            workdir = Path(tmp)
            real_file = workdir / "memo.wav"
            real_file.write_bytes(b"RIFFfake")
            alias = workdir / "link.wav"
            alias.symlink_to(real_file)
            with self.assertRaisesRegex(self.pipeline.PipelineError, "audio_path_must_not_be_symlink"):
                self.pipeline.validate_audio_path(str(alias), max_bytes=1024, max_audio_seconds=None)

    def test_compact_labels_unwraps_classifier_label_values(self):
        # Each label arrives wrapped as {"value": ...}; the compact form is
        # expected to expose the inner value directly under the same key.
        classifier_payload = {
            "labels": {
                "workflow_category": {"value": "media"},
                "tool_needed": {"value": True},
                "urgency": {"value": "high"},
                "safety_confirmation_required": {"value": False},
            }
        }
        compact = self.pipeline.compact_labels(classifier_payload)
        self.assertEqual(compact["workflow_category"], "media")
        self.assertTrue(compact["tool_needed"])
        self.assertEqual(compact["urgency"], "high")
        self.assertFalse(compact["safety_confirmation_required"])

    def test_gate_blocks_missing_npu_proof(self):
        # Whisper is reported as unproven while the classifier is proven.
        label_values = {"tool_needed": True, "urgency": "normal", "safety_confirmation_required": False}
        is_action_worthy, gate, follow_up_gate = self.pipeline.decide_gate(
            "remind me to review logs",
            label_values,
            whisper_proven=False,
            classifier_proven=True,
        )
        self.assertTrue(is_action_worthy)
        self.assertEqual(gate, "blocked_missing_npu_proof")
        self.assertEqual(follow_up_gate, "npu_proof_required")

    def test_loopback_endpoint_policy_accepts_local_urls(self):
        # Accepted endpoints must come back unchanged from the validator.
        local_endpoints = (
            "http://localhost:18816/v1/audio/transcriptions",
            "https://localhost:18816/v1/audio/transcriptions",
            "http://127.0.0.1:18816/v1/audio/transcriptions",
            "http://127.42.0.9:18816/v1/audio/transcriptions",
            "http://[::1]:18816/v1/audio/transcriptions",
        )
        for endpoint in local_endpoints:
            with self.subTest(url=endpoint):
                validated = self.pipeline.validate_loopback_endpoint(endpoint, label="whisper")
                self.assertEqual(validated, endpoint)

    def test_loopback_endpoint_policy_rejects_remote_urls(self):
        # Covers a DNS name, private IPv4 ranges, a non-loopback IPv6 address
        # and a non-HTTP scheme.
        remote_endpoints = (
            "http://example.com:18816/v1/audio/transcriptions",
            "https://10.0.0.5:18816/v1/audio/transcriptions",
            "http://192.168.1.10:18816/v1/audio/transcriptions",
            "http://[2001:db8::1]:18816/v1/audio/transcriptions",
            "file:///tmp/audio.wav",
        )
        rejection_pattern = "whisper_url_.*not_.*|whisper_url_scheme_not_allowed"
        for endpoint in remote_endpoints:
            with self.subTest(url=endpoint):
                with self.assertRaisesRegex(self.pipeline.PipelineError, rejection_pattern):
                    self.pipeline.validate_loopback_endpoint(endpoint, label="whisper")

    def test_run_pipeline_rejects_remote_url_before_audio_read(self):
        # The audio path does not exist; the expected error is nevertheless
        # the endpoint-policy one.
        args = self._smoke_args(
            audio="/tmp/does-not-exist-remote-rejection-smoke.ogg",
            whisper_url="http://example.com:18816/v1/audio/transcriptions",
        )
        with self.assertRaisesRegex(self.pipeline.PipelineError, "whisper_url_host_not_loopback"):
            self.pipeline.run_pipeline(args)

    def test_run_pipeline_compact_success_with_mocked_services(self):
        # Canned replies returned by the patched whisper / classifier calls.
        whisper_reply = {"text": "remind me to check npu logs", "npu_busy_delta_us": 50}
        classifier_reply = {
            "dry_run": True,
            "labels": {
                "workflow_category": {"value": "media"},
                "tool_needed": {"value": True},
                "urgency": {"value": "normal"},
                "safety_confirmation_required": {"value": False},
            },
            "npu_busy_delta_us": 75,
            "sysfs_npu_busy_delta_us": 75,
        }

        with TemporaryDirectory() as tmp:
            audio = Path(tmp) / "memo.ogg"
            audio.write_bytes(b"not-real-audio-but-services-are-mocked")
            args = self._smoke_args(
                audio=str(audio),
                whisper_url="http://127.0.0.1:18816/v1/audio/transcriptions",
            )

            # Successive busy-counter readings handed out one per call:
            # 100 -> 150 and 150 -> 225.
            busy_values = iter([100, 150, 150, 225])
            busy_patch = mock.patch.object(
                self.pipeline, "read_npu_busy_us", side_effect=lambda: next(busy_values)
            )
            whisper_patch = mock.patch.object(self.pipeline, "post_whisper", return_value=whisper_reply)
            classifier_patch = mock.patch.object(self.pipeline, "post_json", return_value=classifier_reply)

            with busy_patch, whisper_patch, classifier_patch:
                result = self.pipeline.run_pipeline(args)

            self.assertTrue(result["ok"])
            self.assertEqual(result["external_sends"], 0)
            self.assertEqual(result["writes"], 0)
            self.assertEqual(result["whisper_sysfs_delta_us"], 50)
            self.assertEqual(result["classifier_observed_sysfs_delta_us"], 75)
            self.assertEqual(result["atlas_gate"], "advisory_only_not_sent")
            self.assertNotIn("transcript", result)
            # Raises if the result contains anything that is not JSON-serialisable.
            json.dumps(result)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|