Files
swarm-master/tests/test_npu_voice_audio_pipeline.py
T
2026-06-05 15:52:43 -07:00

171 lines
7.1 KiB
Python

import importlib.util
import json
import sys
import types
import unittest
from argparse import Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
from unittest import mock
# Location of the script under test: two directory levels above this file's
# resolved path, then scripts/npu_voice_audio_pipeline.py.
MODULE_PATH = (
    Path(__file__)
    .resolve()
    .parents[1]
    .joinpath("scripts", "npu_voice_audio_pipeline.py")
)
def load_module():
    """Load the pipeline script at ``MODULE_PATH`` and return the module object.

    The script is executed from its file path under the module name
    ``npu_voice_audio_pipeline`` and registered in ``sys.modules`` under that
    name. Every call executes the file again and replaces any earlier entry.

    Returns:
        The freshly executed module.

    Raises:
        AssertionError: if no import spec or loader can be built for
            ``MODULE_PATH``.
        Exception: whatever the script raises while executing; in that case
            the partially initialised module is removed from ``sys.modules``
            before the exception propagates.
    """
    spec = importlib.util.spec_from_file_location("npu_voice_audio_pipeline", MODULE_PATH)
    # Also narrows the Optional types for static checkers.
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing, following the importlib "import a source
    # file directly" recipe.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module behind for later lookups.
        if sys.modules.get(spec.name) is module:
            del sys.modules[spec.name]
        raise
    return cast(types.ModuleType, module)
class NpuVoiceAudioPipelineTests(unittest.TestCase):
    """Tests for the pipeline script returned by ``load_module()``.

    Each test talks to the script through ``self.pipeline``. The end-to-end
    success test patches ``read_npu_busy_us``, ``post_whisper`` and
    ``post_json`` on that module, so it does not depend on their real
    implementations.
    """

    def setUp(self):
        """Load a fresh copy of the pipeline module for the test."""
        self.pipeline = load_module()
        # load_module() registers the script in sys.modules; remove the entry
        # after the test so it does not leak into the rest of the test run.
        self.addCleanup(sys.modules.pop, self.pipeline.__name__, None)

    @staticmethod
    def _pipeline_args(*, audio, whisper_url):
        """Build the ``Namespace`` passed to ``run_pipeline``.

        Only ``audio`` and ``whisper_url`` differ between the ``run_pipeline``
        tests; every other option is the shared dry-run configuration.
        """
        return Namespace(
            audio=audio,
            id="voice-smoke",
            source="local_file",
            title="synthetic smoke",
            language="en",
            whisper_url=whisper_url,
            classifier_url="http://127.0.0.1:18819/v1/classify",
            dry_run=True,
            include_transcript=False,
            include_transcript_preview_chars=0,
            include_raw=False,
            max_bytes=1024 * 1024,
            max_audio_seconds=300,
            max_transcript_chars=6000,
            timeout=1,
        )

    def test_rejects_relative_audio_path(self):
        """A relative audio path is rejected with ``audio_path_must_be_absolute``."""
        with self.assertRaisesRegex(self.pipeline.PipelineError, "audio_path_must_be_absolute"):
            self.pipeline.validate_audio_path("memo.wav", max_bytes=1024, max_audio_seconds=300)

    def test_rejects_symlink_audio_path(self):
        """A symlink to an audio file is rejected with ``audio_path_must_not_be_symlink``."""
        with TemporaryDirectory() as tmp:
            root = Path(tmp)
            target = root / "memo.wav"
            target.write_bytes(b"RIFFfake")
            link = root / "link.wav"
            try:
                link.symlink_to(target)
            except (OSError, NotImplementedError) as exc:
                # Creating symlinks can be unsupported or unprivileged on some
                # platforms; that is an environment limit, not a test failure.
                self.skipTest(f"symlinks are not available here: {exc}")
            with self.assertRaisesRegex(self.pipeline.PipelineError, "audio_path_must_not_be_symlink"):
                self.pipeline.validate_audio_path(str(link), max_bytes=1024, max_audio_seconds=None)

    def test_compact_labels_unwraps_classifier_label_values(self):
        """``compact_labels`` maps each label name to its inner ``value``."""
        labels = self.pipeline.compact_labels(
            {
                "labels": {
                    "workflow_category": {"value": "media"},
                    "tool_needed": {"value": True},
                    "urgency": {"value": "high"},
                    "safety_confirmation_required": {"value": False},
                }
            }
        )
        self.assertEqual(labels["workflow_category"], "media")
        self.assertTrue(labels["tool_needed"])
        self.assertEqual(labels["urgency"], "high")
        self.assertFalse(labels["safety_confirmation_required"])

    def test_gate_blocks_missing_npu_proof(self):
        """With ``whisper_proven=False`` the gate reports missing NPU proof."""
        action_worthy, atlas_gate, next_gate = self.pipeline.decide_gate(
            "remind me to review logs",
            {"tool_needed": True, "urgency": "normal", "safety_confirmation_required": False},
            whisper_proven=False,
            classifier_proven=True,
        )
        self.assertTrue(action_worthy)
        self.assertEqual(atlas_gate, "blocked_missing_npu_proof")
        self.assertEqual(next_gate, "npu_proof_required")

    def test_loopback_endpoint_policy_accepts_local_urls(self):
        """Loopback URLs are accepted and returned unchanged."""
        allowed = [
            "http://localhost:18816/v1/audio/transcriptions",
            "https://localhost:18816/v1/audio/transcriptions",
            "http://127.0.0.1:18816/v1/audio/transcriptions",
            "http://127.42.0.9:18816/v1/audio/transcriptions",
            "http://[::1]:18816/v1/audio/transcriptions",
        ]
        for url in allowed:
            with self.subTest(url=url):
                self.assertEqual(self.pipeline.validate_loopback_endpoint(url, label="whisper"), url)

    def test_loopback_endpoint_policy_rejects_remote_urls(self):
        """Non-loopback hosts and non-HTTP schemes raise ``PipelineError``."""
        rejected = [
            "http://example.com:18816/v1/audio/transcriptions",
            "https://10.0.0.5:18816/v1/audio/transcriptions",
            "http://192.168.1.10:18816/v1/audio/transcriptions",
            "http://[2001:db8::1]:18816/v1/audio/transcriptions",
            "file:///tmp/audio.wav",
        ]
        for url in rejected:
            with self.subTest(url=url):
                with self.assertRaisesRegex(self.pipeline.PipelineError, "whisper_url_.*not_.*|whisper_url_scheme_not_allowed"):
                    self.pipeline.validate_loopback_endpoint(url, label="whisper")

    def test_run_pipeline_rejects_remote_url_before_audio_read(self):
        """A remote Whisper URL fails with ``whisper_url_host_not_loopback``.

        The audio path is a name not expected to exist, so the expected error
        has to come from the URL check rather than from reading the file.
        """
        args = self._pipeline_args(
            audio="/tmp/does-not-exist-remote-rejection-smoke.ogg",
            whisper_url="http://example.com:18816/v1/audio/transcriptions",
        )
        with self.assertRaisesRegex(self.pipeline.PipelineError, "whisper_url_host_not_loopback"):
            self.pipeline.run_pipeline(args)

    def test_run_pipeline_compact_success_with_mocked_services(self):
        """A dry run with patched services yields a compact, JSON-safe result."""
        with TemporaryDirectory() as tmp:
            audio = Path(tmp) / "memo.ogg"
            audio.write_bytes(b"not-real-audio-but-services-are-mocked")
            args = self._pipeline_args(
                audio=str(audio),
                whisper_url="http://127.0.0.1:18816/v1/audio/transcriptions",
            )
            # Successive counter readings handed out by the patched
            # read_npu_busy_us; presumably sampled before/after each service
            # call (deltas 50 and 75) - confirm against the script.
            busy_values = iter([100, 150, 150, 225])
            busy_patch = mock.patch.object(
                self.pipeline,
                "read_npu_busy_us",
                side_effect=lambda: next(busy_values),
            )
            whisper_patch = mock.patch.object(
                self.pipeline,
                "post_whisper",
                return_value={"text": "remind me to check npu logs", "npu_busy_delta_us": 50},
            )
            classifier_patch = mock.patch.object(
                self.pipeline,
                "post_json",
                return_value={
                    "dry_run": True,
                    "labels": {
                        "workflow_category": {"value": "media"},
                        "tool_needed": {"value": True},
                        "urgency": {"value": "normal"},
                        "safety_confirmation_required": {"value": False},
                    },
                    "npu_busy_delta_us": 75,
                    "sysfs_npu_busy_delta_us": 75,
                },
            )
            with busy_patch, whisper_patch, classifier_patch:
                result = self.pipeline.run_pipeline(args)
            self.assertTrue(result["ok"])
            self.assertEqual(result["external_sends"], 0)
            self.assertEqual(result["writes"], 0)
            self.assertEqual(result["whisper_sysfs_delta_us"], 50)
            self.assertEqual(result["classifier_observed_sysfs_delta_us"], 75)
            self.assertEqual(result["atlas_gate"], "advisory_only_not_sent")
            self.assertNotIn("transcript", result)
            # Raises if the result holds anything that is not JSON-serialisable.
            json.dumps(result)
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()