ria-toolkit-oss/tests/agent/test_streamer_tx.py
2026-04-16 15:38:35 -04:00

141 lines
4.6 KiB
Python

"""TX streaming happy path + shutdown semantics."""
from __future__ import annotations
import asyncio
import time
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR
class RecordingMockSDR(MockSDR):
    """MockSDR that records a snapshot of every buffer the TX callback returns.

    Overrides the TX streaming loop: while ``_enable_tx`` is set, it asks the
    callback for a buffer, stores a *copy* of it in ``tx_produced``, and sleeps
    briefly so the loop does not spin.
    """

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Every buffer returned by the TX callback, in call order. Appended by
        # the streaming loop and read by the test while streaming is running.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback) -> None:
        """Drive *callback* until ``_enable_tx`` is cleared, recording its output.

        Args:
            callback: Called with a requested sample count; its return value is
                converted to an ndarray and copied into ``tx_produced``.
        """
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            # NOTE(review): requests rx_buffer_size samples for TX, as the
            # original did -- confirm MockSDR has no separate TX buffer size.
            result = callback(self.rx_buffer_size)
            # Copy rather than alias: if the producer ever reuses or mutates the
            # buffer it returned, already-recorded entries must not change.
            self.tx_produced.append(np.array(result, copy=True))
            time.sleep(0.005)
class FakeWs:
    """In-memory websocket stand-in that records everything sent through it."""

    def __init__(self):
        # JSON control messages, in the order they were sent.
        self.json_sent: list[dict] = list()
        # Binary frames, in the order they were sent.
        self.bytes_sent: list[bytes] = list()

    async def send_json(self, payload):
        """Record *payload* as an outgoing JSON message."""
        sent = self.json_sent
        sent.append(payload)

    async def send_bytes(self, data):
        """Record *data* as an outgoing binary frame."""
        self.bytes_sent += [data]
def _iq_frame(samples: np.ndarray) -> bytes:
interleaved = np.empty(samples.size * 2, dtype=np.float32)
interleaved[0::2] = samples.real
interleaved[1::2] = samples.imag
return interleaved.tobytes()
def test_tx_start_streams_binary_to_callback():
    """tx_start plus three IQ frames are transmitted in order; tx_stop cleans up."""
    BUF = 16
    # Frames with distinct content so ordering can be asserted and real data can
    # be told apart from all-zero underrun fill (underrun_policy="zero").
    frame_a = np.arange(BUF, dtype=np.complex64)
    frame_b = np.arange(BUF, 2 * BUF, dtype=np.complex64)
    frame_c = np.arange(2 * BUF, 3 * BUF, dtype=np.complex64)
    sdr = RecordingMockSDR(buffer_size=BUF)

    def nontrivial_buffers() -> list[np.ndarray]:
        # Recorded TX buffers that contain any nonzero sample. Snapshot the
        # list first because the streaming loop may still be appending.
        return [b for b in list(sdr.tx_produced) if np.any(b != 0)]

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
        await s.on_message(
            {
                "type": "tx_start",
                "app_id": "app-1",
                "radio_config": {
                    "device": "mock",
                    "buffer_size": BUF,
                    "tx_sample_rate": 1_000_000,
                    "tx_center_frequency": 2.45e9,
                    "tx_gain": -20,
                    "underrun_policy": "zero",
                },
            }
        )
        try:
            for frame in (frame_a, frame_b, frame_c):
                await s.on_binary(_iq_frame(frame))
            # Let the streaming thread consume the frames. Zero-fill buffers
            # produced before the frames arrived are ignored. Bounded by a
            # deadline so a regression fails instead of waiting forever.
            deadline = time.monotonic() + 2.0
            while len(nontrivial_buffers()) < 3 and time.monotonic() < deadline:
                await asyncio.sleep(0.01)
        finally:
            # Always stop TX: RecordingMockSDR's TX loop only exits once TX is
            # disabled, so skipping this after an error could hang the test.
            await s.on_message({"type": "tx_stop", "app_id": "app-1"})
        return ws, s

    ws, streamer = asyncio.run(scenario())

    nontrivial = nontrivial_buffers()
    assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers"
    # First three nontrivial buffers match the order we pushed them.
    np.testing.assert_array_equal(nontrivial[0], frame_a)
    np.testing.assert_array_equal(nontrivial[1], frame_b)
    np.testing.assert_array_equal(nontrivial[2], frame_c)

    # Lifecycle: armed → transmitting → done.
    states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
    assert states, "no tx_status messages were sent"
    assert states[0] == "armed"
    assert "transmitting" in states
    assert states[-1] == "done"

    # Session cleared.
    assert streamer._tx is None
def test_tx_stop_releases_sdr():
    """Stopping a TX session drops the streamer's hold on the mock SDR."""
    mock_sdr = RecordingMockSDR(buffer_size=8)
    radio_config = {
        "device": "mock",
        "buffer_size": 8,
        "tx_sample_rate": 1_000_000,
        "tx_center_frequency": 2.45e9,
        "tx_gain": -20,
        "underrun_policy": "zero",
    }

    async def start_then_stop():
        streamer = Streamer(
            ws=FakeWs(),
            sdr_factory=lambda d, i: mock_sdr,
            cfg=AgentConfig(tx_enabled=True),
        )
        await streamer.on_message(
            {"type": "tx_start", "app_id": "a", "radio_config": radio_config}
        )
        # Let the TX session run briefly before tearing it down.
        await asyncio.sleep(0.03)
        await streamer.on_message({"type": "tx_stop", "app_id": "a"})
        return streamer

    streamer = asyncio.run(start_then_stop())
    # No outstanding registry references to the ("mock", None) device remain.
    assert streamer._registry.refcount(("mock", None)) == 0
    # The TX session has been cleared.
    assert streamer._tx is None