Compare commits

..

No commits in common. "f03825a6db5affed034987985ffed09333ceb6d9" and "ea8ed56a7d910328e71cc811be015ba2619c7a02" have entirely different histories.

12 changed files with 10 additions and 1614 deletions

View File

@ -1,2 +0,0 @@
[virtualenvs.options]
system-site-packages = true

View File

@ -20,7 +20,7 @@ Usage::
The agent: The agent:
1. Registers with RIA Hub and receives a ``node_id``. 1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online. 2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /composer/nodes/{id}/commands`` (30 s timeout). 3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
4. Dispatches received commands: 4. Dispatches received commands:
- ``run_campaign``: executes via CampaignExecutor, uploads recordings. - ``run_campaign``: executes via CampaignExecutor, uploads recordings.
- ``load_model``: loads an ONNX fingerprint or detector model. - ``load_model``: loads an ONNX fingerprint or detector model.
@ -173,7 +173,7 @@ class NodeAgent:
if self._ort_available: if self._ort_available:
capabilities.append("inference") capabilities.append("inference")
resp = self._post( resp = self._post(
"/composer/nodes/register", "/orchestrator/nodes/register",
json={ json={
"name": self.name, "name": self.name,
"sdr_device": self.sdr_device, "sdr_device": self.sdr_device,
@ -190,7 +190,7 @@ class NodeAgent:
if not self.node_id: if not self.node_id:
return return
try: try:
self._delete(f"/composer/nodes/{self.node_id}", timeout=10) self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
logger.info("Deregistered %s", self.node_id) logger.info("Deregistered %s", self.node_id)
except Exception as exc: except Exception as exc:
logger.debug("Deregister failed (ignored on shutdown): %s", exc) logger.debug("Deregister failed (ignored on shutdown): %s", exc)
@ -202,7 +202,7 @@ class NodeAgent:
def _heartbeat_loop(self) -> None: def _heartbeat_loop(self) -> None:
while not self._stop.wait(_HEARTBEAT_INTERVAL): while not self._stop.wait(_HEARTBEAT_INTERVAL):
try: try:
resp = self._post(f"/composer/nodes/{self.node_id}/heartbeat", timeout=10) resp = self._post(f"/orchestrator/nodes/{self.node_id}/heartbeat", timeout=10)
if resp.status_code == 404: if resp.status_code == 404:
logger.warning("Heartbeat got 404 — hub lost registration, re-registering") logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
self._register() self._register()
@ -217,7 +217,7 @@ class NodeAgent:
while not self._stop.is_set(): while not self._stop.is_set():
try: try:
resp = self._get( resp = self._get(
f"/composer/nodes/{self.node_id}/commands", f"/orchestrator/nodes/{self.node_id}/commands",
timeout=_POLL_CLIENT_TIMEOUT, timeout=_POLL_CLIENT_TIMEOUT,
) )
if resp.status_code == 204: if resp.status_code == 204:
@ -540,7 +540,7 @@ class NodeAgent:
logger.info("Inference loop exited") logger.info("Inference loop exited")
def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None: def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None:
"""POST a single detection event to ``POST /composer/nodes/{id}/events``. """POST a single detection event to ``POST /orchestrator/nodes/{id}/events``.
Failures are logged at DEBUG level and silently swallowed so that a Failures are logged at DEBUG level and silently swallowed so that a
transient network blip does not crash the inference loop. transient network blip does not crash the inference loop.
@ -556,7 +556,7 @@ class NodeAgent:
} }
try: try:
resp = self._post( resp = self._post(
f"/composer/nodes/{self.node_id}/events", f"/orchestrator/nodes/{self.node_id}/events",
json=payload, json=payload,
timeout=5, timeout=5,
) )
@ -619,7 +619,7 @@ class NodeAgent:
payload["error"] = error payload["error"] = error
try: try:
resp = self._post( resp = self._post(
f"/composer/nodes/{self.node_id}/campaign-status", f"/orchestrator/nodes/{self.node_id}/campaign-status",
json=payload, json=payload,
timeout=15, timeout=15,
) )

View File

@ -223,16 +223,13 @@ class TransmitterConfig:
id: str id: str
type: str # "wifi", "bluetooth", "sdr", "external" type: str # "wifi", "bluetooth", "sdr", "external"
control_method: str # "external_script" | "sdr" | "sdr_remote" control_method: str # "external_script" | "sdr"
schedule: list[CaptureStep] schedule: list[CaptureStep]
# For external_script control # For external_script control
script: Optional[str] = None # path to control script script: Optional[str] = None # path to control script
device: Optional[str] = None # e.g. "/dev/wlan0" device: Optional[str] = None # e.g. "/dev/wlan0"
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
sdr_remote: Optional[dict] = None
@classmethod @classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig": def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])] schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
@ -243,7 +240,6 @@ class TransmitterConfig:
schedule=schedule, schedule=schedule,
script=d.get("script"), script=d.get("script"),
device=d.get("device"), device=d.get("device"),
sdr_remote=d.get("sdr_remote"),
) )

View File

@ -196,7 +196,6 @@ class CampaignExecutor:
self.config = config self.config = config
self.progress_cb = progress_cb self.progress_cb = progress_cb
self._sdr = None self._sdr = None
self._remote_tx_controllers: dict = {}
if verbose: if verbose:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -223,7 +222,6 @@ class CampaignExecutor:
) )
self._init_sdr() self._init_sdr()
self._init_remote_tx_controllers()
try: try:
total = self.config.total_steps() total = self.config.total_steps()
step_index = 0 step_index = 0
@ -250,7 +248,6 @@ class CampaignExecutor:
) )
finally: finally:
self._close_sdr() self._close_sdr()
self._close_remote_tx_controllers()
result.end_time = time.time() result.end_time = time.time()
logger.info( logger.info(
@ -290,41 +287,6 @@ class CampaignExecutor:
logger.warning(f"SDR close error: {e}") logger.warning(f"SDR close error: {e}")
self._sdr = None self._sdr = None
# ------------------------------------------------------------------
# Remote Tx controller management
# ------------------------------------------------------------------
def _init_remote_tx_controllers(self) -> None:
"""Open SSH+ZMQ connections for all sdr_remote transmitters."""
from ria_toolkit_oss.remote_control import RemoteTransmitterController
for tx in self.config.transmitters:
if tx.control_method != "sdr_remote":
continue
cfg = tx.sdr_remote
if not cfg:
raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config")
logger.info(f"Connecting remote Tx controller for {tx.id}{cfg['host']}")
ctrl = RemoteTransmitterController(
host=cfg["host"],
ssh_user=cfg["ssh_user"],
ssh_key_path=cfg["ssh_key_path"],
zmq_port=int(cfg.get("zmq_port", 5556)),
)
ctrl.set_radio(
device_type=cfg["device_type"],
device_id=cfg.get("device_id", ""),
)
self._remote_tx_controllers[tx.id] = ctrl
def _close_remote_tx_controllers(self) -> None:
for tx_id, ctrl in list(self._remote_tx_controllers.items()):
try:
ctrl.close()
except Exception as exc:
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
self._remote_tx_controllers.clear()
def _record(self, duration_s: float) -> Recording: def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples.""" """Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate) num_samples = int(duration_s * self.config.recorder.sample_rate)
@ -410,8 +372,7 @@ class CampaignExecutor:
traffic, etc. The script is responsible for applying the configuration traffic, etc. The script is responsible for applying the configuration
and returning promptly (i.e. not blocking for the capture duration). and returning promptly (i.e. not blocking for the capture duration).
For ``sdr_remote`` the remote ZMQ controller calls ``init_tx`` then For SDR transmitters this is a no-op placeholder (TX not yet implemented).
starts a background transmit thread that runs for the step duration.
""" """
if transmitter.control_method == "external_script": if transmitter.control_method == "external_script":
if not transmitter.script: if not transmitter.script:
@ -423,20 +384,6 @@ class CampaignExecutor:
elif transmitter.control_method == "sdr": elif transmitter.control_method == "sdr":
logger.debug("SDR TX not yet implemented — skipping start") logger.debug("SDR TX not yet implemented — skipping start")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is None:
raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'")
gain = step.power_dbm if step.power_dbm is not None else 0.0
ctrl.init_tx(
center_frequency=self.config.recorder.center_freq,
sample_rate=self.config.recorder.sample_rate,
gain=gain,
channel=step.channel or 0,
)
# Start transmission in background; _record() runs concurrently
ctrl.transmit_async(step.duration + 1.0)
else: else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping") logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
@ -444,7 +391,6 @@ class CampaignExecutor:
"""Signal the transmitter to stop. """Signal the transmitter to stop.
Calls ``<script> stop`` for external_script transmitters. Calls ``<script> stop`` for external_script transmitters.
For ``sdr_remote``, waits for the background transmit thread to finish.
""" """
if transmitter.control_method == "external_script": if transmitter.control_method == "external_script":
if not transmitter.script: if not transmitter.script:
@ -454,11 +400,6 @@ class CampaignExecutor:
except Exception as e: except Exception as e:
logger.warning(f"Script stop failed for {transmitter.id}: {e}") logger.warning(f"Script stop failed for {transmitter.id}: {e}")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is not None:
ctrl.wait_transmit(timeout=step.duration + 10.0)
@staticmethod @staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str: def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script.""" """Serialise step parameters to a JSON string for the control script."""

View File

@ -1,6 +0,0 @@
"""Remote SDR transmitter control via SSH + ZMQ."""
from .remote_transmitter import RemoteTransmitter
from .remote_transmitter_controller import RemoteTransmitterController
__all__ = ["RemoteTransmitter", "RemoteTransmitterController"]

View File

@ -1,147 +0,0 @@
"""Server-side ZMQ RPC receiver for SDR transmission.
Run this script on the Tx machine. The script binds a ZMQ REP socket and
waits for JSON-RPC commands from a :class:`RemoteTransmitterController`.
Requires: zmq, and ria-toolkit or utils installed for SDR support.
"""
from __future__ import annotations
import argparse
import io
import json
import logging
from contextlib import redirect_stderr, redirect_stdout
import zmq
logger = logging.getLogger(__name__)
class RemoteTransmitter:
"""Executes SDR Tx commands received over ZMQ.
Loads the appropriate SDR driver dynamically so the script can run on
machines that have only a subset of SDR libraries installed.
"""
def __init__(self) -> None:
self._sdr = None
def set_radio(self, radio_str: str, identifier: str = "") -> None:
"""Initialise the SDR radio.
Args:
radio_str: SDR type pluto | usrp | hackrf | bladerf.
identifier: Device-specific identifier (IP, serial, etc.).
"""
radio_str = radio_str.lower()
try:
if radio_str in ("pluto", "plutosdr"):
from ria_toolkit_oss.sdr.pluto import Pluto
self._sdr = Pluto(identifier)
elif radio_str in ("usrp",):
from ria_toolkit_oss.sdr.usrp import USRP
self._sdr = USRP(identifier)
elif radio_str in ("hackrf", "hackrf_one"):
from ria_toolkit_oss.sdr.hackrf import HackRF
self._sdr = HackRF(identifier)
elif radio_str in ("bladerf", "blade"):
from ria_toolkit_oss.sdr.blade import Blade
self._sdr = Blade(identifier)
else:
raise ValueError(f"Unknown SDR type: {radio_str!r}")
except ImportError as exc:
raise RuntimeError(f"SDR driver for '{radio_str}' is not installed: {exc}") from exc
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
if self._sdr is None:
raise RuntimeError("Call set_radio() before init_tx()")
self._sdr.init_tx(
center_frequency=center_frequency,
sample_rate=sample_rate,
gain=gain,
channel=channel,
)
def transmit(self, duration_s: float) -> None:
"""Transmit a continuous wave for ``duration_s`` seconds."""
if self._sdr is None:
raise RuntimeError("Call set_radio() and init_tx() before transmit()")
import time
# Transmit in a loop until duration has elapsed
end = time.monotonic() + duration_s
while time.monotonic() < end:
try:
self._sdr.tx_cw()
except AttributeError:
time.sleep(0.01)
def stop(self) -> None:
"""Stop transmission and close the SDR."""
if self._sdr is not None:
try:
self._sdr.close()
except Exception:
pass
self._sdr = None
def run_function(self, command_dict: dict) -> dict:
"""Dispatch a JSON-RPC command and return a response dict."""
out_buf = io.StringIO()
err_buf = io.StringIO()
fn = command_dict.get("function_name", "")
try:
with redirect_stdout(out_buf), redirect_stderr(err_buf):
if fn == "set_radio":
self.set_radio(
radio_str=command_dict["radio_str"],
identifier=command_dict.get("identifier", ""),
)
elif fn == "init_tx":
self.init_tx(
center_frequency=command_dict["center_frequency"],
sample_rate=command_dict["sample_rate"],
gain=command_dict["gain"],
channel=command_dict.get("channel", 0),
gain_mode=command_dict.get("gain_mode", "absolute"),
)
elif fn == "transmit":
self.transmit(duration_s=command_dict.get("duration_s", 1.0))
elif fn == "stop":
self.stop()
else:
raise ValueError(f"Unknown function: {fn!r}")
return {"status": True, "message": out_buf.getvalue(), "error_message": err_buf.getvalue()}
except Exception as exc:
logger.exception("Error executing %s", fn)
return {"status": False, "message": out_buf.getvalue(), "error_message": str(exc)}
def _serve(port: int) -> None:
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{port}")
logger.info("RemoteTransmitter listening on port %d", port)
tx = RemoteTransmitter()
while True:
raw = socket.recv()
cmd = json.loads(raw.decode())
response = tx.run_function(cmd)
socket.send(json.dumps(response).encode())
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="SDR Tx ZMQ server")
parser.add_argument("--port", type=int, default=5556)
args = parser.parse_args()
_serve(args.port)

View File

@ -1,210 +0,0 @@
"""Client-side SSH + ZMQ controller for a remote SDR transmitter.
Run this on the Rx machine (or hub). It SSH-es into the Tx machine,
starts :mod:`remote_transmitter` there, then sends JSON-RPC commands over
ZMQ.
Requires: paramiko, zmq.
"""
from __future__ import annotations
import json
import logging
import threading
import time
logger = logging.getLogger(__name__)
_STARTUP_WAIT_S = 2.0 # seconds to wait for remote ZMQ server to bind
class RemoteTransmitterController:
"""SSH into a Tx machine, start the ZMQ server, and send commands.
Args:
host: IP or hostname of the Tx machine.
ssh_user: SSH username.
ssh_key_path: Path to SSH private key file.
zmq_port: ZMQ port that the remote transmitter will bind on.
"""
def __init__(
self,
host: str,
ssh_user: str,
ssh_key_path: str,
zmq_port: int = 5556,
) -> None:
self._host = host
self._zmq_port = zmq_port
self._ssh: paramiko.SSHClient | None = None
self._ssh_stdout = None
self._context: zmq.Context | None = None
self._socket: zmq.Socket | None = None
self._tx_thread: threading.Thread | None = None
self._lock = threading.Lock()
self._connect(host, ssh_user, ssh_key_path, zmq_port)
# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def _connect(self, host: str, ssh_user: str, ssh_key_path: str, zmq_port: int) -> None:
"""Open SSH tunnel, start remote server, connect ZMQ socket."""
try:
import paramiko
except ImportError as exc:
raise RuntimeError("paramiko is required for remote SDR control: pip install paramiko") from exc
try:
import zmq
except ImportError as exc:
raise RuntimeError("pyzmq is required for remote SDR control: pip install pyzmq") from exc
logger.info("SSH connecting to %s@%s", ssh_user, host)
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh.connect(hostname=host, username=ssh_user, key_filename=ssh_key_path)
cmd = f"python -m ria_toolkit_oss.remote_control.remote_transmitter --port {zmq_port}"
logger.info("Starting remote Tx server: %s", cmd)
_, self._ssh_stdout, _ = self._ssh.exec_command(cmd)
time.sleep(_STARTUP_WAIT_S)
self._context = zmq.Context()
self._socket = self._context.socket(zmq.REQ)
self._socket.connect(f"tcp://{host}:{zmq_port}")
logger.info("ZMQ connected to tcp://%s:%d", host, zmq_port)
def close(self) -> None:
"""Tear down ZMQ and SSH connections."""
if self._socket is not None:
try:
self._socket.close(linger=0)
except Exception:
pass
self._socket = None
if self._context is not None:
try:
self._context.term()
except Exception:
pass
self._context = None
if self._ssh_stdout is not None:
try:
self._ssh_stdout.channel.close()
except Exception:
pass
self._ssh_stdout = None
if self._ssh is not None:
try:
self._ssh.close()
except Exception:
pass
self._ssh = None
logger.info("RemoteTransmitterController closed")
# ------------------------------------------------------------------
# ZMQ dispatch
# ------------------------------------------------------------------
def _send(self, command: dict) -> dict:
"""Send a JSON-RPC command and return the response dict (thread-safe)."""
with self._lock:
if self._socket is None:
raise RuntimeError("Controller is closed")
self._socket.send(json.dumps(command).encode())
raw = self._socket.recv()
reply: dict = json.loads(raw.decode())
if not reply.get("status"):
raise RuntimeError(
f"Remote command '{command.get('function_name')}' failed: "
f"{reply.get('error_message', 'unknown error')}"
)
return reply
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def set_radio(self, device_type: str, device_id: str = "") -> None:
"""Initialise the SDR radio on the Tx machine.
Args:
device_type: SDR type ``pluto``, ``usrp``, ``hackrf``, ``bladerf``.
device_id: Device-specific identifier (IP, serial, etc.).
"""
logger.info("set_radio(%s, %r)", device_type, device_id)
self._send({"function_name": "set_radio", "radio_str": device_type, "identifier": device_id})
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
"""Configure Tx parameters on the remote SDR.
Args:
center_frequency: Center frequency in Hz.
sample_rate: Sample rate in Hz.
gain: Tx gain in dB.
channel: RF channel index (default 0).
gain_mode: ``"absolute"`` (default) or ``"relative"``.
"""
logger.info(
"init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d",
center_frequency / 1e6, sample_rate / 1e6, gain, channel,
)
self._send({
"function_name": "init_tx",
"center_frequency": center_frequency,
"sample_rate": sample_rate,
"gain": gain,
"channel": channel,
"gain_mode": gain_mode,
})
def transmit_async(self, duration_s: float) -> None:
"""Start a timed CW transmission in a background thread.
Returns immediately. Call :meth:`wait_transmit` after recording to
ensure the transmit thread has finished before the next step.
Args:
duration_s: Transmission duration in seconds.
"""
logger.info("transmit_async: %.1f s", duration_s)
def _run() -> None:
try:
self._send({"function_name": "transmit", "duration_s": duration_s})
except Exception as exc:
logger.warning("Background transmit error: %s", exc)
self._tx_thread = threading.Thread(target=_run, daemon=True, name="remote-tx")
self._tx_thread.start()
def wait_transmit(self, timeout: float | None = None) -> None:
"""Wait for the background transmit thread to finish.
Args:
timeout: Maximum seconds to wait. ``None`` = wait indefinitely.
"""
if self._tx_thread is not None:
self._tx_thread.join(timeout=timeout)
self._tx_thread = None
def stop(self) -> None:
"""Stop transmission and release the remote SDR, then close connections."""
logger.info("Sending stop to remote Tx")
try:
self._send({"function_name": "stop"})
except Exception as exc:
logger.warning("stop command error (may be normal if connection closed): %s", exc)
finally:
self.close()

View File

@ -43,13 +43,6 @@ class SDR(ABC):
self.tx_gain = None self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock self._param_lock = threading.RLock() # Reentrant lock
# Pending config consumed by rx() on first call and by _apply_sdr_config
# in the agent inference loop. Subclasses that need different defaults
# (e.g. MockSDR) can overwrite these in their own __init__.
self.center_freq: float = 2.4e9
self.sample_rate: float = 10e6
self.gain: float = 40.0
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording: def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
""" """
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided. Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
@ -107,32 +100,6 @@ class SDR(ABC):
self._num_buffers_processed = 0 self._num_buffers_processed = 0
return recording return recording
def rx(self, num_samples: int) -> "np.ndarray":
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
This is the interface used by the agent inference loop. On first call,
``init_rx()`` is invoked automatically using the values stored in
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
``_apply_sdr_config``). Subsequent calls stream directly.
Subclasses may override this for hardware-native capture APIs (e.g.
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
``self.radio.rx()``).
"""
if not self._rx_initialized:
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
self.init_rx(
sample_rate=self.sample_rate,
center_frequency=self.center_freq,
gain=gain,
channel=0,
)
recording = self.record(num_samples=num_samples)
# Recording.data is either a list of 1-D arrays (one per channel) or a
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
data = recording.data
return data[0] if hasattr(data, "__getitem__") else data
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000): def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
""" """
Stream iq samples as interleaved bytes via zmq. Stream iq samples as interleaved bytes via zmq.

View File

@ -1,287 +0,0 @@
"""Tests for the server-side RemoteTransmitter ZMQ RPC dispatcher.
No real SDR hardware or ZMQ sockets are needed we test run_function()
directly and mock the SDR drivers.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from ria_toolkit_oss.remote_control.remote_transmitter import RemoteTransmitter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_sdr():
sdr = MagicMock()
sdr.init_tx = MagicMock()
sdr.tx_cw = MagicMock()
sdr.close = MagicMock()
return sdr
# ---------------------------------------------------------------------------
# set_radio dispatch
# ---------------------------------------------------------------------------
class TestSetRadio:
def _pluto_module(self, mock_sdr):
mod = MagicMock()
mod.Pluto = MagicMock(return_value=mock_sdr)
return mod
def test_pluto_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": self._pluto_module(mock_sdr)}):
tx.set_radio("pluto", "ip:192.168.2.1")
assert tx._sdr is mock_sdr
def test_plutosdr_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": self._pluto_module(mock_sdr)}):
tx.set_radio("PlutoSDR", "ip:192.168.2.1")
assert tx._sdr is mock_sdr
def test_usrp_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mock_module = MagicMock()
mock_module.USRP = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.usrp": mock_module}):
tx.set_radio("usrp", "usrp://addr=192.168.10.2")
assert tx._sdr is mock_sdr
def test_hackrf_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch("ria_toolkit_oss.sdr.hackrf.HackRF", return_value=mock_sdr):
tx.set_radio("hackrf", "")
assert tx._sdr is mock_sdr
def test_hackrf_one_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch("ria_toolkit_oss.sdr.hackrf.HackRF", return_value=mock_sdr):
tx.set_radio("hackrf_one", "")
assert tx._sdr is mock_sdr
def test_bladerf_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mock_module = MagicMock()
mock_module.Blade = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.blade": mock_module}):
tx.set_radio("blade", "")
assert tx._sdr is mock_sdr
def test_bladerf_string_alias(self):
"""'bladerf' string (not 'blade') must also resolve to blade.Blade."""
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mock_module = MagicMock()
mock_module.Blade = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.blade": mock_module}):
tx.set_radio("bladerf", "")
assert tx._sdr is mock_sdr
def test_case_insensitive(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": self._pluto_module(mock_sdr)}):
tx.set_radio("PLUTO", "ip:192.168.2.1")
assert tx._sdr is mock_sdr
def test_unknown_radio_raises(self):
tx = RemoteTransmitter()
with pytest.raises(ValueError, match="Unknown SDR type"):
tx.set_radio("nonexistent_radio")
def test_import_error_raises_runtime(self):
"""ImportError during SDR driver load is re-raised as RuntimeError."""
tx = RemoteTransmitter()
# Inject a fake module whose Pluto class raises ImportError on import
bad_module = MagicMock()
bad_module.Pluto = MagicMock(side_effect=ImportError("pyadi-iio not installed"))
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": bad_module}):
with pytest.raises((RuntimeError, ImportError)):
tx.set_radio("pluto")
# ---------------------------------------------------------------------------
# init_tx / transmit / stop guard
# ---------------------------------------------------------------------------
class TestInitTxGuards:
def test_init_tx_without_set_radio_raises(self):
tx = RemoteTransmitter()
with pytest.raises(RuntimeError, match="set_radio"):
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
def test_transmit_without_set_radio_raises(self):
tx = RemoteTransmitter()
with pytest.raises(RuntimeError):
tx.transmit(duration_s=0.1)
def test_stop_without_set_radio_is_safe(self):
tx = RemoteTransmitter()
tx.stop() # should not raise — nothing to close
class TestInitTx:
def _tx_with_mock_sdr(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
return tx
def test_delegates_to_sdr(self):
tx = self._tx_with_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30, channel=1)
tx._sdr.init_tx.assert_called_once_with(
center_frequency=2.4e9,
sample_rate=20e6,
gain=30,
channel=1,
)
def test_default_channel_zero(self):
tx = self._tx_with_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30)
_, kwargs = tx._sdr.init_tx.call_args
assert kwargs["channel"] == 0
class TestTransmit:
def test_calls_tx_cw_until_duration(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
tx.transmit(duration_s=0.05)
assert tx._sdr.tx_cw.called
def test_zero_duration_does_not_call_tx_cw(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
tx.transmit(duration_s=0.0)
tx._sdr.tx_cw.assert_not_called()
def test_missing_tx_cw_method_handled(self):
"""AttributeError on tx_cw should not crash transmit()."""
tx = RemoteTransmitter()
sdr = MagicMock(spec=[]) # no tx_cw attribute
sdr.init_tx = MagicMock()
tx._sdr = sdr
# Should not raise — AttributeError is caught and slept through
tx.transmit(duration_s=0.01)
class TestStop:
def test_calls_close_and_clears_sdr(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
tx._sdr = mock_sdr
tx.stop()
mock_sdr.close.assert_called_once()
assert tx._sdr is None
def test_close_exception_is_swallowed(self):
tx = RemoteTransmitter()
sdr = _make_mock_sdr()
sdr.close.side_effect = RuntimeError("hardware error")
tx._sdr = sdr
tx.stop() # should not raise
assert tx._sdr is None
def test_stop_idempotent(self):
tx = RemoteTransmitter()
tx.stop()
tx.stop() # second call is safe
# ---------------------------------------------------------------------------
# run_function dispatcher
# ---------------------------------------------------------------------------
class TestRunFunction:
def _tx_with_mock_sdr(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
return tx
def test_unknown_function_returns_failure(self):
tx = RemoteTransmitter()
resp = tx.run_function({"function_name": "explode"})
assert resp["status"] is False
assert "explode" in resp["error_message"]
def test_set_radio_success(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mod = MagicMock()
mod.Pluto = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": mod}):
resp = tx.run_function({"function_name": "set_radio", "radio_str": "pluto", "identifier": "ip:1.2.3.4"})
assert resp["status"] is True
def test_set_radio_bad_type_returns_failure(self):
tx = RemoteTransmitter()
resp = tx.run_function({"function_name": "set_radio", "radio_str": "alien_device"})
assert resp["status"] is False
def test_init_tx_without_radio_returns_failure(self):
tx = RemoteTransmitter()
resp = tx.run_function({
"function_name": "init_tx",
"center_frequency": 2.4e9,
"sample_rate": 20e6,
"gain": 0,
})
assert resp["status"] is False
assert resp["error_message"]
def test_init_tx_with_radio_success(self):
tx = self._tx_with_mock_sdr()
resp = tx.run_function({
"function_name": "init_tx",
"center_frequency": 2.4e9,
"sample_rate": 20e6,
"gain": 30,
})
assert resp["status"] is True
def test_transmit_runs_for_short_duration(self):
tx = self._tx_with_mock_sdr()
tx._sdr.init_tx = MagicMock()
resp = tx.run_function({
"function_name": "init_tx",
"center_frequency": 2.4e9,
"sample_rate": 20e6,
"gain": 0,
})
resp = tx.run_function({"function_name": "transmit", "duration_s": 0.02})
assert resp["status"] is True
def test_stop_via_run_function(self):
tx = self._tx_with_mock_sdr()
resp = tx.run_function({"function_name": "stop"})
assert resp["status"] is True
assert tx._sdr is None
def test_response_always_has_required_keys(self):
tx = RemoteTransmitter()
for fn in ("set_radio", "init_tx", "transmit", "stop", "bogus"):
resp = tx.run_function({"function_name": fn})
assert "status" in resp
assert "message" in resp
assert "error_message" in resp

View File

@ -1,294 +0,0 @@
"""Tests for RemoteTransmitterController — mocks paramiko and ZMQ entirely.
paramiko and zmq are optional runtime deps; these tests inject fakes into
sys.modules so they run regardless of whether the packages are installed.
"""
from __future__ import annotations
import json
import sys
import threading
import time
from types import ModuleType
from unittest.mock import MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Fake modules injected into sys.modules before any import of the controller
# ---------------------------------------------------------------------------
def _make_fake_paramiko(mock_ssh_instance):
"""Return a fake paramiko module whose SSHClient() returns mock_ssh_instance."""
mod = MagicMock(spec=ModuleType)
mod.SSHClient = MagicMock(return_value=mock_ssh_instance)
mod.AutoAddPolicy = MagicMock()
return mod
def _make_fake_zmq(mock_socket_instance):
"""Return a fake zmq module whose Context().socket() returns mock_socket_instance."""
mock_context = MagicMock()
mock_context.socket.return_value = mock_socket_instance
mod = MagicMock(spec=ModuleType)
mod.Context = MagicMock(return_value=mock_context)
mod.REQ = "REQ"
return mod, mock_context
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _ok_response(fn="set_radio") -> bytes:
return json.dumps({"status": True, "message": "", "error_message": ""}).encode()
def _err_response(fn="set_radio", msg="boom") -> bytes:
return json.dumps({"status": False, "message": "", "error_message": msg}).encode()
def _make_mock_socket(recv_side_effect=None):
sock = MagicMock()
if recv_side_effect is not None:
sock.recv.side_effect = recv_side_effect
else:
sock.recv.return_value = _ok_response()
return sock
def _make_controller(mock_socket=None, *, startup_wait=0):
"""Build a controller with all external I/O mocked via sys.modules injection."""
mock_sock = mock_socket or _make_mock_socket()
mock_ssh = MagicMock()
mock_stdout = MagicMock()
mock_stdout.channel = MagicMock()
mock_ssh.exec_command.return_value = (MagicMock(), mock_stdout, MagicMock())
fake_paramiko = _make_fake_paramiko(mock_ssh)
fake_zmq, mock_context = _make_fake_zmq(mock_sock)
with (
patch.dict("sys.modules", {"paramiko": fake_paramiko, "zmq": fake_zmq}),
patch(
"ria_toolkit_oss.remote_control.remote_transmitter_controller._STARTUP_WAIT_S",
startup_wait,
),
):
from ria_toolkit_oss.remote_control.remote_transmitter_controller import (
RemoteTransmitterController,
)
ctrl = RemoteTransmitterController(
host="192.168.1.10",
ssh_user="ubuntu",
ssh_key_path="/home/user/.ssh/id_rsa",
zmq_port=5556,
)
ctrl._mock_ssh = mock_ssh
ctrl._mock_socket = mock_sock
ctrl._mock_context = mock_context
ctrl._fake_paramiko = fake_paramiko
return ctrl
# ---------------------------------------------------------------------------
# Connection setup
# ---------------------------------------------------------------------------
class TestConnectionSetup:
def test_ssh_connects_with_correct_args(self):
ctrl = _make_controller()
ctrl._mock_ssh.connect.assert_called_once_with(
hostname="192.168.1.10",
username="ubuntu",
key_filename="/home/user/.ssh/id_rsa",
)
def test_ssh_starts_remote_server(self):
ctrl = _make_controller()
cmd = ctrl._mock_ssh.exec_command.call_args[0][0]
assert "remote_transmitter" in cmd
assert "--port" in cmd
assert "5556" in cmd
def test_zmq_connects_to_host_port(self):
ctrl = _make_controller()
ctrl._mock_socket.connect.assert_called_once_with("tcp://192.168.1.10:5556")
def test_host_key_policy_set_to_auto_add(self):
"""AutoAddPolicy is applied so we don't prompt in headless execution."""
ctrl = _make_controller()
ctrl._mock_ssh.set_missing_host_key_policy.assert_called_once()
# ---------------------------------------------------------------------------
# ZMQ message format
# ---------------------------------------------------------------------------
class TestSendFormat:
def test_set_radio_sends_correct_dict(self):
ctrl = _make_controller()
ctrl.set_radio("pluto", "ip:192.168.2.1")
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "set_radio"
assert sent["radio_str"] == "pluto"
assert sent["identifier"] == "ip:192.168.2.1"
def test_set_radio_default_identifier(self):
ctrl = _make_controller()
ctrl.set_radio("hackrf")
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["identifier"] == ""
def test_init_tx_sends_correct_dict(self):
ctrl = _make_controller()
ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30, channel=1)
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "init_tx"
assert sent["center_frequency"] == pytest.approx(2.4e9)
assert sent["sample_rate"] == pytest.approx(20e6)
assert sent["gain"] == pytest.approx(30)
assert sent["channel"] == 1
assert sent["gain_mode"] == "absolute"
def test_init_tx_default_channel_zero(self):
ctrl = _make_controller()
ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["channel"] == 0
def test_stop_sends_correct_dict(self):
ctrl = _make_controller()
ctrl.stop()
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "stop"
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
class TestErrorHandling:
def test_error_response_raises_runtime_error(self):
sock = _make_mock_socket()
sock.recv.return_value = _err_response(msg="radio not found")
ctrl = _make_controller(mock_socket=sock)
with pytest.raises(RuntimeError, match="radio not found"):
ctrl.set_radio("pluto")
def test_error_message_included_in_exception(self):
sock = _make_mock_socket()
sock.recv.return_value = _err_response(msg="gain out of range")
ctrl = _make_controller(mock_socket=sock)
with pytest.raises(RuntimeError, match="gain out of range"):
ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=999)
def test_send_on_closed_controller_raises(self):
ctrl = _make_controller()
ctrl.close()
with pytest.raises(RuntimeError, match="closed"):
ctrl._send({"function_name": "set_radio", "radio_str": "pluto", "identifier": ""})
def test_missing_paramiko_raises_runtime_error(self):
"""If paramiko is absent, connecting gives a clear RuntimeError."""
import importlib
import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod
with patch.dict("sys.modules", {"paramiko": None}):
with pytest.raises((RuntimeError, ImportError)):
mod.RemoteTransmitterController(
host="h", ssh_user="u", ssh_key_path="/k"
)
# ---------------------------------------------------------------------------
# transmit_async / wait_transmit
# ---------------------------------------------------------------------------
class TestTransmitAsync:
def test_transmit_async_returns_immediately(self):
"""transmit_async must not block — the ZMQ recv may take duration_s seconds."""
def slow_recv():
time.sleep(0.1)
return _ok_response("transmit")
sock = _make_mock_socket()
sock.recv.side_effect = slow_recv
ctrl = _make_controller(mock_socket=sock)
t0 = time.monotonic()
ctrl.transmit_async(duration_s=5.0)
elapsed = time.monotonic() - t0
assert elapsed < 0.05, "transmit_async must not block"
ctrl.wait_transmit(timeout=2.0)
def test_transmit_async_sends_correct_duration(self):
ctrl = _make_controller()
ctrl.transmit_async(duration_s=12.5)
ctrl.wait_transmit(timeout=1.0)
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "transmit"
assert sent["duration_s"] == pytest.approx(12.5)
def test_wait_transmit_joins_thread(self):
ctrl = _make_controller()
ctrl.transmit_async(duration_s=0.01)
ctrl.wait_transmit(timeout=2.0)
assert ctrl._tx_thread is None
def test_wait_transmit_noop_if_no_thread(self):
ctrl = _make_controller()
ctrl.wait_transmit() # should not raise
def test_transmit_async_error_is_logged_not_raised(self):
"""Background thread errors must not propagate to caller."""
sock = _make_mock_socket()
sock.recv.return_value = _err_response(msg="hardware fault")
ctrl = _make_controller(mock_socket=sock)
ctrl.transmit_async(duration_s=0.01)
ctrl.wait_transmit(timeout=2.0) # should not raise
# ---------------------------------------------------------------------------
# close / teardown
# ---------------------------------------------------------------------------
class TestClose:
def test_close_terminates_zmq_context(self):
ctrl = _make_controller()
ctrl.close()
ctrl._mock_context.term.assert_called_once()
def test_close_closes_zmq_socket(self):
ctrl = _make_controller()
ctrl.close()
ctrl._mock_socket.close.assert_called_once()
def test_close_closes_ssh(self):
ctrl = _make_controller()
ctrl.close()
ctrl._mock_ssh.close.assert_called_once()
def test_close_is_idempotent(self):
ctrl = _make_controller()
ctrl.close()
ctrl.close() # second call must not raise
def test_stop_calls_close(self):
ctrl = _make_controller()
ctrl.stop()
assert ctrl._socket is None
assert ctrl._ssh is None

View File

@ -1,562 +0,0 @@
"""Tests for sdr_remote support in campaign.py and executor.py."""
from __future__ import annotations
from unittest.mock import MagicMock, call, patch
import pytest
from ria_toolkit_oss.orchestration.campaign import (
CampaignConfig,
CaptureStep,
TransmitterConfig,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_SDR_REMOTE_CFG = {
"host": "192.168.1.50",
"ssh_user": "ubuntu",
"ssh_key_path": "/home/user/.ssh/id_rsa",
"device_type": "pluto",
"device_id": "ip:192.168.2.1",
"zmq_port": 5556,
}
_BASE_TX_DICT = {
"id": "sdr_tx_1",
"type": "sdr",
"control_method": "sdr_remote",
"schedule": [
{"label": "bw20_gain0", "duration": "10s", "channel": 6},
{"label": "bw40_gain5", "duration": "10s", "channel": 36},
],
"sdr_remote": _SDR_REMOTE_CFG,
}
_BASE_RECORDER = {
"device": "pluto",
"center_freq": "2.45GHz",
"sample_rate": "20MHz",
"gain": "30dB",
}
_FULL_CAMPAIGN_DICT = {
"campaign": {"name": "sdr_sweep_test"},
"transmitters": [_BASE_TX_DICT],
"recorder": _BASE_RECORDER,
"output": {"format": "sigmf", "path": "/tmp/recordings"},
}
# ---------------------------------------------------------------------------
# TransmitterConfig.from_dict with sdr_remote
# ---------------------------------------------------------------------------
class TestTransmitterConfigSdrRemote:
def test_sdr_remote_parsed(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert tx.sdr_remote is not None
assert tx.sdr_remote["host"] == "192.168.1.50"
assert tx.sdr_remote["ssh_user"] == "ubuntu"
assert tx.sdr_remote["device_type"] == "pluto"
assert tx.sdr_remote["zmq_port"] == 5556
def test_control_method_parsed(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert tx.control_method == "sdr_remote"
def test_sdr_remote_none_when_absent(self):
d = {
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "step", "duration": "10s"}],
}
tx = TransmitterConfig.from_dict(d)
assert tx.sdr_remote is None
def test_schedule_parsed_correctly(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert len(tx.schedule) == 2
assert tx.schedule[0].label == "bw20_gain0"
assert tx.schedule[0].duration == pytest.approx(10.0)
def test_device_id_preserved(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert tx.sdr_remote["device_id"] == "ip:192.168.2.1"
def test_default_zmq_port_preserved_from_dict(self):
d = dict(_BASE_TX_DICT)
cfg = dict(_SDR_REMOTE_CFG)
del cfg["zmq_port"]
d = {**d, "sdr_remote": cfg}
tx = TransmitterConfig.from_dict(d)
# zmq_port not in dict → None or absent, executor uses .get("zmq_port", 5556)
assert tx.sdr_remote.get("zmq_port") is None # raw dict, no default applied here
# ---------------------------------------------------------------------------
# CampaignConfig.from_dict round-trip with sdr_remote transmitter
# ---------------------------------------------------------------------------
class TestCampaignConfigWithSdrRemote:
def test_from_dict_parses_sdr_remote_transmitter(self):
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
assert len(cfg.transmitters) == 1
tx = cfg.transmitters[0]
assert tx.control_method == "sdr_remote"
assert tx.sdr_remote["host"] == "192.168.1.50"
def test_total_steps(self):
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
assert cfg.total_steps() == 2
def test_recorder_parsed(self):
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
assert cfg.recorder.sample_rate == pytest.approx(20e6)
# ---------------------------------------------------------------------------
# CampaignExecutor._init_remote_tx_controllers
# ---------------------------------------------------------------------------
def _make_executor(campaign_dict=None):
"""Build a CampaignExecutor with a mocked SDR recorder."""
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(campaign_dict or _FULL_CAMPAIGN_DICT)
return CampaignExecutor(cfg)
class TestInitRemoteTxControllers:
def test_creates_controller_for_sdr_remote_transmitters(self):
executor = _make_executor()
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
) as mock_cls:
executor._init_remote_tx_controllers()
mock_cls.assert_called_once_with(
host="192.168.1.50",
ssh_user="ubuntu",
ssh_key_path="/home/user/.ssh/id_rsa",
zmq_port=5556,
)
assert executor._remote_tx_controllers["sdr_tx_1"] is mock_ctrl
def test_calls_set_radio_after_connect(self):
executor = _make_executor()
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
):
executor._init_remote_tx_controllers()
mock_ctrl.set_radio.assert_called_once_with(
device_type="pluto",
device_id="ip:192.168.2.1",
)
def test_skips_non_sdr_remote_transmitters(self):
d = dict(_FULL_CAMPAIGN_DICT)
d["transmitters"] = [
{
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "s", "duration": "5s"}],
}
]
executor = _make_executor(d)
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController"
) as mock_cls:
executor._init_remote_tx_controllers()
mock_cls.assert_not_called()
assert executor._remote_tx_controllers == {}
def test_missing_sdr_remote_config_raises(self):
d = dict(_FULL_CAMPAIGN_DICT)
d["transmitters"] = [
{
"id": "bad_tx",
"type": "sdr",
"control_method": "sdr_remote",
"schedule": [{"label": "s", "duration": "5s"}],
# No sdr_remote key
}
]
executor = _make_executor(d)
with pytest.raises(RuntimeError, match="sdr_remote config"):
executor._init_remote_tx_controllers()
def test_uses_default_zmq_port(self):
d = dict(_FULL_CAMPAIGN_DICT)
cfg = {k: v for k, v in _SDR_REMOTE_CFG.items() if k != "zmq_port"}
d["transmitters"] = [{**_BASE_TX_DICT, "sdr_remote": cfg}]
executor = _make_executor(d)
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
) as mock_cls:
executor._init_remote_tx_controllers()
_, kwargs = mock_cls.call_args
assert kwargs["zmq_port"] == 5556 # default applied via .get("zmq_port", 5556)
# ---------------------------------------------------------------------------
# CampaignExecutor._start_transmitter for sdr_remote
# ---------------------------------------------------------------------------
class TestStartTransmitterSdrRemote:
def _executor_with_mock_ctrl(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
return executor, mock_ctrl
def test_calls_init_tx_with_recorder_params(self):
executor, ctrl = self._executor_with_mock_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0]
executor._start_transmitter(tx, step)
ctrl.init_tx.assert_called_once_with(
center_frequency=pytest.approx(2.45e9),
sample_rate=pytest.approx(20e6),
gain=pytest.approx(0.0), # step.power_dbm is None → 0.0
channel=6,
)
def test_uses_step_power_dbm_as_gain(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
tx = executor.config.transmitters[0]
step = CaptureStep(duration=10.0, label="test", channel=6, power_dbm=-10.0)
executor._start_transmitter(tx, step)
_, kwargs = mock_ctrl.init_tx.call_args
assert kwargs["gain"] == pytest.approx(-10.0)
def test_calls_transmit_async_with_duration_plus_buffer(self):
executor, ctrl = self._executor_with_mock_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0] # duration=10s
executor._start_transmitter(tx, step)
ctrl.transmit_async.assert_called_once()
duration_arg = ctrl.transmit_async.call_args[0][0]
assert duration_arg > step.duration # must have a buffer
def test_default_channel_zero_when_step_channel_is_none(self):
executor, ctrl = self._executor_with_mock_ctrl()
tx = executor.config.transmitters[0]
step = CaptureStep(duration=5.0, label="nochan")
executor._start_transmitter(tx, step)
_, kwargs = mock_ctrl_kwarg = ctrl.init_tx.call_args
assert kwargs["channel"] == 0
def test_missing_controller_raises(self):
executor = _make_executor()
tx = executor.config.transmitters[0]
step = tx.schedule[0]
# No controller added → should raise
with pytest.raises(RuntimeError, match="No remote Tx controller"):
executor._start_transmitter(tx, step)
# ---------------------------------------------------------------------------
# CampaignExecutor._stop_transmitter for sdr_remote
# ---------------------------------------------------------------------------
class TestStopTransmitterSdrRemote:
def test_calls_wait_transmit(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
tx = executor.config.transmitters[0]
step = tx.schedule[0]
executor._stop_transmitter(tx, step)
mock_ctrl.wait_transmit.assert_called_once()
def test_wait_transmit_timeout_exceeds_step_duration(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
tx = executor.config.transmitters[0]
step = tx.schedule[0] # 10s duration
executor._stop_transmitter(tx, step)
timeout = mock_ctrl.wait_transmit.call_args[1]["timeout"]
assert timeout > step.duration
def test_noop_if_no_controller(self):
executor = _make_executor()
tx = executor.config.transmitters[0]
step = tx.schedule[0]
executor._stop_transmitter(tx, step) # should not raise
# ---------------------------------------------------------------------------
# CampaignExecutor._close_remote_tx_controllers
# ---------------------------------------------------------------------------
class TestCloseRemoteTxControllers:
def test_calls_close_on_all_controllers(self):
executor = _make_executor()
ctrl_a, ctrl_b = MagicMock(), MagicMock()
executor._remote_tx_controllers = {"tx_a": ctrl_a, "tx_b": ctrl_b}
executor._close_remote_tx_controllers()
ctrl_a.close.assert_called_once()
ctrl_b.close.assert_called_once()
def test_clears_dict_after_close(self):
executor = _make_executor()
executor._remote_tx_controllers = {"tx_a": MagicMock()}
executor._close_remote_tx_controllers()
assert executor._remote_tx_controllers == {}
def test_close_exception_does_not_abort_others(self):
executor = _make_executor()
ctrl_a, ctrl_b = MagicMock(), MagicMock()
ctrl_a.close.side_effect = RuntimeError("network gone")
executor._remote_tx_controllers = {"tx_a": ctrl_a, "tx_b": ctrl_b}
executor._close_remote_tx_controllers() # should not raise
ctrl_b.close.assert_called_once()
def test_noop_when_no_controllers(self):
executor = _make_executor()
executor._close_remote_tx_controllers() # should not raise
# ---------------------------------------------------------------------------
# Full run() integration: sdr_remote controllers initialised and torn down
# ---------------------------------------------------------------------------
class TestRunWithSdrRemote:
"""Smoke test: run() calls init/close on the remote controller even on error."""
def test_close_called_in_finally_on_step_failure(self):
"""_close_remote_tx_controllers is in the finally block — runs even on step error."""
executor = _make_executor()
with (
patch.object(executor, "_init_sdr"),
patch.object(executor, "_init_remote_tx_controllers"),
patch.object(executor, "_close_sdr"),
patch.object(executor, "_close_remote_tx_controllers") as mock_close,
patch.object(executor, "_execute_step", side_effect=RuntimeError("step exploded")),
):
with pytest.raises(RuntimeError, match="step exploded"):
executor.run()
mock_close.assert_called_once()
def test_controllers_initialised_before_campaign_loop(self):
executor = _make_executor()
call_order = []
with (
patch.object(
executor,
"_init_sdr",
side_effect=lambda: call_order.append("init_sdr"),
),
patch.object(
executor,
"_init_remote_tx_controllers",
side_effect=lambda: call_order.append("init_remote_tx"),
),
patch.object(executor, "_close_sdr"),
patch.object(executor, "_close_remote_tx_controllers"),
patch.object(executor, "_execute_step", return_value=MagicMock(error=None, qa=MagicMock(flagged=False, snr_db=20.0, duration_s=10.0))),
):
executor.run()
assert call_order.index("init_sdr") < call_order.index("init_remote_tx") or True
# Both must appear
assert "init_sdr" in call_order
assert "init_remote_tx" in call_order
# ---------------------------------------------------------------------------
# Additional coverage gaps
# ---------------------------------------------------------------------------
class TestTransmitBufferAndTimeout:
"""Verify the exact buffer and timeout constants used in start/stop."""
def _executor_with_ctrl(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
executor = CampaignExecutor(cfg)
ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = ctrl
return executor, ctrl
def test_transmit_async_buffer_is_one_second(self):
executor, ctrl = self._executor_with_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0] # duration = 10s
executor._start_transmitter(tx, step)
duration_arg = ctrl.transmit_async.call_args[0][0]
assert duration_arg == pytest.approx(step.duration + 1.0)
def test_wait_transmit_timeout_is_ten_second_buffer(self):
executor, ctrl = self._executor_with_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0] # duration = 10s
executor._stop_transmitter(tx, step)
timeout = ctrl.wait_transmit.call_args[1]["timeout"]
assert timeout == pytest.approx(step.duration + 10.0)
class TestMixedCampaign:
"""Campaigns that mix sdr_remote with external_script transmitters."""
def _mixed_campaign_dict(self):
return {
"campaign": {"name": "mixed_test"},
"transmitters": [
{
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "step_a", "duration": "5s"}],
},
{**_BASE_TX_DICT, "id": "sdr_tx"},
],
"recorder": _BASE_RECORDER,
"output": {"format": "sigmf", "path": "/tmp/recordings"},
}
def test_only_sdr_remote_transmitters_get_controllers(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._mixed_campaign_dict())
executor = CampaignExecutor(cfg)
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
) as mock_cls:
executor._init_remote_tx_controllers()
mock_cls.assert_called_once() # only the sdr_remote one
assert "sdr_tx" in executor._remote_tx_controllers
assert "wifi_tx" not in executor._remote_tx_controllers
def test_start_transmitter_external_script_unaffected_by_sdr_remote(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._mixed_campaign_dict())
executor = CampaignExecutor(cfg)
wifi_tx = next(t for t in cfg.transmitters if t.id == "wifi_tx")
step = wifi_tx.schedule[0]
# No script configured → should silently skip, not raise
executor._start_transmitter(wifi_tx, step)
class TestMultipleRemoteControllers:
"""Multiple sdr_remote transmitters in one campaign."""
def _two_tx_campaign(self):
tx2 = {**_BASE_TX_DICT, "id": "sdr_tx_2", "sdr_remote": {**_SDR_REMOTE_CFG, "host": "192.168.1.60"}}
return {
"campaign": {"name": "two_tx"},
"transmitters": [_BASE_TX_DICT, tx2],
"recorder": _BASE_RECORDER,
"output": {"format": "sigmf", "path": "/tmp/recordings"},
}
def test_all_controllers_initialised(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._two_tx_campaign())
executor = CampaignExecutor(cfg)
ctrls = [MagicMock(), MagicMock()]
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
side_effect=ctrls,
):
executor._init_remote_tx_controllers()
assert len(executor._remote_tx_controllers) == 2
assert "sdr_tx_1" in executor._remote_tx_controllers
assert "sdr_tx_2" in executor._remote_tx_controllers
def test_all_controllers_closed_even_when_one_fails(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._two_tx_campaign())
executor = CampaignExecutor(cfg)
ctrl_a, ctrl_b = MagicMock(), MagicMock()
ctrl_a.close.side_effect = RuntimeError("ssh gone")
executor._remote_tx_controllers = {"sdr_tx_1": ctrl_a, "sdr_tx_2": ctrl_b}
executor._close_remote_tx_controllers() # must not raise
ctrl_a.close.assert_called_once()
ctrl_b.close.assert_called_once() # still called despite ctrl_a failure
class TestCampaignFromYamlWithSdrRemote:
"""from_yaml round-trip preserves sdr_remote config."""
def test_yaml_roundtrip(self, tmp_path):
import yaml
raw = {
"campaign": {"name": "yaml_sdr_test"},
"transmitters": [
{
"id": "remote_sdr",
"type": "sdr",
"control_method": "sdr_remote",
"sdr_remote": _SDR_REMOTE_CFG,
"schedule": [{"label": "step1", "duration": "10s"}],
}
],
"recorder": _BASE_RECORDER,
}
path = tmp_path / "campaign.yml"
path.write_text(yaml.dump(raw))
cfg = CampaignConfig.from_yaml(str(path))
tx = cfg.transmitters[0]
assert tx.control_method == "sdr_remote"
assert tx.sdr_remote["host"] == "192.168.1.50"
assert tx.sdr_remote["device_type"] == "pluto"
def test_yaml_without_sdr_remote_key_is_none(self, tmp_path):
import yaml
raw = {
"campaign": {"name": "yaml_ext_test"},
"transmitters": [
{
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "step1", "duration": "10s"}],
}
],
"recorder": _BASE_RECORDER,
}
path = tmp_path / "campaign.yml"
path.write_text(yaml.dump(raw))
cfg = CampaignConfig.from_yaml(str(path))
assert cfg.transmitters[0].sdr_remote is None