141 lines
4.6 KiB
Python
141 lines
4.6 KiB
Python
"""TX streaming happy path + shutdown semantics."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
from ria_toolkit_oss.agent.config import AgentConfig
|
|
from ria_toolkit_oss.agent.streamer import Streamer
|
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
|
|
|
|
|
class RecordingMockSDR(MockSDR):
    """MockSDR that records each TX callback's returned buffer.

    ``_stream_tx`` is replaced by a simple polling loop: it keeps calling the
    TX callback until ``_enable_tx`` is cleared, storing a snapshot of every
    returned buffer in ``tx_produced`` in call order.
    """

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Buffers returned by the TX callback, oldest first. Appended by
        # _stream_tx (presumably on an executor thread) and read by the test.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback) -> None:
        """Drive *callback* until TX is disabled, recording each result.

        Args:
            callback: called with a buffer size; its return value is the
                buffer of samples handed to the radio.
        """
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            # NOTE(review): requests rx_buffer_size for a TX buffer; assumes
            # MockSDR uses a single buffer size for both directions — confirm.
            result = callback(self.rx_buffer_size)
            # Copy instead of np.asarray: if the producer reuses one output
            # buffer across calls, an alias would be overwritten by later
            # calls and corrupt the earlier recordings.
            self.tx_produced.append(np.array(result, copy=True))
            time.sleep(0.005)
|
|
|
|
|
|
class FakeWs:
    """In-memory websocket double that records every outgoing message.

    JSON payloads are collected in ``json_sent`` and binary frames in
    ``bytes_sent``, each in send order, so tests can inspect what the
    Streamer emitted.
    """

    def __init__(self):
        self.json_sent: list[dict] = list()
        self.bytes_sent: list[bytes] = list()

    async def send_json(self, payload):
        """Record *payload* as a JSON message."""
        outbox = self.json_sent
        outbox.append(payload)

    async def send_bytes(self, data):
        """Record *data* as a binary frame."""
        outbox = self.bytes_sent
        outbox.append(data)
|
|
|
|
|
|
def _iq_frame(samples: np.ndarray) -> bytes:
|
|
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
|
interleaved[0::2] = samples.real
|
|
interleaved[1::2] = samples.imag
|
|
return interleaved.tobytes()
|
|
|
|
|
|
def test_tx_start_streams_binary_to_callback():
    """tx_start followed by three binary IQ frames reaches the TX callback in order.

    Also checks the tx_status lifecycle reported over the websocket and that
    tx_stop clears the streamer's TX session.
    """
    BUF = 16
    sdr = RecordingMockSDR(buffer_size=BUF)

    # Frames of distinct content so we can assert ordering; the same arrays
    # serve as the expected values below.
    frames = [np.arange(k * BUF, (k + 1) * BUF, dtype=np.complex64) for k in range(3)]

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))

        await s.on_message(
            {
                "type": "tx_start",
                "app_id": "app-1",
                "radio_config": {
                    "device": "mock",
                    "buffer_size": BUF,
                    "tx_sample_rate": 1_000_000,
                    "tx_center_frequency": 2.45e9,
                    "tx_gain": -20,
                    "underrun_policy": "zero",
                },
            }
        )
        try:
            # Push three IQ frames.
            for frame in frames:
                await s.on_binary(_iq_frame(frame))

            # Let the executor thread consume them. Zero-fill buffers may be
            # recorded before the frames arrive, so poll (up to ~1 s) until at
            # least 3 non-zero buffers have been produced.
            for _ in range(100):
                nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
                if len(nontrivial) >= 3:
                    break
                await asyncio.sleep(0.01)
        finally:
            # Always stop, even if a step above raised. Otherwise the
            # RecordingMockSDR TX loop may keep running, and the test can hang
            # instead of reporting the failure.
            await s.on_message({"type": "tx_stop", "app_id": "app-1"})

        return ws, s

    ws, streamer = asyncio.run(scenario())

    nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
    assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers"

    # First three nontrivial buffers match the order we pushed them.
    for got, expected in zip(nontrivial, frames):
        np.testing.assert_array_equal(got, expected)

    # Lifecycle: armed → transmitting → done.
    states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
    assert states, "expected at least one tx_status message"
    assert states[0] == "armed"
    assert "transmitting" in states
    assert states[-1] == "done"

    # Session cleared.
    assert streamer._tx is None
|
|
|
|
|
|
def test_tx_stop_releases_sdr():
    """tx_stop drops the TX session and returns the radio's registry refcount to zero."""
    radio = RecordingMockSDR(buffer_size=8)
    start_msg = {
        "type": "tx_start",
        "app_id": "a",
        "radio_config": {
            "device": "mock",
            "buffer_size": 8,
            "tx_sample_rate": 1_000_000,
            "tx_center_frequency": 2.45e9,
            "tx_gain": -20,
            "underrun_policy": "zero",
        },
    }
    stop_msg = {"type": "tx_stop", "app_id": "a"}

    async def run_start_then_stop():
        streamer = Streamer(
            ws=FakeWs(),
            sdr_factory=lambda d, i: radio,
            cfg=AgentConfig(tx_enabled=True),
        )
        await streamer.on_message(start_msg)
        # Let the TX side run briefly before stopping it.
        await asyncio.sleep(0.03)
        await streamer.on_message(stop_msg)
        return streamer

    streamer = asyncio.run(run_start_then_stop())

    # After stop, the registry has no outstanding references to ("mock", None).
    assert streamer._registry.refcount(("mock", None)) == 0
    assert streamer._tx is None
|