qac-cli-commands #26

Merged
madrigal merged 15 commits from qac-cli-commands into main 2026-04-21 09:03:29 -04:00
17 changed files with 1752 additions and 100 deletions
Showing only changes of commit b955256479 - Show all commits

View File

@ -5,8 +5,8 @@ Subcommands:
- ``ria-agent run [legacy args]`` legacy long-poll NodeAgent (unchanged). - ``ria-agent run [legacy args]`` legacy long-poll NodeAgent (unchanged).
- ``ria-agent stream`` new WebSocket-based IQ streamer. - ``ria-agent stream`` new WebSocket-based IQ streamer.
- ``ria-agent detect`` print SDR drivers whose modules import cleanly. - ``ria-agent detect`` print SDR drivers whose modules import cleanly.
- ``ria-agent register --url URL --token TOKEN`` save credentials to - ``ria-agent register --hub URL --api-key KEY`` register with the hub and
``~/.ria/agent.json``. save credentials (and optional TX interlocks) to ``~/.ria/agent.json``.
Invoking ``ria-agent`` with no subcommand falls through to the legacy Invoking ``ria-agent`` with no subcommand falls through to the legacy
long-poll behavior for back-compatibility with existing deployments. long-poll behavior for back-compatibility with existing deployments.
@ -69,9 +69,27 @@ def _cmd_register(args: argparse.Namespace) -> int:
if args.name: if args.name:
cfg.name = args.name cfg.name = args.name
cfg.insecure = bool(args.insecure) cfg.insecure = bool(args.insecure)
cfg.tx_enabled = bool(getattr(args, "allow_tx", False))
if (v := getattr(args, "tx_max_gain_db", None)) is not None:
cfg.tx_max_gain_db = float(v)
if (v := getattr(args, "tx_max_duration_s", None)) is not None:
cfg.tx_max_duration_s = float(v)
freq_ranges = getattr(args, "tx_freq_range", None) or []
if freq_ranges:
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
path = _config.save(cfg) path = _config.save(cfg)
print(f"Registered agent: {agent_id}") print(f"Registered agent: {agent_id}")
if cfg.tx_enabled:
caps: list[str] = []
if cfg.tx_max_gain_db is not None:
caps.append(f"gain<={cfg.tx_max_gain_db} dB")
if cfg.tx_max_duration_s is not None:
caps.append(f"duration<={cfg.tx_max_duration_s} s")
if cfg.tx_allowed_freq_ranges:
caps.append(f"freq in {cfg.tx_allowed_freq_ranges}")
tail = f" ({', '.join(caps)})" if caps else ""
print(f"TX enabled{tail}")
print(f"Credentials saved to {path}") print(f"Credentials saved to {path}")
return 0 return 0
@ -85,8 +103,10 @@ def _cmd_stream(args: argparse.Namespace) -> int:
if not url: if not url:
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr) print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
return 2 return 2
if getattr(args, "allow_tx", False):
cfg.tx_enabled = True
try: try:
asyncio.run(run_streamer(url, token)) asyncio.run(run_streamer(url, token, cfg=cfg))
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
return 0 return 0
@ -123,11 +143,47 @@ def main() -> None:
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key") p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
p_reg.add_argument("--name", default=None, help="Human-friendly agent name") p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification") p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
p_reg.add_argument(
"--allow-tx",
dest="allow_tx",
action="store_true",
help="Opt this agent in to TX (required for any transmission from the hub)",
)
p_reg.add_argument(
"--tx-max-gain-db",
dest="tx_max_gain_db",
type=float,
default=None,
help="Reject tx_start frames whose tx_gain exceeds this cap (dB)",
)
p_reg.add_argument(
"--tx-max-duration-s",
dest="tx_max_duration_s",
type=float,
default=None,
help="Auto-stop any TX session after this many seconds",
)
p_reg.add_argument(
"--tx-freq-range",
dest="tx_freq_range",
type=float,
nargs=2,
action="append",
metavar=("LO", "HI"),
default=None,
help="Allowed TX center-frequency range in Hz (repeat for multiple bands)",
)
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
p_stream.add_argument("--url", default=None, help="Override WebSocket URL") p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
p_stream.add_argument("--token", default=None, help="Override bearer token") p_stream.add_argument("--token", default=None, help="Override bearer token")
p_stream.add_argument("--log-level", default="INFO") p_stream.add_argument("--log-level", default="INFO")
p_stream.add_argument(
"--allow-tx",
dest="allow_tx",
action="store_true",
help="Runtime override: enable TX for this process without writing config",
)
# Unknown extras are forwarded to the legacy CLI when command == "run". # Unknown extras are forwarded to the legacy CLI when command == "run".
args, extras = parser.parse_known_args(argv) args, extras = parser.parse_known_args(argv)

View File

@ -7,7 +7,11 @@ Schema::
"agent_id": "agent-abc123", "agent_id": "agent-abc123",
"token": "rha_xxxx", "token": "rha_xxxx",
"name": "lab-bench-1", "name": "lab-bench-1",
"insecure": false "insecure": false,
"tx_enabled": false,
"tx_max_gain_db": null,
"tx_max_duration_s": null,
"tx_allowed_freq_ranges": null
} }
""" """
@ -18,7 +22,8 @@ import os
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) def _resolve_default_path() -> Path:
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
@dataclass @dataclass
@ -29,15 +34,29 @@ class AgentConfig:
name: str = "" name: str = ""
insecure: bool = False insecure: bool = False
api_key: str = "" api_key: str = ""
tx_enabled: bool = False
tx_max_gain_db: float | None = None
tx_max_duration_s: float | None = None
tx_allowed_freq_ranges: list[list[float]] | None = None
extra: dict = field(default_factory=dict) extra: dict = field(default_factory=dict)
def default_path() -> Path: def default_path() -> Path:
return _DEFAULT_PATH return _resolve_default_path()
def _coerce_ranges(raw) -> list[list[float]] | None:
if raw is None:
return None
out: list[list[float]] = []
for pair in raw:
lo, hi = pair
out.append([float(lo), float(hi)])
return out
def load(path: Path | None = None) -> AgentConfig: def load(path: Path | None = None) -> AgentConfig:
p = path or _DEFAULT_PATH p = path or _resolve_default_path()
if not p.exists(): if not p.exists():
return AgentConfig() return AgentConfig()
data = json.loads(p.read_text()) data = json.loads(p.read_text())
@ -50,12 +69,16 @@ def load(path: Path | None = None) -> AgentConfig:
name=data.get("name", ""), name=data.get("name", ""),
insecure=bool(data.get("insecure", False)), insecure=bool(data.get("insecure", False)),
api_key=data.get("api_key", ""), api_key=data.get("api_key", ""),
tx_enabled=bool(data.get("tx_enabled", False)),
tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None),
tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None),
tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")),
extra=extra, extra=extra,
) )
def save(cfg: AgentConfig, path: Path | None = None) -> Path: def save(cfg: AgentConfig, path: Path | None = None) -> Path:
p = path or _DEFAULT_PATH p = path or _resolve_default_path()
p.parent.mkdir(parents=True, exist_ok=True) p.parent.mkdir(parents=True, exist_ok=True)
data = asdict(cfg) data = asdict(cfg)
extra = data.pop("extra", {}) or {} extra = data.pop("extra", {}) or {}

View File

@ -4,19 +4,41 @@ from __future__ import annotations
from ria_toolkit_oss.sdr import detect_available from ria_toolkit_oss.sdr import detect_available
from .config import AgentConfig
def available_devices() -> list[str]: def available_devices() -> list[str]:
"""Return a sorted list of device names whose driver modules import cleanly.""" """Return a sorted list of device names whose driver modules import cleanly."""
return sorted(detect_available().keys()) return sorted(detect_available().keys())
def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict: def heartbeat_payload(
"""Build the JSON body of a periodic heartbeat frame.""" status: str = "idle",
app_id: str | None = None,
*,
cfg: AgentConfig | None = None,
sessions: dict | None = None,
) -> dict:
"""Build the JSON body of a periodic heartbeat frame.
*cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not
supplied, the heartbeat advertises RX-only with ``tx_enabled=False``,
matching the pre-TX shape.
"""
c = cfg or AgentConfig()
capabilities = ["rx"]
if c.tx_enabled:
capabilities.append("tx")
payload: dict = { payload: dict = {
"type": "heartbeat", "type": "heartbeat",
"hardware": available_devices(), "hardware": available_devices(),
"status": status, "status": status,
"capabilities": capabilities,
"tx_enabled": bool(c.tx_enabled),
} }
if app_id: if app_id:
payload["app_id"] = app_id payload["app_id"] = app_id
if sessions:
payload["sessions"] = sessions
return payload return payload

View File

@ -1,20 +1,33 @@
"""Thin IQ-streaming agent. """IQ-streaming agent.
Listens for control messages from the RIA Hub over a persistent WebSocket. Listens for control messages from the RIA Hub over a persistent WebSocket.
When the server sends ``start``, opens the SDR described in ``radio_config``, Supports:
loops over ``sdr.rx(buffer_size)``, and sends each buffer as raw
interleaved float32 bytes. ``stop`` closes the SDR; ``configure`` applies - An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens
parameter updates at the next capture boundary. the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ).
- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus
binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires
up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False;
Phase 4 implements the full TX loop.
Both sessions can run concurrently on the same physical SDR (FDD): a
ref-counted SDR registry shares one driver instance when RX and TX name the
same ``(device, identifier)``.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging import logging
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Any from typing import Any
import numpy as np import numpy as np
from .config import AgentConfig
from .hardware import heartbeat_payload from .hardware import heartbeat_payload
from .ws_client import WsClient from .ws_client import WsClient
@ -23,6 +36,98 @@ logger = logging.getLogger("ria_agent.streamer")
_DEFAULT_BUFFER_SIZE = 1024 _DEFAULT_BUFFER_SIZE = 1024
# ---------------------------------------------------------------------------
# Session dataclasses
@dataclass
class RxSession:
    """State for one active RX streaming session."""

    app_id: str  # hub-assigned application id, echoed in status/error frames
    sdr: Any  # SDR driver handle obtained from the shared registry
    device_key: tuple[str, str | None]  # (device, identifier) registry key for release()
    buffer_size: int  # samples requested per sdr.rx() call
    task: asyncio.Task | None = None  # the running _capture_loop task, if started
    pending_config: dict = field(default_factory=dict)  # applied at the next capture boundary
@dataclass
class TxSession:
    """State for one active TX session.

    Shared between the asyncio side and the executor thread running the
    driver's ``_stream_tx`` loop, hence the ``threading.Event`` fields.
    """

    app_id: str  # hub-assigned application id, echoed in tx_status frames
    sdr: Any  # SDR driver handle obtained from the shared registry
    device_key: tuple[str, str | None]  # (device, identifier) registry key for release()
    buffer_size: int  # buffer-size hint popped from radio_config
    task: Any = None  # concurrent.futures.Future from run_in_executor
    pending_config: dict = field(default_factory=dict)  # applied at the next TX buffer boundary
    underrun_policy: str = "pause"  # one of "pause" | "zero" | "repeat"
    last_buffer: np.ndarray | None = None  # retained for the "repeat" underrun policy
    stop_event: threading.Event = field(default_factory=threading.Event)  # set to end TX
    started_at: float = 0.0  # time.monotonic() at tx_start; basis for max_duration_s
    max_duration_s: float | None = None  # agent-enforced cap; None means unlimited
    state: str = "armed"  # "armed" until the first real buffer is handed to the driver
    # Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so
    # hub-side over-production triggers WS backpressure rather than memory
    # growth in the agent.
    in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8))
    # Set by the TX callback when it hits an underrun while policy=="pause";
    # asyncio side flips the session state and emits tx_status.
    underrun_flag: threading.Event = field(default_factory=threading.Event)
# ---------------------------------------------------------------------------
# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously)
class _SdrRegistry:
    """Ref-counted registry of open SDR driver instances.

    Keyed by ``(device, identifier)`` so a concurrent RX and TX session on
    the same physical radio share one driver handle instead of opening the
    hardware twice. All bookkeeping is guarded by a single lock; driver
    construction deliberately happens outside it (see ``acquire``).
    """

    def __init__(self, factory):
        # factory: callable ``(device, identifier) -> SDR``; injectable for tests.
        self._factory = factory
        # key -> (sdr instance, reference count)
        self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {}
        self._lock = threading.Lock()

    def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]:
        """Return ``(sdr, key)``, building the driver on first acquisition.

        If another thread builds the same device concurrently, the loser's
        duplicate instance is closed and the winner's is shared.
        """
        key = (device, identifier)
        with self._lock:
            if key in self._instances:
                sdr, rc = self._instances[key]
                self._instances[key] = (sdr, rc + 1)
                return sdr, key
        # Build outside the lock: driver init can be slow and we don't want to
        # block concurrent releases on unrelated devices.
        sdr = self._factory(device, identifier)
        with self._lock:
            if key in self._instances:
                # Raced another acquirer; discard our duplicate and share theirs.
                other_sdr, rc = self._instances[key]
                try:
                    sdr.close()
                except Exception:
                    pass
                self._instances[key] = (other_sdr, rc + 1)
                return other_sdr, key
            self._instances[key] = (sdr, 1)
            return sdr, key

    def release(self, key: tuple[str, str | None]) -> bool:
        """Decrement refcount. Returns True if the caller owns the last reference
        and should close the SDR."""
        with self._lock:
            sdr, rc = self._instances.get(key, (None, 0))
            if sdr is None:
                # Unknown key (already fully released): nothing to close.
                return False
            if rc <= 1:
                del self._instances[key]
                return True
            self._instances[key] = (sdr, rc - 1)
            return False

    def refcount(self, key: tuple[str, str | None]) -> int:
        """Current reference count for *key* (0 when the key is unknown)."""
        with self._lock:
            return self._instances.get(key, (None, 0))[1]
# ---------------------------------------------------------------------------
# Streamer
class Streamer: class Streamer:
"""Main streamer loop. """Main streamer loop.
@ -31,103 +136,186 @@ class Streamer:
ws: ws:
Connected :class:`WsClient`. Connected :class:`WsClient`.
sdr_factory: sdr_factory:
Callable ``(device, identifier) -> SDR``. Defaults to Callable ``(device, identifier) -> SDR``. Defaults to the helper in
:func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests. :mod:`ria_toolkit_oss.sdr`. Injectable for tests.
cfg:
:class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and
heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which
leaves TX disabled.
""" """
def __init__(self, ws: WsClient, sdr_factory=None) -> None: def __init__(
self,
ws,
sdr_factory=None,
cfg: AgentConfig | None = None,
) -> None:
self.ws = ws self.ws = ws
self._sdr_factory = sdr_factory self._cfg = cfg or AgentConfig()
self._app_id: str | None = None self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory)
self._sdr: Any = None self._rx: RxSession | None = None
self._pending_config: dict = {} self._tx: TxSession | None = None
self._capture_task: asyncio.Task | None = None # Pending radio_config accepted via ``configure`` before ``start``.
self._status = "idle" self._standalone_pending_config: dict = {}
# Cached asyncio event loop, set the first time a handler runs. Used
# to schedule async callbacks from the TX executor thread.
self._loop: asyncio.AbstractEventLoop | None = None
# ------------------------------------------------------------------
# Back-compat read-only shims for callers that check ``._sdr`` etc.
# Writes to these attributes are not supported — use the session objects.
@property
def _sdr(self):
return self._rx.sdr if self._rx is not None else None
@property
def _pending_config(self) -> dict:
return self._rx.pending_config if self._rx is not None else self._standalone_pending_config
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# WsClient wiring # WsClient wiring
def build_heartbeat(self) -> dict: def build_heartbeat(self) -> dict:
return heartbeat_payload(status=self._status, app_id=self._app_id) status = "streaming" if (self._rx is not None or self._tx is not None) else "idle"
app_id: str | None = None
if self._rx is not None:
app_id = self._rx.app_id
elif self._tx is not None:
app_id = self._tx.app_id
sessions: dict[str, dict] = {}
if self._rx is not None:
sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"}
if self._tx is not None:
sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state}
return heartbeat_payload(
status=status,
app_id=app_id,
cfg=self._cfg,
sessions=sessions or None,
)
async def on_message(self, msg: dict) -> None: async def on_message(self, msg: dict) -> None:
t = msg.get("type") t = msg.get("type")
if t == "start": handler = {
await self._handle_start(msg) "start": self._handle_rx_start,
elif t == "stop": "stop": self._handle_rx_stop,
await self._handle_stop(msg) "configure": self._handle_rx_configure,
elif t == "configure": "tx_start": self._handle_tx_start,
self._pending_config.update(msg.get("radio_config") or {}) "tx_stop": self._handle_tx_stop,
logger.debug("Queued configure: %s", self._pending_config) "tx_configure": self._handle_tx_configure,
else: }.get(t)
if handler is None:
logger.warning("Unknown server message type: %r", t) logger.warning("Unknown server message type: %r", t)
return
await handler(msg)
# ------------------------------------------------------------------ async def on_binary(self, data: bytes) -> None:
async def _handle_start(self, msg: dict) -> None: tx = self._tx
if self._capture_task is not None and not self._capture_task.done(): if tx is None:
logger.debug("Dropping %d-byte binary frame: no TX session", len(data))
return
# Backpressure: if the TX queue is full, await briefly so the hub's
# ``await ws.send`` throttles naturally via TCP. We don't block
# indefinitely — a 2s stall means something else is wrong.
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0))
except queue.Full:
logger.warning("TX queue stalled; dropping frame")
# ==================================================================
# RX
async def _handle_rx_start(self, msg: dict) -> None:
if self._rx is not None:
logger.warning("start received while already streaming — ignoring") logger.warning("start received while already streaming — ignoring")
return return
self._app_id = msg.get("app_id") app_id = msg.get("app_id") or ""
radio_config = dict(msg.get("radio_config") or {}) radio_config = dict(msg.get("radio_config") or {})
device = radio_config.pop("device", None) device = radio_config.pop("device", None)
identifier = radio_config.pop("identifier", None) identifier = radio_config.pop("identifier", None)
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
if not device: if not device:
await self._send_error("start missing radio_config.device") await self._send_error(app_id, "start missing radio_config.device")
return return
try: try:
factory = self._sdr_factory or _default_sdr_factory sdr, device_key = self._registry.acquire(device, identifier)
self._sdr = factory(device, identifier) _apply_sdr_config(sdr, radio_config)
_apply_sdr_config(self._sdr, radio_config)
except Exception as exc: except Exception as exc:
logger.exception("Failed to open SDR %r", device) logger.exception("Failed to open SDR %r", device)
await self._send_error(f"SDR init failed: {exc}") await self._send_error(app_id, f"SDR init failed: {exc}")
return return
self._status = "streaming" # Inherit any pending config that was queued before start.
await self._send_status("streaming") pending = dict(self._standalone_pending_config)
self._capture_task = asyncio.create_task( self._standalone_pending_config = {}
self._capture_loop(buffer_size), name="ria-streamer-capture"
session = RxSession(
app_id=app_id,
sdr=sdr,
device_key=device_key,
buffer_size=buffer_size,
pending_config=pending,
)
self._rx = session
await self._send_status("streaming", app_id)
session.task = asyncio.create_task(
self._capture_loop(session), name="ria-streamer-capture"
) )
async def _handle_stop(self, msg: dict) -> None: async def _handle_rx_stop(self, msg: dict) -> None:
if self._capture_task is not None: session = self._rx
self._capture_task.cancel() if session is None:
return
if session.task is not None:
session.task.cancel()
try: try:
await self._capture_task await session.task
except (asyncio.CancelledError, Exception): except (asyncio.CancelledError, Exception):
pass pass
self._capture_task = None self._close_session_sdr(session)
self._close_sdr() app_id = session.app_id
self._app_id = None self._rx = None
self._status = "idle" await self._send_status("idle", app_id)
await self._send_status("idle")
async def _capture_loop(self, buffer_size: int) -> None: async def _handle_rx_configure(self, msg: dict) -> None:
cfg = dict(msg.get("radio_config") or {})
if self._rx is not None:
self._rx.pending_config.update(cfg)
else:
self._standalone_pending_config.update(cfg)
logger.debug("Queued configure: %s", cfg)
async def _capture_loop(self, session: RxSession) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
try: try:
while True: while True:
if self._pending_config: if session.pending_config:
cfg = self._pending_config cfg = session.pending_config
self._pending_config = {} session.pending_config = {}
try: try:
_apply_sdr_config(self._sdr, cfg) _apply_sdr_config(session.sdr, cfg)
except Exception as exc: except Exception as exc:
logger.warning("Applying configure failed: %s", exc) logger.warning("Applying configure failed: %s", exc)
try: try:
samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size) samples = await loop.run_in_executor(
None, session.sdr.rx, session.buffer_size
)
except Exception as exc: except Exception as exc:
from ria_toolkit_oss.sdr import SdrDisconnectedError from ria_toolkit_oss.sdr import SdrDisconnectedError
if isinstance(exc, SdrDisconnectedError): if isinstance(exc, SdrDisconnectedError):
logger.warning("SDR disconnected: %s", exc) logger.warning("SDR disconnected: %s", exc)
await self._send_error(f"SDR disconnected: {exc}") await self._send_error(session.app_id, f"SDR disconnected: {exc}")
else: else:
logger.exception("SDR rx error") logger.exception("SDR rx error")
await self._send_error(f"SDR capture failed: {exc}") await self._send_error(session.app_id, f"SDR capture failed: {exc}")
break break
payload = _samples_to_interleaved_float32(samples) payload = _samples_to_interleaved_float32(samples)
@ -139,29 +327,305 @@ class Streamer:
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
finally: finally:
self._close_sdr() self._close_session_sdr(session)
# If the loop died on its own (e.g. SDR disconnect), clear the
# session handle so future ``start`` messages can proceed.
if self._rx is session:
self._rx = None
def _close_sdr(self) -> None: # ==================================================================
if self._sdr is None: # TX
async def _handle_tx_start(self, msg: dict) -> None:
app_id = msg.get("app_id") or ""
radio_config = dict(msg.get("radio_config") or {})
# --- interlocks (agent-enforced; never trust the hub alone) ---
if not self._cfg.tx_enabled:
await self._send_tx_status(app_id, "error", "tx disabled on this agent")
return return
tx_gain = radio_config.get("tx_gain")
if (
self._cfg.tx_max_gain_db is not None
and tx_gain is not None
and float(tx_gain) > float(self._cfg.tx_max_gain_db)
):
await self._send_tx_status(
app_id,
"error",
f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}",
)
return
tx_freq = radio_config.get("tx_center_frequency")
if self._cfg.tx_allowed_freq_ranges and tx_freq is not None:
f = float(tx_freq)
if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges):
await self._send_tx_status(
app_id,
"error",
f"tx_center_frequency {tx_freq} outside allowed ranges",
)
return
if self._tx is not None:
await self._send_tx_status(app_id, "error", "tx already active on this agent")
return
# --- device ---
device = radio_config.pop("device", None)
identifier = radio_config.pop("identifier", None)
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
if underrun_policy not in ("pause", "zero", "repeat"):
await self._send_tx_status(
app_id, "error", f"invalid underrun_policy {underrun_policy!r}"
)
return
if not device:
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
return
device_key: tuple[str, str | None] | None = None
sdr: Any = None
try: try:
self._sdr.close() sdr, device_key = self._registry.acquire(device, identifier)
_apply_sdr_config(sdr, radio_config)
# Only call init_tx when the hub supplied the three required
# parameters. Drivers that gate _stream_tx on _tx_initialized
# (e.g. Pluto) need this; drivers that don't (e.g. Mock) tolerate
# its absence.
init_args = {
k: radio_config.get(f"tx_{k}")
for k in ("sample_rate", "center_frequency", "gain")
}
if hasattr(sdr, "init_tx") and all(v is not None for v in init_args.values()):
sdr.init_tx(
sample_rate=init_args["sample_rate"],
center_frequency=init_args["center_frequency"],
gain=init_args["gain"],
channel=radio_config.get("tx_channel", 0),
gain_mode=radio_config.get("tx_gain_mode", "manual"),
)
except Exception as exc:
if device_key is not None:
if self._registry.release(device_key):
try:
sdr.close()
except Exception:
pass
logger.exception("Failed to init TX on %r", device)
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
return
self._loop = asyncio.get_running_loop()
session = TxSession(
app_id=app_id,
sdr=sdr,
device_key=device_key,
buffer_size=buffer_size,
underrun_policy=underrun_policy,
started_at=time.monotonic(),
max_duration_s=self._cfg.tx_max_duration_s,
)
self._tx = session
await self._send_tx_status(app_id, "armed")
session.task = self._loop.run_in_executor(None, self._tx_executor_body, session)
# Spawn a small watchdog that transitions armed → transmitting when
# the first buffer has been consumed, and surfaces underrun / max-
# duration terminations back to the hub.
asyncio.create_task(self._tx_watchdog(session))
async def _handle_tx_stop(self, msg: dict) -> None:
    """Tear down the active TX session and report ``done`` to the hub.

    No-op when no TX session exists. Ordering matters: set the stop event
    first so the executor callback returns silence, pause the driver, drain
    queued frames, then wait (bounded) for the executor to exit before
    releasing the SDR handle.
    """
    session = self._tx
    if session is None:
        return
    app_id = session.app_id
    session.stop_event.set()
    try:
        session.sdr.pause_tx()
    except Exception:
        logger.debug("pause_tx raised during stop", exc_info=True)
    # Drop any frames still queued; the callback's ``get`` uses a short
    # timeout, so the executor thread unblocks on its own shortly after.
    self._drain_tx_queue(session)
    if session.task is not None:
        try:
            await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5)
        except asyncio.TimeoutError:
            logger.warning("TX executor did not exit within 1.5s after stop")
        except Exception:
            logger.debug("TX executor raised on shutdown", exc_info=True)
    self._close_session_sdr(session)
    self._tx = None
    await self._send_tx_status(app_id, "done")
async def _handle_tx_configure(self, msg: dict) -> None:
    """Queue a ``tx_configure`` payload for the active TX session.

    Silently ignored when no TX session exists; otherwise the parameters
    are merged into ``pending_config`` for the TX callback to apply at the
    next buffer boundary.
    """
    session = self._tx
    if session is not None:
        updates = msg.get("radio_config") or {}
        session.pending_config.update(updates)
# ------------------------------------------------------------------
# TX executor & watchdog
def _tx_executor_body(self, session: TxSession) -> None:
try:
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
except Exception:
logger.exception("TX stream crashed")
self._schedule(self._send_tx_status(session.app_id, "error", "tx stream crashed"))
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
    """Produce the next TX buffer for the driver's ``_stream_tx`` loop.

    Runs on the executor thread, so hub-facing status updates must go via
    ``self._schedule`` (never awaited here). Always returns exactly
    ``num_samples`` complex64 samples — real IQ when a frame is available,
    otherwise silence or the configured underrun fill.
    """
    n = int(num_samples)
    # Honor stop requests: return silence one last time and let the driver
    # exit its loop on the next iteration (pause_tx flips _enable_tx).
    if session.stop_event.is_set():
        return _silence(n)
    # Max-duration watchdog.
    if (
        session.max_duration_s is not None
        and (time.monotonic() - session.started_at) >= float(session.max_duration_s)
    ):
        session.stop_event.set()
        try:
            session.sdr.pause_tx()
        except Exception:
            pass
        self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached"))
        return _silence(n)
    # Apply queued configure at buffer boundary.
    if session.pending_config:
        cfg = session.pending_config
        session.pending_config = {}
        try:
            _apply_sdr_config(session.sdr, cfg)
        except Exception as exc:
            logger.debug("tx_configure apply failed: %s", exc)
    # Short timeout so stop/teardown never leaves this thread blocked long.
    try:
        raw = session.in_queue.get(timeout=0.1)
    except queue.Empty:
        return self._underrun_fill(session, n)
    # Frames arrive as interleaved float32 bytes: I, Q, I, Q, ...
    arr = np.frombuffer(raw, dtype=np.float32)
    if arr.size < 2 or arr.size % 2 != 0:
        logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
        return self._underrun_fill(session, n)
    # De-interleave: even indices are I, odd indices are Q.
    samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64))
    if samples.size < n:
        # Short frame: zero-pad up to the driver's requested length.
        out = np.zeros(n, dtype=np.complex64)
        out[: samples.size] = samples
        session.last_buffer = out
        return out
    if samples.size > n:
        # Long frame: truncate (excess samples are discarded, not re-queued).
        samples = samples[:n]
    session.last_buffer = samples
    if session.state == "armed":
        # First real buffer consumed: armed -> transmitting, notify the hub.
        session.state = "transmitting"
        self._schedule(self._send_tx_status(session.app_id, "transmitting"))
    return samples
def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray:
policy = session.underrun_policy
if policy == "zero":
return _silence(n)
if policy == "repeat" and session.last_buffer is not None:
buf = session.last_buffer
if buf.size == n:
return buf
if buf.size > n:
return buf[:n].copy()
out = np.zeros(n, dtype=np.complex64)
out[: buf.size] = buf
return out
# "pause" policy (default) or "repeat" before any buffer arrived.
if not session.underrun_flag.is_set():
session.underrun_flag.set()
session.stop_event.set()
try:
session.sdr.pause_tx()
except Exception: except Exception:
pass pass
self._sdr = None return _silence(n)
async def _send_status(self, status: str) -> None: async def _tx_watchdog(self, session: TxSession) -> None:
# Poll the underrun flag so we can emit status + tear down cleanly
# when the callback flips the flag from the executor thread. Check
# underrun_flag before stop_event, since the "pause" path sets both.
while session is self._tx:
if session.underrun_flag.is_set():
await self._send_tx_status(session.app_id, "underrun")
await self._teardown_tx_after_underrun(session)
return
if session.stop_event.is_set():
return
await asyncio.sleep(0.05)
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
    """Finish tearing down *session* after the callback flagged an underrun.

    Runs on the asyncio side (called from the watchdog). The callback's
    "pause" path has already set ``stop_event`` and paused the driver; here
    we drain queued frames, wait (bounded) for the executor to exit, and
    release the SDR handle.
    """
    if self._tx is not session:
        # A different session replaced this one; nothing left to tear down.
        return
    self._drain_tx_queue(session)
    if session.task is not None:
        try:
            await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0)
        except asyncio.TimeoutError:
            logger.warning("TX executor did not exit within 1s after underrun")
        except Exception:
            logger.debug("TX executor raised during underrun teardown", exc_info=True)
    self._close_session_sdr(session)
    if self._tx is session:
        self._tx = None
def _drain_tx_queue(self, session: TxSession) -> None:
try: try:
await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id}) while True:
session.in_queue.get_nowait()
except queue.Empty:
pass
def _schedule(self, coro) -> None:
    """Thread-safely schedule *coro* on the cached asyncio loop.

    Called from the TX executor thread, where ``asyncio.create_task`` is
    unavailable. If no loop has been cached yet, or scheduling fails (e.g.
    the loop is closing during shutdown), the coroutine is closed explicitly
    so Python does not emit a "coroutine ... was never awaited" warning.
    """
    loop = self._loop
    if loop is None:
        # No handler has run yet; drop the status frame cleanly.
        coro.close()
        return
    try:
        asyncio.run_coroutine_threadsafe(coro, loop)
    except Exception:
        logger.debug("_schedule failed", exc_info=True)
        coro.close()
# ==================================================================
# Helpers
def _close_session_sdr(self, session) -> None:
    """Release the session's registry reference; close the SDR on last ref."""
    sdr = session.sdr
    if sdr is None:
        return
    if not self._registry.release(session.device_key):
        # Another session (RX or TX) still holds this device open.
        return
    try:
        sdr.close()
    except Exception:
        logger.debug("SDR close raised", exc_info=True)
async def _send_status(self, status: str, app_id: str) -> None:
try:
await self.ws.send_json({"type": "status", "status": status, "app_id": app_id})
except Exception as exc: except Exception as exc:
logger.debug("Status send failed: %s", exc) logger.debug("Status send failed: %s", exc)
async def _send_error(self, message: str) -> None: async def _send_error(self, app_id: str, message: str) -> None:
try: try:
await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message}) await self.ws.send_json({"type": "error", "app_id": app_id, "message": message})
except Exception as exc: except Exception as exc:
logger.debug("Error-frame send failed: %s", exc) logger.debug("Error-frame send failed: %s", exc)
async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None:
payload: dict = {"type": "tx_status", "app_id": app_id, "state": state}
if message is not None:
payload["message"] = message
try:
await self.ws.send_json(payload)
except Exception as exc:
logger.debug("tx_status send failed: %s", exc)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@ -172,6 +636,10 @@ _CONFIG_ATTR_MAP = {
"center_freq": ("center_freq", "rx_center_frequency"), "center_freq": ("center_freq", "rx_center_frequency"),
"gain": ("gain", "rx_gain"), "gain": ("gain", "rx_gain"),
"bandwidth": ("bandwidth", "rx_bandwidth"), "bandwidth": ("bandwidth", "rx_bandwidth"),
"tx_sample_rate": ("tx_sample_rate",),
"tx_center_frequency": ("tx_center_frequency", "tx_lo"),
"tx_gain": ("tx_gain",),
"tx_bandwidth": ("tx_bandwidth",),
} }
@ -194,6 +662,11 @@ def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
logger.debug("radio_config key %r ignored (no matching attr)", key) logger.debug("radio_config key %r ignored (no matching attr)", key)
def _silence(num_samples: int) -> np.ndarray:
"""Return a ``num_samples``-length zero-filled complex64 buffer."""
return np.zeros(int(num_samples), dtype=np.complex64)
def _samples_to_interleaved_float32(samples: Any) -> bytes: def _samples_to_interleaved_float32(samples: Any) -> bytes:
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes.""" """Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
arr = np.asarray(samples) arr = np.asarray(samples)
@ -214,8 +687,12 @@ def _default_sdr_factory(device: str, identifier: str | None):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Top-level entry # Top-level entry
async def run_streamer(ws_url: str, token: str) -> None: async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
"""Connect to *ws_url* and run the streamer loop until cancelled.""" """Connect to *ws_url* and run the streamer loop until cancelled."""
ws = WsClient(ws_url, token) ws = WsClient(ws_url, token)
streamer = Streamer(ws) streamer = Streamer(ws, cfg=cfg)
await ws.run(streamer.on_message, streamer.build_heartbeat) await ws.run(
streamer.on_message,
streamer.build_heartbeat,
on_binary=streamer.on_binary,
)

View File

@ -15,6 +15,7 @@ logger = logging.getLogger("ria_agent.ws")
MessageHandler = Callable[[dict], Awaitable[None]] MessageHandler = Callable[[dict], Awaitable[None]]
HeartbeatBuilder = Callable[[], dict] HeartbeatBuilder = Callable[[], dict]
BinaryHandler = Callable[[bytes], Awaitable[None]]
class WsClient: class WsClient:
@ -65,7 +66,12 @@ class WsClient:
self._stop.set() self._stop.set()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None: async def run(
self,
on_message: MessageHandler,
heartbeat: HeartbeatBuilder,
on_binary: BinaryHandler | None = None,
) -> None:
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop.""" """Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
while not self._stop.is_set(): while not self._stop.is_set():
try: try:
@ -75,8 +81,13 @@ class WsClient:
try: try:
async for raw in self._ws: async for raw in self._ws:
if isinstance(raw, bytes): if isinstance(raw, bytes):
# Server shouldn't send binary to the agent; log and drop. if on_binary is None:
logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
continue
try:
await on_binary(raw)
except Exception:
logger.exception("on_binary handler raised; dropping frame")
continue continue
try: try:
msg = json.loads(raw) msg = json.loads(raw)

View File

@ -21,6 +21,7 @@ from __future__ import annotations
import argparse import argparse
import json import json
import os
import shutil import shutil
import subprocess import subprocess
import sys import sys
@ -77,24 +78,33 @@ def _inspect_labels(engine: list[str], ref: str) -> dict:
return {} return {}
def _hardware_flags(labels: dict) -> list[str]: def _gpu_available() -> bool:
if os.path.exists("/dev/nvidia0"):
return True
return shutil.which("nvidia-smi") is not None
def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]:
flags: list[str] = [] flags: list[str] = []
notes: list[str] = []
profile = (labels.get(_LABEL_PROFILE) or "").lower() profile = (labels.get(_LABEL_PROFILE) or "").lower()
hardware = (labels.get(_LABEL_HARDWARE) or "").lower() hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
hw_items = {h.strip() for h in hardware.split(",") if h.strip()} hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
if "nvidia" in profile or "holoscan" in profile or "cuda" in profile: wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda"))
flags += ["--gpus", "all"] if wants_gpu and not no_gpu:
if _gpu_available():
flags += ["--gpus", "all"]
else:
notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)")
needs_usb = hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
if needs_usb:
flags += ["--device", "/dev/bus/usb"] flags += ["--device", "/dev/bus/usb"]
needs_net = hw_items & {"usrp", "thinkrf", "pluto"} if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net:
if needs_net:
flags += ["--net", "host"] flags += ["--net", "host"]
return flags return flags, notes
def _cmd_configure(args: argparse.Namespace) -> int: def _cmd_configure(args: argparse.Namespace) -> int:
@ -132,7 +142,10 @@ def _cmd_run(args: argparse.Namespace) -> int:
return rc return rc
labels = _inspect_labels(engine, ref) labels = _inspect_labels(engine, ref)
hw_flags = _hardware_flags(labels) no_gpu = args.no_gpu and not args.force_gpu
hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net)
if args.force_gpu and "--gpus" not in hw_flags:
hw_flags = ["--gpus", "all", *hw_flags]
cmd = [*engine, "run", "--rm"] cmd = [*engine, "run", "--rm"]
if not args.foreground: if not args.foreground:
@ -162,6 +175,8 @@ def _cmd_run(args: argparse.Namespace) -> int:
print(f"Running {ref} [{label_str}]") print(f"Running {ref} [{label_str}]")
if hw_flags: if hw_flags:
print(f" auto flags: {' '.join(hw_flags)}") print(f" auto flags: {' '.join(hw_flags)}")
for note in notes:
print(f" note: {note}")
return subprocess.call(cmd) return subprocess.call(cmd)
@ -225,6 +240,10 @@ def main() -> None:
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount") p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
p_run.add_argument("-p", "--publish", action="append", help="Publish port") p_run.add_argument("-p", "--publish", action="append", help="Publish port")
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)") p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU")
p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected")
p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb")
p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host")
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit") p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run") p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint") p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")

111
tests/agent/test_cli_tx.py Normal file
View File

@ -0,0 +1,111 @@
"""CLI flags for TX opt-in and interlocks."""
from __future__ import annotations
import json
import sys
from unittest.mock import patch
from ria_toolkit_oss.agent import cli as agent_cli
from ria_toolkit_oss.agent import config as agent_config
class _FakeResp:
def __init__(self, payload: dict):
self._payload = payload
def read(self) -> bytes:
return json.dumps(self._payload).encode()
def __enter__(self):
return self
def __exit__(self, *_a):
return False
def _run_register(argv: list[str], cfg_path) -> int:
    """Invoke ``ria-agent <argv>`` with hub I/O and ``sys.argv`` patched out.

    Returns the CLI exit code (0 when ``main()`` returns without raising
    ``SystemExit``).
    """
    response = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"})
    env_patch = patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False)
    url_patch = patch("urllib.request.urlopen", return_value=response)
    argv_patch = patch.object(sys, "argv", ["ria-agent", *argv])
    with env_patch, url_patch, argv_patch:
        try:
            agent_cli.main()
        except SystemExit as exc:
            return int(exc.code or 0)
    return 0
def test_register_without_allow_tx_keeps_tx_disabled(tmp_path):
    """Registering without --allow-tx must leave TX opted out (safe default)."""
    cfg_path = tmp_path / "agent.json"
    _run_register(
        ["register", "--hub", "http://hub:3005", "--api-key", "K"],
        cfg_path,
    )
    cfg = agent_config.load(path=cfg_path)
    assert cfg.agent_id == "agent-1"
    assert cfg.tx_enabled is False
    assert cfg.tx_max_gain_db is None


def test_register_with_allow_tx_and_caps(tmp_path):
    """--allow-tx plus interlock flags persist the TX caps into agent.json."""
    cfg_path = tmp_path / "agent.json"
    _run_register(
        [
            "register",
            "--hub",
            "http://hub:3005",
            "--api-key",
            "K",
            "--allow-tx",
            "--tx-max-gain-db",
            "-10",
            "--tx-max-duration-s",
            "60",
            # Repeatable flag: two allowed frequency bands.
            "--tx-freq-range",
            "2.4e9",
            "2.5e9",
            "--tx-freq-range",
            "5.7e9",
            "5.8e9",
        ],
        cfg_path,
    )
    cfg = agent_config.load(path=cfg_path)
    assert cfg.tx_enabled is True
    assert cfg.tx_max_gain_db == -10.0
    assert cfg.tx_max_duration_s == 60.0
    assert cfg.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]


def test_stream_allow_tx_does_not_persist(tmp_path):
    # Pre-register with tx_enabled=False, then simulate `stream --allow-tx`.
    # The on-disk config must remain unchanged; the runtime flag is process-local.
    cfg_path = tmp_path / "agent.json"
    base = agent_config.AgentConfig(
        hub_url="http://hub:3005",
        agent_id="agent-1",
        token="tok-abc",
        tx_enabled=False,
    )
    agent_config.save(base, path=cfg_path)

    captured: dict = {}

    async def _fake_run_streamer(url, token, *, cfg):
        # Capture the runtime config the CLI hands to the streamer.
        captured["cfg"] = cfg
        return None

    with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
            patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \
            patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]):
        try:
            agent_cli.main()
        except SystemExit:
            pass

    # Runtime cfg had TX flipped on
    assert captured["cfg"].tx_enabled is True
    # But the persisted file is untouched
    on_disk = agent_config.load(path=cfg_path)
    assert on_disk.tx_enabled is False

View File

@ -20,6 +20,36 @@ def test_load_missing_returns_empty(tmp_path):
assert loaded == agent_config.AgentConfig() assert loaded == agent_config.AgentConfig()
def test_tx_fields_round_trip(tmp_path):
p = tmp_path / "agent.json"
cfg = agent_config.AgentConfig(
hub_url="https://hub.example.com",
agent_id="agent-1",
token="t",
tx_enabled=True,
tx_max_gain_db=-10.0,
tx_max_duration_s=60.0,
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
)
agent_config.save(cfg, path=p)
loaded = agent_config.load(path=p)
assert loaded.tx_enabled is True
assert loaded.tx_max_gain_db == -10.0
assert loaded.tx_max_duration_s == 60.0
assert loaded.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
def test_tx_fields_default_when_absent(tmp_path):
# Old configs written before TX existed should load cleanly with safe defaults.
p = tmp_path / "agent.json"
p.write_text('{"hub_url": "x", "agent_id": "a", "token": "t"}')
cfg = agent_config.load(path=p)
assert cfg.tx_enabled is False
assert cfg.tx_max_gain_db is None
assert cfg.tx_max_duration_s is None
assert cfg.tx_allowed_freq_ranges is None
def test_extra_keys_preserved(tmp_path): def test_extra_keys_preserved(tmp_path):
p = tmp_path / "agent.json" p = tmp_path / "agent.json"
p.write_text('{"hub_url": "x", "custom": 42}') p.write_text('{"hub_url": "x", "custom": 42}')

View File

@ -67,9 +67,9 @@ def test_streamer_reports_disconnected_and_ends_capture():
"radio_config": {"device": "fake", "buffer_size": 8}, "radio_config": {"device": "fake", "buffer_size": 8},
} }
) )
# Wait for the capture task to fail out. # Wait for the capture loop to emit its error frame and tear down the session.
for _ in range(50): for _ in range(100):
if streamer._capture_task and streamer._capture_task.done(): if any(m.get("type") == "error" for m in ws.json_sent) and streamer._rx is None:
break break
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
return ws, sdr, streamer return ws, sdr, streamer
@ -79,3 +79,5 @@ def test_streamer_reports_disconnected_and_ends_capture():
errors = [m for m in ws.json_sent if m.get("type") == "error"] errors = [m for m in ws.json_sent if m.get("type") == "error"]
assert errors, "expected an error frame" assert errors, "expected an error frame"
assert "disconnected" in errors[-1]["message"].lower() assert "disconnected" in errors[-1]["message"].lower()
# Session handle cleared so future starts can proceed.
assert streamer._rx is None

View File

@ -0,0 +1,133 @@
"""Concurrent RX + TX sessions on the same agent — shared SDR via registry."""
from __future__ import annotations
import asyncio
import time
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR
class FullDuplexMockSDR(MockSDR):
    """MockSDR with a recording TX path so the test can assert both directions."""

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Every buffer the streamer's TX callback handed us, in order.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback):
        # Drive the TX callback like hardware would: pull one buffer per
        # ~5 ms tick until _enable_tx is cleared by the stop path.
        # NOTE(review): pokes MockSDR private flags — confirm they stay stable.
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            result = callback(self.rx_buffer_size)
            self.tx_produced.append(np.asarray(result).copy())
            time.sleep(0.005)
class FakeWs:
    """In-memory WebSocket stub that records everything sent through it."""

    def __init__(self):
        self.json_sent, self.bytes_sent = [], []

    async def send_json(self, frame):
        self.json_sent.append(frame)

    async def send_bytes(self, blob):
        self.bytes_sent.append(blob)
def _iq_frame(samples: np.ndarray) -> bytes:
interleaved = np.empty(samples.size * 2, dtype=np.float32)
interleaved[0::2] = samples.real
interleaved[1::2] = samples.imag
return interleaved.tobytes()
def test_rx_and_tx_share_one_sdr_instance():
    """RX and TX on the same (device, identifier) must share one SDR handle."""
    built: list[FullDuplexMockSDR] = []

    def factory(device, identifier):
        # Count constructions: the registry should call this exactly once.
        sdr = FullDuplexMockSDR(buffer_size=16)
        built.append(sdr)
        return sdr

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=factory, cfg=AgentConfig(tx_enabled=True))
        # Start RX first.
        await s.on_message(
            {
                "type": "start",
                "app_id": "app-1",
                "radio_config": {"device": "mock", "buffer_size": 16},
            }
        )
        # Then start TX on the same device — should share the SDR handle.
        await s.on_message(
            {
                "type": "tx_start",
                "app_id": "app-1",
                "radio_config": {
                    "device": "mock",
                    "buffer_size": 16,
                    "tx_gain": -20,
                    "tx_center_frequency": 2.45e9,
                    "underrun_policy": "zero",
                },
            }
        )
        # Push a known TX buffer.
        marker = np.arange(16, dtype=np.complex64) + 7
        await s.on_binary(_iq_frame(marker))
        # Let both directions produce output (bounded poll, no hard sleep).
        for _ in range(80):
            rx_ok = len(ws.bytes_sent) >= 2
            tx_ok = any(np.array_equal(b, marker) for b in built[0].tx_produced) if built else False
            if rx_ok and tx_ok:
                break
            await asyncio.sleep(0.01)
        # Heartbeat should show both sessions.
        hb = s.build_heartbeat()
        # Stop TX first, RX keeps running.
        await s.on_message({"type": "tx_stop", "app_id": "app-1"})
        tx_after_stop = s._tx is None
        rx_still_active = s._rx is not None
        # Now stop RX.
        await s.on_message({"type": "stop", "app_id": "app-1"})
        return ws, s, built, hb, tx_after_stop, rx_still_active

    ws, s, built, hb, tx_after_stop, rx_still_active = asyncio.run(scenario())

    # One SDR was built and shared.
    assert len(built) == 1, f"expected exactly one SDR instance, got {len(built)}"
    # Both directions produced output.
    assert len(ws.bytes_sent) >= 1, "RX produced no IQ frames"
    marker = np.arange(16, dtype=np.complex64) + 7
    assert any(
        np.array_equal(b, marker) for b in built[0].tx_produced
    ), "TX callback never saw the pushed marker buffer"
    # Heartbeat reflected both sessions while they were active.
    assert hb["sessions"]["rx"]["app_id"] == "app-1"
    assert hb["sessions"]["tx"]["app_id"] == "app-1"
    # Stopping TX does not tear down RX.
    assert tx_after_stop
    assert rx_still_active
    # After both stops, registry is empty.
    assert s._registry.refcount(("mock", None)) == 0
    assert s._rx is None
    assert s._tx is None

View File

@ -23,7 +23,24 @@ def test_heartbeat_payload_shape():
assert p["status"] == "idle" assert p["status"] == "idle"
assert "mock" in p["hardware"] assert "mock" in p["hardware"]
assert "app_id" not in p assert "app_id" not in p
# New fields, default shape
assert p["capabilities"] == ["rx"]
assert p["tx_enabled"] is False
p2 = hardware.heartbeat_payload(status="streaming", app_id="abc") p2 = hardware.heartbeat_payload(status="streaming", app_id="abc")
assert p2["status"] == "streaming" assert p2["status"] == "streaming"
assert p2["app_id"] == "abc" assert p2["app_id"] == "abc"
def test_heartbeat_payload_tx_capability_from_cfg():
    """tx_enabled in AgentConfig surfaces as the extra 'tx' capability."""
    from ria_toolkit_oss.agent.config import AgentConfig

    p = hardware.heartbeat_payload(cfg=AgentConfig(tx_enabled=True))
    assert p["capabilities"] == ["rx", "tx"]
    assert p["tx_enabled"] is True


def test_heartbeat_payload_sessions_field():
    """A sessions mapping passes through the heartbeat payload unmodified."""
    sessions = {"rx": {"app_id": "a", "state": "streaming"}}
    p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions)
    assert p["sessions"] == sessions

View File

@ -0,0 +1,144 @@
"""End-to-end: local websockets server drives a Streamer's TX path."""
from __future__ import annotations
import asyncio
import json
import time
import numpy as np
import websockets
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.agent.ws_client import WsClient
from ria_toolkit_oss.sdr.mock import MockSDR
class RecordingMockSDR(MockSDR):
    """MockSDR whose TX path records every buffer the callback returns."""

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Buffers handed back by the streamer's TX callback, in order.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback):
        # Emulate hardware pacing: pull one buffer per ~5 ms tick until
        # _enable_tx is cleared by the stop path.
        # NOTE(review): relies on MockSDR private flags — confirm they stay stable.
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            result = callback(self.rx_buffer_size)
            self.tx_produced.append(np.asarray(result).copy())
            time.sleep(0.005)
def _iq_frame(samples: np.ndarray) -> bytes:
interleaved = np.empty(samples.size * 2, dtype=np.float32)
interleaved[0::2] = samples.real
interleaved[1::2] = samples.imag
return interleaved.tobytes()
def test_server_tx_start_binary_stop_cycle_over_real_ws():
    """Full TX lifecycle (tx_start → binary IQ → tx_stop) over a real socket."""
    BUF = 16
    sdr = RecordingMockSDR(buffer_size=BUF)
    marker = np.arange(BUF, dtype=np.complex64) + 1

    async def scenario():
        control_frames: list[dict] = []
        done = asyncio.Event()

        async def server_handler(ws):
            # Plays the hub's role: sends tx_start, pushes IQ, then tx_stop.
            try:
                # Drain initial heartbeat.
                first = await asyncio.wait_for(ws.recv(), timeout=2.0)
                control_frames.append(json.loads(first))
                await ws.send(
                    json.dumps(
                        {
                            "type": "tx_start",
                            "app_id": "tx-app",
                            "radio_config": {
                                "device": "mock",
                                "buffer_size": BUF,
                                "tx_sample_rate": 1_000_000,
                                "tx_center_frequency": 2.45e9,
                                "tx_gain": -20,
                                "underrun_policy": "zero",
                            },
                        }
                    )
                )
                # Push a few binary IQ frames.
                for _ in range(3):
                    await ws.send(_iq_frame(marker))
                # Wait for at least "armed" + "transmitting" statuses.
                for _ in range(100):
                    msg = await asyncio.wait_for(ws.recv(), timeout=2.0)
                    if isinstance(msg, str):
                        control_frames.append(json.loads(msg))
                    if any(
                        f.get("type") == "tx_status" and f.get("state") == "transmitting"
                        for f in control_frames
                    ):
                        break
                await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"}))
                # Drain trailing statuses.
                try:
                    while True:
                        msg = await asyncio.wait_for(ws.recv(), timeout=0.5)
                        if isinstance(msg, str):
                            control_frames.append(json.loads(msg))
                except (asyncio.TimeoutError, Exception):
                    # Timeout or client disconnect both end the drain loop.
                    pass
            finally:
                done.set()

        # Port 0 → OS-assigned free port; read it back for the client URL.
        server = await websockets.serve(server_handler, "127.0.0.1", 0)
        port = server.sockets[0].getsockname()[1]
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )
            streamer = Streamer(
                ws=client,
                sdr_factory=lambda d, i: sdr,
                cfg=AgentConfig(tx_enabled=True),
            )
            task = asyncio.create_task(
                client.run(
                    on_message=streamer.on_message,
                    heartbeat=streamer.build_heartbeat,
                    on_binary=streamer.on_binary,
                )
            )
            await asyncio.wait_for(done.wait(), timeout=5.0)
            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
        finally:
            server.close()
            await server.wait_closed()
        return control_frames, streamer

    controls, streamer = asyncio.run(scenario())

    # Heartbeat reached the server.
    assert any(f.get("type") == "heartbeat" for f in controls)
    # tx_status lifecycle: armed → transmitting → done.
    tx_states = [f["state"] for f in controls if f.get("type") == "tx_status"]
    assert tx_states[0] == "armed"
    assert "transmitting" in tx_states
    assert tx_states[-1] == "done"
    # TX callback saw our marker buffer at least once.
    assert any(np.array_equal(b, marker) for b in sdr.tx_produced)
    # Session cleared.
    assert streamer._tx is None

View File

@ -46,15 +46,29 @@ def test_apply_sdr_config_sets_attributes():
def test_heartbeat_reflects_status_and_app(): def test_heartbeat_reflects_status_and_app():
s = Streamer(ws=FakeWs(), sdr_factory=_factory) async def scenario():
hb = s.build_heartbeat() s = Streamer(ws=FakeWs(), sdr_factory=_factory)
assert hb["type"] == "heartbeat" hb = s.build_heartbeat()
assert hb["status"] == "idle" assert hb["type"] == "heartbeat"
s._status = "streaming" assert hb["status"] == "idle"
s._app_id = "app-42" # capabilities default to rx-only
hb2 = s.build_heartbeat() assert hb["capabilities"] == ["rx"]
assert hb2["status"] == "streaming" assert hb["tx_enabled"] is False
assert hb2["app_id"] == "app-42"
await s.on_message(
{
"type": "start",
"app_id": "app-42",
"radio_config": {"device": "mock", "buffer_size": 32},
}
)
hb2 = s.build_heartbeat()
assert hb2["status"] == "streaming"
assert hb2["app_id"] == "app-42"
assert hb2["sessions"]["rx"]["app_id"] == "app-42"
await s.on_message({"type": "stop", "app_id": "app-42"})
asyncio.run(scenario())
def test_full_start_stream_stop_cycle(): def test_full_start_stream_stop_cycle():
@ -89,7 +103,7 @@ def test_full_start_stream_stop_cycle():
statuses = [m for m in ws.json_sent if m.get("type") == "status"] statuses = [m for m in ws.json_sent if m.get("type") == "status"]
assert statuses[0]["status"] == "streaming" assert statuses[0]["status"] == "streaming"
assert statuses[-1]["status"] == "idle" assert statuses[-1]["status"] == "idle"
assert streamer._sdr is None assert streamer._rx is None
def test_start_without_device_emits_error(): def test_start_without_device_emits_error():
@ -110,6 +124,7 @@ def test_configure_queues_update():
await streamer.on_message( await streamer.on_message(
{"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}} {"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}
) )
# Before start(), pending config lives on the standalone dict exposed via the _pending_config shim.
return streamer._pending_config return streamer._pending_config
pending = asyncio.run(scenario()) pending = asyncio.run(scenario())
@ -122,3 +137,56 @@ def test_unknown_message_type_is_ignored():
await s.on_message({"type": "nope"}) await s.on_message({"type": "nope"})
asyncio.run(scenario()) asyncio.run(scenario())
def test_registry_shares_sdr_across_start_stop_cycles():
# Two sequential start/stop cycles with the same (device, identifier)
# should hit the registry's cache path rather than constructing a new SDR.
built: list[MockSDR] = []
def counting_factory(device: str, identifier):
sdr = MockSDR(buffer_size=16, seed=0)
built.append(sdr)
return sdr
async def scenario():
s = Streamer(ws=FakeWs(), sdr_factory=counting_factory)
for _ in range(2):
await s.on_message(
{
"type": "start",
"app_id": "a",
"radio_config": {"device": "mock", "buffer_size": 16},
}
)
# Let one capture buffer flow before stopping so the loop is engaged.
await asyncio.sleep(0.02)
await s.on_message({"type": "stop", "app_id": "a"})
asyncio.run(scenario())
# A new SDR per cycle (we fully close between starts) — registry refcount
# drops to zero on each stop. This test confirms close-and-rebuild works;
# the ref-counting share-while-open case is covered in the full-duplex tests.
assert len(built) == 2
def test_tx_start_rejected_when_tx_disabled():
from ria_toolkit_oss.agent.config import AgentConfig
async def scenario():
ws = FakeWs()
s = Streamer(ws=ws, sdr_factory=_factory, cfg=AgentConfig(tx_enabled=False))
await s.on_message(
{
"type": "tx_start",
"app_id": "a",
"radio_config": {"device": "mock", "tx_center_frequency": 2.45e9, "tx_gain": -20},
}
)
return ws
ws = asyncio.run(scenario())
tx_statuses = [m for m in ws.json_sent if m.get("type") == "tx_status"]
assert tx_statuses, "expected a tx_status frame"
assert tx_statuses[-1]["state"] == "error"
assert "disabled" in tx_statuses[-1]["message"].lower()

View File

@ -0,0 +1,133 @@
"""TX streaming happy path + shutdown semantics."""
from __future__ import annotations
import asyncio
import time
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR
class RecordingMockSDR(MockSDR):
    """MockSDR that records each TX callback's returned buffer."""

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Buffers returned by the TX callback, in call order.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback) -> None:
        # Drive the TX callback like hardware would, one buffer per ~5 ms
        # tick, until _enable_tx is cleared by the stop path.
        # NOTE(review): pokes MockSDR private flags — confirm they stay stable.
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            result = callback(self.rx_buffer_size)
            # No .copy() here — assumes the callback returns a fresh array
            # each call; TODO confirm against the streamer's buffer reuse.
            self.tx_produced.append(np.asarray(result))
            time.sleep(0.005)
class FakeWs:
    """Transport stub: captures outbound JSON and binary frames."""

    def __init__(self):
        self.json_sent: list[dict] = []
        self.bytes_sent: list[bytes] = []

    async def send_json(self, frame):
        self.json_sent.append(frame)

    async def send_bytes(self, blob):
        self.bytes_sent.append(blob)
def _iq_frame(samples: np.ndarray) -> bytes:
interleaved = np.empty(samples.size * 2, dtype=np.float32)
interleaved[0::2] = samples.real
interleaved[1::2] = samples.imag
return interleaved.tobytes()
def test_tx_start_streams_binary_to_callback():
    """Binary IQ frames pushed via on_binary reach the TX callback in order."""
    BUF = 16
    sdr = RecordingMockSDR(buffer_size=BUF)

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
        # Frames of distinct content so we can assert ordering.
        frame_a = np.arange(BUF, dtype=np.complex64) * (1 + 0j)
        frame_b = (np.arange(BUF, dtype=np.complex64) + BUF) * (1 + 0j)
        frame_c = (np.arange(BUF, dtype=np.complex64) + 2 * BUF) * (1 + 0j)
        await s.on_message(
            {
                "type": "tx_start",
                "app_id": "app-1",
                "radio_config": {
                    "device": "mock",
                    "buffer_size": BUF,
                    "tx_sample_rate": 1_000_000,
                    "tx_center_frequency": 2.45e9,
                    "tx_gain": -20,
                    "underrun_policy": "zero",
                },
            }
        )
        # Push three IQ frames.
        await s.on_binary(_iq_frame(frame_a))
        await s.on_binary(_iq_frame(frame_b))
        await s.on_binary(_iq_frame(frame_c))
        # Let the executor thread consume them.
        for _ in range(100):
            # At least the 3 real frames, plus any zero-fill from before they
            # arrived. We stop once 3 non-trivial buffers are recorded.
            nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
            if len(nontrivial) >= 3:
                break
            await asyncio.sleep(0.01)
        await s.on_message({"type": "tx_stop", "app_id": "app-1"})
        return ws, sdr, s

    ws, sdr, streamer = asyncio.run(scenario())

    nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
    assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers"
    # First three nontrivial buffers match the order we pushed them.
    np.testing.assert_array_equal(nontrivial[0], np.arange(BUF, dtype=np.complex64))
    np.testing.assert_array_equal(nontrivial[1], np.arange(BUF, 2 * BUF, dtype=np.complex64))
    np.testing.assert_array_equal(nontrivial[2], np.arange(2 * BUF, 3 * BUF, dtype=np.complex64))
    # Lifecycle: armed → transmitting → done.
    states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
    assert states[0] == "armed"
    assert "transmitting" in states
    assert states[-1] == "done"
    # Session cleared.
    assert streamer._tx is None


def test_tx_stop_releases_sdr():
    """tx_stop drops the registry refcount to zero and clears the session."""
    sdr = RecordingMockSDR(buffer_size=8)

    async def scenario():
        ws = FakeWs()
        s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
        await s.on_message(
            {
                "type": "tx_start",
                "app_id": "a",
                "radio_config": {"device": "mock", "buffer_size": 8, "underrun_policy": "zero"},
            }
        )
        await asyncio.sleep(0.03)
        await s.on_message({"type": "tx_stop", "app_id": "a"})
        return s

    s = asyncio.run(scenario())
    # After stop, the registry has no outstanding references to ("mock", None).
    assert s._registry.refcount(("mock", None)) == 0
    assert s._tx is None

View File

@ -0,0 +1,167 @@
"""Agent-side TX interlocks: gain cap, freq ranges, duplicate sessions, disabled."""
from __future__ import annotations
import asyncio
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR
class FakeWs:
    """Records control and binary frames instead of sending them."""

    def __init__(self):
        self.json_sent = []
        self.bytes_sent = []

    async def send_json(self, frame):
        self.json_sent.append(frame)

    async def send_bytes(self, payload):
        self.bytes_sent.append(payload)
def _last_tx_status(ws):
frames = [m for m in ws.json_sent if m.get("type") == "tx_status"]
return frames[-1] if frames else None
def _tx_start(app_id="a", **radio):
rc = {"device": "mock", "buffer_size": 16, "underrun_policy": "zero"}
rc.update(radio)
return {"type": "tx_start", "app_id": app_id, "radio_config": rc}
def _make_streamer(cfg):
    """Build a Streamer wired to a FakeWs plus a construction-counting factory.

    Returns ``(streamer, fake_ws, built)`` where *built* collects every
    MockSDR the factory constructed — lets tests assert whether an SDR was
    ever created.
    """
    constructed: list = []

    def _factory(device, identifier):
        radio = MockSDR(buffer_size=16)
        constructed.append(radio)
        return radio

    fake_ws = FakeWs()
    streamer = Streamer(ws=fake_ws, sdr_factory=_factory, cfg=cfg)
    return streamer, fake_ws, constructed
def test_rejects_when_tx_disabled():
    """tx_start is refused outright when the agent has not opted into TX."""
    async def scenario():
        s, ws, built = _make_streamer(AgentConfig(tx_enabled=False))
        await s.on_message(_tx_start(tx_gain=-20, tx_center_frequency=2.45e9))
        return s, ws, built

    s, ws, built = asyncio.run(scenario())
    status = _last_tx_status(ws)
    assert status and status["state"] == "error"
    assert "disabled" in status["message"].lower()
    assert not built, "SDR should never have been constructed"
    assert s._tx is None


def test_rejects_when_tx_gain_exceeds_cap():
    """A requested tx_gain above tx_max_gain_db is rejected before SDR setup."""
    async def scenario():
        s, ws, built = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-15.0))
        await s.on_message(_tx_start(tx_gain=-5, tx_center_frequency=2.45e9))
        return ws, built

    ws, built = asyncio.run(scenario())
    status = _last_tx_status(ws)
    assert status and status["state"] == "error"
    assert "exceeds cap" in status["message"]
    assert not built


def test_allows_gain_at_cap_boundary():
    """A gain exactly at the cap is allowed (inclusive bound)."""
    async def scenario():
        s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-10.0))
        await s.on_message(_tx_start(tx_gain=-10, tx_center_frequency=2.45e9))
        # Stop promptly to avoid keeping an executor thread around.
        await asyncio.sleep(0.02)
        await s.on_message({"type": "tx_stop", "app_id": "a"})
        return ws

    ws = asyncio.run(scenario())
    states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
    assert "armed" in states
    assert states[-1] == "done"
def test_rejects_when_freq_outside_ranges():
async def scenario():
s, ws, built = _make_streamer(
AgentConfig(
tx_enabled=True,
tx_allowed_freq_ranges=[[2.4e9, 2.5e9]],
)
)
await s.on_message(_tx_start(tx_center_frequency=5.8e9, tx_gain=-20))
return ws, built
ws, built = asyncio.run(scenario())
status = _last_tx_status(ws)
assert status and status["state"] == "error"
assert "outside allowed ranges" in status["message"]
assert not built
def test_allows_freq_inside_a_range():
    """A frequency falling inside any one of the configured ranges is accepted."""
    async def run_it():
        streamer, ws, _ = _make_streamer(
            AgentConfig(
                tx_enabled=True,
                tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
            )
        )
        await streamer.on_message(_tx_start(tx_center_frequency=5.75e9, tx_gain=-20))
        await asyncio.sleep(0.02)
        await streamer.on_message({"type": "tx_stop", "app_id": "a"})
        return ws

    ws = asyncio.run(run_it())
    seen = [msg["state"] for msg in ws.json_sent if msg.get("type") == "tx_status"]
    assert "armed" in seen
    assert seen[-1] == "done"
def test_rejects_duplicate_tx_session():
    """A second tx_start while a session is live is answered with an error status."""
    async def run_it():
        streamer, ws, _ = _make_streamer(AgentConfig(tx_enabled=True))
        await streamer.on_message(
            _tx_start(app_id="a", tx_gain=-20, tx_center_frequency=2.45e9)
        )
        await asyncio.sleep(0.01)
        await streamer.on_message(
            _tx_start(app_id="b", tx_gain=-20, tx_center_frequency=2.45e9)
        )
        # Let the second request process, then stop cleanly.
        await asyncio.sleep(0.01)
        await streamer.on_message({"type": "tx_stop", "app_id": "a"})
        return ws

    ws = asyncio.run(run_it())
    error_statuses = [
        msg for msg in ws.json_sent
        if msg.get("type") == "tx_status" and msg.get("state") == "error"
    ]
    assert any("already active" in err.get("message", "") for err in error_statuses)
def test_rejects_invalid_underrun_policy():
    """An unknown underrun_policy value is reported back as an error status."""
    async def run_it():
        streamer, ws, _ = _make_streamer(AgentConfig(tx_enabled=True))
        request = {
            "type": "tx_start",
            "app_id": "a",
            "radio_config": {
                "device": "mock",
                "buffer_size": 8,
                "tx_gain": -20,
                "tx_center_frequency": 2.45e9,
                "underrun_policy": "teleport",
            },
        }
        await streamer.on_message(request)
        return ws

    ws = asyncio.run(run_it())
    status = _last_tx_status(ws)
    assert status and status["state"] == "error"
    assert "underrun_policy" in status["message"]

View File

@ -0,0 +1,136 @@
"""Underrun policies: pause, zero, repeat."""
from __future__ import annotations
import asyncio
import time
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR
class RecordingMockSDR(MockSDR):
    """MockSDR variant that records every buffer the TX callback produces.

    Used by the underrun-policy tests to inspect what the streamer actually
    handed to the radio (zeros, repeats, or real frames).
    """

    def __init__(self, buffer_size: int):
        super().__init__(buffer_size=buffer_size)
        # Copies of every buffer returned by the TX callback, in order.
        self.tx_produced: list[np.ndarray] = []

    def _stream_tx(self, callback):
        # NOTE(review): relies on MockSDR's internal _enable_tx/_tx_initialized
        # flags and rx_buffer_size attribute — confirm against MockSDR if its
        # internals change.
        self._enable_tx = True
        self._tx_initialized = True
        while self._enable_tx:
            result = callback(self.rx_buffer_size)
            # Copy so later in-place reuse of the buffer can't rewrite history.
            self.tx_produced.append(np.asarray(result).copy())
            # Slow tick so tests can observe several iterations deterministically.
            time.sleep(0.005)
class FakeWs:
    """Minimal websocket stand-in that records every outbound payload."""

    def __init__(self):
        # Captured arguments of send_json / send_bytes, in call order.
        self.json_sent = []
        self.bytes_sent = []

    async def send_json(self, payload):
        self.json_sent.append(payload)

    async def send_bytes(self, frame):
        self.bytes_sent.append(frame)
def _iq_frame(samples: np.ndarray) -> bytes:
interleaved = np.empty(samples.size * 2, dtype=np.float32)
interleaved[0::2] = samples.real
interleaved[1::2] = samples.imag
return interleaved.tobytes()
def _start_cfg(policy: str, buf: int = 8) -> dict:
return {
"type": "tx_start",
"app_id": "a",
"radio_config": {
"device": "mock",
"buffer_size": buf,
"tx_gain": -20,
"tx_center_frequency": 2.45e9,
"underrun_policy": policy,
},
}
def test_underrun_pause_stops_session_and_emits_status():
    """With policy 'pause', a starved TX session reports 'underrun' and tears down."""
    sdr = RecordingMockSDR(buffer_size=8)

    async def run_it():
        ws = FakeWs()
        streamer = Streamer(
            ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)
        )
        await streamer.on_message(_start_cfg("pause"))

        # Do not push any buffers. The callback underruns on first tick and
        # the watchdog should emit "underrun" and tear down.
        def saw_underrun():
            return any(
                msg.get("type") == "tx_status" and msg.get("state") == "underrun"
                for msg in ws.json_sent
            )

        for _ in range(100):
            if saw_underrun():
                break
            await asyncio.sleep(0.01)
        for _ in range(50):
            if streamer._tx is None:
                break
            await asyncio.sleep(0.01)
        return ws, streamer

    ws, streamer = asyncio.run(run_it())
    seen = [msg["state"] for msg in ws.json_sent if msg.get("type") == "tx_status"]
    assert "underrun" in seen
    assert streamer._tx is None
def test_underrun_zero_keeps_session_alive():
    """With policy 'zero', starvation is absorbed silently by sending zero buffers."""
    sdr = RecordingMockSDR(buffer_size=8)

    async def run_it():
        ws = FakeWs()
        streamer = Streamer(
            ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)
        )
        await streamer.on_message(_start_cfg("zero"))
        # Let it produce several underrun-filled buffers.
        await asyncio.sleep(0.08)
        alive = streamer._tx is not None
        await streamer.on_message({"type": "tx_stop", "app_id": "a"})
        return ws, alive

    ws, alive = asyncio.run(run_it())
    # No underrun status emitted (policy absorbs it silently).
    underruns = [
        msg for msg in ws.json_sent
        if msg.get("type") == "tx_status" and msg.get("state") == "underrun"
    ]
    assert not underruns
    assert alive
    # All produced buffers are zero (no real data was pushed).
    assert sdr.tx_produced, "expected at least one TX callback invocation"
    assert all(not buf.any() for buf in sdr.tx_produced)
def test_underrun_repeat_replays_last_buffer():
    """With policy 'repeat', the last real buffer is replayed during starvation."""
    BUF = 8
    sdr = RecordingMockSDR(buffer_size=BUF)
    marker = np.arange(BUF, dtype=np.complex64) + 1  # distinct non-zero buffer

    async def run_it():
        ws = FakeWs()
        streamer = Streamer(
            ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)
        )
        await streamer.on_message(_start_cfg("repeat", buf=BUF))
        await streamer.on_binary(_iq_frame(marker))
        # Give the executor time to consume the real frame + several repeats.
        await asyncio.sleep(0.08)
        await streamer.on_message({"type": "tx_stop", "app_id": "a"})
        return ws, sdr

    ws, recorded = asyncio.run(run_it())
    # No underrun status emitted.
    assert not any(
        msg.get("type") == "tx_status" and msg.get("state") == "underrun"
        for msg in ws.json_sent
    )
    # At least two buffers equal to the marker — the real one and ≥1 repeat.
    hits = sum(np.array_equal(buf, marker) for buf in recorded.tx_produced)
    assert hits >= 2, f"expected ≥2 buffers matching marker, got {hits}"

View File

@ -113,6 +113,109 @@ def test_reconnects_after_server_drop():
assert n >= 2
def test_binary_frame_forwarded_to_handler():
    """A server-sent binary frame reaches the on_binary callback unchanged."""
    payload = bytes(range(128))

    async def run_it():
        got: list[bytes] = []
        done = asyncio.Event()

        async def handler(ws):
            await ws.send(payload)
            done.set()
            try:
                await ws.wait_closed()
            except Exception:
                pass

        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_bin(data):
                got.append(data)

            runner = asyncio.create_task(
                client.run(
                    on_message=lambda _m: asyncio.sleep(0),
                    heartbeat=lambda: {"type": "heartbeat"},
                    on_binary=on_bin,
                )
            )
            for _ in range(50):
                if got:
                    break
                await asyncio.sleep(0.02)
            client.stop()
            runner.cancel()
            try:
                await runner
            except (asyncio.CancelledError, Exception):
                pass
        finally:
            server.close()
            await server.wait_closed()
        return got

    got = asyncio.run(run_it())
    assert got == [payload]
def test_binary_frame_dropped_when_no_handler():
    """Regression guard: existing behavior (drop server-sent binary) preserved when
    on_binary is not supplied, and the client does not crash on the binary frame.
    """
    async def scenario():
        crashes: list[Exception] = []

        async def handler(ws):
            # Binary first, JSON second: the client must skip the former and
            # still deliver the latter.
            await ws.send(b"\x00\x01\x02\x03")
            await ws.send(json.dumps({"type": "ping"}))
            try:
                await ws.wait_closed()
            except Exception:
                pass

        messages: list[dict] = []
        server, port = await _open_server(handler)
        try:
            client = WsClient(
                f"ws://127.0.0.1:{port}",
                token="",
                heartbeat_interval=10.0,
                reconnect_pause=0.05,
            )

            async def on_msg(m):
                messages.append(m)

            task = asyncio.create_task(
                client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
            )
            for _ in range(50):
                if messages:
                    break
                await asyncio.sleep(0.02)
            client.stop()
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception) as exc:
                crashes.append(exc)
        finally:
            server.close()
            await server.wait_closed()
        return messages, crashes

    messages, crashes = asyncio.run(scenario())
    # JSON still delivered; binary silently dropped.
    assert messages and messages[0] == {"type": "ping"}
    # Fix: `crashes` was collected but never checked, so the "no uncaught crash"
    # claim went unverified. The only acceptable exception out of `await task`
    # is the CancelledError we triggered ourselves via task.cancel().
    assert all(isinstance(exc, asyncio.CancelledError) for exc in crashes)
def test_malformed_control_frame_does_not_crash():
async def scenario():
handled: list[dict] = []