Compare commits
3 Commits
2182899162
...
5718e109b5
| Author | SHA1 | Date | |
|---|---|---|---|
| 5718e109b5 | |||
| d81c61c3cf | |||
| 54b9bd4fc8 |
|
|
@ -14,19 +14,40 @@ Usage::
|
||||||
[--device plutosdr] \\
|
[--device plutosdr] \\
|
||||||
[--insecure]
|
[--insecure]
|
||||||
|
|
||||||
|
# Or store credentials in a config file and omit them from the command line:
|
||||||
|
ria-agent --config ~/.config/ria-agent/config.json --name lab-bench-1
|
||||||
|
|
||||||
The agent:
|
The agent:
|
||||||
1. Registers with RIA Hub and receives a ``node_id``.
|
1. Registers with RIA Hub and receives a ``node_id``.
|
||||||
2. Sends a heartbeat every 30 s so the hub knows it is online.
|
2. Sends a heartbeat every 30 s so the hub knows it is online.
|
||||||
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
|
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
|
||||||
4. Executes received campaigns via :class:`ria_toolkit_oss.orchestration.executor.CampaignExecutor`.
|
4. Dispatches received commands:
|
||||||
5. Uploads recordings to the hub via chunked POST, keeping each request
|
- ``run_campaign``: executes via CampaignExecutor, uploads recordings.
|
||||||
under 50 MB so it passes through Cloudflare without needing the bypass
|
- ``load_model``: loads an ONNX fingerprint or detector model.
|
||||||
subdomain.
|
- ``start_inference``: opens the SDR, runs the inference loop, posts
|
||||||
6. Deregisters cleanly on SIGINT / SIGTERM.
|
detection events to the hub for SSE fan-out to browsers.
|
||||||
|
- ``stop_inference``: gracefully stops the inference loop.
|
||||||
|
- ``configure_inference``: queues an SDR parameter update (applied at the
|
||||||
|
next capture boundary without restarting the loop).
|
||||||
|
5. Deregisters cleanly on SIGINT / SIGTERM.
|
||||||
|
|
||||||
|
Config file (JSON, optional)::
|
||||||
|
|
||||||
|
{
|
||||||
|
"hub": "https://riahub.company.com",
|
||||||
|
"key": "secret",
|
||||||
|
"name": "lab-bench-1",
|
||||||
|
"device": "plutosdr",
|
||||||
|
"insecure": false,
|
||||||
|
"log_level": "INFO"
|
||||||
|
}
|
||||||
|
|
||||||
|
CLI arguments always override config file values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
@ -49,6 +70,8 @@ _POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
||||||
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
||||||
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
|
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
|
||||||
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
||||||
|
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
|
||||||
|
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -80,6 +103,30 @@ class NodeAgent:
|
||||||
self.node_id: str | None = None
|
self.node_id: str | None = None
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
|
|
||||||
|
# ── Inference state ─────────────────────────────────────────────────
|
||||||
|
# Protected by _inf_lock for cross-thread model swaps.
|
||||||
|
self._inf_lock = threading.Lock()
|
||||||
|
self._inf_session: Any = None # primary fingerprint ONNX session
|
||||||
|
self._inf_index_to_label: dict[int, str] = {}
|
||||||
|
self._inf_detector_session: Any = None # optional protocol-detector session
|
||||||
|
self._inf_detector_index_to_label: dict[int, str] = {}
|
||||||
|
self._inf_detector_threshold: float = 0.7
|
||||||
|
self._inf_pending_config: dict = {} # queued SDR attribute updates
|
||||||
|
|
||||||
|
self._inf_stop = threading.Event()
|
||||||
|
self._inf_thread: threading.Thread | None = None
|
||||||
|
|
||||||
|
# Detect optional dependencies once at startup so capability
|
||||||
|
# advertising is accurate from the first registration.
|
||||||
|
try:
|
||||||
|
import onnxruntime as _ort_mod
|
||||||
|
|
||||||
|
self._ort: Any = _ort_mod
|
||||||
|
self._ort_available = True
|
||||||
|
except ImportError:
|
||||||
|
self._ort = None
|
||||||
|
self._ort_available = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ria_toolkit_oss
|
import ria_toolkit_oss
|
||||||
|
|
||||||
|
|
@ -114,6 +161,7 @@ class NodeAgent:
|
||||||
self._command_loop()
|
self._command_loop()
|
||||||
finally:
|
finally:
|
||||||
self._stop.set()
|
self._stop.set()
|
||||||
|
self._stop_inference()
|
||||||
self._deregister()
|
self._deregister()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -121,13 +169,16 @@ class NodeAgent:
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _register(self) -> None:
|
def _register(self) -> None:
|
||||||
|
capabilities = ["campaign"]
|
||||||
|
if self._ort_available:
|
||||||
|
capabilities.append("inference")
|
||||||
resp = self._post(
|
resp = self._post(
|
||||||
"/orchestrator/nodes/register",
|
"/orchestrator/nodes/register",
|
||||||
json={
|
json={
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"sdr_device": self.sdr_device,
|
"sdr_device": self.sdr_device,
|
||||||
"ria_toolkit_version": self._ria_version,
|
"ria_toolkit_version": self._ria_version,
|
||||||
"capabilities": ["inference", "campaign"],
|
"capabilities": capabilities,
|
||||||
},
|
},
|
||||||
timeout=15,
|
timeout=15,
|
||||||
)
|
)
|
||||||
|
|
@ -200,6 +251,24 @@ class NodeAgent:
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name=f"campaign-{campaign_id[:8]}",
|
name=f"campaign-{campaign_id[:8]}",
|
||||||
).start()
|
).start()
|
||||||
|
elif command == "load_model":
|
||||||
|
threading.Thread(
|
||||||
|
target=self._load_model,
|
||||||
|
args=(cmd,),
|
||||||
|
daemon=True,
|
||||||
|
name="ria-load-model",
|
||||||
|
).start()
|
||||||
|
elif command == "start_inference":
|
||||||
|
threading.Thread(
|
||||||
|
target=self._start_inference,
|
||||||
|
args=(cmd,),
|
||||||
|
daemon=True,
|
||||||
|
name="ria-start-inf",
|
||||||
|
).start()
|
||||||
|
elif command == "stop_inference":
|
||||||
|
self._stop_inference()
|
||||||
|
elif command == "configure_inference":
|
||||||
|
self._queue_sdr_config(cmd)
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown command %r — ignored", command)
|
logger.warning("Unknown command %r — ignored", command)
|
||||||
|
|
||||||
|
|
@ -232,6 +301,270 @@ class NodeAgent:
|
||||||
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
||||||
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Inference — model loading
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_model(self, cmd: dict) -> None:
|
||||||
|
"""Load an ONNX model into the fingerprint or detector slot.
|
||||||
|
|
||||||
|
The ``model_path`` field may be either a local filesystem path or an
|
||||||
|
``http(s)://`` URL; in the latter case the file is downloaded first.
|
||||||
|
"""
|
||||||
|
if not self._ort_available:
|
||||||
|
logger.error("load_model: onnxruntime is not installed — cannot load model")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_path: str = cmd.get("model_path", "")
|
||||||
|
label_map: dict[str, int] = cmd.get("label_map") or {}
|
||||||
|
stage: str = cmd.get("stage", "fingerprint")
|
||||||
|
detector_threshold: float = float(cmd.get("detector_threshold") or 0.7)
|
||||||
|
|
||||||
|
if model_path.startswith(("http://", "https://")):
|
||||||
|
model_path = self._download_model(model_path)
|
||||||
|
if model_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
session = self._ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to load model %r: %s", model_path, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
index_to_label = {v: k for k, v in label_map.items()}
|
||||||
|
with self._inf_lock:
|
||||||
|
if stage == "detector":
|
||||||
|
self._inf_detector_session = session
|
||||||
|
self._inf_detector_index_to_label = index_to_label
|
||||||
|
self._inf_detector_threshold = detector_threshold
|
||||||
|
logger.info(
|
||||||
|
"Detector model loaded: path=%s classes=%d threshold=%.2f",
|
||||||
|
model_path,
|
||||||
|
len(label_map),
|
||||||
|
detector_threshold,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._inf_session = session
|
||||||
|
self._inf_index_to_label = index_to_label
|
||||||
|
logger.info(
|
||||||
|
"Fingerprint model loaded: path=%s classes=%d",
|
||||||
|
model_path,
|
||||||
|
len(label_map),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _download_model(self, url: str) -> str | None:
|
||||||
|
"""Download a model from *url* to a temp file and return the local path."""
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import requests as _requests
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Downloading model from %s", url)
|
||||||
|
resp = _requests.get(
|
||||||
|
url,
|
||||||
|
headers={"X-API-Key": self.api_key},
|
||||||
|
verify=not self.insecure,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as fh:
|
||||||
|
fh.write(resp.content)
|
||||||
|
path = fh.name
|
||||||
|
logger.info("Model downloaded to %s (%d bytes)", path, len(resp.content))
|
||||||
|
return path
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Model download from %s failed: %s", url, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Inference — loop lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _start_inference(self, cmd: dict) -> None:
|
||||||
|
"""Start the SDR capture + ONNX inference loop."""
|
||||||
|
if not self._ort_available:
|
||||||
|
logger.error("start_inference: onnxruntime is not installed")
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._inf_lock:
|
||||||
|
if self._inf_session is None:
|
||||||
|
logger.error("start_inference: no fingerprint model loaded — call load_model first")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._inf_thread is not None and self._inf_thread.is_alive():
|
||||||
|
logger.warning("start_inference: inference loop is already running — ignoring")
|
||||||
|
return
|
||||||
|
|
||||||
|
center_freq: float = float(cmd.get("center_freq", 2.4e9))
|
||||||
|
sample_rate: float = float(cmd.get("sample_rate", 10e6))
|
||||||
|
gain: float | str = cmd.get("gain", "auto")
|
||||||
|
device_type: str = cmd.get("device") or self.sdr_device
|
||||||
|
|
||||||
|
self._inf_stop.clear()
|
||||||
|
self._inf_thread = threading.Thread(
|
||||||
|
target=self._inference_loop,
|
||||||
|
args=(device_type, center_freq, sample_rate, gain),
|
||||||
|
daemon=True,
|
||||||
|
name="ria-agent-inference",
|
||||||
|
)
|
||||||
|
self._inf_thread.start()
|
||||||
|
logger.info(
|
||||||
|
"Inference started (device=%s freq=%.3f MHz rate=%.1f MHz)",
|
||||||
|
device_type,
|
||||||
|
center_freq / 1e6,
|
||||||
|
sample_rate / 1e6,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stop_inference(self) -> None:
|
||||||
|
"""Signal the inference loop to stop and wait up to 5 s for it to exit."""
|
||||||
|
self._inf_stop.set()
|
||||||
|
if self._inf_thread is not None and self._inf_thread.is_alive():
|
||||||
|
self._inf_thread.join(timeout=5.0)
|
||||||
|
if self._inf_thread.is_alive():
|
||||||
|
logger.warning("Inference thread did not exit within 5 s")
|
||||||
|
logger.info("Inference stopped")
|
||||||
|
|
||||||
|
def _queue_sdr_config(self, cmd: dict) -> None:
|
||||||
|
"""Merge SDR parameter updates into the pending-config dict.
|
||||||
|
|
||||||
|
The inference loop checks this at each capture boundary and applies
|
||||||
|
the updates without restarting.
|
||||||
|
"""
|
||||||
|
cfg = {k: v for k, v in cmd.items() if k != "command" and v is not None}
|
||||||
|
with self._inf_lock:
|
||||||
|
self._inf_pending_config.update(cfg)
|
||||||
|
logger.debug("SDR reconfiguration queued: %s", cfg)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Inference — main loop
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _inference_loop(
|
||||||
|
self,
|
||||||
|
device_type: str,
|
||||||
|
center_freq: float,
|
||||||
|
sample_rate: float,
|
||||||
|
gain: float | str,
|
||||||
|
) -> None:
|
||||||
|
"""Continuous SDR capture → ONNX inference → POST events to hub.
|
||||||
|
|
||||||
|
Mirrors the two-stage pipeline in the hub's ``_inference_loop``:
|
||||||
|
an optional protocol-detector gates the fingerprint model so the
|
||||||
|
fingerprint model only runs when an active transmission is detected.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.sdr import get_sdr_device
|
||||||
|
except ImportError as exc:
|
||||||
|
logger.error("inference_loop: ria_toolkit_oss not installed: %s", exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
sdr = get_sdr_device(device_type)
|
||||||
|
_apply_sdr_config(sdr, {"center_freq": center_freq, "sample_rate": sample_rate, "gain": gain})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("SDR initialisation failed: %s", exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
|
||||||
|
except ImportError:
|
||||||
|
estimate_snr_db = None
|
||||||
|
|
||||||
|
# Snapshot model state once at loop start. If the hub sends a
|
||||||
|
# new load_model command while the loop is running, the new session
|
||||||
|
# will be picked up on the next loop restart (stop + start).
|
||||||
|
with self._inf_lock:
|
||||||
|
session = self._inf_session
|
||||||
|
index_to_label = dict(self._inf_index_to_label)
|
||||||
|
det_session = self._inf_detector_session
|
||||||
|
det_threshold = self._inf_detector_threshold
|
||||||
|
|
||||||
|
input_name = session.get_inputs()[0].name
|
||||||
|
det_input_name = det_session.get_inputs()[0].name if det_session else None
|
||||||
|
|
||||||
|
while not self._inf_stop.is_set():
|
||||||
|
# Apply any queued SDR configuration changes.
|
||||||
|
with self._inf_lock:
|
||||||
|
pending = self._inf_pending_config.copy()
|
||||||
|
self._inf_pending_config.clear()
|
||||||
|
if pending:
|
||||||
|
_apply_sdr_config(sdr, pending)
|
||||||
|
|
||||||
|
try:
|
||||||
|
samples = sdr.rx(_CAPTURE_SAMPLES)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("SDR capture error: %s", exc)
|
||||||
|
# Avoid a tight spin when the SDR is in a persistent error
|
||||||
|
# state (e.g. physically disconnected).
|
||||||
|
self._inf_stop.wait(timeout=0.5)
|
||||||
|
continue
|
||||||
|
|
||||||
|
samples = np.array(samples, dtype=np.complex64)
|
||||||
|
snr_db = float(estimate_snr_db(samples)) if estimate_snr_db is not None else 0.0
|
||||||
|
iq = np.stack([samples.real, samples.imag], axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
# Stage 1: protocol detector gate (optional).
|
||||||
|
if det_session is not None:
|
||||||
|
det_out = _run_onnx_session(det_session, det_input_name, iq)
|
||||||
|
det_probs = _softmax(det_out[0][0])
|
||||||
|
det_confidence = float(det_probs.max())
|
||||||
|
if det_confidence < det_threshold:
|
||||||
|
# No active protocol detected — report idle and skip
|
||||||
|
# the fingerprint model for this window.
|
||||||
|
self._post_event(device_id=None, confidence=det_confidence, snr_db=snr_db)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Stage 2: fingerprint model.
|
||||||
|
out = _run_onnx_session(session, input_name, iq)
|
||||||
|
probs = _softmax(out[0][0])
|
||||||
|
pred_idx = int(probs.argmax())
|
||||||
|
confidence = float(probs[pred_idx])
|
||||||
|
device_id = index_to_label.get(pred_idx)
|
||||||
|
idle = (device_id in _IDLE_LABELS) if device_id else True
|
||||||
|
self._post_event(
|
||||||
|
device_id=None if idle else device_id,
|
||||||
|
confidence=confidence,
|
||||||
|
snr_db=snr_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Inference loop terminated unexpectedly: %s", exc)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.info("Inference loop exited")
|
||||||
|
|
||||||
|
def _post_event(self, device_id: str | None, confidence: float, snr_db: float) -> None:
|
||||||
|
"""POST a single detection event to ``POST /orchestrator/nodes/{id}/events``.
|
||||||
|
|
||||||
|
Failures are logged at DEBUG level and silently swallowed so that a
|
||||||
|
transient network blip does not crash the inference loop.
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"type": "detection",
|
||||||
|
"device_id": device_id,
|
||||||
|
"confidence": round(confidence, 6),
|
||||||
|
"snr_db": round(snr_db, 2),
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
resp = self._post(
|
||||||
|
f"/orchestrator/nodes/{self.node_id}/events",
|
||||||
|
json=payload,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if resp.status_code not in (200, 204):
|
||||||
|
logger.debug("Event POST returned HTTP %d", resp.status_code)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Event POST failed (will retry next inference cycle): %s", exc)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Recording upload (chunked for large files)
|
# Recording upload (chunked for large files)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -244,7 +577,7 @@ class NodeAgent:
|
||||||
|
|
||||||
repo_owner, repo_name = output_repo.split("/", 1)
|
repo_owner, repo_name = output_repo.split("/", 1)
|
||||||
base_url = f"{self.hub_url}/datasets/upload"
|
base_url = f"{self.hub_url}/datasets/upload"
|
||||||
steps = getattr(result, "steps", None) or []
|
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
|
||||||
|
|
||||||
for step in steps:
|
for step in steps:
|
||||||
output_path: str | None = getattr(step, "output_path", None)
|
output_path: str | None = getattr(step, "output_path", None)
|
||||||
|
|
@ -304,7 +637,6 @@ class NodeAgent:
|
||||||
headers = {"X-API-Key": self.api_key}
|
headers = {"X-API-Key": self.api_key}
|
||||||
verify = not self.insecure
|
verify = not self.insecure
|
||||||
|
|
||||||
# Small files: single POST (unchanged endpoint, no assembly needed server-side).
|
|
||||||
if size <= _DIRECT_THRESHOLD:
|
if size <= _DIRECT_THRESHOLD:
|
||||||
with open(file_path, "rb") as fh:
|
with open(file_path, "rb") as fh:
|
||||||
resp = _requests.post(
|
resp = _requests.post(
|
||||||
|
|
@ -318,7 +650,6 @@ class NodeAgent:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
# Large files: chunked upload — each request is ≤ 50 MB.
|
|
||||||
total_chunks = math.ceil(size / _CHUNK_SIZE)
|
total_chunks = math.ceil(size / _CHUNK_SIZE)
|
||||||
upload_id = str(uuid.uuid4())
|
upload_id = str(uuid.uuid4())
|
||||||
chunk_url = base_url + "/chunk"
|
chunk_url = base_url + "/chunk"
|
||||||
|
|
@ -339,18 +670,13 @@ class NodeAgent:
|
||||||
chunk_url,
|
chunk_url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
files={"file": (filename, chunk, "application/octet-stream")},
|
files={"file": (filename, chunk, "application/octet-stream")},
|
||||||
data={
|
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
|
||||||
**metadata,
|
|
||||||
"upload_id": upload_id,
|
|
||||||
"chunk_index": i,
|
|
||||||
"total_chunks": total_chunks,
|
|
||||||
},
|
|
||||||
timeout=120,
|
timeout=120,
|
||||||
verify=verify,
|
verify=verify,
|
||||||
)
|
)
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Chunk {i + 1}/{total_chunks} failed: " f"HTTP {resp.status_code}: {resp.text[:300]}"
|
f"Chunk {i + 1}/{total_chunks} failed: HTTP {resp.status_code}: {resp.text[:300]}"
|
||||||
)
|
)
|
||||||
resp_data = resp.json()
|
resp_data = resp.json()
|
||||||
logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)
|
logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)
|
||||||
|
|
@ -393,10 +719,41 @@ class NodeAgent:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Module-level helpers (shared by NodeAgent._inference_loop)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _run_onnx_session(session: Any, input_name: str, iq: Any) -> list:
|
||||||
|
"""Run an ONNX session on an IQ array (2, N).
|
||||||
|
|
||||||
|
Tries channel-first layout (1, 2, N) first; falls back to interleaved flat
|
||||||
|
(1, 2*N) when the model expects a flattened input.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
x = iq[np.newaxis] # (1, 2, N)
|
||||||
|
try:
|
||||||
|
return session.run(None, {input_name: x})
|
||||||
|
except Exception:
|
||||||
|
return session.run(None, {input_name: iq.flatten()[np.newaxis]})
|
||||||
|
|
||||||
|
|
||||||
|
def _softmax(x: Any) -> Any:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
e = np.exp(x - x.max())
|
||||||
|
return e / e.sum()
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
|
||||||
|
for attr in ("center_freq", "sample_rate", "gain"):
|
||||||
|
if attr in cfg:
|
||||||
|
try:
|
||||||
|
setattr(sdr, attr, cfg[attr])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("SDR config %s=%r failed: %s", attr, cfg[attr], exc)
|
||||||
|
|
||||||
|
|
||||||
def _sigmf_files(data_path: str) -> list[str]:
|
def _sigmf_files(data_path: str) -> list[str]:
|
||||||
"""Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording."""
|
"""Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording."""
|
||||||
candidates = [data_path]
|
candidates = [data_path]
|
||||||
|
|
@ -405,6 +762,29 @@ def _sigmf_files(data_path: str) -> list[str]:
|
||||||
return [p for p in candidates if os.path.exists(p)]
|
return [p for p in candidates if os.path.exists(p)]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config file helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_DEFAULT_CONFIG_PATH = os.path.join(
|
||||||
|
os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
|
||||||
|
"ria-agent",
|
||||||
|
"config.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config(path: str) -> dict:
|
||||||
|
"""Load a JSON config file, returning an empty dict if it does not exist."""
|
||||||
|
try:
|
||||||
|
with open(path) as fh:
|
||||||
|
return json.load(fh)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return {}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not read config file %s: %s", path, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# CLI entry point
|
# CLI entry point
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -420,67 +800,94 @@ def main() -> None:
|
||||||
"campaigns / inference on local SDR hardware."
|
"campaigns / inference on local SDR hardware."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
default=None,
|
||||||
|
metavar="PATH",
|
||||||
|
help=(
|
||||||
|
f"Path to a JSON config file (default: {_DEFAULT_CONFIG_PATH}). "
|
||||||
|
"CLI arguments override config file values."
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hub",
|
"--hub",
|
||||||
required=True,
|
default=None,
|
||||||
metavar="URL",
|
metavar="URL",
|
||||||
help="RIA Hub base URL, e.g. https://riahub.company.com",
|
help="RIA Hub base URL, e.g. https://riahub.company.com",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--key",
|
"--key",
|
||||||
required=True,
|
default=None,
|
||||||
metavar="API_KEY",
|
metavar="API_KEY",
|
||||||
help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
|
help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--name",
|
"--name",
|
||||||
required=True,
|
default=None,
|
||||||
metavar="NAME",
|
metavar="NAME",
|
||||||
help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
|
help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
default="unknown",
|
default=None,
|
||||||
metavar="SDR",
|
metavar="SDR",
|
||||||
help=(
|
help=(
|
||||||
"SDR device type reported to the hub (informational only). "
|
"SDR device type reported to the hub and used for inference. "
|
||||||
"Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
|
"Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--insecure",
|
"--insecure",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
default=None,
|
||||||
help="Disable TLS certificate verification (dev/self-signed certs only)",
|
help="Disable TLS certificate verification (dev/self-signed certs only)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
default="INFO",
|
default=None,
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
help="Logging verbosity (default: INFO)",
|
help="Logging verbosity (default: INFO)",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Merge: config file → CLI args (CLI wins).
|
||||||
|
config_path = args.config or _DEFAULT_CONFIG_PATH
|
||||||
|
cfg = _load_config(config_path)
|
||||||
|
|
||||||
|
hub = args.hub or cfg.get("hub")
|
||||||
|
key = args.key or cfg.get("key")
|
||||||
|
name = args.name or cfg.get("name")
|
||||||
|
device = args.device or cfg.get("device", "unknown")
|
||||||
|
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
|
||||||
|
log_level = args.log_level or cfg.get("log_level", "INFO")
|
||||||
|
|
||||||
|
if not hub:
|
||||||
|
parser.error("--hub is required (or set 'hub' in the config file)")
|
||||||
|
if not key:
|
||||||
|
parser.error("--key is required (or set 'key' in the config file)")
|
||||||
|
if not name:
|
||||||
|
parser.error("--name is required (or set 'name' in the config file)")
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=getattr(logging, args.log_level),
|
level=getattr(logging, log_level),
|
||||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
stream=sys.stderr,
|
stream=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Warn loudly if --insecure is used outside of development.
|
if insecure:
|
||||||
if args.insecure:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"--insecure disables TLS certificate verification. "
|
"--insecure disables TLS certificate verification. "
|
||||||
"Only use this for local development with self-signed certs."
|
"Only use this for local development with self-signed certs."
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = NodeAgent(
|
agent = NodeAgent(
|
||||||
hub_url=args.hub,
|
hub_url=hub,
|
||||||
api_key=args.key,
|
api_key=key,
|
||||||
name=args.name,
|
name=name,
|
||||||
sdr_device=args.device,
|
sdr_device=device,
|
||||||
insecure=args.insecure,
|
insecure=insecure,
|
||||||
)
|
)
|
||||||
agent.run()
|
agent.run()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user