ria-toolkit-oss/tests/agent/test_streamer.py
2026-04-13 11:48:15 -04:00

125 lines
3.7 KiB
Python
Raw RIA Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unit tests for the streamer: drive it with a fake WsClient + MockSDR."""
from __future__ import annotations
import asyncio
import numpy as np
from ria_toolkit_oss.agent.streamer import (
Streamer,
_apply_sdr_config,
_samples_to_interleaved_float32,
)
from ria_toolkit_oss.sdr.mock import MockSDR
class FakeWs:
    """In-memory stand-in for the WebSocket client passed to ``Streamer``.

    Nothing is transmitted. Every payload is appended, in call order, to
    ``json_sent`` or ``bytes_sent`` so a test can inspect it afterwards.
    """

    def __init__(self):
        # Recorded JSON messages and binary frames, oldest first.
        self.json_sent: list[dict] = []
        self.bytes_sent: list[bytes] = []

    async def send_json(self, payload: dict) -> None:
        """Record a shallow copy of *payload*.

        Copying keeps the record stable if the sender later mutates or
        reuses the same dict object (nested values are still shared).
        """
        self.json_sent.append(dict(payload))

    async def send_bytes(self, data: bytes) -> None:
        """Record an immutable snapshot of *data*.

        ``bytes(data)`` freezes the content, so a sender that reuses a
        mutable buffer (``bytearray``/``memoryview``) cannot rewrite frames
        that were already recorded.
        """
        self.bytes_sent.append(bytes(data))
def _factory(device: str, identifier):
    """SDR factory handed to ``Streamer`` by these tests.

    Both arguments are ignored: every call returns a fresh ``MockSDR``
    built with ``buffer_size=32`` and ``seed=0``.
    """
    mock_radio = MockSDR(buffer_size=32, seed=0)
    return mock_radio
def test_samples_to_interleaved_float32_roundtrip():
    """Complex samples serialise to I/Q-interleaved float32 and back again."""
    samples = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
    raw = _samples_to_interleaved_float32(samples)

    # Forward direction: the buffer reads back as I0, Q0, I1, Q1.
    floats = np.frombuffer(raw, dtype=np.float32)
    assert floats.tolist() == [1.0, 2.0, 3.0, 4.0]

    # Reverse direction (the "round trip" the test name promises):
    # even slots are the real parts, odd slots the imaginary parts.
    rebuilt = floats[0::2] + 1j * floats[1::2]
    assert np.array_equal(rebuilt, samples)
def test_apply_sdr_config_sets_attributes():
    """``_apply_sdr_config`` writes each radio setting onto the SDR object."""
    radio = MockSDR(buffer_size=16)
    settings = {"sample_rate": 2e6, "center_frequency": 9.15e8, "gain": 30}
    _apply_sdr_config(radio, settings)

    # The config key is "center_frequency" but the attribute read back from
    # the SDR is "center_freq".
    expected_attrs = {"sample_rate": 2e6, "center_freq": 9.15e8, "gain": 30}
    for attr_name, wanted in expected_attrs.items():
        assert getattr(radio, attr_name) == wanted
def test_heartbeat_reflects_status_and_app():
    """``build_heartbeat`` reports the streamer's current status and app id."""
    streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)

    # A freshly constructed streamer reports itself as idle.
    idle_beat = streamer.build_heartbeat()
    assert idle_beat["type"] == "heartbeat"
    assert idle_beat["status"] == "idle"

    # Set the private state directly instead of running a real stream, then
    # check that the next heartbeat picks it up.
    streamer._status = "streaming"
    streamer._app_id = "app-42"
    live_beat = streamer.build_heartbeat()
    assert live_beat["status"] == "streaming"
    assert live_beat["app_id"] == "app-42"
def test_full_start_stream_stop_cycle():
    """Drive start -> stream -> stop and check frames, statuses and teardown."""
    buffer_size = 64  # samples per frame requested in radio_config
    bytes_per_sample = 2 * 4  # (I, Q) x float32

    async def scenario():
        ws = FakeWs()
        streamer = Streamer(ws=ws, sdr_factory=_factory)
        await streamer.on_message(
            {
                "type": "start",
                "app_id": "abc",
                "radio_config": {
                    "device": "mock",
                    "sample_rate": 1_000_000,
                    "center_frequency": 2.45e9,
                    "gain": 40,
                    "buffer_size": buffer_size,
                },
            }
        )
        # Give the streamer up to 30 polls of 20 ms each to emit two frames
        # before asking it to stop.
        for _ in range(30):
            if len(ws.bytes_sent) >= 2:
                break
            await asyncio.sleep(0.02)
        await streamer.on_message({"type": "stop", "app_id": "abc"})
        return ws, streamer

    ws, streamer = asyncio.run(scenario())

    assert len(ws.bytes_sent) >= 1, "no IQ frames were sent before stop"
    for frame in ws.bytes_sent:
        assert len(frame) == buffer_size * bytes_per_sample

    statuses = [m for m in ws.json_sent if m.get("type") == "status"]
    # Fail with a readable message rather than an IndexError below.
    assert statuses, "no status messages were sent"
    assert statuses[0]["status"] == "streaming"
    assert statuses[-1]["status"] == "idle"
    assert streamer._sdr is None
def test_start_without_device_emits_error():
    """A ``start`` with an empty ``radio_config`` yields an error naming 'device'."""

    async def run_start():
        fake_ws = FakeWs()
        streamer = Streamer(ws=fake_ws, sdr_factory=_factory)
        start_msg = {"type": "start", "app_id": "x", "radio_config": {}}
        await streamer.on_message(start_msg)
        return fake_ws

    fake_ws = asyncio.run(run_start())
    error_msgs = [sent for sent in fake_ws.json_sent if sent.get("type") == "error"]
    # At least one error must have been sent, and the first one must
    # mention the missing field.
    assert error_msgs
    assert "device" in error_msgs[0]["message"]
def test_configure_queues_update():
    """A ``configure`` message leaves its ``radio_config`` in ``_pending_config``."""

    async def run_configure():
        streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
        configure_msg = {
            "type": "configure",
            "app_id": "x",
            "radio_config": {"center_frequency": 915e6},
        }
        await streamer.on_message(configure_msg)
        return streamer._pending_config

    queued = asyncio.run(run_configure())
    assert queued == {"center_frequency": 915e6}
def test_unknown_message_type_is_ignored():
    """An unrecognised message ``type`` must not raise.

    There are no explicit assertions: the test passes as long as
    ``on_message`` returns without raising.
    """

    async def send_unknown():
        streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
        await streamer.on_message({"type": "nope"})

    asyncio.run(send_unknown())