Merge pull request 'zfp-oss' (#27) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 44m43s
Test with tox / Test with tox (3.10) (push) Successful in 1h4m45s
Build Project / Build Project (3.10) (push) Successful in 1h16m56s
Build Project / Build Project (3.12) (push) Successful in 1h16m52s
Test with tox / Test with tox (3.12) (push) Successful in 31m45s
Test with tox / Test with tox (3.11) (push) Successful in 47m45s
Build Project / Build Project (3.11) (push) Failing after 1h9m0s

Reviewed-on: #27
This commit is contained in:
benchinnery 2026-04-23 11:10:43 -04:00
commit 2881aaf06e
17 changed files with 1321 additions and 80 deletions

12
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. # This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]] [[package]]
name = "alabaster" name = "alabaster"
@ -230,14 +230,14 @@ uvloop = ["uvloop (>=0.15.2) ; sys_platform != \"win32\"", "winloop (>=0.5.0) ;
[[package]] [[package]]
name = "cachetools" name = "cachetools"
version = "7.0.5" version = "7.0.6"
description = "Extensible memoizing collections and decorators" description = "Extensible memoizing collections and decorators"
optional = false optional = false
python-versions = ">=3.10" python-versions = ">=3.10"
groups = ["test"] groups = ["test"]
files = [ files = [
{file = "cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114"}, {file = "cachetools-7.0.6-py3-none-any.whl", hash = "sha256:4e94956cfdd3086f12042cdd29318f5ced3893014f7d0d059bf3ead3f85b7f8b"},
{file = "cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990"}, {file = "cachetools-7.0.6.tar.gz", hash = "sha256:e5d524d36d65703a87243a26ff08ad84f73352adbeafb1cde81e207b456aaf24"},
] ]
[[package]] [[package]]
@ -1271,7 +1271,7 @@ files = [
[package.dependencies] [package.dependencies]
attrs = ">=22.2.0" attrs = ">=22.2.0"
jsonschema-specifications = ">=2023.3.6" jsonschema-specifications = ">=2023.03.6"
referencing = ">=0.28.4" referencing = ">=0.28.4"
rpds-py = ">=0.25.0" rpds-py = ">=0.25.0"
@ -3749,4 +3749,4 @@ files = [
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = ">=3.10" python-versions = ">=3.10"
content-hash = "ffde300b2fc93161d2279a6e2b899bc988d3b5eb3833135821830affc9a5fb62" content-hash = "66c9adf647316db90f963da05e8a83574378bfa4db2c69ce751446b5ee7c408c"

View File

@ -50,7 +50,7 @@ dependencies = [
"pyyaml (>=6.0.3,<7.0.0)", "pyyaml (>=6.0.3,<7.0.0)",
"click (>=8.1.0,<9.0.0)", "click (>=8.1.0,<9.0.0)",
"matplotlib (>=3.8.0,<4.0.0)", "matplotlib (>=3.8.0,<4.0.0)",
"paramiko (>=4.0.0)" "paramiko (>=3.5.1)"
] ]
# [project.optional-dependencies] Commented out to prevent Tox tests from failing # [project.optional-dependencies] Commented out to prevent Tox tests from failing
@ -149,6 +149,11 @@ exclude = '''
[tool.pytest.ini_options] [tool.pytest.ini_options]
pythonpath = ["src"] pythonpath = ["src"]
filterwarnings = [
# FastAPI emits this internally when handling 422 responses; the constant
# is not yet renamed in the installed starlette version, so we can't migrate.
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
]
[tool.isort] [tool.isort]
profile = "black" profile = "black"

View File

@ -68,7 +68,7 @@ _HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
_POLL_TIMEOUT = 30 # server-side long-poll duration _POLL_TIMEOUT = 30 # server-side long-poll duration
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server _POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying _RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit _CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload _DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window _CAPTURE_SAMPLES = 4096 # IQ samples per inference window
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"}) _IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
@ -93,16 +93,24 @@ class NodeAgent:
name: str, name: str,
sdr_device: str = "unknown", sdr_device: str = "unknown",
insecure: bool = False, insecure: bool = False,
role: str = "general",
session_code: str | None = None,
) -> None: ) -> None:
self.hub_url = hub_url.rstrip("/") self.hub_url = hub_url.rstrip("/")
self.api_key = api_key self.api_key = api_key
self.name = name self.name = name
self.sdr_device = sdr_device self.sdr_device = sdr_device
self.insecure = insecure self.insecure = insecure
self.role = role
self.session_code = session_code
self.node_id: str | None = None self.node_id: str | None = None
self._stop = threading.Event() self._stop = threading.Event()
# ── TX state ────────────────────────────────────────────────────────
self._tx_stop = threading.Event()
self._tx_thread: threading.Thread | None = None
# ── Inference state ───────────────────────────────────────────────── # ── Inference state ─────────────────────────────────────────────────
# Protected by _inf_lock for cross-thread model swaps. # Protected by _inf_lock for cross-thread model swaps.
self._inf_lock = threading.Lock() self._inf_lock = threading.Lock()
@ -172,19 +180,27 @@ class NodeAgent:
capabilities = ["campaign"] capabilities = ["campaign"]
if self._ort_available: if self._ort_available:
capabilities.append("inference") capabilities.append("inference")
resp = self._post( if self.role == "tx":
"/composer/nodes/register", capabilities.append("transmit")
json={ payload: dict = {
"name": self.name, "name": self.name,
"sdr_device": self.sdr_device, "sdr_device": self.sdr_device,
"ria_toolkit_version": self._ria_version, "ria_toolkit_version": self._ria_version,
"capabilities": capabilities, "capabilities": capabilities,
}, "role": self.role,
timeout=15, }
) if self.session_code:
payload["session_code"] = self.session_code
resp = self._post("/composer/nodes/register", json=payload, timeout=15)
resp.raise_for_status() resp.raise_for_status()
self.node_id = resp.json()["node_id"] self.node_id = resp.json()["node_id"]
logger.info("Registered as %r (node_id=%s)", self.name, self.node_id) logger.info(
"Registered as %r (node_id=%s, role=%s%s)",
self.name,
self.node_id,
self.role,
f", session_code={self.session_code!r}" if self.session_code else "",
)
def _deregister(self) -> None: def _deregister(self) -> None:
if not self.node_id: if not self.node_id:
@ -245,9 +261,10 @@ class NodeAgent:
if command == "run_campaign": if command == "run_campaign":
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4()) campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
config_dict: dict = cmd.get("payload") or {} config_dict: dict = cmd.get("payload") or {}
skip_local_tx: bool = bool(cmd.get("skip_local_tx", False))
threading.Thread( threading.Thread(
target=self._run_campaign, target=self._run_campaign,
args=(campaign_id, config_dict), args=(campaign_id, config_dict, skip_local_tx),
daemon=True, daemon=True,
name=f"campaign-{campaign_id[:8]}", name=f"campaign-{campaign_id[:8]}",
).start() ).start()
@ -269,6 +286,17 @@ class NodeAgent:
self._stop_inference() self._stop_inference()
elif command == "configure_inference": elif command == "configure_inference":
self._queue_sdr_config(cmd) self._queue_sdr_config(cmd)
elif command == "start_transmit":
threading.Thread(
target=self._start_transmit,
args=(cmd,),
daemon=True,
name="ria-start-tx",
).start()
elif command == "stop_transmit":
self._stop_transmit()
elif command == "configure_transmit":
logger.info("configure_transmit received — will apply on next step boundary")
else: else:
logger.warning("Unknown command %r — ignored", command) logger.warning("Unknown command %r — ignored", command)
@ -276,7 +304,7 @@ class NodeAgent:
# Campaign execution # Campaign execution
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _run_campaign(self, campaign_id: str, config_dict: dict) -> None: def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None:
try: try:
from ria_toolkit_oss.orchestration.campaign import CampaignConfig from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor from ria_toolkit_oss.orchestration.executor import CampaignExecutor
@ -288,10 +316,10 @@ class NodeAgent:
) )
return return
logger.info("Campaign %s starting", campaign_id[:8]) logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx)
try: try:
config = CampaignConfig.from_dict(config_dict) config = CampaignConfig.from_dict(config_dict)
executor = CampaignExecutor(config) executor = CampaignExecutor(config, skip_local_tx=skip_local_tx)
result = executor.run() result = executor.run()
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8]) logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
self._upload_recordings(campaign_id, config, result) self._upload_recordings(campaign_id, config, result)
@ -301,6 +329,58 @@ class NodeAgent:
logger.error("Campaign %s failed: %s", campaign_id[:8], exc) logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
self._report_campaign_status(campaign_id, "failed", error=str(exc)) self._report_campaign_status(campaign_id, "failed", error=str(exc))
# ------------------------------------------------------------------
# TX execution
# ------------------------------------------------------------------
def _start_transmit(self, cmd: dict) -> None:
"""Execute a synthetic transmit campaign using TxExecutor.
The command payload mirrors a TransmitterConfig dict with an optional
``schedule`` of steps. Each step synthesises a signal and transmits it
via the local SDR in TX mode.
"""
try:
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
except ImportError as exc:
logger.error("start_transmit: TxExecutor not available: %s", exc)
return
if self._tx_thread and self._tx_thread.is_alive():
logger.warning("start_transmit: TX already running — ignoring duplicate command")
return
self._tx_stop.clear()
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
executor = TxExecutor(
config=cmd,
sdr_device=self.sdr_device,
stop_event=self._tx_stop,
)
self._tx_thread = threading.Thread(
target=self._run_tx_campaign,
args=(executor, campaign_id),
daemon=True,
name=f"tx-campaign-{campaign_id[:8]}",
)
self._tx_thread.start()
def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
try:
executor.run()
logger.info("TX campaign %s completed", campaign_id[:8])
self._report_campaign_status(campaign_id, "completed")
except Exception as exc:
logger.error("TX campaign %s failed: %s", campaign_id[:8], exc)
self._report_campaign_status(campaign_id, "failed", error=str(exc))
def _stop_transmit(self) -> None:
"""Signal the TX loop to stop gracefully."""
self._tx_stop.set()
if self._tx_thread and self._tx_thread.is_alive():
self._tx_thread.join(timeout=5.0)
logger.info("TX stopped")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Inference — model loading # Inference — model loading
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -579,13 +659,18 @@ class NodeAgent:
base_url = f"{self.hub_url}/datasets/upload" base_url = f"{self.hub_url}/datasets/upload"
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or [] steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
output_obj = getattr(config, "output", None)
folder = getattr(output_obj, "folder", None)
campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "")
for step in steps: for step in steps:
output_path: str | None = getattr(step, "output_path", None) output_path: str | None = getattr(step, "output_path", None)
if not output_path: if not output_path:
continue continue
device_id: str = getattr(step, "transmitter_id", "") or "" device_id: str = getattr(step, "transmitter_id", "") or ""
for fpath in _sigmf_files(output_path): for fpath in _sigmf_files(output_path):
filename = os.path.basename(fpath) basename = os.path.basename(fpath)
path_parts = [p for p in (campaign_name, device_id) if p]
filename = "/".join(path_parts + [basename])
metadata = { metadata = {
"filename": filename, "filename": filename,
"repo_owner": repo_owner, "repo_owner": repo_owner,
@ -671,7 +756,7 @@ class NodeAgent:
headers=headers, headers=headers,
files={"file": (filename, chunk, "application/octet-stream")}, files={"file": (filename, chunk, "application/octet-stream")},
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks}, data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
timeout=120, timeout=(30, None), # 30s connect, no read timeout — server may take minutes on final chunk
verify=verify, verify=verify,
) )
if not resp.ok: if not resp.ok:
@ -848,6 +933,21 @@ def main() -> None:
choices=["DEBUG", "INFO", "WARNING", "ERROR"], choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging verbosity (default: INFO)", help="Logging verbosity (default: INFO)",
) )
parser.add_argument(
"--role",
default=None,
choices=["general", "rx", "tx"],
help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
)
parser.add_argument(
"--session-code",
default=None,
metavar="CODE",
help=(
"3-word session code to pair this TX agent with a waiting campaign, "
"e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
),
)
args = parser.parse_args() args = parser.parse_args()
@ -861,6 +961,8 @@ def main() -> None:
device = args.device or cfg.get("device", "unknown") device = args.device or cfg.get("device", "unknown")
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False) insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
log_level = args.log_level or cfg.get("log_level", "INFO") log_level = args.log_level or cfg.get("log_level", "INFO")
role = args.role or cfg.get("role", "general")
session_code = args.session_code or cfg.get("session_code")
if not hub: if not hub:
parser.error("--hub is required (or set 'hub' in the config file)") parser.error("--hub is required (or set 'hub' in the config file)")
@ -888,6 +990,8 @@ def main() -> None:
name=name, name=name,
sdr_device=device, sdr_device=device,
insecure=insecure, insecure=insecure,
role=role,
session_code=session_code,
) )
agent.run() agent.run()

View File

@ -233,6 +233,9 @@ class TransmitterConfig:
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port # For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
sdr_remote: Optional[dict] = None sdr_remote: Optional[dict] = None
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
sdr_agent: Optional[dict] = None
@classmethod @classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig": def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])] schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
@ -244,6 +247,7 @@ class TransmitterConfig:
script=d.get("script"), script=d.get("script"),
device=d.get("device"), device=d.get("device"),
sdr_remote=d.get("sdr_remote"), sdr_remote=d.get("sdr_remote"),
sdr_agent=d.get("sdr_agent"),
) )
@ -272,6 +276,7 @@ class OutputConfig:
path: str = "recordings" path: str = "recordings"
device_id: Optional[str] = None # for device-profile campaigns device_id: Optional[str] = None # for device-profile campaigns
repo: Optional[str] = None repo: Optional[str] = None
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
@classmethod @classmethod
def from_dict(cls, d: dict) -> "OutputConfig": def from_dict(cls, d: dict) -> "OutputConfig":
@ -280,6 +285,7 @@ class OutputConfig:
path=str(d.get("path", "recordings")), path=str(d.get("path", "recordings")),
device_id=d.get("device_id"), device_id=d.get("device_id"),
repo=d.get("repo"), repo=d.get("repo"),
folder=d.get("folder"),
) )
@ -293,6 +299,7 @@ class CampaignConfig:
qa: QAConfig = field(default_factory=QAConfig) qa: QAConfig = field(default_factory=QAConfig)
output: OutputConfig = field(default_factory=OutputConfig) output: OutputConfig = field(default_factory=OutputConfig)
mode: str = "controlled_testbed" mode: str = "controlled_testbed"
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Loaders # Loaders
@ -320,6 +327,7 @@ class CampaignConfig:
return cls( return cls(
name=safe_name, name=safe_name,
mode=str(campaign_meta.get("mode", "controlled_testbed")), mode=str(campaign_meta.get("mode", "controlled_testbed")),
loops=max(1, int(campaign_meta.get("loops", 1))),
recorder=RecorderConfig.from_dict(raw["recorder"]), recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters, transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})), qa=QAConfig.from_dict(raw.get("qa", {})),
@ -384,6 +392,7 @@ class CampaignConfig:
return cls( return cls(
name=safe_name, name=safe_name,
mode=str(campaign_meta.get("mode", "controlled_testbed")), mode=str(campaign_meta.get("mode", "controlled_testbed")),
loops=max(1, int(campaign_meta.get("loops", 1))),
recorder=RecorderConfig.from_dict(raw["recorder"]), recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters, transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})), qa=QAConfig.from_dict(raw.get("qa", {})),
@ -486,9 +495,9 @@ class CampaignConfig:
) )
def total_capture_time_s(self) -> float: def total_capture_time_s(self) -> float:
"""Sum of all step durations across all transmitters.""" """Sum of all step durations across all transmitters and loops."""
return sum(step.duration for tx in self.transmitters for step in tx.schedule) return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops
def total_steps(self) -> int: def total_steps(self) -> int:
"""Total number of capture steps across all transmitters.""" """Total number of capture steps across all transmitters and loops."""
return sum(len(tx.schedule) for tx in self.transmitters) return sum(len(tx.schedule) for tx in self.transmitters) * self.loops

View File

@ -5,8 +5,9 @@ from __future__ import annotations
import json import json
import logging import logging
import subprocess import subprocess
import threading
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional
@ -16,6 +17,7 @@ from ria_toolkit_oss.io.recording import to_sigmf
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
from .labeler import build_output_filename, label_recording from .labeler import build_output_filename, label_recording
from .qa import QAResult, check_recording from .qa import QAResult, check_recording
from .tx_executor import TxExecutor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -169,6 +171,21 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
For sdr_agent transmitters, returns the synthetic generation parameters
(modulation, order, symbol_rate, etc.) so recordings capture what was
transmitted. Returns None for control methods without signal-level params.
"""
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
if not sdr_agent_cfg:
return None
# Extract known signal-level fields; ignore infra fields
_INFRA_KEYS = {"node_id", "session_code"}
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
class CampaignExecutor: class CampaignExecutor:
"""Executes a :class:`CampaignConfig` end-to-end. """Executes a :class:`CampaignConfig` end-to-end.
@ -192,11 +209,14 @@ class CampaignExecutor:
config: CampaignConfig, config: CampaignConfig,
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None, progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
verbose: bool = False, verbose: bool = False,
skip_local_tx: bool = False,
): ):
self.config = config self.config = config
self.progress_cb = progress_cb self.progress_cb = progress_cb
self.skip_local_tx = skip_local_tx
self._sdr = None self._sdr = None
self._remote_tx_controllers: dict = {} self._remote_tx_controllers: dict = {}
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
if verbose: if verbose:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -216,10 +236,12 @@ class CampaignExecutor:
""" """
result = CampaignResult(campaign_name=self.config.name) result = CampaignResult(campaign_name=self.config.name)
loops = self.config.loops
logger.info( logger.info(
f"Starting campaign '{self.config.name}': " f"Starting campaign '{self.config.name}': "
f"{self.config.total_steps()} steps, " f"{self.config.total_steps()} steps"
f"~{self.config.total_capture_time_s():.0f}s capture time" + (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
) )
self._init_sdr() self._init_sdr()
@ -228,29 +250,36 @@ class CampaignExecutor:
total = self.config.total_steps() total = self.config.total_steps()
step_index = 0 step_index = 0
for transmitter in self.config.transmitters: for loop_idx in range(loops):
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)") if loops > 1:
for step in transmitter.schedule: logger.info(f"Loop {loop_idx + 1}/{loops}")
step_result = self._execute_step(transmitter, step) for transmitter in self.config.transmitters:
result.steps.append(step_result) logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
step_index += 1 for step in transmitter.schedule:
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
step_result = self._execute_step(transmitter, looped_step)
result.steps.append(step_result)
step_index += 1
if self.progress_cb: if self.progress_cb:
self.progress_cb(step_index, total, step_result) self.progress_cb(step_index, total, step_result)
if step_result.error: if step_result.error:
logger.warning(f"Step '{step.label}' error: {step_result.error}") logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
elif step_result.qa.flagged: elif step_result.qa.flagged:
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues)) logger.warning(
else: f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
logger.info( )
f"Step '{step.label}' OK " else:
f"(SNR {step_result.qa.snr_db:.1f} dB, " logger.info(
f"{step_result.qa.duration_s:.1f}s)" f"Step '{looped_step.label}' OK "
) f"(SNR {step_result.qa.snr_db:.1f} dB, "
f"{step_result.qa.duration_s:.1f}s)"
)
finally: finally:
self._close_sdr() self._close_sdr()
self._close_remote_tx_controllers() self._close_remote_tx_controllers()
self._close_tx_executors()
result.end_time = time.time() result.end_time = time.time()
logger.info( logger.info(
@ -325,6 +354,12 @@ class CampaignExecutor:
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}") logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
self._remote_tx_controllers.clear() self._remote_tx_controllers.clear()
def _close_tx_executors(self) -> None:
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
stop_event.set()
t.join(timeout=5.0)
self._tx_executors.clear()
def _record(self, duration_s: float) -> Recording: def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples.""" """Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate) num_samples = int(duration_s * self.config.recorder.sample_rate)
@ -369,6 +404,7 @@ class CampaignExecutor:
step=step, step=step,
capture_timestamp=capture_timestamp, capture_timestamp=capture_timestamp,
campaign_name=self.config.name, campaign_name=self.config.name,
tx_params=_extract_tx_params(transmitter),
) )
# QA # QA
@ -437,6 +473,30 @@ class CampaignExecutor:
# Start transmission in background; _record() runs concurrently # Start transmission in background; _record() runs concurrently
ctrl.transmit_async(step.duration + 1.0) ctrl.transmit_async(step.duration + 1.0)
elif transmitter.control_method == "sdr_agent":
if self.skip_local_tx:
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
return
if not transmitter.sdr_agent:
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
return
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
if step.power_dbm is not None:
step_dict["power_dbm"] = step.power_dbm
tx_config = {
"id": transmitter.id,
"sdr_agent": transmitter.sdr_agent,
"schedule": [step_dict],
}
rec = self.config.recorder
tx_device = transmitter.device or rec.device
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
stop_event = threading.Event()
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
self._tx_executors[transmitter.id] = (executor, stop_event, t)
t.start()
else: else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping") logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
@ -459,6 +519,13 @@ class CampaignExecutor:
if ctrl is not None: if ctrl is not None:
ctrl.wait_transmit(timeout=step.duration + 10.0) ctrl.wait_transmit(timeout=step.duration + 10.0)
elif transmitter.control_method == "sdr_agent":
entry = self._tx_executors.pop(transmitter.id, None)
if entry is not None:
_, stop_event, t = entry
stop_event.set()
t.join(timeout=step.duration + 10.0)
@staticmethod @staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str: def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script.""" """Serialise step parameters to a JSON string for the control script."""

View File

@ -15,6 +15,7 @@ def label_recording(
step: CaptureStep, step: CaptureStep,
capture_timestamp: float, capture_timestamp: float,
campaign_name: Optional[str] = None, campaign_name: Optional[str] = None,
tx_params: Optional[dict] = None,
) -> Recording: ) -> Recording:
"""Apply device identity and capture configuration labels to a recording's metadata. """Apply device identity and capture configuration labels to a recording's metadata.
@ -27,6 +28,9 @@ def label_recording(
step: The capture step that was active during this recording. step: The capture step that was active during this recording.
capture_timestamp: Unix timestamp (float) of when capture started. capture_timestamp: Unix timestamp (float) of when capture started.
campaign_name: Optional campaign name for cross-recording reference. campaign_name: Optional campaign name for cross-recording reference.
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
training pipelines know what was transmitted into the recording.
Returns: Returns:
The same recording with updated metadata. The same recording with updated metadata.
@ -57,6 +61,11 @@ def label_recording(
if step.power_dbm is not None: if step.power_dbm is not None:
recording.update_metadata("tx_power_dbm", step.power_dbm) recording.update_metadata("tx_power_dbm", step.power_dbm)
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
if tx_params:
for key, value in tx_params.items():
recording.update_metadata(f"tx_{key}", value)
return recording return recording

View File

@ -0,0 +1,299 @@
"""TX campaign executor — synthesises and transmits signals via a local SDR.
The TxExecutor receives a transmitter config dict (matching the
``sdr_agent`` control method's schema) and a step schedule, then for each
step builds a signal chain with the block generator and transmits it via
the local SDR device.
Supported modulations (``modulation`` field in config):
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
Example config dict (matches CampaignConfig transmitter with
``control_method: sdr_agent``)::
{
"id": "synthetic-tx",
"type": "sdr",
"control_method": "sdr_agent",
"sdr_agent": {
"modulation": "QPSK",
"order": 4,
"symbol_rate": 1000000,
"center_frequency": 0.0,
"filter": "rrc",
"rolloff": 0.35
},
"schedule": [
{"label": "step1", "duration": 10, "power_dbm": -10}
]
}
"""
from __future__ import annotations
import logging
import threading
from typing import Any
logger = logging.getLogger(__name__)
def _parse_hz(val: object) -> float:
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
if isinstance(val, (int, float)):
return float(val)
s = str(val).strip()
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
if s.endswith(suffix):
return float(s[: -len(suffix)]) * mult
return float(s)
def _parse_seconds(val: object) -> float:
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
if isinstance(val, (int, float)):
return float(val)
s = str(val).strip()
return float(s[:-1]) if s.endswith("s") else float(s)
# Mapping from modulation name → (bits per symbol, generator type).
# NOTE(review): the int is log2 of the constellation size (BPSK→1, QPSK→2,
# …, 256QAM→8) — the order *exponent*, not the order itself; confirm the
# generator constructors expect this exponent.
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
_MOD_TABLE: dict[str, tuple[int, str]] = {
    "BPSK": (1, "psk"),
    "QPSK": (2, "psk"),
    "8PSK": (3, "psk"),
    "16QAM": (4, "qam"),
    "64QAM": (6, "qam"),
    "256QAM": (8, "qam"),
}
# Modulations from the module docstring that are absent from _MOD_TABLE;
# presumably handled by dedicated generators — verify in _synthesise.
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}

# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
# source buffer for the full tx_time, so only this many samples ever need to
# be in RAM regardless of step duration or sample rate.
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
_SYNTH_BLOCK_SAMPLES = 50_000
class TxExecutor:
    """Synthesise and transmit a signal campaign via a local SDR.

    Args:
        config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
            modulation params, and ``schedule`` list of step dicts).
        sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
        stop_event: External event that aborts the TX loop mid-step.
    """

    def __init__(
        self,
        config: dict,
        sdr_device: str = "unknown",
        stop_event: threading.Event | None = None,
    ) -> None:
        self.config = config
        self.sdr_device = sdr_device
        # Fall back to a private event so run() can always poll/wait on it.
        self.stop_event = stop_event or threading.Event()
        # Lazily opened SDR handle; None means "simulate" (see _init_sdr).
        self._sdr: Any = None

    def run(self) -> None:
        """Execute all steps in the schedule, transmitting for each step duration."""
        agent_cfg: dict = self.config.get("sdr_agent") or {}
        schedule: list[dict] = self.config.get("schedule") or []
        if not schedule:
            logger.warning("TxExecutor: no schedule steps — nothing to transmit")
            return
        modulation: str = agent_cfg.get("modulation", "QPSK").upper()
        symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
        center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
        filter_type: str = agent_cfg.get("filter", "rrc").lower()
        rolloff: float = float(agent_cfg.get("rolloff", 0.35))
        loops: int = max(1, int(self.config.get("loops", 1)))
        # Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
        sps = 8
        sample_rate = symbol_rate * sps
        self._init_sdr(sample_rate, center_freq)
        try:
            for loop_idx in range(loops):
                if self.stop_event.is_set():
                    break
                if loops > 1:
                    logger.info("TX loop %d/%d", loop_idx + 1, loops)
                for step in schedule:
                    if self.stop_event.is_set():
                        break
                    # Suffix step labels per loop so repeated runs stay distinguishable.
                    looped_step = (
                        {**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
                    )
                    self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
        finally:
            # Always release the SDR handle, even on error or cancellation.
            self._close_sdr()

    def _execute_step(
        self,
        step: dict,
        modulation: str,
        sps: int,
        symbol_rate: float,
        filter_type: str,
        rolloff: float,
    ) -> None:
        """Synthesise and transmit (or simulate) a single schedule step.

        Args:
            step: Step dict; reads optional "label", "duration", "power_dbm".
            modulation: Modulation name (e.g. "QPSK", "FSK").
            sps: Samples per symbol (upsampling factor).
            symbol_rate: Symbol rate in symbols/second.
            filter_type: Pulse-shaping filter name ("rrc", "rc", or other).
            rolloff: Filter roll-off factor (beta).
        """
        duration: float = _parse_seconds(step.get("duration", 10.0))
        label: str = step.get("label", "step")
        # NOTE(review): `or 0.0` maps a missing/None power_dbm to 0.0; an
        # explicit 0 also hits the fallback, which yields the same value here.
        gain: float = float(step.get("power_dbm") or 0.0)
        sample_rate = symbol_rate * sps
        logger.info(
            "TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
            label,
            duration,
            modulation,
            symbol_rate / 1e6,
            sps,
            filter_type,
        )
        num_samples = int(duration * sample_rate)
        # Synthesise a short representative block. tx_recording() loops this
        # buffer for the full tx_time using a 2 000-sample streaming callback,
        # so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
        block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
        signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
        if self._sdr is not None:
            try:
                # Apply gain update if SDR supports it
                if hasattr(self._sdr, "set_tx_gain"):
                    self._sdr.set_tx_gain(gain)
                self._sdr.tx_recording(signal, tx_time=duration)
            except Exception as exc:
                # A failed step is logged, not fatal — remaining steps still run.
                logger.error("TX step '%s' SDR error: %s", label, exc)
        else:
            # No SDR available — simulate by sleeping for the step duration.
            # Waiting on stop_event keeps the sleep interruptible.
            logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
            self.stop_event.wait(timeout=duration)

    def _synthesise(
        self,
        modulation: str,
        sps: int,
        num_samples: int,
        filter_type: str,
        rolloff: float,
    ):
        """Build a block-generator chain and return IQ samples as a numpy array.

        Returns:
            1-D ``np.complex64`` array of exactly ``num_samples`` entries;
            shorter modulator output is tiled up to length.

        Raises:
            RuntimeError: If the ria_toolkit_oss block-generator stack is
                not importable.
        """
        try:
            import numpy as np
            from ria_toolkit_oss.signal.block_generator import (
                BinarySource,
                GMSKModulator,
                Mapper,
                OOKModulator,
                OQPSKModulator,
                RaisedCosineFilter,
                RootRaisedCosineFilter,
                Upsampling,
            )
            from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
                FSKModulator,
            )
        except ImportError as exc:
            raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
        # ── Special modulations with their own source-connected modulator ──
        if modulation in ("OOK", "GMSK", "OQPSK"):
            src = BinarySource()
            if modulation == "OOK":
                mod = OOKModulator(src, samples_per_symbol=sps)
            elif modulation == "GMSK":
                mod = GMSKModulator(src, samples_per_symbol=sps)
            else:
                mod = OQPSKModulator(src, samples_per_symbol=sps)
            recording = mod.record(num_samples)
            flat = np.asarray(recording.data).flatten().astype(np.complex64)
            # Tile short modulator output up to the requested length.
            if len(flat) < num_samples:
                flat = np.tile(flat, num_samples // len(flat) + 1)
            return flat[:num_samples]
        if modulation == "FSK":
            # Derive a nominal symbol rate from the block length so spacing,
            # symbol_duration, and sampling_frequency stay mutually consistent.
            symbol_rate = num_samples / sps
            bits_per_sym = 1  # 2-FSK
            num_bits = max(num_samples // sps, 128) * bits_per_sym
            bits = BinarySource()((1, num_bits))
            mod = FSKModulator(
                num_bits_per_symbol=bits_per_sym,
                frequency_spacing=symbol_rate * 0.5,
                symbol_duration=1.0 / max(symbol_rate, 1.0),
                sampling_frequency=symbol_rate * sps,
            )
            flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
            if len(flat) < num_samples:
                flat = np.tile(flat, num_samples // len(flat) + 1)
            return flat[:num_samples]
        # ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
        if modulation not in _MOD_TABLE:
            logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
            modulation = "QPSK"
        bits_per_sym, gen_type = _MOD_TABLE[modulation]
        mod_family = "QAM" if gen_type == "qam" else "PSK"
        source = BinarySource()
        mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
        upsampler = Upsampling(factor=sps)
        mapper.connect_input([source])
        upsampler.connect_input([mapper])
        if filter_type in ("rrc",):
            pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
            pulse_filter.connect_input([upsampler])
            recording = pulse_filter.record(num_samples)
        elif filter_type in ("rc",):
            pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
            pulse_filter.connect_input([upsampler])
            recording = pulse_filter.record(num_samples)
        else:
            # "none", "rect", "gaussian" — use upsampler output directly
            recording = upsampler.record(num_samples)
        flat = np.asarray(recording.data).flatten().astype(np.complex64)
        if len(flat) < num_samples:
            flat = np.tile(flat, num_samples // len(flat) + 1)
        return flat[:num_samples]

    def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
        """Open and configure the SDR for TX; fall back to simulation on failure.

        On any exception ``self._sdr`` is left as None, which makes
        ``_execute_step`` simulate transmission instead of sending samples.
        """
        try:
            from ria_toolkit_oss.sdr import get_sdr_device
            self._sdr = get_sdr_device(self.sdr_device)
            self._sdr.init_tx(
                sample_rate=sample_rate,
                center_frequency=center_freq,
                gain=0,
                channel=0,
                gain_mode="manual",
            )
            logger.info(
                "TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
            )
        except Exception as exc:
            logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
            self._sdr = None

    def _close_sdr(self) -> None:
        """Close the SDR handle if open, logging (not raising) close errors."""
        if self._sdr is not None:
            try:
                self._sdr.close()
            except Exception as exc:
                logger.debug("TX SDR close error: %s", exc)
            self._sdr = None

View File

@ -3,7 +3,7 @@
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from .auth import require_api_key from .auth import require_api_key
from .routers import inference, orchestrator from .routers import conductor, inference
def create_app(api_key: str = "") -> FastAPI: def create_app(api_key: str = "") -> FastAPI:
@ -28,9 +28,9 @@ def create_app(api_key: str = "") -> FastAPI:
app.state.api_key = api_key app.state.api_key = api_key
app.include_router( app.include_router(
orchestrator.router, conductor.router,
prefix="/orchestrator", prefix="/conductor",
tags=["Orchestrator"], tags=["Conductor"],
dependencies=[Depends(require_api_key)], dependencies=[Depends(require_api_key)],
) )
app.include_router( app.include_router(

View File

@ -7,7 +7,7 @@ from pathlib import Path
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Orchestrator # Conductor
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -1,4 +1,4 @@
"""Orchestrator routes: campaign deployment, status, and cancellation.""" """Conductor routes: campaign deployment, status, and cancellation."""
from __future__ import annotations from __future__ import annotations

View File

@ -23,9 +23,9 @@ def serve(host: str, port: int, api_key: str, log_level: str):
\b \b
Endpoints: Endpoints:
POST /orchestrator/deploy POST /conductor/deploy
GET /orchestrator/status/{campaign_id} GET /conductor/status/{campaign_id}
POST /orchestrator/cancel/{campaign_id} POST /conductor/cancel/{campaign_id}
POST /inference/load POST /inference/load
POST /inference/start POST /inference/start
POST /inference/stop POST /inference/stop

View File

@ -0,0 +1,314 @@
"""Tests for orchestration executor — StepResult, CampaignResult, _run_script, _extract_tx_params."""
from __future__ import annotations
import json
import stat
from types import SimpleNamespace
import pytest
from ria_toolkit_oss.orchestration.executor import (
CampaignResult,
StepResult,
_extract_tx_params,
_run_script,
)
from ria_toolkit_oss.orchestration.qa import QAResult
def _ok_qa() -> QAResult:
    """QA result for a clean, unflagged capture."""
    return QAResult(passed=True, flagged=False, snr_db=20.0, duration_s=1.0)


def _flagged_qa() -> QAResult:
    """QA result that passed but was flagged (marginal SNR)."""
    return QAResult(passed=True, flagged=True, snr_db=5.0, duration_s=1.0, issues=["low SNR"])


def _failed_qa() -> QAResult:
    """QA result for a capture that failed outright."""
    return QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=["no signal"])
# ---------------------------------------------------------------------------
# StepResult
# ---------------------------------------------------------------------------
class TestStepResult:
    """StepResult.ok semantics and to_dict serialisation."""

    def test_ok_true_when_no_error_and_qa_passed(self):
        r = StepResult(
            transmitter_id="tx1",
            step_label="step1",
            output_path="/out/rec.sigmf-data",
            qa=_ok_qa(),
            capture_timestamp=0.0,
        )
        assert r.ok is True

    def test_ok_false_when_error_set(self):
        # An error overrides a passing QA result.
        r = StepResult(
            transmitter_id="tx1",
            step_label="step1",
            output_path=None,
            qa=_ok_qa(),
            capture_timestamp=0.0,
            error="SDR failed",
        )
        assert r.ok is False

    def test_ok_false_when_qa_not_passed(self):
        r = StepResult(
            transmitter_id="tx1",
            step_label="step1",
            output_path="/out",
            qa=_failed_qa(),
            capture_timestamp=0.0,
        )
        assert r.ok is False

    def test_to_dict_contains_required_keys(self):
        r = StepResult(
            transmitter_id="tx1",
            step_label="step1",
            output_path="/out/rec.sigmf-data",
            qa=_ok_qa(),
            capture_timestamp=1234.5,
        )
        d = r.to_dict()
        assert d["transmitter_id"] == "tx1"
        assert d["step_label"] == "step1"
        assert d["output_path"] == "/out/rec.sigmf-data"
        assert d["capture_timestamp"] == pytest.approx(1234.5)
        assert d["error"] is None
        assert d["qa"]["passed"] is True

    def test_to_dict_includes_error_when_set(self):
        r = StepResult(
            transmitter_id="tx1",
            step_label="step1",
            output_path=None,
            qa=_failed_qa(),
            capture_timestamp=0.0,
            error="disk full",
        )
        assert r.to_dict()["error"] == "disk full"
# ---------------------------------------------------------------------------
# CampaignResult
# ---------------------------------------------------------------------------
class TestCampaignResult:
    """CampaignResult aggregate counters, duration, and report writing."""

    def _make(self, steps: list) -> CampaignResult:
        # Helper: campaign with the given steps and a fixed 5 s duration.
        r = CampaignResult(campaign_name="test_campaign")
        r.steps = steps
        r.end_time = r.start_time + 5.0
        return r

    def test_total_steps(self):
        r = self._make(
            [
                StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
                StepResult("tx1", "s2", "/out", _ok_qa(), 0.0),
            ]
        )
        assert r.total_steps == 2

    def test_passed_count(self):
        r = self._make(
            [
                StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
                StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
            ]
        )
        assert r.passed == 1

    def test_failed_count(self):
        r = self._make(
            [
                StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
                StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
            ]
        )
        assert r.failed == 1

    def test_flagged_count(self):
        r = self._make(
            [
                StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
                StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0),
            ]
        )
        assert r.flagged == 1

    def test_error_step_counts_as_failed_not_passed(self):
        # Error trumps a passing QA result in the counters too.
        r = self._make(
            [
                StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"),
            ]
        )
        assert r.failed == 1
        assert r.passed == 0

    def test_duration_s_from_end_time(self):
        r = CampaignResult(campaign_name="c")
        r.start_time = 100.0
        r.end_time = 115.0
        assert r.duration_s == pytest.approx(15.0)

    def test_to_dict_structure(self):
        r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
        d = r.to_dict()
        assert d["campaign_name"] == "test_campaign"
        assert d["total_steps"] == 1
        assert d["passed"] == 1
        assert len(d["steps"]) == 1

    def test_write_report(self, tmp_path):
        r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
        out = tmp_path / "report.json"
        r.write_report(str(out))
        assert out.exists()
        data = json.loads(out.read_text())
        assert data["campaign_name"] == "test_campaign"

    def test_write_report_creates_nested_dirs(self, tmp_path):
        # write_report must create intermediate directories itself.
        r = self._make([])
        out = tmp_path / "nested" / "deep" / "report.json"
        r.write_report(str(out))
        assert out.exists()
# ---------------------------------------------------------------------------
# _run_script
# ---------------------------------------------------------------------------
class TestRunScript:
    """_run_script: stdout capture, argument passing, and failure modes."""

    def _script(self, tmp_path, body: str) -> str:
        # Write an executable /bin/sh script and return its absolute path.
        s = tmp_path / "script.sh"
        s.write_text("#!/bin/sh\n" + body)
        s.chmod(s.stat().st_mode | stat.S_IEXEC)
        return str(s)

    def test_returns_stdout(self, tmp_path):
        out = _run_script(self._script(tmp_path, 'echo "hello world"'))
        assert out == "hello world"

    def test_passes_args_to_script(self, tmp_path):
        out = _run_script(self._script(tmp_path, 'echo "$1 $2"'), "configure", "arg2")
        assert "configure" in out

    def test_raises_on_nonzero_exit(self, tmp_path):
        with pytest.raises(RuntimeError, match="exited 1"):
            _run_script(self._script(tmp_path, "exit 1"))

    def test_raises_on_relative_path(self):
        # Relative paths are rejected outright, before execution.
        with pytest.raises(RuntimeError, match="absolute"):
            _run_script("relative/script.sh")

    def test_raises_on_missing_file(self, tmp_path):
        with pytest.raises(RuntimeError):
            _run_script(str(tmp_path / "nonexistent.sh"))

    def test_raises_on_timeout(self, tmp_path):
        with pytest.raises(RuntimeError, match="timed out"):
            _run_script(self._script(tmp_path, "sleep 60"), timeout=0.1)

    def test_stderr_included_in_error_message(self, tmp_path):
        with pytest.raises(RuntimeError) as exc_info:
            _run_script(self._script(tmp_path, "echo 'bad thing' >&2; exit 1"))
        assert "bad thing" in str(exc_info.value)
# ---------------------------------------------------------------------------
# _extract_tx_params
# ---------------------------------------------------------------------------
class TestExtractTxParams:
    """_extract_tx_params: None cases, infra-key stripping, non-mutation."""

    def test_returns_none_when_no_sdr_agent_attribute(self):
        tx = SimpleNamespace()
        assert _extract_tx_params(tx) is None

    def test_returns_none_when_sdr_agent_is_none(self):
        tx = SimpleNamespace(sdr_agent=None)
        assert _extract_tx_params(tx) is None

    def test_returns_none_when_sdr_agent_is_empty_dict(self):
        tx = SimpleNamespace(sdr_agent={})
        assert _extract_tx_params(tx) is None

    def test_returns_signal_params(self):
        tx = SimpleNamespace(
            sdr_agent={
                "modulation": "QPSK",
                "symbol_rate": 1e6,
                "center_frequency": 2.4e9,
            }
        )
        result = _extract_tx_params(tx)
        assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9}

    def test_strips_infra_key_node_id(self):
        tx = SimpleNamespace(
            sdr_agent={
                "modulation": "BPSK",
                "node_id": "node_abc123",
            }
        )
        result = _extract_tx_params(tx)
        assert "node_id" not in result
        assert result == {"modulation": "BPSK"}

    def test_strips_infra_key_session_code(self):
        tx = SimpleNamespace(
            sdr_agent={
                "modulation": "FSK",
                "session_code": "amber-peak-transmit",
            }
        )
        result = _extract_tx_params(tx)
        assert "session_code" not in result

    def test_strips_none_values(self):
        tx = SimpleNamespace(
            sdr_agent={
                "modulation": "QPSK",
                "order": None,
                "rolloff": 0.35,
            }
        )
        result = _extract_tx_params(tx)
        assert "order" not in result
        assert result == {"modulation": "QPSK", "rolloff": 0.35}

    def test_does_not_mutate_source_dict(self):
        # Extraction must work on a copy, not pop keys from the config.
        cfg = {"modulation": "QPSK", "node_id": "nid", "session_code": "code"}
        tx = SimpleNamespace(sdr_agent=cfg)
        _extract_tx_params(tx)
        assert "node_id" in cfg

    def test_full_sdr_agent_config(self):
        tx = SimpleNamespace(
            sdr_agent={
                "modulation": "16QAM",
                "order": 4,
                "symbol_rate": 5e6,
                "center_frequency": 915e6,
                "filter": "rrc",
                "rolloff": 0.35,
                "node_id": "node_xyz",
                "session_code": "some-code",
            }
        )
        result = _extract_tx_params(tx)
        assert result == {
            "modulation": "16QAM",
            "order": 4,
            "symbol_rate": 5e6,
            "center_frequency": 915e6,
            "filter": "rrc",
            "rolloff": 0.35,
        }

View File

@ -109,6 +109,38 @@ class TestLabelRecording:
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time()) result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
assert result is rec assert result is rec
def test_tx_params_none_by_default(self):
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
tx_keys = [k for k in rec.metadata if k.startswith("tx_")]
assert tx_keys == []
def test_tx_params_written_as_tx_prefix_keys(self):
params = {"modulation": "QPSK", "symbol_rate": 1e6}
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
assert rec.metadata["tx_modulation"] == "QPSK"
assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6)
def test_tx_params_multiple_fields(self):
params = {
"modulation": "16QAM",
"order": 4,
"symbol_rate": 5e6,
"center_frequency": 915e6,
"filter": "rrc",
"rolloff": 0.35,
}
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
for k, v in params.items():
assert f"tx_{k}" in rec.metadata
assert (
rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v
)
def test_tx_params_empty_dict_writes_nothing(self):
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={})
tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"]
assert tx_keys == []
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# build_output_filename # build_output_filename

View File

@ -0,0 +1,153 @@
"""Tests for TxExecutor — signal synthesis and step execution."""
from __future__ import annotations
import threading
from unittest.mock import patch
import numpy as np
import pytest
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
def _cfg(modulation="QPSK", symbol_rate=100_000, steps=None):
return {
"id": "test-tx",
"type": "sdr",
"control_method": "sdr_agent",
"sdr_agent": {
"modulation": modulation,
"symbol_rate": symbol_rate,
"center_frequency": 0.0,
"filter": "rrc",
"rolloff": 0.35,
},
"schedule": steps or [{"label": "step1", "duration": 0.001, "power_dbm": -10}],
}
# ---------------------------------------------------------------------------
# Initialisation
# ---------------------------------------------------------------------------
class TestTxExecutorInit:
    """Constructor wiring: sdr_device storage and stop_event handling."""

    def test_stores_sdr_device(self):
        ex = TxExecutor(_cfg(), sdr_device="pluto")
        assert ex.sdr_device == "pluto"

    def test_stop_event_created_when_not_supplied(self):
        # A private, unset Event is created when none is passed in.
        ex = TxExecutor(_cfg())
        assert isinstance(ex.stop_event, threading.Event)
        assert not ex.stop_event.is_set()

    def test_accepts_external_stop_event(self):
        ev = threading.Event()
        ex = TxExecutor(_cfg(), stop_event=ev)
        assert ex.stop_event is ev
# ---------------------------------------------------------------------------
# run() — schedule iteration
# ---------------------------------------------------------------------------
class TestTxExecutorRun:
    """run(): schedule iteration, cancellation, and simulation fallback."""

    def test_empty_schedule_returns_immediately(self):
        cfg = _cfg(steps=[])
        ex = TxExecutor(cfg)
        ex.run()  # must not raise or block

    def test_pre_set_stop_event_skips_all_steps(self):
        ev = threading.Event()
        ev.set()
        ex = TxExecutor(_cfg(), stop_event=ev)
        # If stop was set, _execute_step should never be called.
        # run() should return cleanly without attempting synthesis.
        ex.run()

    def test_no_sdr_falls_back_to_simulation(self, monkeypatch):
        """Without SDR hardware TxExecutor simulates by calling stop_event.wait."""
        cfg = _cfg(steps=[{"label": "s", "duration": 0.001, "power_dbm": 0}])
        waited = []
        real_ev = threading.Event()

        def _fake_wait(timeout=None):
            # Record the timeout instead of actually sleeping.
            waited.append(timeout)
            return False

        monkeypatch.setattr(real_ev, "wait", _fake_wait)
        # Patch SDR init to always fail (forces simulation path)
        with patch.object(TxExecutor, "_init_sdr", lambda self, *a, **kw: setattr(self, "_sdr", None)):
            ex = TxExecutor(cfg, sdr_device="nonexistent_xyz", stop_event=real_ev)
            ex.run()
        assert len(waited) >= 1, "expected stop_event.wait to be called for simulation"
# ---------------------------------------------------------------------------
# _synthesise — all modulation types and filter types
# ---------------------------------------------------------------------------
class TestSynthesise:
    """_synthesise output length, dtype, and finiteness across modulations."""

    @pytest.fixture(autouse=True)
    def _ex(self):
        # Fresh executor per test; _synthesise itself needs no SDR hardware.
        self.ex = TxExecutor(_cfg())

    def _synth(self, mod, num_samples=256):
        return self.ex._synthesise(mod, sps=4, num_samples=num_samples, filter_type="rrc", rolloff=0.35)

    @pytest.mark.parametrize("mod", ["BPSK", "QPSK", "8PSK", "16QAM", "64QAM", "256QAM"])
    def test_psk_qam_returns_complex64_array(self, mod):
        sig = self._synth(mod)
        assert sig.dtype == np.complex64
        assert len(sig) == 256

    def test_fsk_returns_correct_length(self):
        sig = self._synth("FSK")
        assert len(sig) == 256

    def test_ook_returns_correct_length(self):
        sig = self._synth("OOK")
        assert len(sig) == 256

    def test_gmsk_returns_correct_length(self):
        sig = self._synth("GMSK")
        assert len(sig) == 256

    def test_oqpsk_returns_correct_length(self):
        sig = self._synth("OQPSK")
        assert len(sig) == 256

    @pytest.mark.parametrize("mod", ["BPSK", "QPSK", "16QAM", "FSK", "OOK", "GMSK"])
    def test_samples_are_finite(self, mod):
        sig = self._synth(mod)
        assert np.all(np.isfinite(sig.real)), f"{mod}: non-finite real samples"
        assert np.all(np.isfinite(sig.imag)), f"{mod}: non-finite imag samples"

    def test_unknown_modulation_defaults_to_qpsk(self):
        sig = self._synth("UNKNOWN_MOD_XYZ")
        assert len(sig) == 256
        assert sig.dtype == np.complex64

    @pytest.mark.parametrize("filter_type", ["rrc", "rc", "gaussian", "rect", "none"])
    def test_all_filter_types(self, filter_type):
        sig = self.ex._synthesise("QPSK", sps=4, num_samples=128, filter_type=filter_type, rolloff=0.35)
        assert len(sig) == 128

    @pytest.mark.parametrize("n", [64, 128, 512, 1024])
    def test_output_length_matches_requested_samples(self, n):
        sig = self._synth("QPSK", num_samples=n)
        assert len(sig) == n

    def test_bpsk_output_is_complex_not_real(self):
        sig = self._synth("BPSK")
        # complex64 always has imag part; just check dtype
        assert sig.dtype == np.complex64

    def test_256qam_correct_length(self):
        sig = self._synth("256QAM")
        assert len(sig) == 256

View File

@ -189,6 +189,8 @@ class TestNoiseCommand:
"10000", "10000",
"--noise-type", "--noise-type",
"gaussian", "gaussian",
"--power",
"0.01",
"--output", "--output",
output, output,
"-q", "-q",
@ -234,7 +236,7 @@ class TestNoiseCommand:
"--num-samples", "--num-samples",
"10000", "10000",
"--power", "--power",
"0.5", "0.01",
"--output", "--output",
output, output,
"-q", "-q",

View File

@ -1,6 +1,6 @@
"""Tests for the RT-OSS HTTP server. """Tests for the RT-OSS HTTP server.
Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator Covers: auth, inference lifecycle (without SDR/ONNX hardware), conductor
lifecycle (with mocked executor), and state helpers. lifecycle (with mocked executor), and state helpers.
``start_inference`` and ``_inference_loop`` require real SDR hardware and an ``start_inference`` and ``_inference_loop`` require real SDR hardware and an
@ -286,17 +286,17 @@ class TestInferenceStop:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /orchestrator/deploy # POST /conductor/deploy
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestOrchestratorDeploy: class TestConductorDeploy:
def test_deploy_422_on_invalid_config(self, client): def test_deploy_422_on_invalid_config(self, client):
with patch( with patch(
"ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", "ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict",
side_effect=ValueError("missing required field 'name'"), side_effect=ValueError("missing required field 'name'"),
): ):
resp = client.post("/orchestrator/deploy", json={"config": {}}) resp = client.post("/conductor/deploy", json={"config": {}})
assert resp.status_code == 422 assert resp.status_code == 422
def test_deploy_returns_campaign_id(self, client): def test_deploy_returns_campaign_id(self, client):
@ -307,10 +307,10 @@ class TestOrchestratorDeploy:
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {}) mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
with ( with (
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg), patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor), patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
): ):
resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}}) resp = client.post("/conductor/deploy", json={"config": {"name": "test_campaign"}})
assert resp.status_code == 200 assert resp.status_code == 200
body = resp.json() body = resp.json()
@ -325,23 +325,23 @@ class TestOrchestratorDeploy:
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {}) mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
with ( with (
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg), patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor), patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
): ):
resp = client.post("/orchestrator/deploy", json={"config": {}}) resp = client.post("/conductor/deploy", json={"config": {}})
campaign_id = resp.json()["campaign_id"] campaign_id = resp.json()["campaign_id"]
assert state_module._campaigns.get(campaign_id) is not None assert state_module._campaigns.get(campaign_id) is not None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /orchestrator/status/{campaign_id} # GET /conductor/status/{campaign_id}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestOrchestratorStatus: class TestConductorStatus:
def test_status_404_for_unknown_id(self, client): def test_status_404_for_unknown_id(self, client):
resp = client.get("/orchestrator/status/nonexistent-id") resp = client.get("/conductor/status/nonexistent-id")
assert resp.status_code == 404 assert resp.status_code == 404
def test_status_returns_campaign_state(self, client): def test_status_returns_campaign_state(self, client):
@ -357,7 +357,7 @@ class TestOrchestratorStatus:
) )
state_module._campaigns["abc-123"] = state state_module._campaigns["abc-123"] = state
resp = client.get("/orchestrator/status/abc-123") resp = client.get("/conductor/status/abc-123")
assert resp.status_code == 200 assert resp.status_code == 200
body = resp.json() body = resp.json()
assert body["campaign_id"] == "abc-123" assert body["campaign_id"] == "abc-123"
@ -367,13 +367,13 @@ class TestOrchestratorStatus:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# POST /orchestrator/cancel/{campaign_id} # POST /conductor/cancel/{campaign_id}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestOrchestratorCancel: class TestConductorCancel:
def test_cancel_404_for_unknown_id(self, client): def test_cancel_404_for_unknown_id(self, client):
resp = client.post("/orchestrator/cancel/no-such-id") resp = client.post("/conductor/cancel/no-such-id")
assert resp.status_code == 404 assert resp.status_code == 404
def test_cancel_sets_cancel_event(self, client): def test_cancel_sets_cancel_event(self, client):
@ -387,7 +387,7 @@ class TestOrchestratorCancel:
) )
state_module._campaigns["camp-to-cancel"] = state state_module._campaigns["camp-to-cancel"] = state
resp = client.post("/orchestrator/cancel/camp-to-cancel") resp = client.post("/conductor/cancel/camp-to-cancel")
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["cancelled"] is True assert resp.json()["cancelled"] is True
assert cancel_event.is_set() assert cancel_event.is_set()
@ -403,7 +403,7 @@ class TestOrchestratorCancel:
) )
state_module._campaigns["done"] = state state_module._campaigns["done"] = state
resp = client.post("/orchestrator/cancel/done") resp = client.post("/conductor/cancel/done")
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["cancelled"] is False assert resp.json()["cancelled"] is False
assert not cancel_event.is_set() assert not cancel_event.is_set()

247
tests/test_agent.py Normal file
View File

@ -0,0 +1,247 @@
"""Tests for NodeAgent — TX role, session code, and TX command dispatch."""
from __future__ import annotations
import threading
import time
from unittest.mock import MagicMock, patch
from ria_toolkit_oss.agent import NodeAgent
def _agent(role="general", session_code=None, **kwargs):
    """Build a NodeAgent against a dummy hub with the given role/session code."""
    return NodeAgent(
        hub_url="http://hub.test",
        api_key="test-key",
        name="test-node",
        sdr_device="mock",
        role=role,
        session_code=session_code,
        **kwargs,
    )
def _mock_register(agent, node_id="node_abc123"):
    """Patch _post so _register() returns a fake node_id response.

    Returns:
        The MagicMock standing in for ``agent._post`` so callers can
        inspect ``call_args``.
    """
    resp = MagicMock()
    resp.json.return_value = {"node_id": node_id}
    resp.raise_for_status.return_value = None
    agent._post = MagicMock(return_value=resp)
    return agent._post
# ---------------------------------------------------------------------------
# Initialisation
# ---------------------------------------------------------------------------
class TestNodeAgentInit:
    """NodeAgent constructor: role, session code, TX plumbing, URL cleanup."""

    def test_stores_role_general(self):
        assert _agent(role="general").role == "general"

    def test_stores_role_tx(self):
        assert _agent(role="tx").role == "tx"

    def test_stores_role_rx(self):
        assert _agent(role="rx").role == "rx"

    def test_session_code_stored(self):
        assert _agent(session_code="amber-peak-transmit").session_code == "amber-peak-transmit"

    def test_session_code_none_by_default(self):
        assert _agent().session_code is None

    def test_tx_stop_event_created(self):
        a = _agent()
        assert isinstance(a._tx_stop, threading.Event)

    def test_tx_thread_none_initially(self):
        assert _agent()._tx_thread is None

    def test_hub_url_trailing_slash_stripped(self):
        a = NodeAgent(hub_url="http://hub.test/", api_key="k", name="n")
        assert a.hub_url == "http://hub.test"
# ---------------------------------------------------------------------------
# _register payload
# ---------------------------------------------------------------------------
class TestNodeAgentRegisterPayload:
    """Shape of the JSON payload _register() posts to the hub."""

    def _payload(self, agent):
        # Run _register() against a mocked _post and return the JSON payload.
        post = _mock_register(agent)
        agent._register()
        _, kwargs = post.call_args
        return kwargs["json"]

    def test_general_role_in_payload(self):
        payload = self._payload(_agent(role="general"))
        assert payload["role"] == "general"

    def test_tx_role_in_payload(self):
        payload = self._payload(_agent(role="tx"))
        assert payload["role"] == "tx"

    def test_tx_role_adds_transmit_capability(self):
        payload = self._payload(_agent(role="tx"))
        assert "transmit" in payload["capabilities"]

    def test_general_role_omits_transmit_capability(self):
        payload = self._payload(_agent(role="general"))
        assert "transmit" not in payload.get("capabilities", [])

    def test_session_code_included_when_set(self):
        payload = self._payload(_agent(role="tx", session_code="amber-peak-transmit"))
        assert payload["session_code"] == "amber-peak-transmit"

    def test_session_code_omitted_when_none(self):
        # The key is absent entirely, not sent as null.
        payload = self._payload(_agent())
        assert "session_code" not in payload

    def test_register_stores_returned_node_id(self):
        a = _agent()
        _mock_register(a, node_id="node_xyz999")
        a._register()
        assert a.node_id == "node_xyz999"

    def test_name_in_payload(self):
        a = NodeAgent(hub_url="http://h", api_key="k", name="my-bench")
        _mock_register(a)
        a._register()
        _, kwargs = a._post.call_args
        assert kwargs["json"]["name"] == "my-bench"

    def test_sdr_device_in_payload(self):
        a = _agent()
        post = _mock_register(a)
        a._register()
        _, kwargs = post.call_args
        assert kwargs["json"]["sdr_device"] == "mock"

    def test_campaign_capability_always_present(self):
        for role in ("general", "rx", "tx"):
            a = _agent(role=role)
            payload = self._payload(a)
            assert "campaign" in payload["capabilities"]
# ---------------------------------------------------------------------------
# _dispatch — TX commands
# ---------------------------------------------------------------------------
class TestNodeAgentDispatch:
    """Tests for NodeAgent._dispatch handling of TX and campaign commands."""

    # Patch target for the executor constructed by start_transmit.
    _EXECUTOR_PATCH = "ria_toolkit_oss.orchestration.tx_executor.TxExecutor"

    def _make_agent(self):
        """Build a registered tx-role agent with campaign reporting mocked out."""
        agent = _agent(role="tx")
        agent.node_id = "node_abc"
        agent._report_campaign_status = MagicMock()
        return agent

    @staticmethod
    def _blocking_executor(release, on_run=None):
        """Return a fake TxExecutor whose run() blocks until *release* is set.

        *on_run*, if given, is invoked at the start of each run() call so a
        test can count invocations.
        """
        class _Fake:
            def run(self_inner):
                if on_run is not None:
                    on_run()
                release.wait(timeout=2)

        return _Fake()

    def test_start_transmit_spawns_thread(self):
        agent = self._make_agent()
        release = threading.Event()
        fake = self._blocking_executor(release)
        with patch(self._EXECUTOR_PATCH, return_value=fake):
            agent._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
            time.sleep(0.05)
            assert agent._tx_thread is not None
            release.set()

    def test_start_transmit_clears_stop_event(self):
        agent = self._make_agent()
        agent._tx_stop.set()  # pre-set
        release = threading.Event()
        fake = self._blocking_executor(release)
        with patch(self._EXECUTOR_PATCH, return_value=fake):
            agent._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
            time.sleep(0.05)
            assert not agent._tx_stop.is_set()
            release.set()

    def test_stop_transmit_sets_stop_event(self):
        agent = self._make_agent()
        agent._dispatch({"command": "stop_transmit"})
        assert agent._tx_stop.is_set()

    def test_configure_transmit_does_not_raise(self):
        self._make_agent()._dispatch({"command": "configure_transmit", "modulation": "BPSK"})

    def test_unknown_command_is_ignored(self):
        self._make_agent()._dispatch({"command": "frobnicate_xyz"})

    def test_duplicate_start_transmit_ignored_while_running(self):
        agent = self._make_agent()
        release = threading.Event()
        starts = []
        fake = self._blocking_executor(release, on_run=lambda: starts.append(1))
        with patch(self._EXECUTOR_PATCH, return_value=fake):
            agent._dispatch({"command": "start_transmit"})
            time.sleep(0.05)
            agent._dispatch({"command": "start_transmit"})  # second while first alive
            release.set()
            time.sleep(0.05)
            assert len(starts) == 1

    def test_run_campaign_dispatched_in_thread(self):
        agent = self._make_agent()
        finished = threading.Event()
        with patch("ria_toolkit_oss.agent.NodeAgent._run_campaign") as mock_run:
            mock_run.side_effect = lambda *_: finished.set()
            agent._dispatch({"command": "run_campaign", "campaign_id": "c1", "payload": {}})
            finished.wait(timeout=2)
            assert mock_run.called
# ---------------------------------------------------------------------------
# _stop_transmit
# ---------------------------------------------------------------------------
class TestStopTransmit:
    """Tests for NodeAgent._stop_transmit()."""

    def test_no_thread_noop(self):
        # With no TX thread ever started, stopping must be a harmless no-op.
        _agent()._stop_transmit()

    def test_sets_stop_event(self):
        agent = _agent()
        agent._stop_transmit()
        assert agent._tx_stop.is_set()

    def test_joins_live_thread(self):
        agent = _agent()
        finished = threading.Event()
        unblock = threading.Event()

        def _worker():
            unblock.wait(timeout=2)
            finished.set()

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()
        agent._tx_thread = worker
        # Signal stop and let the worker run to completion before joining
        agent._tx_stop.set()
        unblock.set()
        agent._stop_transmit()
        assert not worker.is_alive()