qac-cli-commands #26
|
|
@ -5,8 +5,8 @@ Subcommands:
|
||||||
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
||||||
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
||||||
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
||||||
- ``ria-agent register --url URL --token TOKEN`` — save credentials to
|
- ``ria-agent register --hub URL --api-key KEY`` — register with the hub and
|
||||||
``~/.ria/agent.json``.
|
save credentials (and optional TX interlocks) to ``~/.ria/agent.json``.
|
||||||
|
|
||||||
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
||||||
long-poll behavior for back-compatibility with existing deployments.
|
long-poll behavior for back-compatibility with existing deployments.
|
||||||
|
|
@ -69,9 +69,27 @@ def _cmd_register(args: argparse.Namespace) -> int:
|
||||||
if args.name:
|
if args.name:
|
||||||
cfg.name = args.name
|
cfg.name = args.name
|
||||||
cfg.insecure = bool(args.insecure)
|
cfg.insecure = bool(args.insecure)
|
||||||
|
cfg.tx_enabled = bool(getattr(args, "allow_tx", False))
|
||||||
|
if (v := getattr(args, "tx_max_gain_db", None)) is not None:
|
||||||
|
cfg.tx_max_gain_db = float(v)
|
||||||
|
if (v := getattr(args, "tx_max_duration_s", None)) is not None:
|
||||||
|
cfg.tx_max_duration_s = float(v)
|
||||||
|
freq_ranges = getattr(args, "tx_freq_range", None) or []
|
||||||
|
if freq_ranges:
|
||||||
|
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
|
||||||
path = _config.save(cfg)
|
path = _config.save(cfg)
|
||||||
|
|
||||||
print(f"Registered agent: {agent_id}")
|
print(f"Registered agent: {agent_id}")
|
||||||
|
if cfg.tx_enabled:
|
||||||
|
caps: list[str] = []
|
||||||
|
if cfg.tx_max_gain_db is not None:
|
||||||
|
caps.append(f"gain<={cfg.tx_max_gain_db} dB")
|
||||||
|
if cfg.tx_max_duration_s is not None:
|
||||||
|
caps.append(f"duration<={cfg.tx_max_duration_s} s")
|
||||||
|
if cfg.tx_allowed_freq_ranges:
|
||||||
|
caps.append(f"freq in {cfg.tx_allowed_freq_ranges}")
|
||||||
|
tail = f" ({', '.join(caps)})" if caps else ""
|
||||||
|
print(f"TX enabled{tail}")
|
||||||
print(f"Credentials saved to {path}")
|
print(f"Credentials saved to {path}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
@ -85,8 +103,10 @@ def _cmd_stream(args: argparse.Namespace) -> int:
|
||||||
if not url:
|
if not url:
|
||||||
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
|
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
|
||||||
return 2
|
return 2
|
||||||
|
if getattr(args, "allow_tx", False):
|
||||||
|
cfg.tx_enabled = True
|
||||||
try:
|
try:
|
||||||
asyncio.run(run_streamer(url, token))
|
asyncio.run(run_streamer(url, token, cfg=cfg))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -123,11 +143,47 @@ def main() -> None:
|
||||||
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
|
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
|
||||||
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
|
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
|
||||||
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
|
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--allow-tx",
|
||||||
|
dest="allow_tx",
|
||||||
|
action="store_true",
|
||||||
|
help="Opt this agent in to TX (required for any transmission from the hub)",
|
||||||
|
)
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--tx-max-gain-db",
|
||||||
|
dest="tx_max_gain_db",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Reject tx_start frames whose tx_gain exceeds this cap (dB)",
|
||||||
|
)
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--tx-max-duration-s",
|
||||||
|
dest="tx_max_duration_s",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Auto-stop any TX session after this many seconds",
|
||||||
|
)
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--tx-freq-range",
|
||||||
|
dest="tx_freq_range",
|
||||||
|
type=float,
|
||||||
|
nargs=2,
|
||||||
|
action="append",
|
||||||
|
metavar=("LO", "HI"),
|
||||||
|
default=None,
|
||||||
|
help="Allowed TX center-frequency range in Hz (repeat for multiple bands)",
|
||||||
|
)
|
||||||
|
|
||||||
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
|
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
|
||||||
p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
|
p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
|
||||||
p_stream.add_argument("--token", default=None, help="Override bearer token")
|
p_stream.add_argument("--token", default=None, help="Override bearer token")
|
||||||
p_stream.add_argument("--log-level", default="INFO")
|
p_stream.add_argument("--log-level", default="INFO")
|
||||||
|
p_stream.add_argument(
|
||||||
|
"--allow-tx",
|
||||||
|
dest="allow_tx",
|
||||||
|
action="store_true",
|
||||||
|
help="Runtime override: enable TX for this process without writing config",
|
||||||
|
)
|
||||||
|
|
||||||
# Unknown extras are forwarded to the legacy CLI when command == "run".
|
# Unknown extras are forwarded to the legacy CLI when command == "run".
|
||||||
args, extras = parser.parse_known_args(argv)
|
args, extras = parser.parse_known_args(argv)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,11 @@ Schema::
|
||||||
"agent_id": "agent-abc123",
|
"agent_id": "agent-abc123",
|
||||||
"token": "rha_xxxx",
|
"token": "rha_xxxx",
|
||||||
"name": "lab-bench-1",
|
"name": "lab-bench-1",
|
||||||
"insecure": false
|
"insecure": false,
|
||||||
|
"tx_enabled": false,
|
||||||
|
"tx_max_gain_db": null,
|
||||||
|
"tx_max_duration_s": null,
|
||||||
|
"tx_allowed_freq_ranges": null
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -18,7 +22,8 @@ import os
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
|
def _resolve_default_path() -> Path:
|
||||||
|
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -29,15 +34,29 @@ class AgentConfig:
|
||||||
name: str = ""
|
name: str = ""
|
||||||
insecure: bool = False
|
insecure: bool = False
|
||||||
api_key: str = ""
|
api_key: str = ""
|
||||||
|
tx_enabled: bool = False
|
||||||
|
tx_max_gain_db: float | None = None
|
||||||
|
tx_max_duration_s: float | None = None
|
||||||
|
tx_allowed_freq_ranges: list[list[float]] | None = None
|
||||||
extra: dict = field(default_factory=dict)
|
extra: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
def default_path() -> Path:
|
def default_path() -> Path:
|
||||||
return _DEFAULT_PATH
|
return _resolve_default_path()
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_ranges(raw) -> list[list[float]] | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
out: list[list[float]] = []
|
||||||
|
for pair in raw:
|
||||||
|
lo, hi = pair
|
||||||
|
out.append([float(lo), float(hi)])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def load(path: Path | None = None) -> AgentConfig:
|
def load(path: Path | None = None) -> AgentConfig:
|
||||||
p = path or _DEFAULT_PATH
|
p = path or _resolve_default_path()
|
||||||
if not p.exists():
|
if not p.exists():
|
||||||
return AgentConfig()
|
return AgentConfig()
|
||||||
data = json.loads(p.read_text())
|
data = json.loads(p.read_text())
|
||||||
|
|
@ -50,12 +69,16 @@ def load(path: Path | None = None) -> AgentConfig:
|
||||||
name=data.get("name", ""),
|
name=data.get("name", ""),
|
||||||
insecure=bool(data.get("insecure", False)),
|
insecure=bool(data.get("insecure", False)),
|
||||||
api_key=data.get("api_key", ""),
|
api_key=data.get("api_key", ""),
|
||||||
|
tx_enabled=bool(data.get("tx_enabled", False)),
|
||||||
|
tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None),
|
||||||
|
tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None),
|
||||||
|
tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")),
|
||||||
extra=extra,
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save(cfg: AgentConfig, path: Path | None = None) -> Path:
|
def save(cfg: AgentConfig, path: Path | None = None) -> Path:
|
||||||
p = path or _DEFAULT_PATH
|
p = path or _resolve_default_path()
|
||||||
p.parent.mkdir(parents=True, exist_ok=True)
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
data = asdict(cfg)
|
data = asdict(cfg)
|
||||||
extra = data.pop("extra", {}) or {}
|
extra = data.pop("extra", {}) or {}
|
||||||
|
|
|
||||||
|
|
@ -4,19 +4,41 @@ from __future__ import annotations
|
||||||
|
|
||||||
from ria_toolkit_oss.sdr import detect_available
|
from ria_toolkit_oss.sdr import detect_available
|
||||||
|
|
||||||
|
from .config import AgentConfig
|
||||||
|
|
||||||
|
|
||||||
def available_devices() -> list[str]:
|
def available_devices() -> list[str]:
|
||||||
"""Return a sorted list of device names whose driver modules import cleanly."""
|
"""Return a sorted list of device names whose driver modules import cleanly."""
|
||||||
return sorted(detect_available().keys())
|
return sorted(detect_available().keys())
|
||||||
|
|
||||||
|
|
||||||
def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict:
|
def heartbeat_payload(
|
||||||
"""Build the JSON body of a periodic heartbeat frame."""
|
status: str = "idle",
|
||||||
|
app_id: str | None = None,
|
||||||
|
*,
|
||||||
|
cfg: AgentConfig | None = None,
|
||||||
|
sessions: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build the JSON body of a periodic heartbeat frame.
|
||||||
|
|
||||||
|
*cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not
|
||||||
|
supplied, the heartbeat advertises RX-only with ``tx_enabled=False`` —
|
||||||
|
matching the pre-TX shape.
|
||||||
|
"""
|
||||||
|
c = cfg or AgentConfig()
|
||||||
|
capabilities = ["rx"]
|
||||||
|
if c.tx_enabled:
|
||||||
|
capabilities.append("tx")
|
||||||
|
|
||||||
payload: dict = {
|
payload: dict = {
|
||||||
"type": "heartbeat",
|
"type": "heartbeat",
|
||||||
"hardware": available_devices(),
|
"hardware": available_devices(),
|
||||||
"status": status,
|
"status": status,
|
||||||
|
"capabilities": capabilities,
|
||||||
|
"tx_enabled": bool(c.tx_enabled),
|
||||||
}
|
}
|
||||||
if app_id:
|
if app_id:
|
||||||
payload["app_id"] = app_id
|
payload["app_id"] = app_id
|
||||||
|
if sessions:
|
||||||
|
payload["sessions"] = sessions
|
||||||
return payload
|
return payload
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,33 @@
|
||||||
"""Thin IQ-streaming agent.
|
"""IQ-streaming agent.
|
||||||
|
|
||||||
Listens for control messages from the RIA Hub over a persistent WebSocket.
|
Listens for control messages from the RIA Hub over a persistent WebSocket.
|
||||||
When the server sends ``start``, opens the SDR described in ``radio_config``,
|
Supports:
|
||||||
loops over ``sdr.rx(buffer_size)``, and sends each buffer as raw
|
|
||||||
interleaved float32 bytes. ``stop`` closes the SDR; ``configure`` applies
|
- An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens
|
||||||
parameter updates at the next capture boundary.
|
the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ).
|
||||||
|
- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus
|
||||||
|
binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires
|
||||||
|
up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False;
|
||||||
|
Phase 4 implements the full TX loop.
|
||||||
|
|
||||||
|
Both sessions can run concurrently on the same physical SDR (FDD) — a
|
||||||
|
ref-counted SDR registry shares one driver instance when RX and TX name the
|
||||||
|
same ``(device, identifier)``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from .config import AgentConfig
|
||||||
from .hardware import heartbeat_payload
|
from .hardware import heartbeat_payload
|
||||||
from .ws_client import WsClient
|
from .ws_client import WsClient
|
||||||
|
|
||||||
|
|
@ -23,6 +36,98 @@ logger = logging.getLogger("ria_agent.streamer")
|
||||||
_DEFAULT_BUFFER_SIZE = 1024
|
_DEFAULT_BUFFER_SIZE = 1024
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RxSession:
|
||||||
|
app_id: str
|
||||||
|
sdr: Any
|
||||||
|
device_key: tuple[str, str | None]
|
||||||
|
buffer_size: int
|
||||||
|
task: asyncio.Task | None = None
|
||||||
|
pending_config: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TxSession:
|
||||||
|
app_id: str
|
||||||
|
sdr: Any
|
||||||
|
device_key: tuple[str, str | None]
|
||||||
|
buffer_size: int
|
||||||
|
task: Any = None # concurrent.futures.Future from run_in_executor
|
||||||
|
pending_config: dict = field(default_factory=dict)
|
||||||
|
underrun_policy: str = "pause"
|
||||||
|
last_buffer: np.ndarray | None = None
|
||||||
|
stop_event: threading.Event = field(default_factory=threading.Event)
|
||||||
|
started_at: float = 0.0
|
||||||
|
max_duration_s: float | None = None
|
||||||
|
state: str = "armed"
|
||||||
|
# Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so
|
||||||
|
# hub-side over-production triggers WS backpressure rather than memory
|
||||||
|
# growth in the agent.
|
||||||
|
in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8))
|
||||||
|
# Set by the TX callback when it hits an underrun while policy=="pause";
|
||||||
|
# asyncio side flips the session state and emits tx_status.
|
||||||
|
underrun_flag: threading.Event = field(default_factory=threading.Event)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously)
|
||||||
|
|
||||||
|
|
||||||
|
class _SdrRegistry:
|
||||||
|
def __init__(self, factory):
|
||||||
|
self._factory = factory
|
||||||
|
self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]:
|
||||||
|
key = (device, identifier)
|
||||||
|
with self._lock:
|
||||||
|
if key in self._instances:
|
||||||
|
sdr, rc = self._instances[key]
|
||||||
|
self._instances[key] = (sdr, rc + 1)
|
||||||
|
return sdr, key
|
||||||
|
# Build outside the lock: driver init can be slow and we don't want to
|
||||||
|
# block concurrent releases on unrelated devices.
|
||||||
|
sdr = self._factory(device, identifier)
|
||||||
|
with self._lock:
|
||||||
|
if key in self._instances:
|
||||||
|
# Raced another acquirer; discard our duplicate and share theirs.
|
||||||
|
other_sdr, rc = self._instances[key]
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._instances[key] = (other_sdr, rc + 1)
|
||||||
|
return other_sdr, key
|
||||||
|
self._instances[key] = (sdr, 1)
|
||||||
|
return sdr, key
|
||||||
|
|
||||||
|
def release(self, key: tuple[str, str | None]) -> bool:
|
||||||
|
"""Decrement refcount. Returns True if the caller owns the last reference
|
||||||
|
and should close the SDR."""
|
||||||
|
with self._lock:
|
||||||
|
sdr, rc = self._instances.get(key, (None, 0))
|
||||||
|
if sdr is None:
|
||||||
|
return False
|
||||||
|
if rc <= 1:
|
||||||
|
del self._instances[key]
|
||||||
|
return True
|
||||||
|
self._instances[key] = (sdr, rc - 1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def refcount(self, key: tuple[str, str | None]) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._instances.get(key, (None, 0))[1]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Streamer
|
||||||
|
|
||||||
|
|
||||||
class Streamer:
|
class Streamer:
|
||||||
"""Main streamer loop.
|
"""Main streamer loop.
|
||||||
|
|
||||||
|
|
@ -31,103 +136,186 @@ class Streamer:
|
||||||
ws:
|
ws:
|
||||||
Connected :class:`WsClient`.
|
Connected :class:`WsClient`.
|
||||||
sdr_factory:
|
sdr_factory:
|
||||||
Callable ``(device, identifier) -> SDR``. Defaults to
|
Callable ``(device, identifier) -> SDR``. Defaults to the helper in
|
||||||
:func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests.
|
:mod:`ria_toolkit_oss.sdr`. Injectable for tests.
|
||||||
|
cfg:
|
||||||
|
:class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and
|
||||||
|
heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which
|
||||||
|
leaves TX disabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ws: WsClient, sdr_factory=None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
ws,
|
||||||
|
sdr_factory=None,
|
||||||
|
cfg: AgentConfig | None = None,
|
||||||
|
) -> None:
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
self._sdr_factory = sdr_factory
|
self._cfg = cfg or AgentConfig()
|
||||||
self._app_id: str | None = None
|
self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory)
|
||||||
self._sdr: Any = None
|
self._rx: RxSession | None = None
|
||||||
self._pending_config: dict = {}
|
self._tx: TxSession | None = None
|
||||||
self._capture_task: asyncio.Task | None = None
|
# Pending radio_config accepted via ``configure`` before ``start``.
|
||||||
self._status = "idle"
|
self._standalone_pending_config: dict = {}
|
||||||
|
# Cached asyncio event loop, set the first time a handler runs. Used
|
||||||
|
# to schedule async callbacks from the TX executor thread.
|
||||||
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Back-compat read-only shims for callers that check ``._sdr`` etc.
|
||||||
|
# Writes to these attributes are not supported — use the session objects.
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _sdr(self):
|
||||||
|
return self._rx.sdr if self._rx is not None else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _pending_config(self) -> dict:
|
||||||
|
return self._rx.pending_config if self._rx is not None else self._standalone_pending_config
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# WsClient wiring
|
# WsClient wiring
|
||||||
|
|
||||||
def build_heartbeat(self) -> dict:
|
def build_heartbeat(self) -> dict:
|
||||||
return heartbeat_payload(status=self._status, app_id=self._app_id)
|
status = "streaming" if (self._rx is not None or self._tx is not None) else "idle"
|
||||||
|
app_id: str | None = None
|
||||||
|
if self._rx is not None:
|
||||||
|
app_id = self._rx.app_id
|
||||||
|
elif self._tx is not None:
|
||||||
|
app_id = self._tx.app_id
|
||||||
|
|
||||||
|
sessions: dict[str, dict] = {}
|
||||||
|
if self._rx is not None:
|
||||||
|
sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"}
|
||||||
|
if self._tx is not None:
|
||||||
|
sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state}
|
||||||
|
|
||||||
|
return heartbeat_payload(
|
||||||
|
status=status,
|
||||||
|
app_id=app_id,
|
||||||
|
cfg=self._cfg,
|
||||||
|
sessions=sessions or None,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_message(self, msg: dict) -> None:
|
async def on_message(self, msg: dict) -> None:
|
||||||
t = msg.get("type")
|
t = msg.get("type")
|
||||||
if t == "start":
|
handler = {
|
||||||
await self._handle_start(msg)
|
"start": self._handle_rx_start,
|
||||||
elif t == "stop":
|
"stop": self._handle_rx_stop,
|
||||||
await self._handle_stop(msg)
|
"configure": self._handle_rx_configure,
|
||||||
elif t == "configure":
|
"tx_start": self._handle_tx_start,
|
||||||
self._pending_config.update(msg.get("radio_config") or {})
|
"tx_stop": self._handle_tx_stop,
|
||||||
logger.debug("Queued configure: %s", self._pending_config)
|
"tx_configure": self._handle_tx_configure,
|
||||||
else:
|
}.get(t)
|
||||||
|
if handler is None:
|
||||||
logger.warning("Unknown server message type: %r", t)
|
logger.warning("Unknown server message type: %r", t)
|
||||||
|
return
|
||||||
|
await handler(msg)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
async def on_binary(self, data: bytes) -> None:
|
||||||
async def _handle_start(self, msg: dict) -> None:
|
tx = self._tx
|
||||||
if self._capture_task is not None and not self._capture_task.done():
|
if tx is None:
|
||||||
|
logger.debug("Dropping %d-byte binary frame: no TX session", len(data))
|
||||||
|
return
|
||||||
|
# Backpressure: if the TX queue is full, await briefly so the hub's
|
||||||
|
# ``await ws.send`` throttles naturally via TCP. We don't block
|
||||||
|
# indefinitely — a 2s stall means something else is wrong.
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0))
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning("TX queue stalled; dropping frame")
|
||||||
|
|
||||||
|
# ==================================================================
|
||||||
|
# RX
|
||||||
|
|
||||||
|
async def _handle_rx_start(self, msg: dict) -> None:
|
||||||
|
if self._rx is not None:
|
||||||
logger.warning("start received while already streaming — ignoring")
|
logger.warning("start received while already streaming — ignoring")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._app_id = msg.get("app_id")
|
app_id = msg.get("app_id") or ""
|
||||||
radio_config = dict(msg.get("radio_config") or {})
|
radio_config = dict(msg.get("radio_config") or {})
|
||||||
device = radio_config.pop("device", None)
|
device = radio_config.pop("device", None)
|
||||||
identifier = radio_config.pop("identifier", None)
|
identifier = radio_config.pop("identifier", None)
|
||||||
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
||||||
if not device:
|
if not device:
|
||||||
await self._send_error("start missing radio_config.device")
|
await self._send_error(app_id, "start missing radio_config.device")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
factory = self._sdr_factory or _default_sdr_factory
|
sdr, device_key = self._registry.acquire(device, identifier)
|
||||||
self._sdr = factory(device, identifier)
|
_apply_sdr_config(sdr, radio_config)
|
||||||
_apply_sdr_config(self._sdr, radio_config)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Failed to open SDR %r", device)
|
logger.exception("Failed to open SDR %r", device)
|
||||||
await self._send_error(f"SDR init failed: {exc}")
|
await self._send_error(app_id, f"SDR init failed: {exc}")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._status = "streaming"
|
# Inherit any pending config that was queued before start.
|
||||||
await self._send_status("streaming")
|
pending = dict(self._standalone_pending_config)
|
||||||
self._capture_task = asyncio.create_task(
|
self._standalone_pending_config = {}
|
||||||
self._capture_loop(buffer_size), name="ria-streamer-capture"
|
|
||||||
|
session = RxSession(
|
||||||
|
app_id=app_id,
|
||||||
|
sdr=sdr,
|
||||||
|
device_key=device_key,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
pending_config=pending,
|
||||||
|
)
|
||||||
|
self._rx = session
|
||||||
|
await self._send_status("streaming", app_id)
|
||||||
|
session.task = asyncio.create_task(
|
||||||
|
self._capture_loop(session), name="ria-streamer-capture"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_stop(self, msg: dict) -> None:
|
async def _handle_rx_stop(self, msg: dict) -> None:
|
||||||
if self._capture_task is not None:
|
session = self._rx
|
||||||
self._capture_task.cancel()
|
if session is None:
|
||||||
|
return
|
||||||
|
if session.task is not None:
|
||||||
|
session.task.cancel()
|
||||||
try:
|
try:
|
||||||
await self._capture_task
|
await session.task
|
||||||
except (asyncio.CancelledError, Exception):
|
except (asyncio.CancelledError, Exception):
|
||||||
pass
|
pass
|
||||||
self._capture_task = None
|
self._close_session_sdr(session)
|
||||||
self._close_sdr()
|
app_id = session.app_id
|
||||||
self._app_id = None
|
self._rx = None
|
||||||
self._status = "idle"
|
await self._send_status("idle", app_id)
|
||||||
await self._send_status("idle")
|
|
||||||
|
|
||||||
async def _capture_loop(self, buffer_size: int) -> None:
|
async def _handle_rx_configure(self, msg: dict) -> None:
|
||||||
|
cfg = dict(msg.get("radio_config") or {})
|
||||||
|
if self._rx is not None:
|
||||||
|
self._rx.pending_config.update(cfg)
|
||||||
|
else:
|
||||||
|
self._standalone_pending_config.update(cfg)
|
||||||
|
logger.debug("Queued configure: %s", cfg)
|
||||||
|
|
||||||
|
async def _capture_loop(self, session: RxSession) -> None:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if self._pending_config:
|
if session.pending_config:
|
||||||
cfg = self._pending_config
|
cfg = session.pending_config
|
||||||
self._pending_config = {}
|
session.pending_config = {}
|
||||||
try:
|
try:
|
||||||
_apply_sdr_config(self._sdr, cfg)
|
_apply_sdr_config(session.sdr, cfg)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Applying configure failed: %s", exc)
|
logger.warning("Applying configure failed: %s", exc)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size)
|
samples = await loop.run_in_executor(
|
||||||
|
None, session.sdr.rx, session.buffer_size
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
from ria_toolkit_oss.sdr import SdrDisconnectedError
|
from ria_toolkit_oss.sdr import SdrDisconnectedError
|
||||||
|
|
||||||
if isinstance(exc, SdrDisconnectedError):
|
if isinstance(exc, SdrDisconnectedError):
|
||||||
logger.warning("SDR disconnected: %s", exc)
|
logger.warning("SDR disconnected: %s", exc)
|
||||||
await self._send_error(f"SDR disconnected: {exc}")
|
await self._send_error(session.app_id, f"SDR disconnected: {exc}")
|
||||||
else:
|
else:
|
||||||
logger.exception("SDR rx error")
|
logger.exception("SDR rx error")
|
||||||
await self._send_error(f"SDR capture failed: {exc}")
|
await self._send_error(session.app_id, f"SDR capture failed: {exc}")
|
||||||
break
|
break
|
||||||
|
|
||||||
payload = _samples_to_interleaved_float32(samples)
|
payload = _samples_to_interleaved_float32(samples)
|
||||||
|
|
@ -139,29 +327,305 @@ class Streamer:
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._close_sdr()
|
self._close_session_sdr(session)
|
||||||
|
# If the loop died on its own (e.g. SDR disconnect), clear the
|
||||||
|
# session handle so future ``start`` messages can proceed.
|
||||||
|
if self._rx is session:
|
||||||
|
self._rx = None
|
||||||
|
|
||||||
def _close_sdr(self) -> None:
|
# ==================================================================
|
||||||
if self._sdr is None:
|
# TX
|
||||||
|
|
||||||
|
async def _handle_tx_start(self, msg: dict) -> None:
|
||||||
|
app_id = msg.get("app_id") or ""
|
||||||
|
radio_config = dict(msg.get("radio_config") or {})
|
||||||
|
|
||||||
|
# --- interlocks (agent-enforced; never trust the hub alone) ---
|
||||||
|
if not self._cfg.tx_enabled:
|
||||||
|
await self._send_tx_status(app_id, "error", "tx disabled on this agent")
|
||||||
return
|
return
|
||||||
|
tx_gain = radio_config.get("tx_gain")
|
||||||
|
if (
|
||||||
|
self._cfg.tx_max_gain_db is not None
|
||||||
|
and tx_gain is not None
|
||||||
|
and float(tx_gain) > float(self._cfg.tx_max_gain_db)
|
||||||
|
):
|
||||||
|
await self._send_tx_status(
|
||||||
|
app_id,
|
||||||
|
"error",
|
||||||
|
f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
tx_freq = radio_config.get("tx_center_frequency")
|
||||||
|
if self._cfg.tx_allowed_freq_ranges and tx_freq is not None:
|
||||||
|
f = float(tx_freq)
|
||||||
|
if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges):
|
||||||
|
await self._send_tx_status(
|
||||||
|
app_id,
|
||||||
|
"error",
|
||||||
|
f"tx_center_frequency {tx_freq} outside allowed ranges",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._tx is not None:
|
||||||
|
await self._send_tx_status(app_id, "error", "tx already active on this agent")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- device ---
|
||||||
|
device = radio_config.pop("device", None)
|
||||||
|
identifier = radio_config.pop("identifier", None)
|
||||||
|
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
||||||
|
underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
|
||||||
|
if underrun_policy not in ("pause", "zero", "repeat"):
|
||||||
|
await self._send_tx_status(
|
||||||
|
app_id, "error", f"invalid underrun_policy {underrun_policy!r}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if not device:
|
||||||
|
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
|
||||||
|
return
|
||||||
|
|
||||||
|
device_key: tuple[str, str | None] | None = None
|
||||||
|
sdr: Any = None
|
||||||
try:
|
try:
|
||||||
self._sdr.close()
|
sdr, device_key = self._registry.acquire(device, identifier)
|
||||||
|
_apply_sdr_config(sdr, radio_config)
|
||||||
|
# Only call init_tx when the hub supplied the three required
|
||||||
|
# parameters. Drivers that gate _stream_tx on _tx_initialized
|
||||||
|
# (e.g. Pluto) need this; drivers that don't (e.g. Mock) tolerate
|
||||||
|
# its absence.
|
||||||
|
init_args = {
|
||||||
|
k: radio_config.get(f"tx_{k}")
|
||||||
|
for k in ("sample_rate", "center_frequency", "gain")
|
||||||
|
}
|
||||||
|
if hasattr(sdr, "init_tx") and all(v is not None for v in init_args.values()):
|
||||||
|
sdr.init_tx(
|
||||||
|
sample_rate=init_args["sample_rate"],
|
||||||
|
center_frequency=init_args["center_frequency"],
|
||||||
|
gain=init_args["gain"],
|
||||||
|
channel=radio_config.get("tx_channel", 0),
|
||||||
|
gain_mode=radio_config.get("tx_gain_mode", "manual"),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if device_key is not None:
|
||||||
|
if self._registry.release(device_key):
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.exception("Failed to init TX on %r", device)
|
||||||
|
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._loop = asyncio.get_running_loop()
|
||||||
|
session = TxSession(
|
||||||
|
app_id=app_id,
|
||||||
|
sdr=sdr,
|
||||||
|
device_key=device_key,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
underrun_policy=underrun_policy,
|
||||||
|
started_at=time.monotonic(),
|
||||||
|
max_duration_s=self._cfg.tx_max_duration_s,
|
||||||
|
)
|
||||||
|
self._tx = session
|
||||||
|
await self._send_tx_status(app_id, "armed")
|
||||||
|
session.task = self._loop.run_in_executor(None, self._tx_executor_body, session)
|
||||||
|
# Spawn a small watchdog that transitions armed → transmitting when
|
||||||
|
# the first buffer has been consumed, and surfaces underrun / max-
|
||||||
|
# duration terminations back to the hub.
|
||||||
|
asyncio.create_task(self._tx_watchdog(session))
|
||||||
|
|
||||||
|
async def _handle_tx_stop(self, msg: dict) -> None:
|
||||||
|
session = self._tx
|
||||||
|
if session is None:
|
||||||
|
return
|
||||||
|
app_id = session.app_id
|
||||||
|
session.stop_event.set()
|
||||||
|
try:
|
||||||
|
session.sdr.pause_tx()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("pause_tx raised during stop", exc_info=True)
|
||||||
|
# Wake the executor thread if it's blocked on ``queue.get``.
|
||||||
|
self._drain_tx_queue(session)
|
||||||
|
if session.task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("TX executor did not exit within 1.5s after stop")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("TX executor raised on shutdown", exc_info=True)
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
self._tx = None
|
||||||
|
await self._send_tx_status(app_id, "done")
|
||||||
|
|
||||||
|
async def _handle_tx_configure(self, msg: dict) -> None:
|
||||||
|
if self._tx is None:
|
||||||
|
return
|
||||||
|
self._tx.pending_config.update(msg.get("radio_config") or {})
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# TX executor & watchdog
|
||||||
|
|
||||||
|
def _tx_executor_body(self, session: TxSession) -> None:
|
||||||
|
try:
|
||||||
|
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("TX stream crashed")
|
||||||
|
self._schedule(self._send_tx_status(session.app_id, "error", "tx stream crashed"))
|
||||||
|
|
||||||
|
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
|
||||||
|
n = int(num_samples)
|
||||||
|
# Honor stop requests: return silence one last time and let the driver
|
||||||
|
# exit its loop on the next iteration (pause_tx flips _enable_tx).
|
||||||
|
if session.stop_event.is_set():
|
||||||
|
return _silence(n)
|
||||||
|
|
||||||
|
# Max-duration watchdog.
|
||||||
|
if (
|
||||||
|
session.max_duration_s is not None
|
||||||
|
and (time.monotonic() - session.started_at) >= float(session.max_duration_s)
|
||||||
|
):
|
||||||
|
session.stop_event.set()
|
||||||
|
try:
|
||||||
|
session.sdr.pause_tx()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached"))
|
||||||
|
return _silence(n)
|
||||||
|
|
||||||
|
# Apply queued configure at buffer boundary.
|
||||||
|
if session.pending_config:
|
||||||
|
cfg = session.pending_config
|
||||||
|
session.pending_config = {}
|
||||||
|
try:
|
||||||
|
_apply_sdr_config(session.sdr, cfg)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("tx_configure apply failed: %s", exc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = session.in_queue.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
return self._underrun_fill(session, n)
|
||||||
|
|
||||||
|
arr = np.frombuffer(raw, dtype=np.float32)
|
||||||
|
if arr.size < 2 or arr.size % 2 != 0:
|
||||||
|
logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
|
||||||
|
return self._underrun_fill(session, n)
|
||||||
|
samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64))
|
||||||
|
if samples.size < n:
|
||||||
|
out = np.zeros(n, dtype=np.complex64)
|
||||||
|
out[: samples.size] = samples
|
||||||
|
session.last_buffer = out
|
||||||
|
return out
|
||||||
|
if samples.size > n:
|
||||||
|
samples = samples[:n]
|
||||||
|
session.last_buffer = samples
|
||||||
|
if session.state == "armed":
|
||||||
|
session.state = "transmitting"
|
||||||
|
self._schedule(self._send_tx_status(session.app_id, "transmitting"))
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray:
|
||||||
|
policy = session.underrun_policy
|
||||||
|
if policy == "zero":
|
||||||
|
return _silence(n)
|
||||||
|
if policy == "repeat" and session.last_buffer is not None:
|
||||||
|
buf = session.last_buffer
|
||||||
|
if buf.size == n:
|
||||||
|
return buf
|
||||||
|
if buf.size > n:
|
||||||
|
return buf[:n].copy()
|
||||||
|
out = np.zeros(n, dtype=np.complex64)
|
||||||
|
out[: buf.size] = buf
|
||||||
|
return out
|
||||||
|
# "pause" policy (default) or "repeat" before any buffer arrived.
|
||||||
|
if not session.underrun_flag.is_set():
|
||||||
|
session.underrun_flag.set()
|
||||||
|
session.stop_event.set()
|
||||||
|
try:
|
||||||
|
session.sdr.pause_tx()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._sdr = None
|
return _silence(n)
|
||||||
|
|
||||||
async def _send_status(self, status: str) -> None:
|
async def _tx_watchdog(self, session: TxSession) -> None:
|
||||||
|
# Poll the underrun flag so we can emit status + tear down cleanly
|
||||||
|
# when the callback flips the flag from the executor thread. Check
|
||||||
|
# underrun_flag before stop_event, since the "pause" path sets both.
|
||||||
|
while session is self._tx:
|
||||||
|
if session.underrun_flag.is_set():
|
||||||
|
await self._send_tx_status(session.app_id, "underrun")
|
||||||
|
await self._teardown_tx_after_underrun(session)
|
||||||
|
return
|
||||||
|
if session.stop_event.is_set():
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
|
||||||
|
if self._tx is not session:
|
||||||
|
return
|
||||||
|
self._drain_tx_queue(session)
|
||||||
|
if session.task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("TX executor did not exit within 1s after underrun")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("TX executor raised during underrun teardown", exc_info=True)
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
if self._tx is session:
|
||||||
|
self._tx = None
|
||||||
|
|
||||||
|
def _drain_tx_queue(self, session: TxSession) -> None:
|
||||||
try:
|
try:
|
||||||
await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id})
|
while True:
|
||||||
|
session.in_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _schedule(self, coro) -> None:
|
||||||
|
loop = self._loop
|
||||||
|
if loop is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
asyncio.run_coroutine_threadsafe(coro, loop)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("_schedule failed", exc_info=True)
|
||||||
|
|
||||||
|
# ==================================================================
|
||||||
|
# Helpers
|
||||||
|
|
||||||
|
def _close_session_sdr(self, session) -> None:
|
||||||
|
if session.sdr is None:
|
||||||
|
return
|
||||||
|
should_close = self._registry.release(session.device_key)
|
||||||
|
if should_close:
|
||||||
|
try:
|
||||||
|
session.sdr.close()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("SDR close raised", exc_info=True)
|
||||||
|
|
||||||
|
async def _send_status(self, status: str, app_id: str) -> None:
|
||||||
|
try:
|
||||||
|
await self.ws.send_json({"type": "status", "status": status, "app_id": app_id})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Status send failed: %s", exc)
|
logger.debug("Status send failed: %s", exc)
|
||||||
|
|
||||||
async def _send_error(self, message: str) -> None:
|
async def _send_error(self, app_id: str, message: str) -> None:
|
||||||
try:
|
try:
|
||||||
await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message})
|
await self.ws.send_json({"type": "error", "app_id": app_id, "message": message})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Error-frame send failed: %s", exc)
|
logger.debug("Error-frame send failed: %s", exc)
|
||||||
|
|
||||||
|
async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None:
|
||||||
|
payload: dict = {"type": "tx_status", "app_id": app_id, "state": state}
|
||||||
|
if message is not None:
|
||||||
|
payload["message"] = message
|
||||||
|
try:
|
||||||
|
await self.ws.send_json(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("tx_status send failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
|
|
@ -172,6 +636,10 @@ _CONFIG_ATTR_MAP = {
|
||||||
"center_freq": ("center_freq", "rx_center_frequency"),
|
"center_freq": ("center_freq", "rx_center_frequency"),
|
||||||
"gain": ("gain", "rx_gain"),
|
"gain": ("gain", "rx_gain"),
|
||||||
"bandwidth": ("bandwidth", "rx_bandwidth"),
|
"bandwidth": ("bandwidth", "rx_bandwidth"),
|
||||||
|
"tx_sample_rate": ("tx_sample_rate",),
|
||||||
|
"tx_center_frequency": ("tx_center_frequency", "tx_lo"),
|
||||||
|
"tx_gain": ("tx_gain",),
|
||||||
|
"tx_bandwidth": ("tx_bandwidth",),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -194,6 +662,11 @@ def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
|
||||||
logger.debug("radio_config key %r ignored (no matching attr)", key)
|
logger.debug("radio_config key %r ignored (no matching attr)", key)
|
||||||
|
|
||||||
|
|
||||||
|
def _silence(num_samples: int) -> np.ndarray:
|
||||||
|
"""Return a ``num_samples``-length zero-filled complex64 buffer."""
|
||||||
|
return np.zeros(int(num_samples), dtype=np.complex64)
|
||||||
|
|
||||||
|
|
||||||
def _samples_to_interleaved_float32(samples: Any) -> bytes:
|
def _samples_to_interleaved_float32(samples: Any) -> bytes:
|
||||||
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
|
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
|
||||||
arr = np.asarray(samples)
|
arr = np.asarray(samples)
|
||||||
|
|
@ -214,8 +687,12 @@ def _default_sdr_factory(device: str, identifier: str | None):
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Top-level entry
|
# Top-level entry
|
||||||
|
|
||||||
async def run_streamer(ws_url: str, token: str) -> None:
|
async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
|
||||||
"""Connect to *ws_url* and run the streamer loop until cancelled."""
|
"""Connect to *ws_url* and run the streamer loop until cancelled."""
|
||||||
ws = WsClient(ws_url, token)
|
ws = WsClient(ws_url, token)
|
||||||
streamer = Streamer(ws)
|
streamer = Streamer(ws, cfg=cfg)
|
||||||
await ws.run(streamer.on_message, streamer.build_heartbeat)
|
await ws.run(
|
||||||
|
streamer.on_message,
|
||||||
|
streamer.build_heartbeat,
|
||||||
|
on_binary=streamer.on_binary,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ logger = logging.getLogger("ria_agent.ws")
|
||||||
|
|
||||||
MessageHandler = Callable[[dict], Awaitable[None]]
|
MessageHandler = Callable[[dict], Awaitable[None]]
|
||||||
HeartbeatBuilder = Callable[[], dict]
|
HeartbeatBuilder = Callable[[], dict]
|
||||||
|
BinaryHandler = Callable[[bytes], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
class WsClient:
|
class WsClient:
|
||||||
|
|
@ -65,7 +66,12 @@ class WsClient:
|
||||||
self._stop.set()
|
self._stop.set()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None:
|
async def run(
|
||||||
|
self,
|
||||||
|
on_message: MessageHandler,
|
||||||
|
heartbeat: HeartbeatBuilder,
|
||||||
|
on_binary: BinaryHandler | None = None,
|
||||||
|
) -> None:
|
||||||
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
|
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
|
||||||
while not self._stop.is_set():
|
while not self._stop.is_set():
|
||||||
try:
|
try:
|
||||||
|
|
@ -75,8 +81,13 @@ class WsClient:
|
||||||
try:
|
try:
|
||||||
async for raw in self._ws:
|
async for raw in self._ws:
|
||||||
if isinstance(raw, bytes):
|
if isinstance(raw, bytes):
|
||||||
# Server shouldn't send binary to the agent; log and drop.
|
if on_binary is None:
|
||||||
logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
|
logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await on_binary(raw)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("on_binary handler raised; dropping frame")
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
msg = json.loads(raw)
|
msg = json.loads(raw)
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -77,24 +78,33 @@ def _inspect_labels(engine: list[str], ref: str) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _hardware_flags(labels: dict) -> list[str]:
|
def _gpu_available() -> bool:
|
||||||
|
if os.path.exists("/dev/nvidia0"):
|
||||||
|
return True
|
||||||
|
return shutil.which("nvidia-smi") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]:
|
||||||
flags: list[str] = []
|
flags: list[str] = []
|
||||||
|
notes: list[str] = []
|
||||||
profile = (labels.get(_LABEL_PROFILE) or "").lower()
|
profile = (labels.get(_LABEL_PROFILE) or "").lower()
|
||||||
hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
|
hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
|
||||||
hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
|
hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
|
||||||
|
|
||||||
if "nvidia" in profile or "holoscan" in profile or "cuda" in profile:
|
wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda"))
|
||||||
flags += ["--gpus", "all"]
|
if wants_gpu and not no_gpu:
|
||||||
|
if _gpu_available():
|
||||||
|
flags += ["--gpus", "all"]
|
||||||
|
else:
|
||||||
|
notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)")
|
||||||
|
|
||||||
needs_usb = hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"}
|
if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
|
||||||
if needs_usb:
|
|
||||||
flags += ["--device", "/dev/bus/usb"]
|
flags += ["--device", "/dev/bus/usb"]
|
||||||
|
|
||||||
needs_net = hw_items & {"usrp", "thinkrf", "pluto"}
|
if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net:
|
||||||
if needs_net:
|
|
||||||
flags += ["--net", "host"]
|
flags += ["--net", "host"]
|
||||||
|
|
||||||
return flags
|
return flags, notes
|
||||||
|
|
||||||
|
|
||||||
def _cmd_configure(args: argparse.Namespace) -> int:
|
def _cmd_configure(args: argparse.Namespace) -> int:
|
||||||
|
|
@ -132,7 +142,10 @@ def _cmd_run(args: argparse.Namespace) -> int:
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
labels = _inspect_labels(engine, ref)
|
labels = _inspect_labels(engine, ref)
|
||||||
hw_flags = _hardware_flags(labels)
|
no_gpu = args.no_gpu and not args.force_gpu
|
||||||
|
hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net)
|
||||||
|
if args.force_gpu and "--gpus" not in hw_flags:
|
||||||
|
hw_flags = ["--gpus", "all", *hw_flags]
|
||||||
|
|
||||||
cmd = [*engine, "run", "--rm"]
|
cmd = [*engine, "run", "--rm"]
|
||||||
if not args.foreground:
|
if not args.foreground:
|
||||||
|
|
@ -162,6 +175,8 @@ def _cmd_run(args: argparse.Namespace) -> int:
|
||||||
print(f"Running {ref} [{label_str}]")
|
print(f"Running {ref} [{label_str}]")
|
||||||
if hw_flags:
|
if hw_flags:
|
||||||
print(f" auto flags: {' '.join(hw_flags)}")
|
print(f" auto flags: {' '.join(hw_flags)}")
|
||||||
|
for note in notes:
|
||||||
|
print(f" note: {note}")
|
||||||
return subprocess.call(cmd)
|
return subprocess.call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -225,6 +240,10 @@ def main() -> None:
|
||||||
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
|
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
|
||||||
p_run.add_argument("-p", "--publish", action="append", help="Publish port")
|
p_run.add_argument("-p", "--publish", action="append", help="Publish port")
|
||||||
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
|
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
|
||||||
|
p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU")
|
||||||
|
p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected")
|
||||||
|
p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb")
|
||||||
|
p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host")
|
||||||
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
|
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
|
||||||
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
|
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
|
||||||
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")
|
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")
|
||||||
|
|
|
||||||
111
tests/agent/test_cli_tx.py
Normal file
111
tests/agent/test_cli_tx.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""CLI flags for TX opt-in and interlocks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent import cli as agent_cli
|
||||||
|
from ria_toolkit_oss.agent import config as agent_config
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResp:
|
||||||
|
def __init__(self, payload: dict):
|
||||||
|
self._payload = payload
|
||||||
|
|
||||||
|
def read(self) -> bytes:
|
||||||
|
return json.dumps(self._payload).encode()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *_a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _run_register(argv: list[str], cfg_path) -> int:
|
||||||
|
fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"})
|
||||||
|
with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
|
||||||
|
patch("urllib.request.urlopen", return_value=fake_resp), \
|
||||||
|
patch.object(sys, "argv", ["ria-agent", *argv]):
|
||||||
|
try:
|
||||||
|
agent_cli.main()
|
||||||
|
except SystemExit as exc:
|
||||||
|
return int(exc.code or 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_without_allow_tx_keeps_tx_disabled(tmp_path):
|
||||||
|
cfg_path = tmp_path / "agent.json"
|
||||||
|
_run_register(
|
||||||
|
["register", "--hub", "http://hub:3005", "--api-key", "K"],
|
||||||
|
cfg_path,
|
||||||
|
)
|
||||||
|
cfg = agent_config.load(path=cfg_path)
|
||||||
|
assert cfg.agent_id == "agent-1"
|
||||||
|
assert cfg.tx_enabled is False
|
||||||
|
assert cfg.tx_max_gain_db is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_with_allow_tx_and_caps(tmp_path):
|
||||||
|
cfg_path = tmp_path / "agent.json"
|
||||||
|
_run_register(
|
||||||
|
[
|
||||||
|
"register",
|
||||||
|
"--hub",
|
||||||
|
"http://hub:3005",
|
||||||
|
"--api-key",
|
||||||
|
"K",
|
||||||
|
"--allow-tx",
|
||||||
|
"--tx-max-gain-db",
|
||||||
|
"-10",
|
||||||
|
"--tx-max-duration-s",
|
||||||
|
"60",
|
||||||
|
"--tx-freq-range",
|
||||||
|
"2.4e9",
|
||||||
|
"2.5e9",
|
||||||
|
"--tx-freq-range",
|
||||||
|
"5.7e9",
|
||||||
|
"5.8e9",
|
||||||
|
],
|
||||||
|
cfg_path,
|
||||||
|
)
|
||||||
|
cfg = agent_config.load(path=cfg_path)
|
||||||
|
assert cfg.tx_enabled is True
|
||||||
|
assert cfg.tx_max_gain_db == -10.0
|
||||||
|
assert cfg.tx_max_duration_s == 60.0
|
||||||
|
assert cfg.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_allow_tx_does_not_persist(tmp_path):
|
||||||
|
# Pre-register with tx_enabled=False, then simulate `stream --allow-tx`.
|
||||||
|
# The on-disk config must remain unchanged; the runtime flag is process-local.
|
||||||
|
cfg_path = tmp_path / "agent.json"
|
||||||
|
base = agent_config.AgentConfig(
|
||||||
|
hub_url="http://hub:3005",
|
||||||
|
agent_id="agent-1",
|
||||||
|
token="tok-abc",
|
||||||
|
tx_enabled=False,
|
||||||
|
)
|
||||||
|
agent_config.save(base, path=cfg_path)
|
||||||
|
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
async def _fake_run_streamer(url, token, *, cfg):
|
||||||
|
captured["cfg"] = cfg
|
||||||
|
return None
|
||||||
|
|
||||||
|
with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
|
||||||
|
patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \
|
||||||
|
patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]):
|
||||||
|
try:
|
||||||
|
agent_cli.main()
|
||||||
|
except SystemExit:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Runtime cfg had TX flipped on
|
||||||
|
assert captured["cfg"].tx_enabled is True
|
||||||
|
# But the persisted file is untouched
|
||||||
|
on_disk = agent_config.load(path=cfg_path)
|
||||||
|
assert on_disk.tx_enabled is False
|
||||||
|
|
@ -20,6 +20,36 @@ def test_load_missing_returns_empty(tmp_path):
|
||||||
assert loaded == agent_config.AgentConfig()
|
assert loaded == agent_config.AgentConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tx_fields_round_trip(tmp_path):
|
||||||
|
p = tmp_path / "agent.json"
|
||||||
|
cfg = agent_config.AgentConfig(
|
||||||
|
hub_url="https://hub.example.com",
|
||||||
|
agent_id="agent-1",
|
||||||
|
token="t",
|
||||||
|
tx_enabled=True,
|
||||||
|
tx_max_gain_db=-10.0,
|
||||||
|
tx_max_duration_s=60.0,
|
||||||
|
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
|
||||||
|
)
|
||||||
|
agent_config.save(cfg, path=p)
|
||||||
|
loaded = agent_config.load(path=p)
|
||||||
|
assert loaded.tx_enabled is True
|
||||||
|
assert loaded.tx_max_gain_db == -10.0
|
||||||
|
assert loaded.tx_max_duration_s == 60.0
|
||||||
|
assert loaded.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tx_fields_default_when_absent(tmp_path):
|
||||||
|
# Old configs written before TX existed should load cleanly with safe defaults.
|
||||||
|
p = tmp_path / "agent.json"
|
||||||
|
p.write_text('{"hub_url": "x", "agent_id": "a", "token": "t"}')
|
||||||
|
cfg = agent_config.load(path=p)
|
||||||
|
assert cfg.tx_enabled is False
|
||||||
|
assert cfg.tx_max_gain_db is None
|
||||||
|
assert cfg.tx_max_duration_s is None
|
||||||
|
assert cfg.tx_allowed_freq_ranges is None
|
||||||
|
|
||||||
|
|
||||||
def test_extra_keys_preserved(tmp_path):
|
def test_extra_keys_preserved(tmp_path):
|
||||||
p = tmp_path / "agent.json"
|
p = tmp_path / "agent.json"
|
||||||
p.write_text('{"hub_url": "x", "custom": 42}')
|
p.write_text('{"hub_url": "x", "custom": 42}')
|
||||||
|
|
|
||||||
|
|
@ -67,9 +67,9 @@ def test_streamer_reports_disconnected_and_ends_capture():
|
||||||
"radio_config": {"device": "fake", "buffer_size": 8},
|
"radio_config": {"device": "fake", "buffer_size": 8},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# Wait for the capture task to fail out.
|
# Wait for the capture loop to emit its error frame and tear down the session.
|
||||||
for _ in range(50):
|
for _ in range(100):
|
||||||
if streamer._capture_task and streamer._capture_task.done():
|
if any(m.get("type") == "error" for m in ws.json_sent) and streamer._rx is None:
|
||||||
break
|
break
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
return ws, sdr, streamer
|
return ws, sdr, streamer
|
||||||
|
|
@ -79,3 +79,5 @@ def test_streamer_reports_disconnected_and_ends_capture():
|
||||||
errors = [m for m in ws.json_sent if m.get("type") == "error"]
|
errors = [m for m in ws.json_sent if m.get("type") == "error"]
|
||||||
assert errors, "expected an error frame"
|
assert errors, "expected an error frame"
|
||||||
assert "disconnected" in errors[-1]["message"].lower()
|
assert "disconnected" in errors[-1]["message"].lower()
|
||||||
|
# Session handle cleared so future starts can proceed.
|
||||||
|
assert streamer._rx is None
|
||||||
|
|
|
||||||
133
tests/agent/test_full_duplex.py
Normal file
133
tests/agent/test_full_duplex.py
Normal file
|
|
@ -0,0 +1,133 @@
|
||||||
|
"""Concurrent RX + TX sessions on the same agent — shared SDR via registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||||
|
|
||||||
|
|
||||||
|
class FullDuplexMockSDR(MockSDR):
|
||||||
|
"""MockSDR with a recording TX path so the test can assert both directions."""
|
||||||
|
|
||||||
|
def __init__(self, buffer_size: int):
|
||||||
|
super().__init__(buffer_size=buffer_size)
|
||||||
|
self.tx_produced: list[np.ndarray] = []
|
||||||
|
|
||||||
|
def _stream_tx(self, callback):
|
||||||
|
self._enable_tx = True
|
||||||
|
self._tx_initialized = True
|
||||||
|
while self._enable_tx:
|
||||||
|
result = callback(self.rx_buffer_size)
|
||||||
|
self.tx_produced.append(np.asarray(result).copy())
|
||||||
|
time.sleep(0.005)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWs:
|
||||||
|
def __init__(self):
|
||||||
|
self.json_sent = []
|
||||||
|
self.bytes_sent = []
|
||||||
|
|
||||||
|
async def send_json(self, p):
|
||||||
|
self.json_sent.append(p)
|
||||||
|
|
||||||
|
async def send_bytes(self, b):
|
||||||
|
self.bytes_sent.append(b)
|
||||||
|
|
||||||
|
|
||||||
|
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||||
|
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = samples.real
|
||||||
|
interleaved[1::2] = samples.imag
|
||||||
|
return interleaved.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def test_rx_and_tx_share_one_sdr_instance():
|
||||||
|
built: list[FullDuplexMockSDR] = []
|
||||||
|
|
||||||
|
def factory(device, identifier):
|
||||||
|
sdr = FullDuplexMockSDR(buffer_size=16)
|
||||||
|
built.append(sdr)
|
||||||
|
return sdr
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=factory, cfg=AgentConfig(tx_enabled=True))
|
||||||
|
|
||||||
|
# Start RX first.
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "start",
|
||||||
|
"app_id": "app-1",
|
||||||
|
"radio_config": {"device": "mock", "buffer_size": 16},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Then start TX on the same device — should share the SDR handle.
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "app-1",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "mock",
|
||||||
|
"buffer_size": 16,
|
||||||
|
"tx_gain": -20,
|
||||||
|
"tx_center_frequency": 2.45e9,
|
||||||
|
"underrun_policy": "zero",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Push a known TX buffer.
|
||||||
|
marker = np.arange(16, dtype=np.complex64) + 7
|
||||||
|
await s.on_binary(_iq_frame(marker))
|
||||||
|
|
||||||
|
# Let both directions produce output.
|
||||||
|
for _ in range(80):
|
||||||
|
rx_ok = len(ws.bytes_sent) >= 2
|
||||||
|
tx_ok = any(np.array_equal(b, marker) for b in built[0].tx_produced) if built else False
|
||||||
|
if rx_ok and tx_ok:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
# Heartbeat should show both sessions.
|
||||||
|
hb = s.build_heartbeat()
|
||||||
|
|
||||||
|
# Stop TX first, RX keeps running.
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "app-1"})
|
||||||
|
tx_after_stop = s._tx is None
|
||||||
|
rx_still_active = s._rx is not None
|
||||||
|
|
||||||
|
# Now stop RX.
|
||||||
|
await s.on_message({"type": "stop", "app_id": "app-1"})
|
||||||
|
|
||||||
|
return ws, s, built, hb, tx_after_stop, rx_still_active
|
||||||
|
|
||||||
|
ws, s, built, hb, tx_after_stop, rx_still_active = asyncio.run(scenario())
|
||||||
|
|
||||||
|
# One SDR was built and shared.
|
||||||
|
assert len(built) == 1, f"expected exactly one SDR instance, got {len(built)}"
|
||||||
|
|
||||||
|
# Both directions produced output.
|
||||||
|
assert len(ws.bytes_sent) >= 1, "RX produced no IQ frames"
|
||||||
|
marker = np.arange(16, dtype=np.complex64) + 7
|
||||||
|
assert any(
|
||||||
|
np.array_equal(b, marker) for b in built[0].tx_produced
|
||||||
|
), "TX callback never saw the pushed marker buffer"
|
||||||
|
|
||||||
|
# Heartbeat reflected both sessions while they were active.
|
||||||
|
assert hb["sessions"]["rx"]["app_id"] == "app-1"
|
||||||
|
assert hb["sessions"]["tx"]["app_id"] == "app-1"
|
||||||
|
|
||||||
|
# Stopping TX does not tear down RX.
|
||||||
|
assert tx_after_stop
|
||||||
|
assert rx_still_active
|
||||||
|
|
||||||
|
# After both stops, registry is empty.
|
||||||
|
assert s._registry.refcount(("mock", None)) == 0
|
||||||
|
assert s._rx is None
|
||||||
|
assert s._tx is None
|
||||||
|
|
@ -23,7 +23,24 @@ def test_heartbeat_payload_shape():
|
||||||
assert p["status"] == "idle"
|
assert p["status"] == "idle"
|
||||||
assert "mock" in p["hardware"]
|
assert "mock" in p["hardware"]
|
||||||
assert "app_id" not in p
|
assert "app_id" not in p
|
||||||
|
# New fields, default shape
|
||||||
|
assert p["capabilities"] == ["rx"]
|
||||||
|
assert p["tx_enabled"] is False
|
||||||
|
|
||||||
p2 = hardware.heartbeat_payload(status="streaming", app_id="abc")
|
p2 = hardware.heartbeat_payload(status="streaming", app_id="abc")
|
||||||
assert p2["status"] == "streaming"
|
assert p2["status"] == "streaming"
|
||||||
assert p2["app_id"] == "abc"
|
assert p2["app_id"] == "abc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_heartbeat_payload_tx_capability_from_cfg():
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
|
||||||
|
p = hardware.heartbeat_payload(cfg=AgentConfig(tx_enabled=True))
|
||||||
|
assert p["capabilities"] == ["rx", "tx"]
|
||||||
|
assert p["tx_enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_heartbeat_payload_sessions_field():
|
||||||
|
sessions = {"rx": {"app_id": "a", "state": "streaming"}}
|
||||||
|
p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions)
|
||||||
|
assert p["sessions"] == sessions
|
||||||
|
|
|
||||||
144
tests/agent/test_integration_tx.py
Normal file
144
tests/agent/test_integration_tx.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
"""End-to-end: local websockets server drives a Streamer's TX path."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||||
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingMockSDR(MockSDR):
|
||||||
|
def __init__(self, buffer_size: int):
|
||||||
|
super().__init__(buffer_size=buffer_size)
|
||||||
|
self.tx_produced: list[np.ndarray] = []
|
||||||
|
|
||||||
|
def _stream_tx(self, callback):
|
||||||
|
self._enable_tx = True
|
||||||
|
self._tx_initialized = True
|
||||||
|
while self._enable_tx:
|
||||||
|
result = callback(self.rx_buffer_size)
|
||||||
|
self.tx_produced.append(np.asarray(result).copy())
|
||||||
|
time.sleep(0.005)
|
||||||
|
|
||||||
|
|
||||||
|
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||||
|
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = samples.real
|
||||||
|
interleaved[1::2] = samples.imag
|
||||||
|
return interleaved.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_tx_start_binary_stop_cycle_over_real_ws():
|
||||||
|
BUF = 16
|
||||||
|
sdr = RecordingMockSDR(buffer_size=BUF)
|
||||||
|
marker = np.arange(BUF, dtype=np.complex64) + 1
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
control_frames: list[dict] = []
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def server_handler(ws):
|
||||||
|
try:
|
||||||
|
# Drain initial heartbeat.
|
||||||
|
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||||
|
control_frames.append(json.loads(first))
|
||||||
|
|
||||||
|
await ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "tx-app",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "mock",
|
||||||
|
"buffer_size": BUF,
|
||||||
|
"tx_sample_rate": 1_000_000,
|
||||||
|
"tx_center_frequency": 2.45e9,
|
||||||
|
"tx_gain": -20,
|
||||||
|
"underrun_policy": "zero",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Push a few binary IQ frames.
|
||||||
|
for _ in range(3):
|
||||||
|
await ws.send(_iq_frame(marker))
|
||||||
|
|
||||||
|
# Wait for at least "armed" + "transmitting" statuses.
|
||||||
|
for _ in range(100):
|
||||||
|
msg = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||||
|
if isinstance(msg, str):
|
||||||
|
control_frames.append(json.loads(msg))
|
||||||
|
if any(
|
||||||
|
f.get("type") == "tx_status" and f.get("state") == "transmitting"
|
||||||
|
for f in control_frames
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"}))
|
||||||
|
|
||||||
|
# Drain trailing statuses.
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await asyncio.wait_for(ws.recv(), timeout=0.5)
|
||||||
|
if isinstance(msg, str):
|
||||||
|
control_frames.append(json.loads(msg))
|
||||||
|
except (asyncio.TimeoutError, Exception):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
server = await websockets.serve(server_handler, "127.0.0.1", 0)
|
||||||
|
port = server.sockets[0].getsockname()[1]
|
||||||
|
try:
|
||||||
|
client = WsClient(
|
||||||
|
f"ws://127.0.0.1:{port}",
|
||||||
|
token="",
|
||||||
|
heartbeat_interval=10.0,
|
||||||
|
reconnect_pause=0.05,
|
||||||
|
)
|
||||||
|
streamer = Streamer(
|
||||||
|
ws=client,
|
||||||
|
sdr_factory=lambda d, i: sdr,
|
||||||
|
cfg=AgentConfig(tx_enabled=True),
|
||||||
|
)
|
||||||
|
task = asyncio.create_task(
|
||||||
|
client.run(
|
||||||
|
on_message=streamer.on_message,
|
||||||
|
heartbeat=streamer.build_heartbeat,
|
||||||
|
on_binary=streamer.on_binary,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await asyncio.wait_for(done.wait(), timeout=5.0)
|
||||||
|
client.stop()
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
return control_frames, streamer
|
||||||
|
|
||||||
|
controls, streamer = asyncio.run(scenario())
|
||||||
|
|
||||||
|
# Heartbeat reached the server.
|
||||||
|
assert any(f.get("type") == "heartbeat" for f in controls)
|
||||||
|
# tx_status lifecycle: armed → transmitting → done.
|
||||||
|
tx_states = [f["state"] for f in controls if f.get("type") == "tx_status"]
|
||||||
|
assert tx_states[0] == "armed"
|
||||||
|
assert "transmitting" in tx_states
|
||||||
|
assert tx_states[-1] == "done"
|
||||||
|
# TX callback saw our marker buffer at least once.
|
||||||
|
assert any(np.array_equal(b, marker) for b in sdr.tx_produced)
|
||||||
|
# Session cleared.
|
||||||
|
assert streamer._tx is None
|
||||||
|
|
@ -46,15 +46,29 @@ def test_apply_sdr_config_sets_attributes():
|
||||||
|
|
||||||
|
|
||||||
def test_heartbeat_reflects_status_and_app():
|
def test_heartbeat_reflects_status_and_app():
|
||||||
s = Streamer(ws=FakeWs(), sdr_factory=_factory)
|
async def scenario():
|
||||||
hb = s.build_heartbeat()
|
s = Streamer(ws=FakeWs(), sdr_factory=_factory)
|
||||||
assert hb["type"] == "heartbeat"
|
hb = s.build_heartbeat()
|
||||||
assert hb["status"] == "idle"
|
assert hb["type"] == "heartbeat"
|
||||||
s._status = "streaming"
|
assert hb["status"] == "idle"
|
||||||
s._app_id = "app-42"
|
# capabilities default to rx-only
|
||||||
hb2 = s.build_heartbeat()
|
assert hb["capabilities"] == ["rx"]
|
||||||
assert hb2["status"] == "streaming"
|
assert hb["tx_enabled"] is False
|
||||||
assert hb2["app_id"] == "app-42"
|
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "start",
|
||||||
|
"app_id": "app-42",
|
||||||
|
"radio_config": {"device": "mock", "buffer_size": 32},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
hb2 = s.build_heartbeat()
|
||||||
|
assert hb2["status"] == "streaming"
|
||||||
|
assert hb2["app_id"] == "app-42"
|
||||||
|
assert hb2["sessions"]["rx"]["app_id"] == "app-42"
|
||||||
|
await s.on_message({"type": "stop", "app_id": "app-42"})
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_full_start_stream_stop_cycle():
|
def test_full_start_stream_stop_cycle():
|
||||||
|
|
@ -89,7 +103,7 @@ def test_full_start_stream_stop_cycle():
|
||||||
statuses = [m for m in ws.json_sent if m.get("type") == "status"]
|
statuses = [m for m in ws.json_sent if m.get("type") == "status"]
|
||||||
assert statuses[0]["status"] == "streaming"
|
assert statuses[0]["status"] == "streaming"
|
||||||
assert statuses[-1]["status"] == "idle"
|
assert statuses[-1]["status"] == "idle"
|
||||||
assert streamer._sdr is None
|
assert streamer._rx is None
|
||||||
|
|
||||||
|
|
||||||
def test_start_without_device_emits_error():
|
def test_start_without_device_emits_error():
|
||||||
|
|
@ -110,6 +124,7 @@ def test_configure_queues_update():
|
||||||
await streamer.on_message(
|
await streamer.on_message(
|
||||||
{"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}
|
{"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}
|
||||||
)
|
)
|
||||||
|
# Before start(), pending config lives on the standalone dict exposed via the _pending_config shim.
|
||||||
return streamer._pending_config
|
return streamer._pending_config
|
||||||
|
|
||||||
pending = asyncio.run(scenario())
|
pending = asyncio.run(scenario())
|
||||||
|
|
@ -122,3 +137,56 @@ def test_unknown_message_type_is_ignored():
|
||||||
await s.on_message({"type": "nope"})
|
await s.on_message({"type": "nope"})
|
||||||
|
|
||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_shares_sdr_across_start_stop_cycles():
|
||||||
|
# Two sequential start/stop cycles with the same (device, identifier)
|
||||||
|
# should hit the registry's cache path rather than constructing a new SDR.
|
||||||
|
built: list[MockSDR] = []
|
||||||
|
|
||||||
|
def counting_factory(device: str, identifier):
|
||||||
|
sdr = MockSDR(buffer_size=16, seed=0)
|
||||||
|
built.append(sdr)
|
||||||
|
return sdr
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
s = Streamer(ws=FakeWs(), sdr_factory=counting_factory)
|
||||||
|
for _ in range(2):
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "start",
|
||||||
|
"app_id": "a",
|
||||||
|
"radio_config": {"device": "mock", "buffer_size": 16},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Let one capture buffer flow before stopping so the loop is engaged.
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
await s.on_message({"type": "stop", "app_id": "a"})
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
# A new SDR per cycle (we fully close between starts) — registry refcount
|
||||||
|
# drops to zero on each stop. This test confirms close-and-rebuild works;
|
||||||
|
# the ref-counting share-while-open case is covered in the full-duplex tests.
|
||||||
|
assert len(built) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tx_start_rejected_when_tx_disabled():
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=_factory, cfg=AgentConfig(tx_enabled=False))
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "a",
|
||||||
|
"radio_config": {"device": "mock", "tx_center_frequency": 2.45e9, "tx_gain": -20},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ws
|
||||||
|
|
||||||
|
ws = asyncio.run(scenario())
|
||||||
|
tx_statuses = [m for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||||
|
assert tx_statuses, "expected a tx_status frame"
|
||||||
|
assert tx_statuses[-1]["state"] == "error"
|
||||||
|
assert "disabled" in tx_statuses[-1]["message"].lower()
|
||||||
|
|
|
||||||
133
tests/agent/test_streamer_tx.py
Normal file
133
tests/agent/test_streamer_tx.py
Normal file
|
|
@ -0,0 +1,133 @@
|
||||||
|
"""TX streaming happy path + shutdown semantics."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingMockSDR(MockSDR):
|
||||||
|
"""MockSDR that records each TX callback's returned buffer."""
|
||||||
|
|
||||||
|
def __init__(self, buffer_size: int):
|
||||||
|
super().__init__(buffer_size=buffer_size)
|
||||||
|
self.tx_produced: list[np.ndarray] = []
|
||||||
|
|
||||||
|
def _stream_tx(self, callback) -> None:
|
||||||
|
self._enable_tx = True
|
||||||
|
self._tx_initialized = True
|
||||||
|
while self._enable_tx:
|
||||||
|
result = callback(self.rx_buffer_size)
|
||||||
|
self.tx_produced.append(np.asarray(result))
|
||||||
|
time.sleep(0.005)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWs:
|
||||||
|
def __init__(self):
|
||||||
|
self.json_sent: list[dict] = []
|
||||||
|
self.bytes_sent: list[bytes] = []
|
||||||
|
|
||||||
|
async def send_json(self, payload):
|
||||||
|
self.json_sent.append(payload)
|
||||||
|
|
||||||
|
async def send_bytes(self, data):
|
||||||
|
self.bytes_sent.append(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||||
|
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = samples.real
|
||||||
|
interleaved[1::2] = samples.imag
|
||||||
|
return interleaved.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tx_start_streams_binary_to_callback():
|
||||||
|
BUF = 16
|
||||||
|
sdr = RecordingMockSDR(buffer_size=BUF)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||||
|
|
||||||
|
# Frames of distinct content so we can assert ordering.
|
||||||
|
frame_a = np.arange(BUF, dtype=np.complex64) * (1 + 0j)
|
||||||
|
frame_b = (np.arange(BUF, dtype=np.complex64) + BUF) * (1 + 0j)
|
||||||
|
frame_c = (np.arange(BUF, dtype=np.complex64) + 2 * BUF) * (1 + 0j)
|
||||||
|
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "app-1",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "mock",
|
||||||
|
"buffer_size": BUF,
|
||||||
|
"tx_sample_rate": 1_000_000,
|
||||||
|
"tx_center_frequency": 2.45e9,
|
||||||
|
"tx_gain": -20,
|
||||||
|
"underrun_policy": "zero",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Push three IQ frames.
|
||||||
|
await s.on_binary(_iq_frame(frame_a))
|
||||||
|
await s.on_binary(_iq_frame(frame_b))
|
||||||
|
await s.on_binary(_iq_frame(frame_c))
|
||||||
|
|
||||||
|
# Let the executor thread consume them.
|
||||||
|
for _ in range(100):
|
||||||
|
# At least the 3 real frames, plus any zero-fill from before they
|
||||||
|
# arrived. We stop once 3 non-trivial buffers are recorded.
|
||||||
|
nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
|
||||||
|
if len(nontrivial) >= 3:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "app-1"})
|
||||||
|
return ws, sdr, s
|
||||||
|
|
||||||
|
ws, sdr, streamer = asyncio.run(scenario())
|
||||||
|
|
||||||
|
nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
|
||||||
|
assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers"
|
||||||
|
|
||||||
|
# First three nontrivial buffers match the order we pushed them.
|
||||||
|
np.testing.assert_array_equal(nontrivial[0], np.arange(BUF, dtype=np.complex64))
|
||||||
|
np.testing.assert_array_equal(nontrivial[1], np.arange(BUF, 2 * BUF, dtype=np.complex64))
|
||||||
|
np.testing.assert_array_equal(nontrivial[2], np.arange(2 * BUF, 3 * BUF, dtype=np.complex64))
|
||||||
|
|
||||||
|
# Lifecycle: armed → transmitting → done.
|
||||||
|
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||||
|
assert states[0] == "armed"
|
||||||
|
assert "transmitting" in states
|
||||||
|
assert states[-1] == "done"
|
||||||
|
# Session cleared.
|
||||||
|
assert streamer._tx is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tx_stop_releases_sdr():
|
||||||
|
sdr = RecordingMockSDR(buffer_size=8)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "a",
|
||||||
|
"radio_config": {"device": "mock", "buffer_size": 8, "underrun_policy": "zero"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.03)
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||||
|
return s
|
||||||
|
|
||||||
|
s = asyncio.run(scenario())
|
||||||
|
# After stop, the registry has no outstanding references to ("mock", None).
|
||||||
|
assert s._registry.refcount(("mock", None)) == 0
|
||||||
|
assert s._tx is None
|
||||||
167
tests/agent/test_tx_safety.py
Normal file
167
tests/agent/test_tx_safety.py
Normal file
|
|
@ -0,0 +1,167 @@
|
||||||
|
"""Agent-side TX interlocks: gain cap, freq ranges, duplicate sessions, disabled."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWs:
|
||||||
|
def __init__(self):
|
||||||
|
self.json_sent = []
|
||||||
|
self.bytes_sent = []
|
||||||
|
|
||||||
|
async def send_json(self, p):
|
||||||
|
self.json_sent.append(p)
|
||||||
|
|
||||||
|
async def send_bytes(self, b):
|
||||||
|
self.bytes_sent.append(b)
|
||||||
|
|
||||||
|
|
||||||
|
def _last_tx_status(ws):
|
||||||
|
frames = [m for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||||
|
return frames[-1] if frames else None
|
||||||
|
|
||||||
|
|
||||||
|
def _tx_start(app_id="a", **radio):
|
||||||
|
rc = {"device": "mock", "buffer_size": 16, "underrun_policy": "zero"}
|
||||||
|
rc.update(radio)
|
||||||
|
return {"type": "tx_start", "app_id": app_id, "radio_config": rc}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_streamer(cfg):
|
||||||
|
built: list = []
|
||||||
|
|
||||||
|
def factory(device, identifier):
|
||||||
|
sdr = MockSDR(buffer_size=16)
|
||||||
|
built.append(sdr)
|
||||||
|
return sdr
|
||||||
|
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=factory, cfg=cfg)
|
||||||
|
return s, ws, built
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_when_tx_disabled():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, built = _make_streamer(AgentConfig(tx_enabled=False))
|
||||||
|
await s.on_message(_tx_start(tx_gain=-20, tx_center_frequency=2.45e9))
|
||||||
|
return s, ws, built
|
||||||
|
|
||||||
|
s, ws, built = asyncio.run(scenario())
|
||||||
|
status = _last_tx_status(ws)
|
||||||
|
assert status and status["state"] == "error"
|
||||||
|
assert "disabled" in status["message"].lower()
|
||||||
|
assert not built, "SDR should never have been constructed"
|
||||||
|
assert s._tx is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_when_tx_gain_exceeds_cap():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, built = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-15.0))
|
||||||
|
await s.on_message(_tx_start(tx_gain=-5, tx_center_frequency=2.45e9))
|
||||||
|
return ws, built
|
||||||
|
|
||||||
|
ws, built = asyncio.run(scenario())
|
||||||
|
status = _last_tx_status(ws)
|
||||||
|
assert status and status["state"] == "error"
|
||||||
|
assert "exceeds cap" in status["message"]
|
||||||
|
assert not built
|
||||||
|
|
||||||
|
|
||||||
|
def test_allows_gain_at_cap_boundary():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-10.0))
|
||||||
|
await s.on_message(_tx_start(tx_gain=-10, tx_center_frequency=2.45e9))
|
||||||
|
# Stop promptly to avoid keeping an executor thread around.
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||||
|
return ws
|
||||||
|
|
||||||
|
ws = asyncio.run(scenario())
|
||||||
|
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||||
|
assert "armed" in states
|
||||||
|
assert states[-1] == "done"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_when_freq_outside_ranges():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, built = _make_streamer(
|
||||||
|
AgentConfig(
|
||||||
|
tx_enabled=True,
|
||||||
|
tx_allowed_freq_ranges=[[2.4e9, 2.5e9]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await s.on_message(_tx_start(tx_center_frequency=5.8e9, tx_gain=-20))
|
||||||
|
return ws, built
|
||||||
|
|
||||||
|
ws, built = asyncio.run(scenario())
|
||||||
|
status = _last_tx_status(ws)
|
||||||
|
assert status and status["state"] == "error"
|
||||||
|
assert "outside allowed ranges" in status["message"]
|
||||||
|
assert not built
|
||||||
|
|
||||||
|
|
||||||
|
def test_allows_freq_inside_a_range():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, _ = _make_streamer(
|
||||||
|
AgentConfig(
|
||||||
|
tx_enabled=True,
|
||||||
|
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await s.on_message(_tx_start(tx_center_frequency=5.75e9, tx_gain=-20))
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||||
|
return ws
|
||||||
|
|
||||||
|
ws = asyncio.run(scenario())
|
||||||
|
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||||
|
assert "armed" in states
|
||||||
|
assert states[-1] == "done"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_duplicate_tx_session():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True))
|
||||||
|
await s.on_message(_tx_start(app_id="a", tx_gain=-20, tx_center_frequency=2.45e9))
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
await s.on_message(_tx_start(app_id="b", tx_gain=-20, tx_center_frequency=2.45e9))
|
||||||
|
# Let the second request process, then stop cleanly.
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||||
|
return ws
|
||||||
|
|
||||||
|
ws = asyncio.run(scenario())
|
||||||
|
errors = [
|
||||||
|
m for m in ws.json_sent
|
||||||
|
if m.get("type") == "tx_status" and m.get("state") == "error"
|
||||||
|
]
|
||||||
|
assert any("already active" in e.get("message", "") for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_invalid_underrun_policy():
|
||||||
|
async def scenario():
|
||||||
|
s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True))
|
||||||
|
await s.on_message(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "a",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "mock",
|
||||||
|
"buffer_size": 8,
|
||||||
|
"tx_gain": -20,
|
||||||
|
"tx_center_frequency": 2.45e9,
|
||||||
|
"underrun_policy": "teleport",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ws
|
||||||
|
|
||||||
|
ws = asyncio.run(scenario())
|
||||||
|
status = _last_tx_status(ws)
|
||||||
|
assert status and status["state"] == "error"
|
||||||
|
assert "underrun_policy" in status["message"]
|
||||||
136
tests/agent/test_tx_underrun.py
Normal file
136
tests/agent/test_tx_underrun.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""Underrun policies: pause, zero, repeat."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingMockSDR(MockSDR):
|
||||||
|
def __init__(self, buffer_size: int):
|
||||||
|
super().__init__(buffer_size=buffer_size)
|
||||||
|
self.tx_produced: list[np.ndarray] = []
|
||||||
|
|
||||||
|
def _stream_tx(self, callback):
|
||||||
|
self._enable_tx = True
|
||||||
|
self._tx_initialized = True
|
||||||
|
while self._enable_tx:
|
||||||
|
result = callback(self.rx_buffer_size)
|
||||||
|
self.tx_produced.append(np.asarray(result).copy())
|
||||||
|
time.sleep(0.005)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWs:
|
||||||
|
def __init__(self):
|
||||||
|
self.json_sent = []
|
||||||
|
self.bytes_sent = []
|
||||||
|
|
||||||
|
async def send_json(self, p):
|
||||||
|
self.json_sent.append(p)
|
||||||
|
|
||||||
|
async def send_bytes(self, b):
|
||||||
|
self.bytes_sent.append(b)
|
||||||
|
|
||||||
|
|
||||||
|
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||||
|
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = samples.real
|
||||||
|
interleaved[1::2] = samples.imag
|
||||||
|
return interleaved.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def _start_cfg(policy: str, buf: int = 8) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "a",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "mock",
|
||||||
|
"buffer_size": buf,
|
||||||
|
"tx_gain": -20,
|
||||||
|
"tx_center_frequency": 2.45e9,
|
||||||
|
"underrun_policy": policy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_underrun_pause_stops_session_and_emits_status():
|
||||||
|
sdr = RecordingMockSDR(buffer_size=8)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||||
|
await s.on_message(_start_cfg("pause"))
|
||||||
|
# Do not push any buffers. The callback underruns on first tick and
|
||||||
|
# the watchdog should emit "underrun" and tear down.
|
||||||
|
for _ in range(100):
|
||||||
|
if any(
|
||||||
|
m.get("type") == "tx_status" and m.get("state") == "underrun"
|
||||||
|
for m in ws.json_sent
|
||||||
|
):
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
for _ in range(50):
|
||||||
|
if s._tx is None:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return ws, s
|
||||||
|
|
||||||
|
ws, s = asyncio.run(scenario())
|
||||||
|
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||||
|
assert "underrun" in states
|
||||||
|
assert s._tx is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_underrun_zero_keeps_session_alive():
|
||||||
|
sdr = RecordingMockSDR(buffer_size=8)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||||
|
await s.on_message(_start_cfg("zero"))
|
||||||
|
# Let it produce several underrun-filled buffers.
|
||||||
|
await asyncio.sleep(0.08)
|
||||||
|
still_alive = s._tx is not None
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||||
|
return ws, still_alive
|
||||||
|
|
||||||
|
ws, still_alive = asyncio.run(scenario())
|
||||||
|
# No underrun status emitted (policy absorbs it silently).
|
||||||
|
assert not any(
|
||||||
|
m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
|
||||||
|
)
|
||||||
|
assert still_alive
|
||||||
|
# All produced buffers are zero (no real data was pushed).
|
||||||
|
assert sdr.tx_produced, "expected at least one TX callback invocation"
|
||||||
|
assert all(not np.any(b != 0) for b in sdr.tx_produced)
|
||||||
|
|
||||||
|
|
||||||
|
def test_underrun_repeat_replays_last_buffer():
|
||||||
|
BUF = 8
|
||||||
|
sdr = RecordingMockSDR(buffer_size=BUF)
|
||||||
|
marker = np.arange(BUF, dtype=np.complex64) + 1 # distinct non-zero buffer
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
ws = FakeWs()
|
||||||
|
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||||
|
await s.on_message(_start_cfg("repeat", buf=BUF))
|
||||||
|
await s.on_binary(_iq_frame(marker))
|
||||||
|
# Give the executor time to consume the real frame + several repeats.
|
||||||
|
await asyncio.sleep(0.08)
|
||||||
|
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||||
|
return ws, sdr
|
||||||
|
|
||||||
|
ws, sdr = asyncio.run(scenario())
|
||||||
|
# No underrun status emitted.
|
||||||
|
assert not any(
|
||||||
|
m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
|
||||||
|
)
|
||||||
|
# At least two buffers equal to the marker — the real one and ≥1 repeat.
|
||||||
|
matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)]
|
||||||
|
assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}"
|
||||||
|
|
@ -113,6 +113,109 @@ def test_reconnects_after_server_drop():
|
||||||
assert n >= 2
|
assert n >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_binary_frame_forwarded_to_handler():
|
||||||
|
payload = bytes(range(128))
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
received: list[bytes] = []
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def handler(ws):
|
||||||
|
await ws.send(payload)
|
||||||
|
done.set()
|
||||||
|
try:
|
||||||
|
await ws.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
server, port = await _open_server(handler)
|
||||||
|
try:
|
||||||
|
client = WsClient(
|
||||||
|
f"ws://127.0.0.1:{port}",
|
||||||
|
token="",
|
||||||
|
heartbeat_interval=10.0,
|
||||||
|
reconnect_pause=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_bin(data):
|
||||||
|
received.append(data)
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
client.run(
|
||||||
|
on_message=lambda _m: asyncio.sleep(0),
|
||||||
|
heartbeat=lambda: {"type": "heartbeat"},
|
||||||
|
on_binary=on_bin,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(50):
|
||||||
|
if received:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
client.stop()
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
return received
|
||||||
|
|
||||||
|
received = asyncio.run(scenario())
|
||||||
|
assert received == [payload]
|
||||||
|
|
||||||
|
|
||||||
|
def test_binary_frame_dropped_when_no_handler():
|
||||||
|
# Regression guard: existing behavior (drop server-sent binary) preserved when
|
||||||
|
# on_binary is not supplied.
|
||||||
|
async def scenario():
|
||||||
|
crashes: list[Exception] = []
|
||||||
|
|
||||||
|
async def handler(ws):
|
||||||
|
await ws.send(b"\x00\x01\x02\x03")
|
||||||
|
await ws.send(json.dumps({"type": "ping"}))
|
||||||
|
try:
|
||||||
|
await ws.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
messages: list[dict] = []
|
||||||
|
server, port = await _open_server(handler)
|
||||||
|
try:
|
||||||
|
client = WsClient(
|
||||||
|
f"ws://127.0.0.1:{port}",
|
||||||
|
token="",
|
||||||
|
heartbeat_interval=10.0,
|
||||||
|
reconnect_pause=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_msg(m):
|
||||||
|
messages.append(m)
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
|
||||||
|
)
|
||||||
|
for _ in range(50):
|
||||||
|
if messages:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
client.stop()
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except (asyncio.CancelledError, Exception) as exc:
|
||||||
|
crashes.append(exc)
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
return messages, crashes
|
||||||
|
|
||||||
|
messages, crashes = asyncio.run(scenario())
|
||||||
|
# JSON still delivered; binary silently dropped; no uncaught crash.
|
||||||
|
assert messages and messages[0] == {"type": "ping"}
|
||||||
|
|
||||||
|
|
||||||
def test_malformed_control_frame_does_not_crash():
|
def test_malformed_control_frame_does_not_crash():
|
||||||
async def scenario():
|
async def scenario():
|
||||||
handled: list[dict] = []
|
handled: list[dict] = []
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user