Compare commits

..

No commits in common. "22b035dbeea70003da6d6afc747ccf9d34f3f8aa" and "f03825a6db5affed034987985ffed09333ceb6d9" have entirely different histories.

25 changed files with 192 additions and 184 deletions

4
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. # This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand.
[[package]] [[package]]
name = "alabaster" name = "alabaster"
@ -1096,7 +1096,7 @@ files = [
[package.dependencies] [package.dependencies]
attrs = ">=22.2.0" attrs = ">=22.2.0"
jsonschema-specifications = ">=2023.03.6" jsonschema-specifications = ">=2023.3.6"
referencing = ">=0.28.4" referencing = ">=0.28.4"
rpds-py = ">=0.25.0" rpds-py = ">=0.25.0"

View File

@ -66,9 +66,8 @@ class LoggingFakeWs:
pass pass
def _make_iq_frame( def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float,
buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float = 0.0 phase_offset: float = 0.0) -> tuple[bytes, float]:
) -> tuple[bytes, float]:
"""Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone. """Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone.
Emitting one continuous phase-coherent tone requires threading the phase Emitting one continuous phase-coherent tone requires threading the phase
@ -94,9 +93,7 @@ def _make_pluto_factory(identifier: str | None):
if device != "pluto": if device != "pluto":
raise ValueError(f"this script only drives pluto; got device={device!r}") raise ValueError(f"this script only drives pluto; got device={device!r}")
from ria_toolkit_oss.sdr.pluto import Pluto from ria_toolkit_oss.sdr.pluto import Pluto
return Pluto(identifier=identifier) return Pluto(identifier=identifier)
return factory return factory
@ -133,14 +130,13 @@ async def _run(args: argparse.Namespace) -> int:
# Abort if tx_start was rejected by an interlock (no session → nothing to do). # Abort if tx_start was rejected by an interlock (no session → nothing to do).
if streamer._tx is None: if streamer._tx is None:
print("tx_start rejected — see [tx_status] line above for the reason.", file=sys.stderr) print("tx_start rejected — see [tx_status] line above for the reason.",
file=sys.stderr)
return 2 return 2
print( print(f"Transmitting at {args.frequency/1e6:.3f} MHz with "
f"Transmitting at {args.frequency/1e6:.3f} MHz with " f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. "
f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. " f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}.")
f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}."
)
# Arrange a clean shutdown on Ctrl-C. # Arrange a clean shutdown on Ctrl-C.
stop = asyncio.Event() stop = asyncio.Event()
@ -161,11 +157,12 @@ async def _run(args: argparse.Namespace) -> int:
# topped up. The queue's own backpressure keeps us from spinning. # topped up. The queue's own backpressure keeps us from spinning.
produce_interval = buffer_dt * 0.5 produce_interval = buffer_dt * 0.5
try: try:
async def producer(): async def producer():
nonlocal phase nonlocal phase
while not stop.is_set(): while not stop.is_set():
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase) frame, phase = _make_iq_frame(
args.buffer_size, args.tone, args.sample_rate, phase
)
await streamer.on_binary(frame) await streamer.on_binary(frame)
await asyncio.sleep(produce_interval) await asyncio.sleep(produce_interval)
@ -196,17 +193,20 @@ def main() -> int:
p = argparse.ArgumentParser( p = argparse.ArgumentParser(
description="End-to-end TX smoke test: agent → Pluto continuous tone.", description="End-to-end TX smoke test: agent → Pluto continuous tone.",
) )
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)") p.add_argument("--identifier", default=None,
p.add_argument("--frequency", type=float, default=3_410_000_000.0, help="TX LO in Hz (default 2.45 GHz)") help="Pluto IP/hostname (default: auto-discover pluto.local)")
p.add_argument("--gain", type=float, default=-0.0, help="TX gain in dB; Pluto range [-89, 0] (default -30)") p.add_argument("--frequency", type=float, default=3_410_000_000.0,
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)") help="TX LO in Hz (default 2.45 GHz)")
p.add_argument( p.add_argument("--gain", type=float, default=-0.0,
"--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)" help="TX gain in dB; Pluto range [-89, 0] (default -30)")
) p.add_argument("--sample-rate", type=float, default=1_000_000.0,
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)") help="Baseband sample rate (default 1 Msps)")
p.add_argument( p.add_argument("--tone", type=float, default=100_000.0,
"--duration", type=float, default=60.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)" help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)")
) p.add_argument("--buffer-size", type=int, default=4096,
help="Complex samples per frame (default 4096)")
p.add_argument("--duration", type=float, default=60.0,
help="Seconds to transmit; 0 = run until Ctrl-C (default 30)")
p.add_argument("--log-level", default="INFO") p.add_argument("--log-level", default="INFO")
args = p.parse_args() args = p.parse_args()

View File

@ -41,7 +41,8 @@ from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.agent.ws_client import WsClient from ria_toolkit_oss.agent.ws_client import WsClient
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float) -> tuple[bytes, float]: def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float,
phase_offset: float) -> tuple[bytes, float]:
n = np.arange(buffer_size, dtype=np.float64) n = np.arange(buffer_size, dtype=np.float64)
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
amp = 0.7 amp = 0.7
@ -58,9 +59,7 @@ def _make_pluto_factory(identifier: str | None):
if device != "pluto": if device != "pluto":
raise ValueError(f"this script only drives pluto; got device={device!r}") raise ValueError(f"this script only drives pluto; got device={device!r}")
from ria_toolkit_oss.sdr.pluto import Pluto from ria_toolkit_oss.sdr.pluto import Pluto
return Pluto(identifier=identifier) return Pluto(identifier=identifier)
return factory return factory
@ -74,29 +73,27 @@ async def _mock_hub_handler(ws, args, stop: asyncio.Event):
payload = json.loads(first) payload = json.loads(first)
if payload.get("type") == "heartbeat": if payload.get("type") == "heartbeat":
caps = payload.get("capabilities") caps = payload.get("capabilities")
print(f"[mock-hub] agent heartbeat: capabilities={caps} " f"tx_enabled={payload.get('tx_enabled')}") print(f"[mock-hub] agent heartbeat: capabilities={caps} "
f"tx_enabled={payload.get('tx_enabled')}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
print("[mock-hub] warning: no heartbeat received in first 2s") print("[mock-hub] warning: no heartbeat received in first 2s")
# Arm the agent's TX path. # Arm the agent's TX path.
await ws.send( await ws.send(json.dumps({
json.dumps( "type": "tx_start",
{ "app_id": "ws-smoke",
"type": "tx_start", "radio_config": {
"app_id": "ws-smoke", "device": "pluto",
"radio_config": { "identifier": args.identifier,
"device": "pluto", "tx_sample_rate": int(args.sample_rate),
"identifier": args.identifier, "tx_center_frequency": int(args.frequency),
"tx_sample_rate": int(args.sample_rate), "tx_gain": int(args.gain),
"tx_center_frequency": int(args.frequency), "buffer_size": int(args.buffer_size),
"tx_gain": int(args.gain), "underrun_policy": "repeat",
"buffer_size": int(args.buffer_size), },
"underrun_policy": "repeat", }))
}, print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, "
} f"gain={args.gain} dB")
)
)
print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " f"gain={args.gain} dB")
# Producer: push IQ frames at a steady clip. Use a concurrent receiver so # Producer: push IQ frames at a steady clip. Use a concurrent receiver so
# tx_status frames show up in real time rather than being queued behind # tx_status frames show up in real time rather than being queued behind
@ -115,11 +112,15 @@ async def _mock_hub_handler(ws, args, stop: asyncio.Event):
recv_task = asyncio.create_task(receiver()) recv_task = asyncio.create_task(receiver())
try: try:
deadline = None if args.duration <= 0 else (asyncio.get_event_loop().time() + args.duration) deadline = None if args.duration <= 0 else (
asyncio.get_event_loop().time() + args.duration
)
while not stop.is_set(): while not stop.is_set():
if deadline is not None and asyncio.get_event_loop().time() >= deadline: if deadline is not None and asyncio.get_event_loop().time() >= deadline:
break break
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase) frame, phase = _make_iq_frame(
args.buffer_size, args.tone, args.sample_rate, phase
)
try: try:
await ws.send(frame) await ws.send(frame)
except websockets.ConnectionClosed: except websockets.ConnectionClosed:
@ -203,15 +204,20 @@ def main() -> int:
p = argparse.ArgumentParser( p = argparse.ArgumentParser(
description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.", description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.",
) )
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)") p.add_argument("--identifier", default=None,
p.add_argument("--frequency", type=float, default=2_450_000_000.0, help="TX LO in Hz (default 2.45 GHz)") help="Pluto IP/hostname (default: auto-discover pluto.local)")
p.add_argument("--gain", type=float, default=0.0, help="TX gain in dB; Pluto range [-89, 0] (default 0)") p.add_argument("--frequency", type=float, default=2_450_000_000.0,
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)") help="TX LO in Hz (default 2.45 GHz)")
p.add_argument("--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz (default 100 kHz)") p.add_argument("--gain", type=float, default=0.0,
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)") help="TX gain in dB; Pluto range [-89, 0] (default 0)")
p.add_argument( p.add_argument("--sample-rate", type=float, default=1_000_000.0,
"--duration", type=float, default=30.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)" help="Baseband sample rate (default 1 Msps)")
) p.add_argument("--tone", type=float, default=100_000.0,
help="Baseband tone offset in Hz (default 100 kHz)")
p.add_argument("--buffer-size", type=int, default=4096,
help="Complex samples per frame (default 4096)")
p.add_argument("--duration", type=float, default=30.0,
help="Seconds to transmit; 0 = run until Ctrl-C (default 30)")
p.add_argument("--log-level", default="INFO") p.add_argument("--log-level", default="INFO")
args = p.parse_args() args = p.parse_args()

View File

@ -118,9 +118,9 @@ def _derive_ws_url(hub_url: str, agent_id: str) -> str:
return "" return ""
base = hub_url.rstrip("/") base = hub_url.rstrip("/")
if base.startswith("https://"): if base.startswith("https://"):
base = "wss://" + base[len("https://") :] base = "wss://" + base[len("https://"):]
elif base.startswith("http://"): elif base.startswith("http://"):
base = "ws://" + base[len("http://") :] base = "ws://" + base[len("http://"):]
suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws" suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws"
return base + suffix return base + suffix

View File

@ -22,7 +22,6 @@ import os
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
def _resolve_default_path() -> Path: def _resolve_default_path() -> Path:
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))

View File

@ -46,7 +46,9 @@ def heartbeat_payload(
if c.tx_max_duration_s is not None: if c.tx_max_duration_s is not None:
payload["tx_max_duration_s"] = float(c.tx_max_duration_s) payload["tx_max_duration_s"] = float(c.tx_max_duration_s)
if c.tx_allowed_freq_ranges: if c.tx_allowed_freq_ranges:
payload["tx_allowed_freq_ranges"] = [[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges] payload["tx_allowed_freq_ranges"] = [
[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges
]
if app_id: if app_id:
payload["app_id"] = app_id payload["app_id"] = app_id
if sessions: if sessions:

View File

@ -270,7 +270,9 @@ class Streamer:
) )
self._rx = session self._rx = session
await self._send_status("streaming", app_id) await self._send_status("streaming", app_id)
session.task = asyncio.create_task(self._capture_loop(session), name="ria-streamer-capture") session.task = asyncio.create_task(
self._capture_loop(session), name="ria-streamer-capture"
)
async def _handle_rx_stop(self, msg: dict) -> None: async def _handle_rx_stop(self, msg: dict) -> None:
session = self._rx session = self._rx
@ -308,7 +310,9 @@ class Streamer:
logger.warning("Applying configure failed: %s", exc) logger.warning("Applying configure failed: %s", exc)
try: try:
samples = await loop.run_in_executor(None, session.sdr.rx, session.buffer_size) samples = await loop.run_in_executor(
None, session.sdr.rx, session.buffer_size
)
except Exception as exc: except Exception as exc:
from ria_toolkit_oss.sdr import SdrDisconnectedError from ria_toolkit_oss.sdr import SdrDisconnectedError
@ -338,7 +342,7 @@ class Streamer:
# ================================================================== # ==================================================================
# TX # TX
async def _handle_tx_start(self, msg: dict) -> None: # noqa: C901 async def _handle_tx_start(self, msg: dict) -> None:
app_id = msg.get("app_id") or "" app_id = msg.get("app_id") or ""
radio_config = dict(msg.get("radio_config") or {}) radio_config = dict(msg.get("radio_config") or {})
@ -379,7 +383,9 @@ class Streamer:
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
underrun_policy = str(radio_config.pop("underrun_policy", "pause")) underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
if underrun_policy not in ("pause", "zero", "repeat"): if underrun_policy not in ("pause", "zero", "repeat"):
await self._send_tx_status(app_id, "error", f"invalid underrun_policy {underrun_policy!r}") await self._send_tx_status(
app_id, "error", f"invalid underrun_policy {underrun_policy!r}"
)
return return
if not device: if not device:
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device") await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
@ -398,10 +404,15 @@ class Streamer:
# manifest bug and we want it surfaced immediately, not papered # manifest bug and we want it surfaced immediately, not papered
# over with stale radio state. # over with stale radio state.
if hasattr(sdr, "init_tx"): if hasattr(sdr, "init_tx"):
init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")} init_args = {
k: radio_config.get(f"tx_{k}")
for k in ("sample_rate", "center_frequency", "gain")
}
missing = [f"tx_{k}" for k, v in init_args.items() if v is None] missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
if missing: if missing:
raise ValueError(f"tx_start missing required radio_config keys: {missing}") raise ValueError(
f"tx_start missing required radio_config keys: {missing}"
)
sdr.init_tx( sdr.init_tx(
sample_rate=init_args["sample_rate"], sample_rate=init_args["sample_rate"],
center_frequency=init_args["center_frequency"], center_frequency=init_args["center_frequency"],
@ -487,8 +498,9 @@ class Streamer:
return _silence(n) return _silence(n)
# Max-duration watchdog. # Max-duration watchdog.
if session.max_duration_s is not None and (time.monotonic() - session.started_at) >= float( if (
session.max_duration_s session.max_duration_s is not None
and (time.monotonic() - session.started_at) >= float(session.max_duration_s)
): ):
session.stop_event.set() session.stop_event.set()
try: try:
@ -516,7 +528,7 @@ class Streamer:
if arr.size < 2 or arr.size % 2 != 0: if arr.size < 2 or arr.size % 2 != 0:
logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size) logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
return self._underrun_fill(session, n) return self._underrun_fill(session, n)
samples = arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64) samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64))
if samples.size < n: if samples.size < n:
out = np.zeros(n, dtype=np.complex64) out = np.zeros(n, dtype=np.complex64)
out[: samples.size] = samples out[: samples.size] = samples
@ -735,7 +747,6 @@ def _default_sdr_factory(device: str, identifier: str | None):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Top-level entry # Top-level entry
async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None: async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
"""Connect to *ws_url* and run the streamer loop until cancelled.""" """Connect to *ws_url* and run the streamer loop until cancelled."""
ws = WsClient(ws_url, token) ws = WsClient(ws_url, token)

View File

@ -37,7 +37,7 @@ def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]:
for exe in ("docker", "podman"): for exe in ("docker", "podman"):
if shutil.which(exe): if shutil.which(exe):
use_sudo = sudo_override or cfg.sudo use_sudo = sudo_override or cfg.sudo
return ["sudo", exe] if use_sudo else [exe] return (["sudo", exe] if use_sudo else [exe])
print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr) print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr)
sys.exit(2) sys.exit(2)
@ -96,9 +96,7 @@ def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool)
if _gpu_available(): if _gpu_available():
flags += ["--gpus", "all"] flags += ["--gpus", "all"]
else: else:
notes.append( notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)")
"image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)"
)
if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb: if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
flags += ["--device", "/dev/bus/usb"] flags += ["--device", "/dev/bus/usb"]

View File

@ -40,19 +40,15 @@ class RemoteTransmitter:
try: try:
if radio_str in ("pluto", "plutosdr"): if radio_str in ("pluto", "plutosdr"):
from ria_toolkit_oss.sdr.pluto import Pluto from ria_toolkit_oss.sdr.pluto import Pluto
self._sdr = Pluto(identifier) self._sdr = Pluto(identifier)
elif radio_str in ("usrp",): elif radio_str in ("usrp",):
from ria_toolkit_oss.sdr.usrp import USRP from ria_toolkit_oss.sdr.usrp import USRP
self._sdr = USRP(identifier) self._sdr = USRP(identifier)
elif radio_str in ("hackrf", "hackrf_one"): elif radio_str in ("hackrf", "hackrf_one"):
from ria_toolkit_oss.sdr.hackrf import HackRF from ria_toolkit_oss.sdr.hackrf import HackRF
self._sdr = HackRF(identifier) self._sdr = HackRF(identifier)
elif radio_str in ("bladerf", "blade"): elif radio_str in ("bladerf", "blade"):
from ria_toolkit_oss.sdr.blade import Blade from ria_toolkit_oss.sdr.blade import Blade
self._sdr = Blade(identifier) self._sdr = Blade(identifier)
else: else:
raise ValueError(f"Unknown SDR type: {radio_str!r}") raise ValueError(f"Unknown SDR type: {radio_str!r}")
@ -81,7 +77,6 @@ class RemoteTransmitter:
if self._sdr is None: if self._sdr is None:
raise RuntimeError("Call set_radio() and init_tx() before transmit()") raise RuntimeError("Call set_radio() and init_tx() before transmit()")
import time import time
# Transmit in a loop until duration has elapsed # Transmit in a loop until duration has elapsed
end = time.monotonic() + duration_s end = time.monotonic() + duration_s
while time.monotonic() < end: while time.monotonic() < end:

View File

@ -13,11 +13,6 @@ import json
import logging import logging
import threading import threading
import time import time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import paramiko
import zmq
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -163,21 +158,16 @@ class RemoteTransmitterController:
""" """
logger.info( logger.info(
"init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d", "init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d",
center_frequency / 1e6, center_frequency / 1e6, sample_rate / 1e6, gain, channel,
sample_rate / 1e6,
gain,
channel,
)
self._send(
{
"function_name": "init_tx",
"center_frequency": center_frequency,
"sample_rate": sample_rate,
"gain": gain,
"channel": channel,
"gain_mode": gain_mode,
}
) )
self._send({
"function_name": "init_tx",
"center_frequency": center_frequency,
"sample_rate": sample_rate,
"gain": gain,
"channel": channel,
"gain_mode": gain_mode,
})
def transmit_async(self, duration_s: float) -> None: def transmit_async(self, duration_s: float) -> None:
"""Start a timed CW transmission in a background thread. """Start a timed CW transmission in a background thread.

View File

@ -15,13 +15,8 @@ __all__ = [
] ]
from .mock import MockSDR from .mock import MockSDR
from .sdr import ( # noqa: F401 from .sdr import SDR, SDRError, SdrDisconnectedError, SDRParameterError, translate_disconnect # noqa: F401
SDR,
SdrDisconnectedError,
SDRError,
SDRParameterError,
translate_disconnect,
)
_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = ( _DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = (
("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"), ("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"),

View File

@ -8,12 +8,7 @@ import adi
import numpy as np import numpy as np
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.sdr.sdr import ( from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError, translate_disconnect
SDR,
SDRError,
SDRParameterError,
translate_disconnect,
)
class Pluto(SDR): class Pluto(SDR):

View File

@ -583,7 +583,7 @@ _DISCONNECT_MARKERS = (
"i/o error", "i/o error",
"input/output error", "input/output error",
"errno 19", # ENODEV "errno 19", # ENODEV
"errno 5", # EIO "errno 5", # EIO
) )

View File

@ -26,11 +26,9 @@ class _FakeResp:
def _run_register(argv: list[str], cfg_path) -> int: def _run_register(argv: list[str], cfg_path) -> int:
fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"}) fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"})
with ( with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), patch("urllib.request.urlopen", return_value=fake_resp), \
patch("urllib.request.urlopen", return_value=fake_resp), patch.object(sys, "argv", ["ria-agent", *argv]):
patch.object(sys, "argv", ["ria-agent", *argv]),
):
try: try:
agent_cli.main() agent_cli.main()
except SystemExit as exc: except SystemExit as exc:
@ -98,11 +96,9 @@ def test_stream_allow_tx_does_not_persist(tmp_path):
captured["cfg"] = cfg captured["cfg"] = cfg
return None return None
with ( with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \
patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]):
patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]),
):
try: try:
agent_cli.main() agent_cli.main()
except SystemExit: except SystemExit:

View File

@ -70,7 +70,9 @@ def test_server_start_stream_stop_cycle_over_real_ws():
reconnect_pause=0.05, reconnect_pause=0.05,
) )
streamer = Streamer(ws=client, sdr_factory=lambda d, i: MockSDR(buffer_size=32, seed=0)) streamer = Streamer(ws=client, sdr_factory=lambda d, i: MockSDR(buffer_size=32, seed=0))
task = asyncio.create_task(client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat)) task = asyncio.create_task(
client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat)
)
await asyncio.wait_for(ready.wait(), timeout=3.0) await asyncio.wait_for(ready.wait(), timeout=3.0)
await asyncio.wait_for(stopped.wait(), timeout=3.0) await asyncio.wait_for(stopped.wait(), timeout=3.0)
client.stop() client.stop()

View File

@ -77,7 +77,10 @@ def test_server_tx_start_binary_stop_cycle_over_real_ws():
msg = await asyncio.wait_for(ws.recv(), timeout=2.0) msg = await asyncio.wait_for(ws.recv(), timeout=2.0)
if isinstance(msg, str): if isinstance(msg, str):
control_frames.append(json.loads(msg)) control_frames.append(json.loads(msg))
if any(f.get("type") == "tx_status" and f.get("state") == "transmitting" for f in control_frames): if any(
f.get("type") == "tx_status" and f.get("state") == "transmitting"
for f in control_frames
):
break break
await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"})) await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"}))

View File

@ -30,6 +30,7 @@ from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.sdr.mock import MockSDR from ria_toolkit_oss.sdr.mock import MockSDR
_STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0")) _STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0"))
@ -155,21 +156,18 @@ def test_full_duplex_stays_healthy_over_stress_window():
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
await s.on_message( await s.on_message(
{"type": "start", "app_id": "app-1", "radio_config": {"device": "mock", "buffer_size": BUF}} {"type": "start", "app_id": "app-1",
"radio_config": {"device": "mock", "buffer_size": BUF}}
) )
await s.on_message( await s.on_message(
{ {"type": "tx_start", "app_id": "app-1",
"type": "tx_start", "radio_config": {
"app_id": "app-1", "device": "mock", "buffer_size": BUF,
"radio_config": { "tx_sample_rate": 1_000_000,
"device": "mock", "tx_center_frequency": 2.45e9,
"buffer_size": BUF, "tx_gain": -20,
"tx_sample_rate": 1_000_000, "underrun_policy": "zero",
"tx_center_frequency": 2.45e9, }}
"tx_gain": -20,
"underrun_policy": "zero",
},
}
) )
marker = np.arange(BUF, dtype=np.complex64) + 1 marker = np.arange(BUF, dtype=np.complex64) + 1
@ -182,10 +180,12 @@ def test_full_duplex_stays_healthy_over_stress_window():
# which routes through the same setters the stress test above # which routes through the same setters the stress test above
# verifies. # verifies.
await s.on_message( await s.on_message(
{"type": "tx_configure", "app_id": "app-1", "radio_config": {"tx_sample_rate": 1_000_000 + i}} {"type": "tx_configure", "app_id": "app-1",
"radio_config": {"tx_sample_rate": 1_000_000 + i}}
) )
await s.on_message( await s.on_message(
{"type": "configure", "app_id": "app-1", "radio_config": {"sample_rate": 2_000_000 + i}} {"type": "configure", "app_id": "app-1",
"radio_config": {"sample_rate": 2_000_000 + i}}
) )
i += 1 i += 1
await asyncio.sleep(0.005) await asyncio.sleep(0.005)
@ -197,7 +197,8 @@ def test_full_duplex_stays_healthy_over_stress_window():
ws, s = asyncio.run(scenario()) ws, s = asyncio.run(scenario())
# No error frame leaked out. # No error frame leaked out.
errors = [m for m in ws.json_sent if m.get("type") in ("error", "tx_status") and m.get("state") == "error"] errors = [m for m in ws.json_sent
if m.get("type") in ("error", "tx_status") and m.get("state") == "error"]
assert errors == [], f"Unexpected error frames: {errors}" assert errors == [], f"Unexpected error frames: {errors}"
# RX produced IQ frames and TX's callback ran — heartbeat-level contention # RX produced IQ frames and TX's callback ran — heartbeat-level contention
# check: both setter paths were hit at least once during configure dispatch. # check: both setter paths were hit at least once during configure dispatch.

View File

@ -121,7 +121,9 @@ def test_start_without_device_emits_error():
def test_configure_queues_update(): def test_configure_queues_update():
async def scenario(): async def scenario():
streamer = Streamer(ws=FakeWs(), sdr_factory=_factory) streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
await streamer.on_message({"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}) await streamer.on_message(
{"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}
)
# Before start(), pending config lives on the standalone dict exposed via the _pending_config shim. # Before start(), pending config lives on the standalone dict exposed via the _pending_config shim.
return streamer._pending_config return streamer._pending_config

View File

@ -143,7 +143,10 @@ def test_rejects_duplicate_tx_session():
return ws return ws
ws = asyncio.run(scenario()) ws = asyncio.run(scenario())
errors = [m for m in ws.json_sent if m.get("type") == "tx_status" and m.get("state") == "error"] errors = [
m for m in ws.json_sent
if m.get("type") == "tx_status" and m.get("state") == "error"
]
assert any("already active" in e.get("message", "") for e in errors) assert any("already active" in e.get("message", "") for e in errors)

View File

@ -70,7 +70,10 @@ def test_underrun_pause_stops_session_and_emits_status():
# Do not push any buffers. The callback underruns on first tick and # Do not push any buffers. The callback underruns on first tick and
# the watchdog should emit "underrun" and tear down. # the watchdog should emit "underrun" and tear down.
for _ in range(100): for _ in range(100):
if any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent): if any(
m.get("type") == "tx_status" and m.get("state") == "underrun"
for m in ws.json_sent
):
break break
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
for _ in range(50): for _ in range(50):
@ -100,7 +103,9 @@ def test_underrun_zero_keeps_session_alive():
ws, still_alive = asyncio.run(scenario()) ws, still_alive = asyncio.run(scenario())
# No underrun status emitted (policy absorbs it silently). # No underrun status emitted (policy absorbs it silently).
assert not any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent) assert not any(
m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
)
assert still_alive assert still_alive
# All produced buffers are zero (no real data was pushed). # All produced buffers are zero (no real data was pushed).
assert sdr.tx_produced, "expected at least one TX callback invocation" assert sdr.tx_produced, "expected at least one TX callback invocation"
@ -124,7 +129,9 @@ def test_underrun_repeat_replays_last_buffer():
ws, sdr = asyncio.run(scenario()) ws, sdr = asyncio.run(scenario())
# No underrun status emitted. # No underrun status emitted.
assert not any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent) assert not any(
m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
)
# At least two buffers equal to the marker — the real one and ≥1 repeat. # At least two buffers equal to the marker — the real one and ≥1 repeat.
matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)] matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)]
assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}" assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}"

View File

@ -142,7 +142,9 @@ def test_malformed_control_frame_does_not_crash():
async def on_msg(m): async def on_msg(m):
handled.append(m) handled.append(m)
task = asyncio.create_task(client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})) task = asyncio.create_task(
client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
)
for _ in range(50): for _ in range(50):
if handled: if handled:
break break

View File

@ -102,7 +102,9 @@ def test_binary_frame_dropped_when_no_handler():
async def on_msg(m): async def on_msg(m):
messages.append(m) messages.append(m)
task = asyncio.create_task(client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})) task = asyncio.create_task(
client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
)
for _ in range(50): for _ in range(50):
if messages: if messages:
break break

View File

@ -12,6 +12,7 @@ import pytest
from ria_toolkit_oss.remote_control.remote_transmitter import RemoteTransmitter from ria_toolkit_oss.remote_control.remote_transmitter import RemoteTransmitter
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -240,40 +241,34 @@ class TestRunFunction:
def test_init_tx_without_radio_returns_failure(self): def test_init_tx_without_radio_returns_failure(self):
tx = RemoteTransmitter() tx = RemoteTransmitter()
resp = tx.run_function( resp = tx.run_function({
{ "function_name": "init_tx",
"function_name": "init_tx", "center_frequency": 2.4e9,
"center_frequency": 2.4e9, "sample_rate": 20e6,
"sample_rate": 20e6, "gain": 0,
"gain": 0, })
}
)
assert resp["status"] is False assert resp["status"] is False
assert resp["error_message"] assert resp["error_message"]
def test_init_tx_with_radio_success(self): def test_init_tx_with_radio_success(self):
tx = self._tx_with_mock_sdr() tx = self._tx_with_mock_sdr()
resp = tx.run_function( resp = tx.run_function({
{ "function_name": "init_tx",
"function_name": "init_tx", "center_frequency": 2.4e9,
"center_frequency": 2.4e9, "sample_rate": 20e6,
"sample_rate": 20e6, "gain": 30,
"gain": 30, })
}
)
assert resp["status"] is True assert resp["status"] is True
def test_transmit_runs_for_short_duration(self): def test_transmit_runs_for_short_duration(self):
tx = self._tx_with_mock_sdr() tx = self._tx_with_mock_sdr()
tx._sdr.init_tx = MagicMock() tx._sdr.init_tx = MagicMock()
resp = tx.run_function( resp = tx.run_function({
{ "function_name": "init_tx",
"function_name": "init_tx", "center_frequency": 2.4e9,
"center_frequency": 2.4e9, "sample_rate": 20e6,
"sample_rate": 20e6, "gain": 0,
"gain": 0, })
}
)
resp = tx.run_function({"function_name": "transmit", "duration_s": 0.02}) resp = tx.run_function({"function_name": "transmit", "duration_s": 0.02})
assert resp["status"] is True assert resp["status"] is True

View File

@ -7,6 +7,8 @@ sys.modules so they run regardless of whether the packages are installed.
from __future__ import annotations from __future__ import annotations
import json import json
import sys
import threading
import time import time
from types import ModuleType from types import ModuleType
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -197,11 +199,15 @@ class TestErrorHandling:
def test_missing_paramiko_raises_runtime_error(self): def test_missing_paramiko_raises_runtime_error(self):
"""If paramiko is absent, connecting gives a clear RuntimeError.""" """If paramiko is absent, connecting gives a clear RuntimeError."""
import importlib
import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod
with patch.dict("sys.modules", {"paramiko": None}): with patch.dict("sys.modules", {"paramiko": None}):
with pytest.raises((RuntimeError, ImportError)): with pytest.raises((RuntimeError, ImportError)):
mod.RemoteTransmitterController(host="h", ssh_user="u", ssh_key_path="/k") mod.RemoteTransmitterController(
host="h", ssh_user="u", ssh_key_path="/k"
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, call, patch
import pytest import pytest
@ -12,6 +12,7 @@ from ria_toolkit_oss.orchestration.campaign import (
TransmitterConfig, TransmitterConfig,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -178,7 +179,9 @@ class TestInitRemoteTxControllers:
} }
] ]
executor = _make_executor(d) executor = _make_executor(d)
with patch("ria_toolkit_oss.remote_control.RemoteTransmitterController") as mock_cls: with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController"
) as mock_cls:
executor._init_remote_tx_controllers() executor._init_remote_tx_controllers()
mock_cls.assert_not_called() mock_cls.assert_not_called()
assert executor._remote_tx_controllers == {} assert executor._remote_tx_controllers == {}
@ -261,7 +264,7 @@ class TestStartTransmitterSdrRemote:
tx = executor.config.transmitters[0] tx = executor.config.transmitters[0]
step = CaptureStep(duration=5.0, label="nochan") step = CaptureStep(duration=5.0, label="nochan")
executor._start_transmitter(tx, step) executor._start_transmitter(tx, step)
_, kwargs = ctrl.init_tx.call_args _, kwargs = mock_ctrl_kwarg = ctrl.init_tx.call_args
assert kwargs["channel"] == 0 assert kwargs["channel"] == 0
def test_missing_controller_raises(self): def test_missing_controller_raises(self):
@ -378,11 +381,7 @@ class TestRunWithSdrRemote:
), ),
patch.object(executor, "_close_sdr"), patch.object(executor, "_close_sdr"),
patch.object(executor, "_close_remote_tx_controllers"), patch.object(executor, "_close_remote_tx_controllers"),
patch.object( patch.object(executor, "_execute_step", return_value=MagicMock(error=None, qa=MagicMock(flagged=False, snr_db=20.0, duration_s=10.0))),
executor,
"_execute_step",
return_value=MagicMock(error=None, qa=MagicMock(flagged=False, snr_db=20.0, duration_s=10.0)),
),
): ):
executor.run() executor.run()
@ -402,7 +401,6 @@ class TestTransmitBufferAndTimeout:
def _executor_with_ctrl(self): def _executor_with_ctrl(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT) cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
executor = CampaignExecutor(cfg) executor = CampaignExecutor(cfg)
ctrl = MagicMock() ctrl = MagicMock()