125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
"""Unit tests for the streamer: drive it with a fake WsClient + MockSDR."""
|
||
|
||
from __future__ import annotations

import asyncio
import copy

import numpy as np

from ria_toolkit_oss.agent.streamer import (
    Streamer,
    _apply_sdr_config,
    _samples_to_interleaved_float32,
)
from ria_toolkit_oss.sdr.mock import MockSDR
|
||
|
||
|
||
class FakeWs:
    """Recording stand-in for the WebSocket client handed to ``Streamer``.

    Every outgoing message is captured in memory so tests can inspect what
    was sent. Payloads are snapshotted at send time, so a sender that later
    mutates its dict or reuses a mutable buffer cannot alter what was
    recorded.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # JSON messages in the order they were sent (deep copies).
        self.json_sent: list[dict] = []
        # Binary frames in the order they were sent (immutable ``bytes``).
        self.bytes_sent: list[bytes] = []

    async def send_json(self, payload: dict) -> None:
        """Record a deep copy of *payload* as the next JSON message."""
        # Deep copy so nested dicts (e.g. a radio config) are frozen too.
        self.json_sent.append(copy.deepcopy(payload))

    async def send_bytes(self, data: bytes) -> None:
        """Record *data* as the next binary frame."""
        # ``bytes(...)`` is a no-op for ``bytes`` input and an immutable
        # snapshot for ``bytearray`` / ``memoryview`` input.
        self.bytes_sent.append(bytes(data))
|
||
|
||
|
||
def _factory(device: str, identifier):
    """SDR factory for tests: ignore both arguments and return a seeded mock.

    The mock is always built with ``buffer_size=32`` and ``seed=0``.
    """
    mock_radio = MockSDR(buffer_size=32, seed=0)
    return mock_radio
|
||
|
||
|
||
def test_samples_to_interleaved_float32_roundtrip():
    """Two complex64 samples decode back as float32 values 1, 2, 3, 4."""
    samples = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
    decoded = np.frombuffer(
        _samples_to_interleaved_float32(samples),
        dtype=np.float32,
    )
    # Real and imaginary parts alternate, sample by sample.
    expected = [1.0, 2.0, 3.0, 4.0]
    assert decoded.tolist() == expected
|
||
|
||
|
||
def test_apply_sdr_config_sets_attributes():
    """``_apply_sdr_config`` copies the config values onto the SDR object."""
    radio = MockSDR(buffer_size=16)
    config = {"sample_rate": 2e6, "center_frequency": 9.15e8, "gain": 30}

    _apply_sdr_config(radio, config)

    assert radio.sample_rate == 2e6
    # The config key is "center_frequency"; the attribute read is "center_freq".
    assert radio.center_freq == 9.15e8
    assert radio.gain == 30
|
||
|
||
|
||
def test_heartbeat_reflects_status_and_app():
    """The heartbeat payload mirrors the streamer's status and app id."""
    streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)

    # A freshly constructed streamer reports itself as idle.
    idle_beat = streamer.build_heartbeat()
    assert idle_beat["type"] == "heartbeat"
    assert idle_beat["status"] == "idle"

    # Set the private state directly rather than starting a real stream.
    streamer._status = "streaming"
    streamer._app_id = "app-42"

    active_beat = streamer.build_heartbeat()
    assert active_beat["status"] == "streaming"
    assert active_beat["app_id"] == "app-42"
|
||
|
||
|
||
def test_full_start_stream_stop_cycle():
    """Start a stream on the mock device, wait for frames, then stop.

    Checks that at least one frame was sent, that every frame has the
    expected byte length, that the first and last status messages are
    "streaming" and "idle", and that the streamer has dropped its SDR.
    """
    buffer_size = 64
    start_msg = {
        "type": "start",
        "app_id": "abc",
        "radio_config": {
            "device": "mock",
            "sample_rate": 1_000_000,
            "center_frequency": 2.45e9,
            "gain": 40,
            "buffer_size": buffer_size,
        },
    }
    stop_msg = {"type": "stop", "app_id": "abc"}

    async def scenario():
        ws = FakeWs()
        streamer = Streamer(ws=ws, sdr_factory=_factory)

        await streamer.on_message(start_msg)

        # Poll in 20 ms steps, at most 30 times, until two frames are recorded.
        polls_left = 30
        while polls_left and len(ws.bytes_sent) < 2:
            await asyncio.sleep(0.02)
            polls_left -= 1

        await streamer.on_message(stop_msg)
        return ws, streamer

    ws, streamer = asyncio.run(scenario())

    assert ws.bytes_sent
    frame_bytes = buffer_size * 2 * 4  # 64 samples × (I,Q) × float32
    for frame in ws.bytes_sent:
        assert len(frame) == frame_bytes

    status_msgs = [msg for msg in ws.json_sent if msg.get("type") == "status"]
    first_status, last_status = status_msgs[0], status_msgs[-1]
    assert first_status["status"] == "streaming"
    assert last_status["status"] == "idle"
    assert streamer._sdr is None
|
||
|
||
|
||
def test_start_without_device_emits_error():
    """A start message with an empty radio_config yields an error message."""
    start_msg = {"type": "start", "app_id": "x", "radio_config": {}}

    async def scenario():
        ws = FakeWs()
        await Streamer(ws=ws, sdr_factory=_factory).on_message(start_msg)
        return ws

    ws = asyncio.run(scenario())

    error_msgs = [msg for msg in ws.json_sent if msg.get("type") == "error"]
    assert error_msgs
    # The first error's text must mention the missing "device".
    assert "device" in error_msgs[0]["message"]
|
||
|
||
|
||
def test_configure_queues_update():
    """A configure message leaves its radio_config in ``_pending_config``."""
    configure_msg = {
        "type": "configure",
        "app_id": "x",
        "radio_config": {"center_frequency": 915e6},
    }

    async def scenario():
        streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
        await streamer.on_message(configure_msg)
        return streamer._pending_config

    assert asyncio.run(scenario()) == {"center_frequency": 915e6}
|
||
|
||
|
||
def test_unknown_message_type_is_ignored():
    """Handling an unrecognized message type must not raise.

    There are no assertions; the test passes if ``on_message`` completes.
    """
    unknown_msg = {"type": "nope"}

    async def scenario():
        streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
        await streamer.on_message(unknown_msg)

    asyncio.run(scenario())
|