"""RT-OSS Node Agent — connects to RIA Hub and dispatches work to local hardware.

The agent runs on any machine with an SDR attached and connects **outbound**
to RIA Hub.  No inbound ports need to be opened on the user's machine, and the
connection works identically through NAT, corporate firewalls, or a Pi on a
cellular link.

Usage::

    ria-agent \\
        --hub https://riahub.company.com \\
        --key <API_KEY> \\
        --name lab-bench-1 \\
        [--device plutosdr] \\
        [--insecure] \\
        [--log-level DEBUG]

The agent:

1. Registers with RIA Hub and receives a ``node_id``.
2. Sends a heartbeat every 30 s so the hub knows it is online.
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
4. Executes received campaigns via
   :class:`ria_toolkit_oss.orchestration.executor.CampaignExecutor`.
5. Uploads recordings to the hub.  Files up to 90 MB are sent in a single
   POST; larger files are sent as 50 MB chunks, so every request stays below
   Cloudflare's 100 MB limit without needing the bypass subdomain.
6. Deregisters cleanly on SIGINT / SIGTERM.

If the hub answers a heartbeat or a command poll with HTTP 404 (it has lost
this node's registration), the agent re-registers once and carries on with the
new ``node_id``.
"""

from __future__ import annotations

import logging
import math
import os
import signal
import sys
import threading
import time
import uuid
from typing import Any

logger = logging.getLogger("ria_agent")

# ---------------------------------------------------------------------------
# Tuneable constants
# ---------------------------------------------------------------------------

_HEARTBEAT_INTERVAL = 30  # seconds between heartbeats
_POLL_TIMEOUT = 30  # server-side long-poll duration (informational; not sent to the hub)
_POLL_CLIENT_TIMEOUT = 40  # client read timeout — slightly longer than server
_RECONNECT_PAUSE = 5  # seconds to wait after a poll error before retrying
_CHUNK_SIZE = 50 * 1024 * 1024  # 50 MB — well below Cloudflare's 100 MB limit
_DIRECT_THRESHOLD = 90 * 1024 * 1024  # files above this use chunked upload


# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------


class NodeAgent:
    """Outbound-connecting agent that bridges RIA Hub to local SDR hardware.

    All network I/O is initiated by the agent (outbound).  RIA Hub never opens
    a connection back to the agent's machine.

    Threads used by :meth:`run`:

    * the calling thread runs the command long-poll loop;
    * a daemon thread posts a heartbeat every ``_HEARTBEAT_INTERVAL`` seconds;
    * every ``run_campaign`` command is executed on its own daemon thread.
    """

    def __init__(
        self,
        hub_url: str,
        api_key: str,
        name: str,
        sdr_device: str = "unknown",
        insecure: bool = False,
    ) -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            hub_url: RIA Hub base URL; a trailing ``/`` is stripped.
            api_key: Sent as the ``X-API-Key`` header on every request.
            name: Human-readable node name reported at registration.
            sdr_device: SDR type reported to the hub (informational).
            insecure: If true, TLS certificate verification is disabled.
        """
        self.hub_url = hub_url.rstrip("/")
        self.api_key = api_key
        self.name = name
        self.sdr_device = sdr_device
        self.insecure = insecure
        self.node_id: str | None = None
        self._stop = threading.Event()
        # Serialises re-registration: the heartbeat thread and the command loop
        # can both receive a 404 for the same stale node_id.
        self._register_lock = threading.Lock()

        # Version is reported to the hub; the toolkit is optional at import time.
        try:
            import ria_toolkit_oss

            self._ria_version: str = getattr(ria_toolkit_oss, "__version__", "unknown")
        except Exception:
            self._ria_version = "unknown"

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def run(self) -> None:
        """Register, start the heartbeat thread, and enter the command loop.

        Blocks until SIGINT or SIGTERM is received.  The signal handler only
        sets a stop flag, so the loop exits once the in-flight long-poll returns
        (at most ``_POLL_CLIENT_TIMEOUT`` seconds); the node is then
        deregistered.  Must be called from the main thread, because
        :func:`signal.signal` can only install handlers there.

        Any exception raised by the initial registration (HTTP or network
        error) propagates to the caller.
        """
        self._register()

        def _shutdown(sig: int, _frame: Any) -> None:
            logger.info("Shutdown signal received — stopping agent")
            self._stop.set()

        signal.signal(signal.SIGINT, _shutdown)
        signal.signal(signal.SIGTERM, _shutdown)

        hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="ria-agent-heartbeat")
        hb.start()

        logger.info("Agent %r online (node_id=%s, hub=%s)", self.name, self.node_id, self.hub_url)
        try:
            self._command_loop()
        finally:
            self._stop.set()
            self._deregister()

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def _register(self) -> None:
        """Register this node with the hub and store the returned ``node_id``.

        Raises on a non-2xx response (``raise_for_status``), on network errors,
        and with KeyError if the response JSON lacks ``node_id``.  After start-up,
        use :meth:`_reregister` instead so concurrent callers are serialised.
        """
        resp = self._post(
            "/orchestrator/nodes/register",
            json={
                "name": self.name,
                "sdr_device": self.sdr_device,
                "ria_toolkit_version": self._ria_version,
                "capabilities": ["inference", "campaign"],
            },
            timeout=15,
        )
        resp.raise_for_status()
        self.node_id = resp.json()["node_id"]
        logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)

    def _reregister(self, stale_node_id: str | None) -> None:
        """Re-register after the hub answered 404 for *stale_node_id*.

        Both the heartbeat thread and the command loop may observe a 404 for the
        same id.  Only the first caller re-registers.  A later caller finds that
        ``self.node_id`` has already been replaced and skips, so the hub does not
        end up with a second, orphaned node entry.  Registration errors
        propagate; ``node_id`` is then unchanged and the next 404 retries.
        """
        with self._register_lock:
            if self.node_id != stale_node_id:
                logger.debug("Already re-registered as %s — skipping", self.node_id)
                return
            self._register()

    def _deregister(self) -> None:
        """Best-effort DELETE of this node on shutdown; errors are only logged."""
        if not self.node_id:
            return
        try:
            self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
            logger.info("Deregistered %s", self.node_id)
        except Exception as exc:
            logger.debug("Deregister failed (ignored on shutdown): %s", exc)

    # ------------------------------------------------------------------
    # Heartbeat thread
    # ------------------------------------------------------------------

    def _heartbeat_loop(self) -> None:
        """Post a heartbeat every ``_HEARTBEAT_INTERVAL`` seconds until stopped.

        Intended for a daemon thread.  Errors are logged and never end the loop.
        """
        while not self._stop.wait(_HEARTBEAT_INTERVAL):
            # Snapshot the id: a 404 refers to the id this request used, which
            # the command loop may have replaced while the request was in flight.
            node_id = self.node_id
            try:
                resp = self._post(f"/orchestrator/nodes/{node_id}/heartbeat", timeout=10)
                if resp.status_code == 404:
                    logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
                    self._reregister(node_id)
                elif not resp.ok:
                    logger.warning("Heartbeat rejected: HTTP %d", resp.status_code)
            except Exception as exc:
                logger.warning("Heartbeat failed: %s", exc)

    # ------------------------------------------------------------------
    # Command poll loop
    # ------------------------------------------------------------------

    def _command_loop(self) -> None:
        """Long-poll the hub for commands and dispatch them until stopped.

        HTTP 204 means no command arrived within the poll window.  HTTP 404
        triggers re-registration.  Any other error is logged, and the loop waits
        ``_RECONNECT_PAUSE`` seconds before retrying.  That wait ends early if
        shutdown is requested.
        """
        while not self._stop.is_set():
            node_id = self.node_id  # snapshot, see _heartbeat_loop
            try:
                resp = self._get(
                    f"/orchestrator/nodes/{node_id}/commands",
                    timeout=_POLL_CLIENT_TIMEOUT,
                )
                if resp.status_code == 204:
                    # No command within the timeout window — loop immediately.
                    continue
                if resp.status_code == 404:
                    logger.warning("Command poll got 404 — re-registering")
                    self._reregister(node_id)
                    continue
                resp.raise_for_status()
                cmd = resp.json()
                logger.info("Received command: %s", cmd.get("command"))
                self._dispatch(cmd)
            except Exception as exc:
                if not self._stop.is_set():
                    logger.warning("Command poll error: %s — retrying in %ds", exc, _RECONNECT_PAUSE)
                    # Event.wait instead of time.sleep so shutdown is not delayed.
                    self._stop.wait(_RECONNECT_PAUSE)

    # ------------------------------------------------------------------
    # Command dispatch
    # ------------------------------------------------------------------

    def _dispatch(self, cmd: dict) -> None:
        """Route one hub command to its handler.

        ``run_campaign`` starts :meth:`_run_campaign` on a daemon thread, so the
        poll loop keeps receiving commands.  Unknown commands are logged and
        dropped.
        """
        command = cmd.get("command")
        if command != "run_campaign":
            logger.warning("Unknown command %r — ignored", command)
            return

        # Coerce to str: the id is sliced for thread/log names below, and the
        # hub might send it as a JSON number.  A random id is used if absent.
        campaign_id = str(cmd.get("campaign_id") or uuid.uuid4())
        config_dict: dict = cmd.get("payload") or {}
        threading.Thread(
            target=self._run_campaign,
            args=(campaign_id, config_dict),
            daemon=True,
            name=f"campaign-{campaign_id[:8]}",
        ).start()

    # ------------------------------------------------------------------
    # Campaign execution
    # ------------------------------------------------------------------

    def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
        """Execute one campaign, then upload its recordings.

        Runs on its own thread.  Every failure is logged (with traceback) rather
        than raised, because nothing would catch it on this thread.
        """
        try:
            from ria_toolkit_oss.orchestration.campaign import CampaignConfig
            from ria_toolkit_oss.orchestration.executor import CampaignExecutor
        except ImportError as exc:
            logger.error(
                "Campaign %s cannot start — ria_toolkit_oss not fully installed: %s",
                campaign_id[:8],
                exc,
            )
            return

        logger.info("Campaign %s starting", campaign_id[:8])
        try:
            config = CampaignConfig.from_dict(config_dict)
            executor = CampaignExecutor(config)
            result = executor.run()
            logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
            self._upload_recordings(campaign_id, config, result)
        except Exception as exc:
            logger.exception("Campaign %s failed: %s", campaign_id[:8], exc)

    # ------------------------------------------------------------------
    # Recording upload (chunked for large files)
    # ------------------------------------------------------------------

    def _upload_recordings(self, campaign_id: str, config: Any, result: Any) -> None:
        """Upload every SigMF file produced by *result*'s steps to ``config.output.repo``.

        ``config.output.repo`` must look like ``"owner/name"``; otherwise the
        upload is skipped.  A failure on one file is logged and does not stop
        the remaining uploads.
        """
        output_repo: str | None = getattr(getattr(config, "output", None), "repo", None)
        if not output_repo or "/" not in output_repo:
            logger.warning("Campaign %s: no output.repo — skipping upload", campaign_id[:8])
            return

        repo_owner, repo_name = output_repo.split("/", 1)
        base_url = f"{self.hub_url}/datasets/upload"

        steps = getattr(result, "steps", None) or []
        for step in steps:
            output_path: str | None = getattr(step, "output_path", None)
            if not output_path:
                continue
            device_id: str = getattr(step, "transmitter_id", "") or ""

            for fpath in _sigmf_files(output_path):
                filename = os.path.basename(fpath)
                metadata = {
                    "filename": filename,
                    "repo_owner": repo_owner,
                    "repo_name": repo_name,
                    "device_id": device_id,
                    "campaign_id": campaign_id,
                }
                try:
                    resp_data = self._upload_file(base_url, fpath, metadata)
                    logger.info(
                        "Campaign %s: uploaded %s (oid=%s)",
                        campaign_id[:8],
                        filename,
                        resp_data.get("oid", "?"),
                    )
                except Exception as exc:
                    logger.warning("Campaign %s: upload of %s failed: %s", campaign_id[:8], filename, exc)

    def _upload_file(self, base_url: str, file_path: str, metadata: dict) -> dict:
        """Upload *file_path*, choosing chunked or direct path based on file size.

        Files of at most ``_DIRECT_THRESHOLD`` bytes are sent in one multipart
        POST to *base_url*.  Larger files are split into ``_CHUNK_SIZE`` pieces
        and POSTed to ``base_url + "/chunk"``; all pieces share one
        ``upload_id``.

        Returns:
            The JSON body of the last successful response.

        Raises:
            requests.HTTPError: if the direct upload is rejected.
            RuntimeError: if any chunk is rejected.
        """
        import requests as _requests

        size = os.path.getsize(file_path)
        filename = os.path.basename(file_path)
        headers = self._headers()
        verify = not self.insecure

        # Small files: single POST (unchanged endpoint, no assembly needed server-side).
        if size <= _DIRECT_THRESHOLD:
            with open(file_path, "rb") as fh:
                resp = _requests.post(
                    base_url,
                    headers=headers,
                    files={"file": (filename, fh)},
                    data=metadata,
                    timeout=300,
                    verify=verify,
                )
            resp.raise_for_status()
            return resp.json()

        # Large files: chunked upload — each request is ≤ 50 MB.
        total_chunks = math.ceil(size / _CHUNK_SIZE)
        upload_id = str(uuid.uuid4())
        chunk_url = base_url + "/chunk"
        logger.info(
            "Chunked upload: %s (%d bytes, %d × %d MB chunks)",
            filename,
            size,
            total_chunks,
            _CHUNK_SIZE // (1024 * 1024),
        )

        resp_data: dict = {}
        with open(file_path, "rb") as fh:
            for i in range(total_chunks):
                chunk = fh.read(_CHUNK_SIZE)
                resp = _requests.post(
                    chunk_url,
                    headers=headers,
                    files={"file": (filename, chunk, "application/octet-stream")},
                    data={
                        **metadata,
                        "upload_id": upload_id,
                        "chunk_index": i,
                        "total_chunks": total_chunks,
                    },
                    timeout=120,
                    verify=verify,
                )
                if not resp.ok:
                    raise RuntimeError(
                        f"Chunk {i + 1}/{total_chunks} failed: "
                        f"HTTP {resp.status_code}: {resp.text[:300]}"
                    )
                resp_data = resp.json()
                logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)

        return resp_data

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    def _headers(self) -> dict[str, str]:
        """Return the authentication headers sent with every hub request."""
        return {"X-API-Key": self.api_key}

    def _request(self, method: str, path: str, **kwargs: Any):
        """Send *method* to ``hub_url + path`` with auth headers and TLS setting applied."""
        import requests as _requests

        return _requests.request(
            method,
            f"{self.hub_url}{path}",
            headers=self._headers(),
            verify=not self.insecure,
            **kwargs,
        )

    def _get(self, path: str, **kwargs: Any):
        """GET *path* on the hub; *kwargs* are passed to ``requests``."""
        return self._request("GET", path, **kwargs)

    def _post(self, path: str, **kwargs: Any):
        """POST to *path* on the hub; *kwargs* are passed to ``requests``."""
        return self._request("POST", path, **kwargs)

    def _delete(self, path: str, **kwargs: Any):
        """DELETE *path* on the hub; *kwargs* are passed to ``requests``."""
        return self._request("DELETE", path, **kwargs)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _sigmf_files(data_path: str) -> list[str]:
    """Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording.

    The sibling ``.sigmf-meta`` is only looked for when *data_path* ends in
    ``.sigmf-data``.  Paths that do not exist on disk are omitted.
    """
    candidates = [data_path]
    if data_path.endswith(".sigmf-data"):
        candidates.append(data_path[: -len(".sigmf-data")] + ".sigmf-meta")
    return [p for p in candidates if os.path.exists(p)]


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------


def main() -> None:
    """Parse CLI arguments, configure logging, and run a :class:`NodeAgent` until stopped."""
    import argparse

    parser = argparse.ArgumentParser(
        prog="ria-agent",
        description=(
            "RT-OSS Node Agent — connects outbound to RIA Hub and executes "
            "campaigns / inference on local SDR hardware."
        ),
    )
    parser.add_argument(
        "--hub",
        required=True,
        metavar="URL",
        help="RIA Hub base URL, e.g. https://riahub.company.com",
    )
    parser.add_argument(
        "--key",
        required=True,
        metavar="API_KEY",
        help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
    )
    parser.add_argument(
        "--name",
        required=True,
        metavar="NAME",
        help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
    )
    parser.add_argument(
        "--device",
        default="unknown",
        metavar="SDR",
        help=(
            "SDR device type reported to the hub (informational only). "
            "Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
        ),
    )
    parser.add_argument(
        "--insecure",
        action="store_true",
        help="Disable TLS certificate verification (dev/self-signed certs only)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity (default: INFO)",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stderr,
    )

    # Warn loudly if --insecure is used outside of development.
    if args.insecure:
        logger.warning(
            "--insecure disables TLS certificate verification. "
            "Only use this for local development with self-signed certs."
        )

    agent = NodeAgent(
        hub_url=args.hub,
        api_key=args.key,
        name=args.name,
        sdr_device=args.device,
        insecure=args.insecure,
    )
    agent.run()


if __name__ == "__main__":
    main()