ria-toolkit-oss/tests/agent/test_streamer.py
ben 22b035dbee
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Has been cancelled
Test with tox / Test with tox (3.10) (pull_request) Has been cancelled
Test with tox / Test with tox (3.11) (pull_request) Has been cancelled
Test with tox / Test with tox (3.12) (pull_request) Has been cancelled
Build Project / Build Project (3.12) (pull_request) Has been cancelled
Build Project / Build Project (3.11) (pull_request) Has been cancelled
Build Project / Build Project (3.10) (pull_request) Has been cancelled
format fixes
2026-04-20 13:51:15 -04:00

206 lines
6.6 KiB
Python
Raw Permalink RIA Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unit tests for the streamer: drive it with a fake WsClient + MockSDR."""
from __future__ import annotations
import asyncio
import numpy as np
from ria_toolkit_oss.agent.streamer import (
Streamer,
_apply_sdr_config,
_samples_to_interleaved_float32,
)
from ria_toolkit_oss.sdr.mock import MockSDR
class FakeWs:
    """In-memory stand-in for the WebSocket client that records outbound traffic.

    JSON payloads are appended to ``json_sent`` and binary frames to
    ``bytes_sent`` in the order they were sent, so tests can inspect them
    after the scenario has run. Nothing is transmitted anywhere.
    """

    def __init__(self) -> None:
        # Outbound JSON messages, oldest first.
        self.json_sent: list[dict] = []
        # Outbound binary frames, oldest first.
        self.bytes_sent: list[bytes] = []

    async def send_json(self, payload: dict) -> None:
        """Record *payload* instead of sending it."""
        self.json_sent.append(payload)

    async def send_bytes(self, data: bytes) -> None:
        """Record an immutable snapshot of *data* instead of sending it."""
        # bytes() hands back bytes input unchanged but copies a bytearray or
        # memoryview, so a sender that reuses its buffer cannot alter frames
        # that were already recorded.
        self.bytes_sent.append(bytes(data))
def _factory(device: str, identifier):
    """Return a MockSDR seeded with 0; both arguments are accepted but unused."""
    mock_kwargs = {"buffer_size": 32, "seed": 0}
    return MockSDR(**mock_kwargs)
def test_samples_to_interleaved_float32_roundtrip():
    """Complex samples decode back as float32 values in real, imag order."""
    samples = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
    payload = _samples_to_interleaved_float32(samples)
    # Reinterpret the raw buffer as float32 to check the interleaving.
    decoded = np.frombuffer(payload, dtype=np.float32).tolist()
    assert decoded == [1.0, 2.0, 3.0, 4.0]
def test_apply_sdr_config_sets_attributes():
    """_apply_sdr_config transfers sample rate, center frequency and gain to the SDR."""
    radio_config = {"sample_rate": 2e6, "center_frequency": 9.15e8, "gain": 30}
    sdr = MockSDR(buffer_size=16)
    _apply_sdr_config(sdr, radio_config)
    assert sdr.sample_rate == 2e6
    # The config key is "center_frequency"; the SDR attribute is center_freq.
    assert sdr.center_freq == 9.15e8
    assert sdr.gain == 30
def test_heartbeat_reflects_status_and_app():
    """Heartbeat reports idle, rx-only defaults, then the active app once started."""

    async def scenario():
        s = Streamer(ws=FakeWs(), sdr_factory=_factory)

        hb = s.build_heartbeat()
        assert hb["type"] == "heartbeat"
        assert hb["status"] == "idle"
        # capabilities default to rx-only
        assert hb["capabilities"] == ["rx"]
        assert hb["tx_enabled"] is False

        await s.on_message(
            {
                "type": "start",
                "app_id": "app-42",
                "radio_config": {"device": "mock", "buffer_size": 32},
            }
        )
        try:
            hb2 = s.build_heartbeat()
            assert hb2["status"] == "streaming"
            assert hb2["app_id"] == "app-42"
            assert hb2["sessions"]["rx"]["app_id"] == "app-42"
        finally:
            # Always send stop, so a failing assertion above does not leave
            # the started session running when the event loop shuts down.
            await s.on_message({"type": "stop", "app_id": "app-42"})

    asyncio.run(scenario())
def test_full_start_stream_stop_cycle():
    """Start a stream, wait for frames, stop, then check frames and status messages."""

    async def scenario():
        ws = FakeWs()
        streamer = Streamer(ws=ws, sdr_factory=_factory)
        await streamer.on_message(
            {
                "type": "start",
                "app_id": "abc",
                "radio_config": {
                    "device": "mock",
                    "sample_rate": 1_000_000,
                    "center_frequency": 2.45e9,
                    "gain": 40,
                    "buffer_size": 64,
                },
            }
        )
        # Poll in 20 ms steps (30 attempts at most) until two frames have
        # arrived, rather than sleeping for a fixed worst-case duration.
        for _ in range(30):
            if len(ws.bytes_sent) >= 2:
                break
            await asyncio.sleep(0.02)
        await streamer.on_message({"type": "stop", "app_id": "abc"})
        return ws, streamer

    ws, streamer = asyncio.run(scenario())

    assert len(ws.bytes_sent) >= 1
    for frame in ws.bytes_sent:
        assert len(frame) == 64 * 2 * 4  # 64 samples x (I, Q) x float32

    statuses = [m for m in ws.json_sent if m.get("type") == "status"]
    # Fail with a readable assertion instead of an IndexError when no status
    # message was emitted at all.
    assert statuses, "expected at least one status frame"
    assert statuses[0]["status"] == "streaming"
    assert statuses[-1]["status"] == "idle"
    assert streamer._rx is None
def test_start_without_device_emits_error():
    """A start request with an empty radio_config yields an error mentioning the device."""

    async def scenario():
        fake_ws = FakeWs()
        streamer = Streamer(ws=fake_ws, sdr_factory=_factory)
        start_request = {"type": "start", "app_id": "x", "radio_config": {}}
        await streamer.on_message(start_request)
        return fake_ws

    ws = asyncio.run(scenario())
    errors = [msg for msg in ws.json_sent if msg.get("type") == "error"]
    # Only the first error message is inspected.
    assert errors and "device" in errors[0]["message"]
def test_configure_queues_update():
    """A configure message sent before any start is kept as pending config."""

    async def scenario():
        streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
        configure_request = {
            "type": "configure",
            "app_id": "x",
            "radio_config": {"center_frequency": 915e6},
        }
        await streamer.on_message(configure_request)
        # Before start(), pending config lives on the standalone dict exposed
        # via the _pending_config shim.
        return streamer._pending_config

    queued = asyncio.run(scenario())
    assert queued == {"center_frequency": 915e6}
def test_unknown_message_type_is_ignored():
    """Handling a message with an unrecognized type must not raise."""

    async def scenario():
        streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
        unknown = {"type": "nope"}
        await streamer.on_message(unknown)

    # The test passes when on_message returns normally; nothing else is checked.
    asyncio.run(scenario())
def test_tx_data_available_is_a_silent_noop():
    """A tx_data_available message produces no outbound JSON or binary frames."""

    # Hub sends this as a keepalive; we should accept and ignore without
    # emitting a WARNING or treating it as an error.
    # NOTE(review): only outbound frames are asserted here; log output is not checked.
    async def scenario():
        fake_ws = FakeWs()
        streamer = Streamer(ws=fake_ws, sdr_factory=_factory)
        keepalive = {"type": "tx_data_available", "app_id": "x"}
        await streamer.on_message(keepalive)
        return fake_ws

    recorded = asyncio.run(scenario())
    # No outbound frames emitted.
    assert recorded.json_sent == []
    assert recorded.bytes_sent == []
def test_registry_shares_sdr_across_start_stop_cycles():
    """Two sequential start/stop cycles each build their own SDR.

    Despite the test name, nothing is shared here: the factory is expected to
    be called once per cycle, which is what the final assertion checks.
    """
    # Every SDR the factory hands out, in construction order.
    built: list[MockSDR] = []

    def counting_factory(device: str, identifier):
        sdr = MockSDR(buffer_size=16, seed=0)
        built.append(sdr)
        return sdr

    async def scenario():
        s = Streamer(ws=FakeWs(), sdr_factory=counting_factory)
        for _ in range(2):
            await s.on_message(
                {
                    "type": "start",
                    "app_id": "a",
                    "radio_config": {"device": "mock", "buffer_size": 16},
                }
            )
            # Let one capture buffer flow before stopping so the loop is engaged.
            await asyncio.sleep(0.02)
            await s.on_message({"type": "stop", "app_id": "a"})

    asyncio.run(scenario())
    # A new SDR per cycle (we fully close between starts): the registry
    # refcount drops to zero on each stop. This test confirms close-and-rebuild
    # works; the ref-counting share-while-open case is covered in the
    # full-duplex tests.
    assert len(built) == 2
def test_tx_start_rejected_when_tx_disabled():
    """tx_start with tx_enabled=False ends in a tx_status error mentioning 'disabled'."""
    from ria_toolkit_oss.agent.config import AgentConfig

    async def scenario():
        fake_ws = FakeWs()
        tx_off_cfg = AgentConfig(tx_enabled=False)
        streamer = Streamer(ws=fake_ws, sdr_factory=_factory, cfg=tx_off_cfg)
        tx_request = {
            "type": "tx_start",
            "app_id": "a",
            "radio_config": {"device": "mock", "tx_center_frequency": 2.45e9, "tx_gain": -20},
        }
        await streamer.on_message(tx_request)
        return fake_ws

    ws = asyncio.run(scenario())
    tx_statuses = [msg for msg in ws.json_sent if msg.get("type") == "tx_status"]
    assert tx_statuses, "expected a tx_status frame"
    # Only the most recent tx_status is inspected.
    latest = tx_statuses[-1]
    assert latest["state"] == "error"
    assert "disabled" in latest["message"].lower()