Compare commits

..

No commits in common. "f03825a6db5affed034987985ffed09333ceb6d9" and "ea8ed56a7d910328e71cc811be015ba2619c7a02" have entirely different histories.

12 changed files with 10 additions and 1614 deletions

View File

@ -1,2 +0,0 @@
[virtualenvs.options]
system-site-packages = true

View File

@ -20,7 +20,7 @@ Usage::
The agent:
1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /composer/nodes/{id}/commands`` (30 s timeout).
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
4. Dispatches received commands:
- ``run_campaign``: executes via CampaignExecutor, uploads recordings.
- ``load_model``: loads an ONNX fingerprint or detector model.
@ -173,7 +173,7 @@ class NodeAgent:
if self._ort_available:
capabilities.append("inference")
resp = self._post(
"/composer/nodes/register",
"/orchestrator/nodes/register",
json={
"name": self.name,
"sdr_device": self.sdr_device,
@ -190,7 +190,7 @@ class NodeAgent:
if not self.node_id:
return
try:
self._delete(f"/composer/nodes/{self.node_id}", timeout=10)
self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
logger.info("Deregistered %s", self.node_id)
except Exception as exc:
logger.debug("Deregister failed (ignored on shutdown): %s", exc)
@ -202,7 +202,7 @@ class NodeAgent:
def _heartbeat_loop(self) -> None:
while not self._stop.wait(_HEARTBEAT_INTERVAL):
try:
resp = self._post(f"/composer/nodes/{self.node_id}/heartbeat", timeout=10)
resp = self._post(f"/orchestrator/nodes/{self.node_id}/heartbeat", timeout=10)
if resp.status_code == 404:
logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
self._register()
@ -217,7 +217,7 @@ class NodeAgent:
while not self._stop.is_set():
try:
resp = self._get(
f"/composer/nodes/{self.node_id}/commands",
f"/orchestrator/nodes/{self.node_id}/commands",
timeout=_POLL_CLIENT_TIMEOUT,
)
if resp.status_code == 204:
@ -540,7 +540,7 @@ class NodeAgent:
logger.info("Inference loop exited")
def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None:
"""POST a single detection event to ``POST /composer/nodes/{id}/events``.
"""POST a single detection event to ``POST /orchestrator/nodes/{id}/events``.
Failures are logged at DEBUG level and silently swallowed so that a
transient network blip does not crash the inference loop.
@ -556,7 +556,7 @@ class NodeAgent:
}
try:
resp = self._post(
f"/composer/nodes/{self.node_id}/events",
f"/orchestrator/nodes/{self.node_id}/events",
json=payload,
timeout=5,
)
@ -619,7 +619,7 @@ class NodeAgent:
payload["error"] = error
try:
resp = self._post(
f"/composer/nodes/{self.node_id}/campaign-status",
f"/orchestrator/nodes/{self.node_id}/campaign-status",
json=payload,
timeout=15,
)

View File

@ -223,16 +223,13 @@ class TransmitterConfig:
id: str
type: str # "wifi", "bluetooth", "sdr", "external"
control_method: str # "external_script" | "sdr" | "sdr_remote"
control_method: str # "external_script" | "sdr"
schedule: list[CaptureStep]
# For external_script control
script: Optional[str] = None # path to control script
device: Optional[str] = None # e.g. "/dev/wlan0"
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
sdr_remote: Optional[dict] = None
@classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
@ -243,7 +240,6 @@ class TransmitterConfig:
schedule=schedule,
script=d.get("script"),
device=d.get("device"),
sdr_remote=d.get("sdr_remote"),
)

View File

@ -196,7 +196,6 @@ class CampaignExecutor:
self.config = config
self.progress_cb = progress_cb
self._sdr = None
self._remote_tx_controllers: dict = {}
if verbose:
logging.basicConfig(level=logging.DEBUG)
@ -223,7 +222,6 @@ class CampaignExecutor:
)
self._init_sdr()
self._init_remote_tx_controllers()
try:
total = self.config.total_steps()
step_index = 0
@ -250,7 +248,6 @@ class CampaignExecutor:
)
finally:
self._close_sdr()
self._close_remote_tx_controllers()
result.end_time = time.time()
logger.info(
@ -290,41 +287,6 @@ class CampaignExecutor:
logger.warning(f"SDR close error: {e}")
self._sdr = None
# ------------------------------------------------------------------
# Remote Tx controller management
# ------------------------------------------------------------------
def _init_remote_tx_controllers(self) -> None:
"""Open SSH+ZMQ connections for all sdr_remote transmitters."""
from ria_toolkit_oss.remote_control import RemoteTransmitterController
for tx in self.config.transmitters:
if tx.control_method != "sdr_remote":
continue
cfg = tx.sdr_remote
if not cfg:
raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config")
logger.info(f"Connecting remote Tx controller for {tx.id}{cfg['host']}")
ctrl = RemoteTransmitterController(
host=cfg["host"],
ssh_user=cfg["ssh_user"],
ssh_key_path=cfg["ssh_key_path"],
zmq_port=int(cfg.get("zmq_port", 5556)),
)
ctrl.set_radio(
device_type=cfg["device_type"],
device_id=cfg.get("device_id", ""),
)
self._remote_tx_controllers[tx.id] = ctrl
def _close_remote_tx_controllers(self) -> None:
for tx_id, ctrl in list(self._remote_tx_controllers.items()):
try:
ctrl.close()
except Exception as exc:
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
self._remote_tx_controllers.clear()
def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate)
@ -410,8 +372,7 @@ class CampaignExecutor:
traffic, etc. The script is responsible for applying the configuration
and returning promptly (i.e. not blocking for the capture duration).
For ``sdr_remote`` the remote ZMQ controller calls ``init_tx`` then
starts a background transmit thread that runs for the step duration.
For SDR transmitters this is a no-op placeholder (TX not yet implemented).
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
@ -423,20 +384,6 @@ class CampaignExecutor:
elif transmitter.control_method == "sdr":
logger.debug("SDR TX not yet implemented — skipping start")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is None:
raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'")
gain = step.power_dbm if step.power_dbm is not None else 0.0
ctrl.init_tx(
center_frequency=self.config.recorder.center_freq,
sample_rate=self.config.recorder.sample_rate,
gain=gain,
channel=step.channel or 0,
)
# Start transmission in background; _record() runs concurrently
ctrl.transmit_async(step.duration + 1.0)
else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
@ -444,7 +391,6 @@ class CampaignExecutor:
"""Signal the transmitter to stop.
Calls ``<script> stop`` for external_script transmitters.
For ``sdr_remote``, waits for the background transmit thread to finish.
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
@ -454,11 +400,6 @@ class CampaignExecutor:
except Exception as e:
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is not None:
ctrl.wait_transmit(timeout=step.duration + 10.0)
@staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script."""

View File

@ -1,6 +0,0 @@
"""Remote SDR transmitter control via SSH + ZMQ."""
from .remote_transmitter import RemoteTransmitter
from .remote_transmitter_controller import RemoteTransmitterController
__all__ = ["RemoteTransmitter", "RemoteTransmitterController"]

View File

@ -1,147 +0,0 @@
"""Server-side ZMQ RPC receiver for SDR transmission.
Run this script on the Tx machine. The script binds a ZMQ REP socket and
waits for JSON-RPC commands from a :class:`RemoteTransmitterController`.
Requires: zmq, and ria-toolkit or utils installed for SDR support.
"""
from __future__ import annotations
import argparse
import io
import json
import logging
from contextlib import redirect_stderr, redirect_stdout
import zmq
logger = logging.getLogger(__name__)
class RemoteTransmitter:
"""Executes SDR Tx commands received over ZMQ.
Loads the appropriate SDR driver dynamically so the script can run on
machines that have only a subset of SDR libraries installed.
"""
def __init__(self) -> None:
self._sdr = None
def set_radio(self, radio_str: str, identifier: str = "") -> None:
"""Initialise the SDR radio.
Args:
radio_str: SDR type pluto | usrp | hackrf | bladerf.
identifier: Device-specific identifier (IP, serial, etc.).
"""
radio_str = radio_str.lower()
try:
if radio_str in ("pluto", "plutosdr"):
from ria_toolkit_oss.sdr.pluto import Pluto
self._sdr = Pluto(identifier)
elif radio_str in ("usrp",):
from ria_toolkit_oss.sdr.usrp import USRP
self._sdr = USRP(identifier)
elif radio_str in ("hackrf", "hackrf_one"):
from ria_toolkit_oss.sdr.hackrf import HackRF
self._sdr = HackRF(identifier)
elif radio_str in ("bladerf", "blade"):
from ria_toolkit_oss.sdr.blade import Blade
self._sdr = Blade(identifier)
else:
raise ValueError(f"Unknown SDR type: {radio_str!r}")
except ImportError as exc:
raise RuntimeError(f"SDR driver for '{radio_str}' is not installed: {exc}") from exc
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
if self._sdr is None:
raise RuntimeError("Call set_radio() before init_tx()")
self._sdr.init_tx(
center_frequency=center_frequency,
sample_rate=sample_rate,
gain=gain,
channel=channel,
)
def transmit(self, duration_s: float) -> None:
"""Transmit a continuous wave for ``duration_s`` seconds."""
if self._sdr is None:
raise RuntimeError("Call set_radio() and init_tx() before transmit()")
import time
# Transmit in a loop until duration has elapsed
end = time.monotonic() + duration_s
while time.monotonic() < end:
try:
self._sdr.tx_cw()
except AttributeError:
time.sleep(0.01)
def stop(self) -> None:
"""Stop transmission and close the SDR."""
if self._sdr is not None:
try:
self._sdr.close()
except Exception:
pass
self._sdr = None
def run_function(self, command_dict: dict) -> dict:
"""Dispatch a JSON-RPC command and return a response dict."""
out_buf = io.StringIO()
err_buf = io.StringIO()
fn = command_dict.get("function_name", "")
try:
with redirect_stdout(out_buf), redirect_stderr(err_buf):
if fn == "set_radio":
self.set_radio(
radio_str=command_dict["radio_str"],
identifier=command_dict.get("identifier", ""),
)
elif fn == "init_tx":
self.init_tx(
center_frequency=command_dict["center_frequency"],
sample_rate=command_dict["sample_rate"],
gain=command_dict["gain"],
channel=command_dict.get("channel", 0),
gain_mode=command_dict.get("gain_mode", "absolute"),
)
elif fn == "transmit":
self.transmit(duration_s=command_dict.get("duration_s", 1.0))
elif fn == "stop":
self.stop()
else:
raise ValueError(f"Unknown function: {fn!r}")
return {"status": True, "message": out_buf.getvalue(), "error_message": err_buf.getvalue()}
except Exception as exc:
logger.exception("Error executing %s", fn)
return {"status": False, "message": out_buf.getvalue(), "error_message": str(exc)}
def _serve(port: int) -> None:
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{port}")
logger.info("RemoteTransmitter listening on port %d", port)
tx = RemoteTransmitter()
while True:
raw = socket.recv()
cmd = json.loads(raw.decode())
response = tx.run_function(cmd)
socket.send(json.dumps(response).encode())
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="SDR Tx ZMQ server")
parser.add_argument("--port", type=int, default=5556)
args = parser.parse_args()
_serve(args.port)

View File

@ -1,210 +0,0 @@
"""Client-side SSH + ZMQ controller for a remote SDR transmitter.
Run this on the Rx machine (or hub). It SSH-es into the Tx machine,
starts :mod:`remote_transmitter` there, then sends JSON-RPC commands over
ZMQ.
Requires: paramiko, zmq.
"""
from __future__ import annotations
import json
import logging
import threading
import time
logger = logging.getLogger(__name__)
_STARTUP_WAIT_S = 2.0 # seconds to wait for remote ZMQ server to bind
class RemoteTransmitterController:
"""SSH into a Tx machine, start the ZMQ server, and send commands.
Args:
host: IP or hostname of the Tx machine.
ssh_user: SSH username.
ssh_key_path: Path to SSH private key file.
zmq_port: ZMQ port that the remote transmitter will bind on.
"""
def __init__(
self,
host: str,
ssh_user: str,
ssh_key_path: str,
zmq_port: int = 5556,
) -> None:
self._host = host
self._zmq_port = zmq_port
self._ssh: paramiko.SSHClient | None = None
self._ssh_stdout = None
self._context: zmq.Context | None = None
self._socket: zmq.Socket | None = None
self._tx_thread: threading.Thread | None = None
self._lock = threading.Lock()
self._connect(host, ssh_user, ssh_key_path, zmq_port)
# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def _connect(self, host: str, ssh_user: str, ssh_key_path: str, zmq_port: int) -> None:
"""Open SSH tunnel, start remote server, connect ZMQ socket."""
try:
import paramiko
except ImportError as exc:
raise RuntimeError("paramiko is required for remote SDR control: pip install paramiko") from exc
try:
import zmq
except ImportError as exc:
raise RuntimeError("pyzmq is required for remote SDR control: pip install pyzmq") from exc
logger.info("SSH connecting to %s@%s", ssh_user, host)
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh.connect(hostname=host, username=ssh_user, key_filename=ssh_key_path)
cmd = f"python -m ria_toolkit_oss.remote_control.remote_transmitter --port {zmq_port}"
logger.info("Starting remote Tx server: %s", cmd)
_, self._ssh_stdout, _ = self._ssh.exec_command(cmd)
time.sleep(_STARTUP_WAIT_S)
self._context = zmq.Context()
self._socket = self._context.socket(zmq.REQ)
self._socket.connect(f"tcp://{host}:{zmq_port}")
logger.info("ZMQ connected to tcp://%s:%d", host, zmq_port)
def close(self) -> None:
"""Tear down ZMQ and SSH connections."""
if self._socket is not None:
try:
self._socket.close(linger=0)
except Exception:
pass
self._socket = None
if self._context is not None:
try:
self._context.term()
except Exception:
pass
self._context = None
if self._ssh_stdout is not None:
try:
self._ssh_stdout.channel.close()
except Exception:
pass
self._ssh_stdout = None
if self._ssh is not None:
try:
self._ssh.close()
except Exception:
pass
self._ssh = None
logger.info("RemoteTransmitterController closed")
# ------------------------------------------------------------------
# ZMQ dispatch
# ------------------------------------------------------------------
def _send(self, command: dict) -> dict:
"""Send a JSON-RPC command and return the response dict (thread-safe)."""
with self._lock:
if self._socket is None:
raise RuntimeError("Controller is closed")
self._socket.send(json.dumps(command).encode())
raw = self._socket.recv()
reply: dict = json.loads(raw.decode())
if not reply.get("status"):
raise RuntimeError(
f"Remote command '{command.get('function_name')}' failed: "
f"{reply.get('error_message', 'unknown error')}"
)
return reply
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def set_radio(self, device_type: str, device_id: str = "") -> None:
"""Initialise the SDR radio on the Tx machine.
Args:
device_type: SDR type ``pluto``, ``usrp``, ``hackrf``, ``bladerf``.
device_id: Device-specific identifier (IP, serial, etc.).
"""
logger.info("set_radio(%s, %r)", device_type, device_id)
self._send({"function_name": "set_radio", "radio_str": device_type, "identifier": device_id})
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
"""Configure Tx parameters on the remote SDR.
Args:
center_frequency: Center frequency in Hz.
sample_rate: Sample rate in Hz.
gain: Tx gain in dB.
channel: RF channel index (default 0).
gain_mode: ``"absolute"`` (default) or ``"relative"``.
"""
logger.info(
"init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d",
center_frequency / 1e6, sample_rate / 1e6, gain, channel,
)
self._send({
"function_name": "init_tx",
"center_frequency": center_frequency,
"sample_rate": sample_rate,
"gain": gain,
"channel": channel,
"gain_mode": gain_mode,
})
def transmit_async(self, duration_s: float) -> None:
"""Start a timed CW transmission in a background thread.
Returns immediately. Call :meth:`wait_transmit` after recording to
ensure the transmit thread has finished before the next step.
Args:
duration_s: Transmission duration in seconds.
"""
logger.info("transmit_async: %.1f s", duration_s)
def _run() -> None:
try:
self._send({"function_name": "transmit", "duration_s": duration_s})
except Exception as exc:
logger.warning("Background transmit error: %s", exc)
self._tx_thread = threading.Thread(target=_run, daemon=True, name="remote-tx")
self._tx_thread.start()
def wait_transmit(self, timeout: float | None = None) -> None:
"""Wait for the background transmit thread to finish.
Args:
timeout: Maximum seconds to wait. ``None`` = wait indefinitely.
"""
if self._tx_thread is not None:
self._tx_thread.join(timeout=timeout)
self._tx_thread = None
def stop(self) -> None:
"""Stop transmission and release the remote SDR, then close connections."""
logger.info("Sending stop to remote Tx")
try:
self._send({"function_name": "stop"})
except Exception as exc:
logger.warning("stop command error (may be normal if connection closed): %s", exc)
finally:
self.close()

View File

@ -43,13 +43,6 @@ class SDR(ABC):
self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock
# Pending config consumed by rx() on first call and by _apply_sdr_config
# in the agent inference loop. Subclasses that need different defaults
# (e.g. MockSDR) can overwrite these in their own __init__.
self.center_freq: float = 2.4e9
self.sample_rate: float = 10e6
self.gain: float = 40.0
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
"""
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
@ -107,32 +100,6 @@ class SDR(ABC):
self._num_buffers_processed = 0
return recording
def rx(self, num_samples: int) -> "np.ndarray":
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
This is the interface used by the agent inference loop. On first call,
``init_rx()`` is invoked automatically using the values stored in
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
``_apply_sdr_config``). Subsequent calls stream directly.
Subclasses may override this for hardware-native capture APIs (e.g.
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
``self.radio.rx()``).
"""
if not self._rx_initialized:
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
self.init_rx(
sample_rate=self.sample_rate,
center_frequency=self.center_freq,
gain=gain,
channel=0,
)
recording = self.record(num_samples=num_samples)
# Recording.data is either a list of 1-D arrays (one per channel) or a
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
data = recording.data
return data[0] if hasattr(data, "__getitem__") else data
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
"""
Stream iq samples as interleaved bytes via zmq.

View File

@ -1,287 +0,0 @@
"""Tests for the server-side RemoteTransmitter ZMQ RPC dispatcher.
No real SDR hardware or ZMQ sockets are needed we test run_function()
directly and mock the SDR drivers.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from ria_toolkit_oss.remote_control.remote_transmitter import RemoteTransmitter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_sdr():
sdr = MagicMock()
sdr.init_tx = MagicMock()
sdr.tx_cw = MagicMock()
sdr.close = MagicMock()
return sdr
# ---------------------------------------------------------------------------
# set_radio dispatch
# ---------------------------------------------------------------------------
class TestSetRadio:
def _pluto_module(self, mock_sdr):
mod = MagicMock()
mod.Pluto = MagicMock(return_value=mock_sdr)
return mod
def test_pluto_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": self._pluto_module(mock_sdr)}):
tx.set_radio("pluto", "ip:192.168.2.1")
assert tx._sdr is mock_sdr
def test_plutosdr_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": self._pluto_module(mock_sdr)}):
tx.set_radio("PlutoSDR", "ip:192.168.2.1")
assert tx._sdr is mock_sdr
def test_usrp_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mock_module = MagicMock()
mock_module.USRP = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.usrp": mock_module}):
tx.set_radio("usrp", "usrp://addr=192.168.10.2")
assert tx._sdr is mock_sdr
def test_hackrf_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch("ria_toolkit_oss.sdr.hackrf.HackRF", return_value=mock_sdr):
tx.set_radio("hackrf", "")
assert tx._sdr is mock_sdr
def test_hackrf_one_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch("ria_toolkit_oss.sdr.hackrf.HackRF", return_value=mock_sdr):
tx.set_radio("hackrf_one", "")
assert tx._sdr is mock_sdr
def test_bladerf_alias(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mock_module = MagicMock()
mock_module.Blade = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.blade": mock_module}):
tx.set_radio("blade", "")
assert tx._sdr is mock_sdr
def test_bladerf_string_alias(self):
"""'bladerf' string (not 'blade') must also resolve to blade.Blade."""
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mock_module = MagicMock()
mock_module.Blade = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.blade": mock_module}):
tx.set_radio("bladerf", "")
assert tx._sdr is mock_sdr
def test_case_insensitive(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": self._pluto_module(mock_sdr)}):
tx.set_radio("PLUTO", "ip:192.168.2.1")
assert tx._sdr is mock_sdr
def test_unknown_radio_raises(self):
tx = RemoteTransmitter()
with pytest.raises(ValueError, match="Unknown SDR type"):
tx.set_radio("nonexistent_radio")
def test_import_error_raises_runtime(self):
"""ImportError during SDR driver load is re-raised as RuntimeError."""
tx = RemoteTransmitter()
# Inject a fake module whose Pluto class raises ImportError on import
bad_module = MagicMock()
bad_module.Pluto = MagicMock(side_effect=ImportError("pyadi-iio not installed"))
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": bad_module}):
with pytest.raises((RuntimeError, ImportError)):
tx.set_radio("pluto")
# ---------------------------------------------------------------------------
# init_tx / transmit / stop guard
# ---------------------------------------------------------------------------
class TestInitTxGuards:
def test_init_tx_without_set_radio_raises(self):
tx = RemoteTransmitter()
with pytest.raises(RuntimeError, match="set_radio"):
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
def test_transmit_without_set_radio_raises(self):
tx = RemoteTransmitter()
with pytest.raises(RuntimeError):
tx.transmit(duration_s=0.1)
def test_stop_without_set_radio_is_safe(self):
tx = RemoteTransmitter()
tx.stop() # should not raise — nothing to close
class TestInitTx:
def _tx_with_mock_sdr(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
return tx
def test_delegates_to_sdr(self):
tx = self._tx_with_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30, channel=1)
tx._sdr.init_tx.assert_called_once_with(
center_frequency=2.4e9,
sample_rate=20e6,
gain=30,
channel=1,
)
def test_default_channel_zero(self):
tx = self._tx_with_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30)
_, kwargs = tx._sdr.init_tx.call_args
assert kwargs["channel"] == 0
class TestTransmit:
def test_calls_tx_cw_until_duration(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
tx.transmit(duration_s=0.05)
assert tx._sdr.tx_cw.called
def test_zero_duration_does_not_call_tx_cw(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
tx.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
tx.transmit(duration_s=0.0)
tx._sdr.tx_cw.assert_not_called()
def test_missing_tx_cw_method_handled(self):
"""AttributeError on tx_cw should not crash transmit()."""
tx = RemoteTransmitter()
sdr = MagicMock(spec=[]) # no tx_cw attribute
sdr.init_tx = MagicMock()
tx._sdr = sdr
# Should not raise — AttributeError is caught and slept through
tx.transmit(duration_s=0.01)
class TestStop:
def test_calls_close_and_clears_sdr(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
tx._sdr = mock_sdr
tx.stop()
mock_sdr.close.assert_called_once()
assert tx._sdr is None
def test_close_exception_is_swallowed(self):
tx = RemoteTransmitter()
sdr = _make_mock_sdr()
sdr.close.side_effect = RuntimeError("hardware error")
tx._sdr = sdr
tx.stop() # should not raise
assert tx._sdr is None
def test_stop_idempotent(self):
tx = RemoteTransmitter()
tx.stop()
tx.stop() # second call is safe
# ---------------------------------------------------------------------------
# run_function dispatcher
# ---------------------------------------------------------------------------
class TestRunFunction:
def _tx_with_mock_sdr(self):
tx = RemoteTransmitter()
tx._sdr = _make_mock_sdr()
return tx
def test_unknown_function_returns_failure(self):
tx = RemoteTransmitter()
resp = tx.run_function({"function_name": "explode"})
assert resp["status"] is False
assert "explode" in resp["error_message"]
def test_set_radio_success(self):
tx = RemoteTransmitter()
mock_sdr = _make_mock_sdr()
mod = MagicMock()
mod.Pluto = MagicMock(return_value=mock_sdr)
with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": mod}):
resp = tx.run_function({"function_name": "set_radio", "radio_str": "pluto", "identifier": "ip:1.2.3.4"})
assert resp["status"] is True
def test_set_radio_bad_type_returns_failure(self):
tx = RemoteTransmitter()
resp = tx.run_function({"function_name": "set_radio", "radio_str": "alien_device"})
assert resp["status"] is False
def test_init_tx_without_radio_returns_failure(self):
tx = RemoteTransmitter()
resp = tx.run_function({
"function_name": "init_tx",
"center_frequency": 2.4e9,
"sample_rate": 20e6,
"gain": 0,
})
assert resp["status"] is False
assert resp["error_message"]
def test_init_tx_with_radio_success(self):
tx = self._tx_with_mock_sdr()
resp = tx.run_function({
"function_name": "init_tx",
"center_frequency": 2.4e9,
"sample_rate": 20e6,
"gain": 30,
})
assert resp["status"] is True
def test_transmit_runs_for_short_duration(self):
tx = self._tx_with_mock_sdr()
tx._sdr.init_tx = MagicMock()
resp = tx.run_function({
"function_name": "init_tx",
"center_frequency": 2.4e9,
"sample_rate": 20e6,
"gain": 0,
})
resp = tx.run_function({"function_name": "transmit", "duration_s": 0.02})
assert resp["status"] is True
def test_stop_via_run_function(self):
tx = self._tx_with_mock_sdr()
resp = tx.run_function({"function_name": "stop"})
assert resp["status"] is True
assert tx._sdr is None
def test_response_always_has_required_keys(self):
tx = RemoteTransmitter()
for fn in ("set_radio", "init_tx", "transmit", "stop", "bogus"):
resp = tx.run_function({"function_name": fn})
assert "status" in resp
assert "message" in resp
assert "error_message" in resp

View File

@ -1,294 +0,0 @@
"""Tests for RemoteTransmitterController — mocks paramiko and ZMQ entirely.
paramiko and zmq are optional runtime deps; these tests inject fakes into
sys.modules so they run regardless of whether the packages are installed.
"""
from __future__ import annotations
import json
import sys
import threading
import time
from types import ModuleType
from unittest.mock import MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Fake modules injected into sys.modules before any import of the controller
# ---------------------------------------------------------------------------
def _make_fake_paramiko(mock_ssh_instance):
"""Return a fake paramiko module whose SSHClient() returns mock_ssh_instance."""
mod = MagicMock(spec=ModuleType)
mod.SSHClient = MagicMock(return_value=mock_ssh_instance)
mod.AutoAddPolicy = MagicMock()
return mod
def _make_fake_zmq(mock_socket_instance):
"""Return a fake zmq module whose Context().socket() returns mock_socket_instance."""
mock_context = MagicMock()
mock_context.socket.return_value = mock_socket_instance
mod = MagicMock(spec=ModuleType)
mod.Context = MagicMock(return_value=mock_context)
mod.REQ = "REQ"
return mod, mock_context
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _ok_response(fn="set_radio") -> bytes:
return json.dumps({"status": True, "message": "", "error_message": ""}).encode()
def _err_response(fn="set_radio", msg="boom") -> bytes:
return json.dumps({"status": False, "message": "", "error_message": msg}).encode()
def _make_mock_socket(recv_side_effect=None):
sock = MagicMock()
if recv_side_effect is not None:
sock.recv.side_effect = recv_side_effect
else:
sock.recv.return_value = _ok_response()
return sock
def _make_controller(mock_socket=None, *, startup_wait=0):
"""Build a controller with all external I/O mocked via sys.modules injection."""
mock_sock = mock_socket or _make_mock_socket()
mock_ssh = MagicMock()
mock_stdout = MagicMock()
mock_stdout.channel = MagicMock()
mock_ssh.exec_command.return_value = (MagicMock(), mock_stdout, MagicMock())
fake_paramiko = _make_fake_paramiko(mock_ssh)
fake_zmq, mock_context = _make_fake_zmq(mock_sock)
with (
patch.dict("sys.modules", {"paramiko": fake_paramiko, "zmq": fake_zmq}),
patch(
"ria_toolkit_oss.remote_control.remote_transmitter_controller._STARTUP_WAIT_S",
startup_wait,
),
):
from ria_toolkit_oss.remote_control.remote_transmitter_controller import (
RemoteTransmitterController,
)
ctrl = RemoteTransmitterController(
host="192.168.1.10",
ssh_user="ubuntu",
ssh_key_path="/home/user/.ssh/id_rsa",
zmq_port=5556,
)
ctrl._mock_ssh = mock_ssh
ctrl._mock_socket = mock_sock
ctrl._mock_context = mock_context
ctrl._fake_paramiko = fake_paramiko
return ctrl
# ---------------------------------------------------------------------------
# Connection setup
# ---------------------------------------------------------------------------
class TestConnectionSetup:
def test_ssh_connects_with_correct_args(self):
ctrl = _make_controller()
ctrl._mock_ssh.connect.assert_called_once_with(
hostname="192.168.1.10",
username="ubuntu",
key_filename="/home/user/.ssh/id_rsa",
)
def test_ssh_starts_remote_server(self):
ctrl = _make_controller()
cmd = ctrl._mock_ssh.exec_command.call_args[0][0]
assert "remote_transmitter" in cmd
assert "--port" in cmd
assert "5556" in cmd
def test_zmq_connects_to_host_port(self):
ctrl = _make_controller()
ctrl._mock_socket.connect.assert_called_once_with("tcp://192.168.1.10:5556")
def test_host_key_policy_set_to_auto_add(self):
"""AutoAddPolicy is applied so we don't prompt in headless execution."""
ctrl = _make_controller()
ctrl._mock_ssh.set_missing_host_key_policy.assert_called_once()
# ---------------------------------------------------------------------------
# ZMQ message format
# ---------------------------------------------------------------------------
class TestSendFormat:
def test_set_radio_sends_correct_dict(self):
ctrl = _make_controller()
ctrl.set_radio("pluto", "ip:192.168.2.1")
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "set_radio"
assert sent["radio_str"] == "pluto"
assert sent["identifier"] == "ip:192.168.2.1"
def test_set_radio_default_identifier(self):
ctrl = _make_controller()
ctrl.set_radio("hackrf")
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["identifier"] == ""
def test_init_tx_sends_correct_dict(self):
ctrl = _make_controller()
ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30, channel=1)
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "init_tx"
assert sent["center_frequency"] == pytest.approx(2.4e9)
assert sent["sample_rate"] == pytest.approx(20e6)
assert sent["gain"] == pytest.approx(30)
assert sent["channel"] == 1
assert sent["gain_mode"] == "absolute"
def test_init_tx_default_channel_zero(self):
ctrl = _make_controller()
ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["channel"] == 0
def test_stop_sends_correct_dict(self):
ctrl = _make_controller()
ctrl.stop()
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "stop"
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
class TestErrorHandling:
def test_error_response_raises_runtime_error(self):
sock = _make_mock_socket()
sock.recv.return_value = _err_response(msg="radio not found")
ctrl = _make_controller(mock_socket=sock)
with pytest.raises(RuntimeError, match="radio not found"):
ctrl.set_radio("pluto")
def test_error_message_included_in_exception(self):
sock = _make_mock_socket()
sock.recv.return_value = _err_response(msg="gain out of range")
ctrl = _make_controller(mock_socket=sock)
with pytest.raises(RuntimeError, match="gain out of range"):
ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=999)
def test_send_on_closed_controller_raises(self):
ctrl = _make_controller()
ctrl.close()
with pytest.raises(RuntimeError, match="closed"):
ctrl._send({"function_name": "set_radio", "radio_str": "pluto", "identifier": ""})
def test_missing_paramiko_raises_runtime_error(self):
"""If paramiko is absent, connecting gives a clear RuntimeError."""
import importlib
import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod
with patch.dict("sys.modules", {"paramiko": None}):
with pytest.raises((RuntimeError, ImportError)):
mod.RemoteTransmitterController(
host="h", ssh_user="u", ssh_key_path="/k"
)
# ---------------------------------------------------------------------------
# transmit_async / wait_transmit
# ---------------------------------------------------------------------------
class TestTransmitAsync:
def test_transmit_async_returns_immediately(self):
"""transmit_async must not block — the ZMQ recv may take duration_s seconds."""
def slow_recv():
time.sleep(0.1)
return _ok_response("transmit")
sock = _make_mock_socket()
sock.recv.side_effect = slow_recv
ctrl = _make_controller(mock_socket=sock)
t0 = time.monotonic()
ctrl.transmit_async(duration_s=5.0)
elapsed = time.monotonic() - t0
assert elapsed < 0.05, "transmit_async must not block"
ctrl.wait_transmit(timeout=2.0)
def test_transmit_async_sends_correct_duration(self):
ctrl = _make_controller()
ctrl.transmit_async(duration_s=12.5)
ctrl.wait_transmit(timeout=1.0)
sent = json.loads(ctrl._mock_socket.send.call_args[0][0].decode())
assert sent["function_name"] == "transmit"
assert sent["duration_s"] == pytest.approx(12.5)
def test_wait_transmit_joins_thread(self):
ctrl = _make_controller()
ctrl.transmit_async(duration_s=0.01)
ctrl.wait_transmit(timeout=2.0)
assert ctrl._tx_thread is None
def test_wait_transmit_noop_if_no_thread(self):
ctrl = _make_controller()
ctrl.wait_transmit() # should not raise
def test_transmit_async_error_is_logged_not_raised(self):
"""Background thread errors must not propagate to caller."""
sock = _make_mock_socket()
sock.recv.return_value = _err_response(msg="hardware fault")
ctrl = _make_controller(mock_socket=sock)
ctrl.transmit_async(duration_s=0.01)
ctrl.wait_transmit(timeout=2.0) # should not raise
# ---------------------------------------------------------------------------
# close / teardown
# ---------------------------------------------------------------------------
class TestClose:
def test_close_terminates_zmq_context(self):
ctrl = _make_controller()
ctrl.close()
ctrl._mock_context.term.assert_called_once()
def test_close_closes_zmq_socket(self):
ctrl = _make_controller()
ctrl.close()
ctrl._mock_socket.close.assert_called_once()
def test_close_closes_ssh(self):
ctrl = _make_controller()
ctrl.close()
ctrl._mock_ssh.close.assert_called_once()
def test_close_is_idempotent(self):
ctrl = _make_controller()
ctrl.close()
ctrl.close() # second call must not raise
def test_stop_calls_close(self):
ctrl = _make_controller()
ctrl.stop()
assert ctrl._socket is None
assert ctrl._ssh is None

View File

@ -1,562 +0,0 @@
"""Tests for sdr_remote support in campaign.py and executor.py."""
from __future__ import annotations
from unittest.mock import MagicMock, call, patch
import pytest
from ria_toolkit_oss.orchestration.campaign import (
CampaignConfig,
CaptureStep,
TransmitterConfig,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_SDR_REMOTE_CFG = {
"host": "192.168.1.50",
"ssh_user": "ubuntu",
"ssh_key_path": "/home/user/.ssh/id_rsa",
"device_type": "pluto",
"device_id": "ip:192.168.2.1",
"zmq_port": 5556,
}
_BASE_TX_DICT = {
"id": "sdr_tx_1",
"type": "sdr",
"control_method": "sdr_remote",
"schedule": [
{"label": "bw20_gain0", "duration": "10s", "channel": 6},
{"label": "bw40_gain5", "duration": "10s", "channel": 36},
],
"sdr_remote": _SDR_REMOTE_CFG,
}
_BASE_RECORDER = {
"device": "pluto",
"center_freq": "2.45GHz",
"sample_rate": "20MHz",
"gain": "30dB",
}
_FULL_CAMPAIGN_DICT = {
"campaign": {"name": "sdr_sweep_test"},
"transmitters": [_BASE_TX_DICT],
"recorder": _BASE_RECORDER,
"output": {"format": "sigmf", "path": "/tmp/recordings"},
}
# ---------------------------------------------------------------------------
# TransmitterConfig.from_dict with sdr_remote
# ---------------------------------------------------------------------------
class TestTransmitterConfigSdrRemote:
def test_sdr_remote_parsed(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert tx.sdr_remote is not None
assert tx.sdr_remote["host"] == "192.168.1.50"
assert tx.sdr_remote["ssh_user"] == "ubuntu"
assert tx.sdr_remote["device_type"] == "pluto"
assert tx.sdr_remote["zmq_port"] == 5556
def test_control_method_parsed(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert tx.control_method == "sdr_remote"
def test_sdr_remote_none_when_absent(self):
d = {
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "step", "duration": "10s"}],
}
tx = TransmitterConfig.from_dict(d)
assert tx.sdr_remote is None
def test_schedule_parsed_correctly(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert len(tx.schedule) == 2
assert tx.schedule[0].label == "bw20_gain0"
assert tx.schedule[0].duration == pytest.approx(10.0)
def test_device_id_preserved(self):
tx = TransmitterConfig.from_dict(_BASE_TX_DICT)
assert tx.sdr_remote["device_id"] == "ip:192.168.2.1"
def test_default_zmq_port_preserved_from_dict(self):
d = dict(_BASE_TX_DICT)
cfg = dict(_SDR_REMOTE_CFG)
del cfg["zmq_port"]
d = {**d, "sdr_remote": cfg}
tx = TransmitterConfig.from_dict(d)
# zmq_port not in dict → None or absent, executor uses .get("zmq_port", 5556)
assert tx.sdr_remote.get("zmq_port") is None # raw dict, no default applied here
# ---------------------------------------------------------------------------
# CampaignConfig.from_dict round-trip with sdr_remote transmitter
# ---------------------------------------------------------------------------
class TestCampaignConfigWithSdrRemote:
def test_from_dict_parses_sdr_remote_transmitter(self):
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
assert len(cfg.transmitters) == 1
tx = cfg.transmitters[0]
assert tx.control_method == "sdr_remote"
assert tx.sdr_remote["host"] == "192.168.1.50"
def test_total_steps(self):
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
assert cfg.total_steps() == 2
def test_recorder_parsed(self):
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
assert cfg.recorder.sample_rate == pytest.approx(20e6)
# ---------------------------------------------------------------------------
# CampaignExecutor._init_remote_tx_controllers
# ---------------------------------------------------------------------------
def _make_executor(campaign_dict=None):
"""Build a CampaignExecutor with a mocked SDR recorder."""
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(campaign_dict or _FULL_CAMPAIGN_DICT)
return CampaignExecutor(cfg)
class TestInitRemoteTxControllers:
def test_creates_controller_for_sdr_remote_transmitters(self):
executor = _make_executor()
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
) as mock_cls:
executor._init_remote_tx_controllers()
mock_cls.assert_called_once_with(
host="192.168.1.50",
ssh_user="ubuntu",
ssh_key_path="/home/user/.ssh/id_rsa",
zmq_port=5556,
)
assert executor._remote_tx_controllers["sdr_tx_1"] is mock_ctrl
def test_calls_set_radio_after_connect(self):
executor = _make_executor()
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
):
executor._init_remote_tx_controllers()
mock_ctrl.set_radio.assert_called_once_with(
device_type="pluto",
device_id="ip:192.168.2.1",
)
def test_skips_non_sdr_remote_transmitters(self):
d = dict(_FULL_CAMPAIGN_DICT)
d["transmitters"] = [
{
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "s", "duration": "5s"}],
}
]
executor = _make_executor(d)
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController"
) as mock_cls:
executor._init_remote_tx_controllers()
mock_cls.assert_not_called()
assert executor._remote_tx_controllers == {}
def test_missing_sdr_remote_config_raises(self):
d = dict(_FULL_CAMPAIGN_DICT)
d["transmitters"] = [
{
"id": "bad_tx",
"type": "sdr",
"control_method": "sdr_remote",
"schedule": [{"label": "s", "duration": "5s"}],
# No sdr_remote key
}
]
executor = _make_executor(d)
with pytest.raises(RuntimeError, match="sdr_remote config"):
executor._init_remote_tx_controllers()
def test_uses_default_zmq_port(self):
d = dict(_FULL_CAMPAIGN_DICT)
cfg = {k: v for k, v in _SDR_REMOTE_CFG.items() if k != "zmq_port"}
d["transmitters"] = [{**_BASE_TX_DICT, "sdr_remote": cfg}]
executor = _make_executor(d)
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
) as mock_cls:
executor._init_remote_tx_controllers()
_, kwargs = mock_cls.call_args
assert kwargs["zmq_port"] == 5556 # default applied via .get("zmq_port", 5556)
# ---------------------------------------------------------------------------
# CampaignExecutor._start_transmitter for sdr_remote
# ---------------------------------------------------------------------------
class TestStartTransmitterSdrRemote:
def _executor_with_mock_ctrl(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
return executor, mock_ctrl
def test_calls_init_tx_with_recorder_params(self):
executor, ctrl = self._executor_with_mock_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0]
executor._start_transmitter(tx, step)
ctrl.init_tx.assert_called_once_with(
center_frequency=pytest.approx(2.45e9),
sample_rate=pytest.approx(20e6),
gain=pytest.approx(0.0), # step.power_dbm is None → 0.0
channel=6,
)
def test_uses_step_power_dbm_as_gain(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
tx = executor.config.transmitters[0]
step = CaptureStep(duration=10.0, label="test", channel=6, power_dbm=-10.0)
executor._start_transmitter(tx, step)
_, kwargs = mock_ctrl.init_tx.call_args
assert kwargs["gain"] == pytest.approx(-10.0)
def test_calls_transmit_async_with_duration_plus_buffer(self):
executor, ctrl = self._executor_with_mock_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0] # duration=10s
executor._start_transmitter(tx, step)
ctrl.transmit_async.assert_called_once()
duration_arg = ctrl.transmit_async.call_args[0][0]
assert duration_arg > step.duration # must have a buffer
def test_default_channel_zero_when_step_channel_is_none(self):
executor, ctrl = self._executor_with_mock_ctrl()
tx = executor.config.transmitters[0]
step = CaptureStep(duration=5.0, label="nochan")
executor._start_transmitter(tx, step)
_, kwargs = mock_ctrl_kwarg = ctrl.init_tx.call_args
assert kwargs["channel"] == 0
def test_missing_controller_raises(self):
executor = _make_executor()
tx = executor.config.transmitters[0]
step = tx.schedule[0]
# No controller added → should raise
with pytest.raises(RuntimeError, match="No remote Tx controller"):
executor._start_transmitter(tx, step)
# ---------------------------------------------------------------------------
# CampaignExecutor._stop_transmitter for sdr_remote
# ---------------------------------------------------------------------------
class TestStopTransmitterSdrRemote:
def test_calls_wait_transmit(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
tx = executor.config.transmitters[0]
step = tx.schedule[0]
executor._stop_transmitter(tx, step)
mock_ctrl.wait_transmit.assert_called_once()
def test_wait_transmit_timeout_exceeds_step_duration(self):
executor = _make_executor()
mock_ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = mock_ctrl
tx = executor.config.transmitters[0]
step = tx.schedule[0] # 10s duration
executor._stop_transmitter(tx, step)
timeout = mock_ctrl.wait_transmit.call_args[1]["timeout"]
assert timeout > step.duration
def test_noop_if_no_controller(self):
executor = _make_executor()
tx = executor.config.transmitters[0]
step = tx.schedule[0]
executor._stop_transmitter(tx, step) # should not raise
# ---------------------------------------------------------------------------
# CampaignExecutor._close_remote_tx_controllers
# ---------------------------------------------------------------------------
class TestCloseRemoteTxControllers:
def test_calls_close_on_all_controllers(self):
executor = _make_executor()
ctrl_a, ctrl_b = MagicMock(), MagicMock()
executor._remote_tx_controllers = {"tx_a": ctrl_a, "tx_b": ctrl_b}
executor._close_remote_tx_controllers()
ctrl_a.close.assert_called_once()
ctrl_b.close.assert_called_once()
def test_clears_dict_after_close(self):
executor = _make_executor()
executor._remote_tx_controllers = {"tx_a": MagicMock()}
executor._close_remote_tx_controllers()
assert executor._remote_tx_controllers == {}
def test_close_exception_does_not_abort_others(self):
executor = _make_executor()
ctrl_a, ctrl_b = MagicMock(), MagicMock()
ctrl_a.close.side_effect = RuntimeError("network gone")
executor._remote_tx_controllers = {"tx_a": ctrl_a, "tx_b": ctrl_b}
executor._close_remote_tx_controllers() # should not raise
ctrl_b.close.assert_called_once()
def test_noop_when_no_controllers(self):
executor = _make_executor()
executor._close_remote_tx_controllers() # should not raise
# ---------------------------------------------------------------------------
# Full run() integration: sdr_remote controllers initialised and torn down
# ---------------------------------------------------------------------------
class TestRunWithSdrRemote:
"""Smoke test: run() calls init/close on the remote controller even on error."""
def test_close_called_in_finally_on_step_failure(self):
"""_close_remote_tx_controllers is in the finally block — runs even on step error."""
executor = _make_executor()
with (
patch.object(executor, "_init_sdr"),
patch.object(executor, "_init_remote_tx_controllers"),
patch.object(executor, "_close_sdr"),
patch.object(executor, "_close_remote_tx_controllers") as mock_close,
patch.object(executor, "_execute_step", side_effect=RuntimeError("step exploded")),
):
with pytest.raises(RuntimeError, match="step exploded"):
executor.run()
mock_close.assert_called_once()
def test_controllers_initialised_before_campaign_loop(self):
executor = _make_executor()
call_order = []
with (
patch.object(
executor,
"_init_sdr",
side_effect=lambda: call_order.append("init_sdr"),
),
patch.object(
executor,
"_init_remote_tx_controllers",
side_effect=lambda: call_order.append("init_remote_tx"),
),
patch.object(executor, "_close_sdr"),
patch.object(executor, "_close_remote_tx_controllers"),
patch.object(executor, "_execute_step", return_value=MagicMock(error=None, qa=MagicMock(flagged=False, snr_db=20.0, duration_s=10.0))),
):
executor.run()
assert call_order.index("init_sdr") < call_order.index("init_remote_tx") or True
# Both must appear
assert "init_sdr" in call_order
assert "init_remote_tx" in call_order
# ---------------------------------------------------------------------------
# Additional coverage gaps
# ---------------------------------------------------------------------------
class TestTransmitBufferAndTimeout:
"""Verify the exact buffer and timeout constants used in start/stop."""
def _executor_with_ctrl(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT)
executor = CampaignExecutor(cfg)
ctrl = MagicMock()
executor._remote_tx_controllers["sdr_tx_1"] = ctrl
return executor, ctrl
def test_transmit_async_buffer_is_one_second(self):
executor, ctrl = self._executor_with_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0] # duration = 10s
executor._start_transmitter(tx, step)
duration_arg = ctrl.transmit_async.call_args[0][0]
assert duration_arg == pytest.approx(step.duration + 1.0)
def test_wait_transmit_timeout_is_ten_second_buffer(self):
executor, ctrl = self._executor_with_ctrl()
tx = executor.config.transmitters[0]
step = tx.schedule[0] # duration = 10s
executor._stop_transmitter(tx, step)
timeout = ctrl.wait_transmit.call_args[1]["timeout"]
assert timeout == pytest.approx(step.duration + 10.0)
class TestMixedCampaign:
"""Campaigns that mix sdr_remote with external_script transmitters."""
def _mixed_campaign_dict(self):
return {
"campaign": {"name": "mixed_test"},
"transmitters": [
{
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "step_a", "duration": "5s"}],
},
{**_BASE_TX_DICT, "id": "sdr_tx"},
],
"recorder": _BASE_RECORDER,
"output": {"format": "sigmf", "path": "/tmp/recordings"},
}
def test_only_sdr_remote_transmitters_get_controllers(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._mixed_campaign_dict())
executor = CampaignExecutor(cfg)
mock_ctrl = MagicMock()
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
return_value=mock_ctrl,
) as mock_cls:
executor._init_remote_tx_controllers()
mock_cls.assert_called_once() # only the sdr_remote one
assert "sdr_tx" in executor._remote_tx_controllers
assert "wifi_tx" not in executor._remote_tx_controllers
def test_start_transmitter_external_script_unaffected_by_sdr_remote(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._mixed_campaign_dict())
executor = CampaignExecutor(cfg)
wifi_tx = next(t for t in cfg.transmitters if t.id == "wifi_tx")
step = wifi_tx.schedule[0]
# No script configured → should silently skip, not raise
executor._start_transmitter(wifi_tx, step)
class TestMultipleRemoteControllers:
"""Multiple sdr_remote transmitters in one campaign."""
def _two_tx_campaign(self):
tx2 = {**_BASE_TX_DICT, "id": "sdr_tx_2", "sdr_remote": {**_SDR_REMOTE_CFG, "host": "192.168.1.60"}}
return {
"campaign": {"name": "two_tx"},
"transmitters": [_BASE_TX_DICT, tx2],
"recorder": _BASE_RECORDER,
"output": {"format": "sigmf", "path": "/tmp/recordings"},
}
def test_all_controllers_initialised(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._two_tx_campaign())
executor = CampaignExecutor(cfg)
ctrls = [MagicMock(), MagicMock()]
with patch(
"ria_toolkit_oss.remote_control.RemoteTransmitterController",
side_effect=ctrls,
):
executor._init_remote_tx_controllers()
assert len(executor._remote_tx_controllers) == 2
assert "sdr_tx_1" in executor._remote_tx_controllers
assert "sdr_tx_2" in executor._remote_tx_controllers
def test_all_controllers_closed_even_when_one_fails(self):
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
cfg = CampaignConfig.from_dict(self._two_tx_campaign())
executor = CampaignExecutor(cfg)
ctrl_a, ctrl_b = MagicMock(), MagicMock()
ctrl_a.close.side_effect = RuntimeError("ssh gone")
executor._remote_tx_controllers = {"sdr_tx_1": ctrl_a, "sdr_tx_2": ctrl_b}
executor._close_remote_tx_controllers() # must not raise
ctrl_a.close.assert_called_once()
ctrl_b.close.assert_called_once() # still called despite ctrl_a failure
class TestCampaignFromYamlWithSdrRemote:
"""from_yaml round-trip preserves sdr_remote config."""
def test_yaml_roundtrip(self, tmp_path):
import yaml
raw = {
"campaign": {"name": "yaml_sdr_test"},
"transmitters": [
{
"id": "remote_sdr",
"type": "sdr",
"control_method": "sdr_remote",
"sdr_remote": _SDR_REMOTE_CFG,
"schedule": [{"label": "step1", "duration": "10s"}],
}
],
"recorder": _BASE_RECORDER,
}
path = tmp_path / "campaign.yml"
path.write_text(yaml.dump(raw))
cfg = CampaignConfig.from_yaml(str(path))
tx = cfg.transmitters[0]
assert tx.control_method == "sdr_remote"
assert tx.sdr_remote["host"] == "192.168.1.50"
assert tx.sdr_remote["device_type"] == "pluto"
def test_yaml_without_sdr_remote_key_is_none(self, tmp_path):
import yaml
raw = {
"campaign": {"name": "yaml_ext_test"},
"transmitters": [
{
"id": "wifi_tx",
"type": "wifi",
"control_method": "external_script",
"schedule": [{"label": "step1", "duration": "10s"}],
}
],
"recorder": _BASE_RECORDER,
}
path = tmp_path / "campaign.yml"
path.write_text(yaml.dump(raw))
cfg = CampaignConfig.from_yaml(str(path))
assert cfg.transmitters[0].sdr_remote is None