Compare commits
No commits in common. "44df45160e0be09d5db59b73dc8790ac1c1a92f2" and "b1e3ebf74ffa44af0a23bed054e65bafbbafc93a" have entirely different histories.
44df45160e
...
b1e3ebf74f
20
CHANGELOG.md
20
CHANGELOG.md
|
|
@ -1,20 +0,0 @@
|
|||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
---
|
||||
## [0.1.1] - 2026-03-20
|
||||
|
||||
### Added
|
||||
|
||||
- **Campaign orchestration** — new `orchestration` module that manages the full lifecycle of an RF data collection campaign: SDR capture, automatic labeling, QA checks, and dataset packaging.
|
||||
- **HTTP inference server** — `ria-server` command starts a REST API server for deploying campaigns and controlling live inference from external systems such as the RIA Hub platform.
|
||||
- **Campaign CLI** — `ria campaign` commands for starting, monitoring, and managing campaigns from the terminal.
|
||||
|
||||
### Changed
|
||||
|
||||
- **Visualization layout** — recording and dataset views have been reformatted with improved sizing, repositioned titles, and updated Qoherent branding.
|
||||
|
||||
---
|
||||
2876
poetry.lock
generated
2876
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -85,25 +85,15 @@ build-backend = "poetry.core.masonry.api"
|
|||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^8.0.0"
|
||||
tox = "^4.19.0"
|
||||
fastapi = ">=0.111,<1.0"
|
||||
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
|
||||
onnxruntime = ">=1.17,<2.0"
|
||||
httpx = ">=0.27,<1.0"
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
sphinx = "^7.2.6"
|
||||
sphinx-rtd-theme = "^2.0.0"
|
||||
sphinx-autobuild = "^2024.2.4"
|
||||
|
||||
[tool.poetry.group.agent]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.agent.dependencies]
|
||||
requests = ">=2.28,<3.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
flake8 = "^7.1.0"
|
||||
black = "^26.3.1"
|
||||
black = "^24.3.0"
|
||||
isort = "^5.13.2"
|
||||
pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
|
||||
|
||||
|
|
@ -115,13 +105,6 @@ pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
|
|||
[tool.poetry.scripts]
|
||||
ria = "ria_toolkit_oss_cli.cli:cli"
|
||||
ria-tools = "ria_toolkit_oss_cli.cli:cli"
|
||||
ria-server = "ria_toolkit_oss.server.cli:serve"
|
||||
ria-agent = "ria_toolkit_oss.agent:main"
|
||||
|
||||
[tool.poetry.group.server.dependencies]
|
||||
fastapi = ">=0.111,<1.0"
|
||||
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
|
||||
onnxruntime = ">=1.17,<2.0"
|
||||
|
||||
[tool.black]
|
||||
line-length = 119
|
||||
|
|
@ -144,8 +127,5 @@ exclude = '''
|
|||
)/
|
||||
'''
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["src"]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
|
|
|||
|
|
@ -1,462 +0,0 @@
|
|||
"""RT-OSS Node Agent — connects to RIA Hub and dispatches work to local hardware.
|
||||
|
||||
The agent runs on any machine with an SDR attached and connects **outbound** to
|
||||
RIA Hub. No inbound ports need to be opened on the user's machine, and the
|
||||
connection works identically through NAT, corporate firewalls, or a Pi on a
|
||||
cellular link.
|
||||
|
||||
Usage::
|
||||
|
||||
ria-agent \\
|
||||
--hub https://riahub.company.com \\
|
||||
--key <api-key> \\
|
||||
--name lab-bench-1 \\
|
||||
[--device plutosdr] \\
|
||||
[--insecure]
|
||||
|
||||
The agent:
|
||||
1. Registers with RIA Hub and receives a ``node_id``.
|
||||
2. Sends a heartbeat every 30 s so the hub knows it is online.
|
||||
3. Long-polls ``GET /orchestrator/nodes/{id}/commands`` (30 s timeout).
|
||||
4. Executes received campaigns via :class:`ria_toolkit_oss.orchestration.executor.CampaignExecutor`.
|
||||
5. Uploads recordings to the hub via chunked POST, keeping each request
|
||||
under 50 MB so it passes through Cloudflare without needing the bypass
|
||||
subdomain.
|
||||
6. Deregisters cleanly on SIGINT / SIGTERM.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger("ria_agent")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tuneable constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
|
||||
_POLL_TIMEOUT = 30 # server-side long-poll duration
|
||||
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
||||
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
||||
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
|
||||
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class NodeAgent:
|
||||
"""Outbound-connecting agent that bridges RIA Hub to local SDR hardware.
|
||||
|
||||
All network I/O is initiated by the agent (outbound). RIA Hub never opens
|
||||
a connection back to the agent's machine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hub_url: str,
|
||||
api_key: str,
|
||||
name: str,
|
||||
sdr_device: str = "unknown",
|
||||
insecure: bool = False,
|
||||
) -> None:
|
||||
self.hub_url = hub_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.name = name
|
||||
self.sdr_device = sdr_device
|
||||
self.insecure = insecure
|
||||
|
||||
self.node_id: str | None = None
|
||||
self._stop = threading.Event()
|
||||
|
||||
try:
|
||||
import ria_toolkit_oss
|
||||
|
||||
self._ria_version: str = getattr(ria_toolkit_oss, "__version__", "unknown")
|
||||
except Exception:
|
||||
self._ria_version = "unknown"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run(self) -> None:
|
||||
"""Register, start the heartbeat thread, and enter the command loop.
|
||||
|
||||
Blocks until SIGINT or SIGTERM is received.
|
||||
"""
|
||||
self._register()
|
||||
|
||||
def _shutdown(sig: int, _frame: Any) -> None:
|
||||
logger.info("Shutdown signal received — stopping agent")
|
||||
self._stop.set()
|
||||
|
||||
signal.signal(signal.SIGINT, _shutdown)
|
||||
signal.signal(signal.SIGTERM, _shutdown)
|
||||
|
||||
hb = threading.Thread(target=self._heartbeat_loop, daemon=True, name="ria-agent-heartbeat")
|
||||
hb.start()
|
||||
|
||||
logger.info("Agent %r online (node_id=%s, hub=%s)", self.name, self.node_id, self.hub_url)
|
||||
|
||||
try:
|
||||
self._command_loop()
|
||||
finally:
|
||||
self._stop.set()
|
||||
self._deregister()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _register(self) -> None:
|
||||
resp = self._post(
|
||||
"/orchestrator/nodes/register",
|
||||
json={
|
||||
"name": self.name,
|
||||
"sdr_device": self.sdr_device,
|
||||
"ria_toolkit_version": self._ria_version,
|
||||
"capabilities": ["inference", "campaign"],
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
self.node_id = resp.json()["node_id"]
|
||||
logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)
|
||||
|
||||
def _deregister(self) -> None:
|
||||
if not self.node_id:
|
||||
return
|
||||
try:
|
||||
self._delete(f"/orchestrator/nodes/{self.node_id}", timeout=10)
|
||||
logger.info("Deregistered %s", self.node_id)
|
||||
except Exception as exc:
|
||||
logger.debug("Deregister failed (ignored on shutdown): %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Heartbeat thread
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _heartbeat_loop(self) -> None:
|
||||
while not self._stop.wait(_HEARTBEAT_INTERVAL):
|
||||
try:
|
||||
resp = self._post(f"/orchestrator/nodes/{self.node_id}/heartbeat", timeout=10)
|
||||
if resp.status_code == 404:
|
||||
logger.warning("Heartbeat got 404 — hub lost registration, re-registering")
|
||||
self._register()
|
||||
except Exception as exc:
|
||||
logger.warning("Heartbeat failed: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Command poll loop
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _command_loop(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
try:
|
||||
resp = self._get(
|
||||
f"/orchestrator/nodes/{self.node_id}/commands",
|
||||
timeout=_POLL_CLIENT_TIMEOUT,
|
||||
)
|
||||
if resp.status_code == 204:
|
||||
# No command within the timeout window — loop immediately.
|
||||
continue
|
||||
if resp.status_code == 404:
|
||||
logger.warning("Command poll got 404 — re-registering")
|
||||
self._register()
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
cmd = resp.json()
|
||||
logger.info("Received command: %s", cmd.get("command"))
|
||||
self._dispatch(cmd)
|
||||
except Exception as exc:
|
||||
if not self._stop.is_set():
|
||||
logger.warning("Command poll error: %s — retrying in %ds", exc, _RECONNECT_PAUSE)
|
||||
time.sleep(_RECONNECT_PAUSE)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Command dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _dispatch(self, cmd: dict) -> None:
|
||||
command = cmd.get("command")
|
||||
if command == "run_campaign":
|
||||
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
||||
config_dict: dict = cmd.get("payload") or {}
|
||||
threading.Thread(
|
||||
target=self._run_campaign,
|
||||
args=(campaign_id, config_dict),
|
||||
daemon=True,
|
||||
name=f"campaign-{campaign_id[:8]}",
|
||||
).start()
|
||||
else:
|
||||
logger.warning("Unknown command %r — ignored", command)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Campaign execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
|
||||
try:
|
||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||
except ImportError as exc:
|
||||
logger.error(
|
||||
"Campaign %s cannot start — ria_toolkit_oss not fully installed: %s",
|
||||
campaign_id[:8],
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("Campaign %s starting", campaign_id[:8])
|
||||
try:
|
||||
config = CampaignConfig.from_dict(config_dict)
|
||||
executor = CampaignExecutor(config)
|
||||
result = executor.run()
|
||||
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
|
||||
self._upload_recordings(campaign_id, config, result)
|
||||
except Exception as exc:
|
||||
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recording upload (chunked for large files)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_recordings(self, campaign_id: str, config: Any, result: Any) -> None:
|
||||
output_repo: str | None = getattr(getattr(config, "output", None), "repo", None)
|
||||
if not output_repo or "/" not in output_repo:
|
||||
logger.warning("Campaign %s: no output.repo — skipping upload", campaign_id[:8])
|
||||
return
|
||||
|
||||
repo_owner, repo_name = output_repo.split("/", 1)
|
||||
base_url = f"{self.hub_url}/datasets/upload"
|
||||
steps = getattr(result, "steps", None) or []
|
||||
|
||||
for step in steps:
|
||||
output_path: str | None = getattr(step, "output_path", None)
|
||||
if not output_path:
|
||||
continue
|
||||
device_id: str = getattr(step, "transmitter_id", "") or ""
|
||||
for fpath in _sigmf_files(output_path):
|
||||
filename = os.path.basename(fpath)
|
||||
metadata = {
|
||||
"filename": filename,
|
||||
"repo_owner": repo_owner,
|
||||
"repo_name": repo_name,
|
||||
"device_id": device_id,
|
||||
"campaign_id": campaign_id,
|
||||
}
|
||||
try:
|
||||
resp_data = self._upload_file(base_url, fpath, metadata)
|
||||
logger.info(
|
||||
"Campaign %s: uploaded %s (oid=%s)",
|
||||
campaign_id[:8],
|
||||
filename,
|
||||
resp_data.get("oid", "?"),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Campaign %s: upload of %s failed: %s", campaign_id[:8], filename, exc)
|
||||
|
||||
def _upload_file(self, base_url: str, file_path: str, metadata: dict) -> dict:
|
||||
"""Upload *file_path*, choosing chunked or direct path based on file size."""
|
||||
import requests as _requests
|
||||
|
||||
size = os.path.getsize(file_path)
|
||||
filename = os.path.basename(file_path)
|
||||
headers = {"X-API-Key": self.api_key}
|
||||
verify = not self.insecure
|
||||
|
||||
# Small files: single POST (unchanged endpoint, no assembly needed server-side).
|
||||
if size <= _DIRECT_THRESHOLD:
|
||||
with open(file_path, "rb") as fh:
|
||||
resp = _requests.post(
|
||||
base_url,
|
||||
headers=headers,
|
||||
files={"file": (filename, fh)},
|
||||
data=metadata,
|
||||
timeout=300,
|
||||
verify=verify,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# Large files: chunked upload — each request is ≤ 50 MB.
|
||||
total_chunks = math.ceil(size / _CHUNK_SIZE)
|
||||
upload_id = str(uuid.uuid4())
|
||||
chunk_url = base_url + "/chunk"
|
||||
|
||||
logger.info(
|
||||
"Chunked upload: %s (%d bytes, %d × %d MB chunks)",
|
||||
filename,
|
||||
size,
|
||||
total_chunks,
|
||||
_CHUNK_SIZE // (1024 * 1024),
|
||||
)
|
||||
|
||||
resp_data: dict = {}
|
||||
with open(file_path, "rb") as fh:
|
||||
for i in range(total_chunks):
|
||||
chunk = fh.read(_CHUNK_SIZE)
|
||||
resp = _requests.post(
|
||||
chunk_url,
|
||||
headers=headers,
|
||||
files={"file": (filename, chunk, "application/octet-stream")},
|
||||
data={
|
||||
**metadata,
|
||||
"upload_id": upload_id,
|
||||
"chunk_index": i,
|
||||
"total_chunks": total_chunks,
|
||||
},
|
||||
timeout=120,
|
||||
verify=verify,
|
||||
)
|
||||
if not resp.ok:
|
||||
raise RuntimeError(
|
||||
f"Chunk {i + 1}/{total_chunks} failed: " f"HTTP {resp.status_code}: {resp.text[:300]}"
|
||||
)
|
||||
resp_data = resp.json()
|
||||
logger.debug("Chunk %d/%d uploaded", i + 1, total_chunks)
|
||||
|
||||
return resp_data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get(self, path: str, **kwargs: Any):
|
||||
import requests as _requests
|
||||
|
||||
return _requests.get(
|
||||
f"{self.hub_url}{path}",
|
||||
headers={"X-API-Key": self.api_key},
|
||||
verify=not self.insecure,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _post(self, path: str, **kwargs: Any):
|
||||
import requests as _requests
|
||||
|
||||
return _requests.post(
|
||||
f"{self.hub_url}{path}",
|
||||
headers={"X-API-Key": self.api_key},
|
||||
verify=not self.insecure,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _delete(self, path: str, **kwargs: Any):
|
||||
import requests as _requests
|
||||
|
||||
return _requests.delete(
|
||||
f"{self.hub_url}{path}",
|
||||
headers={"X-API-Key": self.api_key},
|
||||
verify=not self.insecure,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sigmf_files(data_path: str) -> list[str]:
|
||||
"""Return paths to both SigMF files (.sigmf-data and .sigmf-meta) for a recording."""
|
||||
candidates = [data_path]
|
||||
if data_path.endswith(".sigmf-data"):
|
||||
candidates.append(data_path[: -len(".sigmf-data")] + ".sigmf-meta")
|
||||
return [p for p in candidates if os.path.exists(p)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="ria-agent",
|
||||
description=(
|
||||
"RT-OSS Node Agent — connects outbound to RIA Hub and executes "
|
||||
"campaigns / inference on local SDR hardware."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub",
|
||||
required=True,
|
||||
metavar="URL",
|
||||
help="RIA Hub base URL, e.g. https://riahub.company.com",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
required=True,
|
||||
metavar="API_KEY",
|
||||
help="Shared API key (must match [wac] API_KEY in the hub's app.ini)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
required=True,
|
||||
metavar="NAME",
|
||||
help='Human-readable name shown in the Target Node dropdown, e.g. "lab-bench-1"',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="unknown",
|
||||
metavar="SDR",
|
||||
help=(
|
||||
"SDR device type reported to the hub (informational only). "
|
||||
"Examples: plutosdr, usrp_b210, rtlsdr, mock. Default: unknown"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--insecure",
|
||||
action="store_true",
|
||||
help="Disable TLS certificate verification (dev/self-signed certs only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging verbosity (default: INFO)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
stream=sys.stderr,
|
||||
)
|
||||
|
||||
# Warn loudly if --insecure is used outside of development.
|
||||
if args.insecure:
|
||||
logger.warning(
|
||||
"--insecure disables TLS certificate verification. "
|
||||
"Only use this for local development with self-signed certs."
|
||||
)
|
||||
|
||||
agent = NodeAgent(
|
||||
hub_url=args.hub,
|
||||
api_key=args.key,
|
||||
name=args.name,
|
||||
sdr_device=args.device,
|
||||
insecure=args.insecure,
|
||||
)
|
||||
agent.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -21,8 +21,7 @@ class DatasetBuilder(ABC):
|
|||
"""
|
||||
|
||||
_url: str = abstract_attribute()
|
||||
_SHA256: Optional[str] = None # SHA256 checksum.
|
||||
_MD5: Optional[str] = None # MD5 checksum.
|
||||
_SHA256: str # SHA256 checksum.
|
||||
_name: str = abstract_attribute()
|
||||
_author: str = abstract_attribute()
|
||||
_license: DatasetLicense = abstract_attribute()
|
||||
|
|
|
|||
|
|
@ -109,10 +109,13 @@ def copy_file(original_source: str | os.PathLike, new_source: str | os.PathLike)
|
|||
|
||||
:return: None
|
||||
"""
|
||||
with h5py.File(original_source, "r") as original_file:
|
||||
with h5py.File(new_source, "w") as new_file:
|
||||
for key in original_file.keys():
|
||||
original_file.copy(key, new_file)
|
||||
original_file = h5py.File(original_source, "r")
|
||||
|
||||
with h5py.File(new_source, "w") as new_file:
|
||||
for key in original_file.keys():
|
||||
original_file.copy(key, new_file)
|
||||
|
||||
original_file.close()
|
||||
|
||||
|
||||
def make_empty_clone(original_source: str | os.PathLike, new_source: str | os.PathLike, example_length: int) -> None:
|
||||
|
|
@ -169,10 +172,8 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None:
|
|||
with h5py.File(source, "a") as f:
|
||||
ds, md = f["data"], f["metadata/metadata"]
|
||||
m, c, n = ds.shape
|
||||
if not (0 <= idx <= m - 1):
|
||||
raise IndexError(f"Index {idx} out of range [0, {m - 1}]")
|
||||
if len(ds) != len(md):
|
||||
raise ValueError("Data and metadata array lengths do not match")
|
||||
assert 0 <= idx <= m - 1
|
||||
assert len(ds) == len(md)
|
||||
|
||||
new_ds = f.create_dataset(
|
||||
"data.temp",
|
||||
|
|
@ -217,3 +218,4 @@ def overwrite_file(source: str | os.PathLike, new_data: np.ndarray) -> None:
|
|||
ds_name = tuple(f.keys())[0]
|
||||
del f[ds_name]
|
||||
f.create_dataset(ds_name, data=new_data)
|
||||
f.close()
|
||||
|
|
|
|||
|
|
@ -169,10 +169,8 @@ class IQDataset(RadioDataset, ABC):
|
|||
"""
|
||||
|
||||
if split_factor is not None and example_length is not None:
|
||||
# Warn and use split factor
|
||||
import warnings
|
||||
|
||||
warnings.warn("split_factor and example_length should not both be specified.")
|
||||
# Raise warning and use split factor
|
||||
raise Warning("split_factor and example_length should not both be specified.")
|
||||
|
||||
if not inplace:
|
||||
# ds = self.create_new_dataset(example_length=example_length)
|
||||
|
|
|
|||
|
|
@ -255,9 +255,7 @@ class RadioDataset(ABC):
|
|||
else:
|
||||
classes_to_augment = classes_to_augment.encode("utf-8")
|
||||
if classes_to_augment not in class_sizes:
|
||||
raise ValueError(
|
||||
f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
|
||||
)
|
||||
raise ValueError(f"class name of {i} does not belong to the class key of {class_key}")
|
||||
|
||||
result_sizes = get_result_sizes(
|
||||
level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
|
||||
|
|
@ -377,7 +375,7 @@ class RadioDataset(ABC):
|
|||
counters[key] = counters.get(key, 0)
|
||||
|
||||
idx = 0
|
||||
with h5py.File(self.source, "r") as f:
|
||||
with h5py.File(self.source, "a") as f:
|
||||
while idx < len(self):
|
||||
labels = f["metadata/metadata"][class_key]
|
||||
current_class = labels[idx]
|
||||
|
|
@ -516,7 +514,7 @@ class RadioDataset(ABC):
|
|||
|
||||
idx = 0
|
||||
|
||||
with h5py.File(self.source, "r") as f:
|
||||
with h5py.File(self.source, "a") as f:
|
||||
while idx < len(self):
|
||||
labels = f["metadata/metadata"][class_key]
|
||||
current_class = labels[idx]
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None:
|
|||
"""Ensure that each ID is present in one and only one sublist."""
|
||||
all_elements = [item for sublist in list_of_lists for item in sublist]
|
||||
|
||||
assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements))
|
||||
assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort()
|
||||
|
||||
|
||||
def _generate_split_source_filenames(
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class Recording:
|
|||
self._metadata["timestamp"] = time.time()
|
||||
else:
|
||||
if not isinstance(self._metadata["timestamp"], (int, float)):
|
||||
raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}")
|
||||
raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"]))
|
||||
|
||||
if "rec_id" not in self.metadata:
|
||||
self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
|
||||
|
|
@ -393,7 +393,6 @@ class Recording:
|
|||
"""
|
||||
if key not in self.metadata:
|
||||
self.add_to_metadata(key=key, value=value)
|
||||
return
|
||||
|
||||
if not _is_jsonable(value):
|
||||
raise ValueError("Value must be JSON serializable.")
|
||||
|
|
@ -445,7 +444,7 @@ class Recording:
|
|||
'rec_id': 'fda0f41...'} # Example value
|
||||
"""
|
||||
if key not in PROTECTED_KEYS:
|
||||
self._metadata.pop(key, None)
|
||||
self._metadata.pop(key)
|
||||
else:
|
||||
raise ValueError(f"Key {key} is protected and cannot be modified or removed.")
|
||||
|
||||
|
|
@ -602,7 +601,7 @@ class Recording:
|
|||
>>> recording = Recording(data=samples, metadata=metadata)
|
||||
>>> recording.to_wav()
|
||||
"""
|
||||
from ria_toolkit_oss.io.recording import to_wav
|
||||
from utils.io.recording import to_wav
|
||||
|
||||
return to_wav(
|
||||
recording=self,
|
||||
|
|
@ -652,7 +651,7 @@ class Recording:
|
|||
>>> recording = Recording(data=samples, metadata=metadata)
|
||||
>>> recording.to_blue()
|
||||
"""
|
||||
from ria_toolkit_oss.io.recording import to_blue
|
||||
from utils.io.recording import to_blue
|
||||
|
||||
return to_blue(recording=self, filename=filename, path=path, data_format=data_format, overwrite=overwrite)
|
||||
|
||||
|
|
@ -703,14 +702,7 @@ class Recording:
|
|||
data = self.data[:, start_sample:end_sample]
|
||||
|
||||
new_annotations = copy.deepcopy(self.annotations)
|
||||
trimmed_annotations = []
|
||||
for annotation in new_annotations:
|
||||
# skip annotations entirely outside the trim window
|
||||
if annotation.sample_start + annotation.sample_count <= start_sample:
|
||||
continue
|
||||
if annotation.sample_start >= end_sample:
|
||||
continue
|
||||
|
||||
# trim annotation if it goes outside the trim boundaries
|
||||
if annotation.sample_start < start_sample:
|
||||
annotation.sample_count = annotation.sample_count - (start_sample - annotation.sample_start)
|
||||
|
|
@ -721,9 +713,8 @@ class Recording:
|
|||
|
||||
# shift annotation to align with the new start point
|
||||
annotation.sample_start = annotation.sample_start - start_sample
|
||||
trimmed_annotations.append(annotation)
|
||||
|
||||
return Recording(data=data, metadata=self.metadata, annotations=trimmed_annotations)
|
||||
return Recording(data=data, metadata=self.metadata, annotations=new_annotations)
|
||||
|
||||
def normalize(self) -> Recording:
|
||||
"""Scale the recording data, relative to its maximum value, so that the magnitude of the maximum sample is 1.
|
||||
|
|
@ -752,10 +743,7 @@ class Recording:
|
|||
>>> print(numpy.max(numpy.abs(normalized_recording.data)))
|
||||
1
|
||||
"""
|
||||
max_val = np.max(abs(self.data))
|
||||
if max_val == 0:
|
||||
raise ValueError("Cannot normalize a recording with all-zero data.")
|
||||
scaled_data = self.data / max_val
|
||||
scaled_data = self.data / np.max(abs(self.data))
|
||||
return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations)
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
|
|||
|
|
@ -4,12 +4,10 @@ Utilities for input/output operations on the ria_toolkit_oss.datatypes.Recording
|
|||
|
||||
import datetime
|
||||
import datetime as dt
|
||||
import json
|
||||
import numbers
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import warnings
|
||||
from datetime import timezone
|
||||
from typing import Any, List, Optional
|
||||
|
||||
|
|
@ -93,35 +91,15 @@ def to_npy(
|
|||
metadata = recording.metadata
|
||||
annotations = recording.annotations
|
||||
|
||||
# Serialize metadata and annotations as JSON to avoid pickle-based deserialization.
|
||||
# JSON is safe; pickle allows arbitrary code execution when loading untrusted files.
|
||||
metadata_bytes = json.dumps(convert_to_serializable(metadata)).encode()
|
||||
annotations_bytes = json.dumps([a.__dict__ for a in annotations]).encode()
|
||||
|
||||
with open(file=fullpath, mode="wb") as f:
|
||||
# Write format version marker first so from_npy can detect the safe JSON format.
|
||||
np.save(f, np.array("ria-toolkit-oss-v2"))
|
||||
np.save(f, data)
|
||||
np.save(f, np.frombuffer(metadata_bytes, dtype=np.uint8))
|
||||
np.save(f, np.frombuffer(annotations_bytes, dtype=np.uint8))
|
||||
np.save(f, metadata)
|
||||
np.save(f, annotations)
|
||||
|
||||
# print(f"Saved recording to {os.getcwd()}/{fullpath}")
|
||||
return str(fullpath)
|
||||
|
||||
|
||||
_NPY_MAGIC = b"\x93NUMPY"
|
||||
|
||||
|
||||
def _check_npy_magic(filepath: str) -> None:
|
||||
"""Raise ValueError if the file does not start with the NumPy magic bytes."""
|
||||
try:
|
||||
with open(filepath, "rb") as f:
|
||||
header = f.read(6)
|
||||
except OSError as e:
|
||||
raise IOError(f"Cannot open file for validation: {filepath}") from e
|
||||
if header != _NPY_MAGIC:
|
||||
raise ValueError(f"File does not appear to be a valid NumPy .npy file (bad magic bytes): {filepath}")
|
||||
|
||||
|
||||
def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
|
||||
"""Load a recording from a ``.npy`` binary file.
|
||||
|
||||
|
|
@ -148,37 +126,14 @@ def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
|
|||
if legacy:
|
||||
return from_npy_legacy(filename)
|
||||
|
||||
_check_npy_magic(filename)
|
||||
|
||||
with open(file=filename, mode="rb") as f:
|
||||
first = np.load(f, allow_pickle=False)
|
||||
|
||||
if first.ndim == 0 and first.dtype.kind in ("U", "S") and str(first) == "ria-toolkit-oss-v2":
|
||||
# Safe JSON format written by current to_npy.
|
||||
data = np.load(f, allow_pickle=False)
|
||||
raw_meta = np.load(f, allow_pickle=False)
|
||||
metadata = json.loads(raw_meta.tobytes().decode())
|
||||
try:
|
||||
raw_ann = np.load(f, allow_pickle=False)
|
||||
ann_list = json.loads(raw_ann.tobytes().decode())
|
||||
from ria_toolkit_oss.datatypes.annotation import Annotation
|
||||
|
||||
annotations = [Annotation(**a) for a in ann_list]
|
||||
except EOFError:
|
||||
annotations = []
|
||||
else:
|
||||
# Legacy pickle-based format. Only load files from trusted sources.
|
||||
warnings.warn(
|
||||
"Loading .npy file in legacy pickle format — only load files from trusted sources. "
|
||||
"Re-save with to_npy() to upgrade to the safe JSON format.",
|
||||
stacklevel=2,
|
||||
)
|
||||
data = first # already loaded without pickle (numeric array)
|
||||
metadata = np.load(f, allow_pickle=True).tolist()
|
||||
try:
|
||||
annotations = list(np.load(f, allow_pickle=True))
|
||||
except EOFError:
|
||||
annotations = []
|
||||
data = np.load(f, allow_pickle=True)
|
||||
metadata = np.load(f, allow_pickle=True)
|
||||
metadata = metadata.tolist()
|
||||
try:
|
||||
annotations = list(np.load(f, allow_pickle=True))
|
||||
except EOFError:
|
||||
annotations = []
|
||||
|
||||
recording = Recording(data=data, metadata=metadata, annotations=annotations)
|
||||
return recording
|
||||
|
|
@ -216,20 +171,14 @@ def from_npy_legacy(file: os.PathLike | str) -> Recording:
|
|||
# Rebuild with .npy extension.
|
||||
filename = str(filename) + ".npy"
|
||||
|
||||
warnings.warn(
|
||||
"from_npy_legacy uses pickle deserialization for extended metadata — only load files from trusted sources.",
|
||||
stacklevel=2,
|
||||
)
|
||||
_check_npy_magic(filename)
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
# Read IQ data (2, N) format
|
||||
iqdata = np.load(f, allow_pickle=False)
|
||||
iqdata = np.load(f)
|
||||
|
||||
# Read basic metadata array [center_freq, rec_length, decimation, sample_rate]
|
||||
meta = np.load(f, allow_pickle=False)
|
||||
meta = np.load(f)
|
||||
|
||||
# Read extended metadata dict (legacy format requires pickle)
|
||||
# Read extended metadata dict
|
||||
extended_meta = np.load(f, allow_pickle=True)[0]
|
||||
|
||||
# Convert IQ data from (2, N) to (N,) complex format
|
||||
|
|
@ -330,7 +279,7 @@ def to_sigmf(
|
|||
converted_metadata = {
|
||||
sigmf_key: metadata[metadata_key]
|
||||
for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items()
|
||||
if metadata_key in metadata and sigmf_key != SigMFFile.HASH_KEY
|
||||
if metadata_key in metadata
|
||||
}
|
||||
|
||||
# Merge dictionaries, giving priority to sigmf_meta
|
||||
|
|
@ -367,8 +316,6 @@ def to_sigmf(
|
|||
meta_dict = sigMF_metafile.ordered_metadata()
|
||||
meta_dict["ria"] = metadata
|
||||
|
||||
if overwrite and os.path.isfile(meta_file_path):
|
||||
os.remove(meta_file_path)
|
||||
sigMF_metafile.tofile(meta_file_path)
|
||||
|
||||
|
||||
|
|
@ -387,8 +334,9 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
|
|||
"""
|
||||
|
||||
file = str(file)
|
||||
if not file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")):
|
||||
file = file + ".sigmf-data"
|
||||
if len(file) > 11:
|
||||
if file[-11:-5] != ".sigmf":
|
||||
file = file + ".sigmf-data"
|
||||
|
||||
sigmf_file = sigmffile.fromfile(file)
|
||||
|
||||
|
|
@ -401,7 +349,7 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
|
|||
# Process core keys
|
||||
if key.startswith("core:"):
|
||||
base_key = key[5:] # Remove 'core:' prefix
|
||||
converted_key = SIGMF_KEY_CONVERSION.get(key, base_key)
|
||||
converted_key = SIGMF_KEY_CONVERSION.get(base_key, base_key)
|
||||
# Process ria keys
|
||||
elif key.startswith("ria:"):
|
||||
converted_key = key[4:] # Remove 'ria:' prefix
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
"""Orchestration layer for automated RF capture campaigns."""
|
||||
|
||||
from .campaign import (
|
||||
CampaignConfig,
|
||||
CaptureStep,
|
||||
QAConfig,
|
||||
RecorderConfig,
|
||||
TransmitterConfig,
|
||||
)
|
||||
from .executor import CampaignExecutor, CampaignResult, StepResult
|
||||
from .labeler import label_recording
|
||||
from .qa import QAResult, check_recording
|
||||
|
||||
__all__ = [
|
||||
"CampaignConfig",
|
||||
"CaptureStep",
|
||||
"QAConfig",
|
||||
"RecorderConfig",
|
||||
"TransmitterConfig",
|
||||
"CampaignExecutor",
|
||||
"CampaignResult",
|
||||
"StepResult",
|
||||
"label_recording",
|
||||
"QAResult",
|
||||
"check_recording",
|
||||
]
|
||||
|
|
@ -1,490 +0,0 @@
|
|||
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
# Allowed characters in campaign names when used as filename components.
# Any character outside this set is replaced with "_" before the name is
# used on disk (see CampaignConfig).
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9_\-]")

# Reasonable RF bounds for consumer/research SDR hardware.
# parse_frequency() and parse_gain() reject values outside these ranges.
_FREQ_MIN_HZ = 1.0  # 1 Hz
_FREQ_MAX_HZ = 300e9  # 300 GHz
_GAIN_MIN_DB = -30.0
_GAIN_MAX_DB = 120.0
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_duration(value: str | float | int) -> float:
    """Parse a duration value to seconds.

    Accepts:
        "30s" → 30.0
        "1.5m" or "1.5min" → 90.0
        "2h" → 7200.0
        30 (numeric) → 30.0

    Raises:
        ValueError: If the value cannot be parsed as a duration.
    """
    if isinstance(value, (int, float)):
        return float(value)
    value = str(value).strip()
    match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
    if not match:
        raise ValueError(f"Cannot parse duration: '{value}'")
    try:
        amount = float(match.group(1))
    except ValueError:
        # "[\d.]+" admits malformed numbers like "1..5" that float()
        # rejects — raise the uniform parse error instead of leaking
        # float()'s own message.
        raise ValueError(f"Cannot parse duration: '{value}'") from None
    unit = (match.group(2) or "s").lower()
    if unit in ("h", "hr"):
        return amount * 3600
    if unit in ("m", "min"):
        return amount * 60
    return amount
|
||||
|
||||
|
||||
def _checked_freq(result: float, original: str | None = None) -> float:
    """Validate that *result* (Hz) lies within the supported RF range.

    Returns the value unchanged, or raises ValueError. When *original*
    is given (the raw input text), it is appended to the message for
    context — matching the string-input error format.
    """
    if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
        context = f": '{original}'" if original is not None else ""
        raise ValueError(
            f"Frequency {result:.3g} Hz is outside the supported range "
            f"({_FREQ_MIN_HZ:.0f} Hz – {_FREQ_MAX_HZ:.3g} Hz){context}"
        )
    return result


def parse_frequency(value: str | float | int) -> float:
    """Parse a frequency value to Hz.

    Accepts:
        "2.45GHz" → 2_450_000_000.0
        "40MHz" → 40_000_000.0
        "915e6" → 915_000_000.0
        2.45e9 (numeric) → 2_450_000_000.0

    Raises:
        ValueError: If the value cannot be parsed, or the parsed
            frequency falls outside the supported RF range.
    """
    if isinstance(value, (int, float)):
        return _checked_freq(float(value))

    value = str(value).strip()

    # Try bare numeric first (handles scientific notation like "915e6")
    try:
        numeric = float(value)
    except ValueError:
        pass
    else:
        return _checked_freq(numeric, value)

    # Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
    match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
    if match:
        try:
            amount = float(match.group(1))
        except ValueError:
            # "[\d.]+" admits malformed numbers like "1..5" that float()
            # rejects — raise the uniform parse error instead.
            raise ValueError(f"Cannot parse frequency: '{value}'") from None
        suffix = match.group(2).upper()
        return _checked_freq(amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix], value)

    raise ValueError(f"Cannot parse frequency: '{value}'")
|
||||
|
||||
|
||||
def parse_gain(value: str | float | int) -> float | str:
    """Parse a gain value.

    Accepts:
        "40dB" or "40 dB" → 40.0
        "auto" → "auto"
        40 (numeric) → 40.0

    Raises:
        ValueError: If the value cannot be parsed, or the gain falls
            outside the supported range.
    """
    if isinstance(value, (int, float)):
        result = float(value)
        if not (_GAIN_MIN_DB <= result <= _GAIN_MAX_DB):
            raise ValueError(f"Gain {result} dB is outside the supported range ({_GAIN_MIN_DB} – {_GAIN_MAX_DB} dB)")
        return result
    value = str(value).strip()
    if value.lower() == "auto":
        return "auto"
    match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
    if not match:
        raise ValueError(f"Cannot parse gain: '{value}'")
    try:
        result = float(match.group(1))
    except ValueError:
        # "[\d.+\-]+" admits strings like "1-2" or "+-3" that float()
        # rejects — raise the uniform parse error instead of leaking
        # float()'s own message.
        raise ValueError(f"Cannot parse gain: '{value}'") from None
    if not (_GAIN_MIN_DB <= result <= _GAIN_MAX_DB):
        raise ValueError(
            f"Gain {result} dB is outside the supported range ({_GAIN_MIN_DB} – {_GAIN_MAX_DB} dB): '{value}'"
        )
    return result
|
||||
|
||||
|
||||
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
    """Parse a bandwidth value to MHz.

    Accepts:
        "20MHz" → 20.0
        "40MHz" → 40.0
        20 (numeric, assumed MHz) → 20.0
        None → None

    Raises:
        ValueError: If a string value cannot be parsed.
    """
    if value is None:
        return None
    if isinstance(value, (int, float)):
        return float(value)

    text = str(value).strip()
    # Try "<number>MHz" (unit optional in case), then a bare number.
    with_unit = re.fullmatch(r"([\d.]+)\s*MHz?", text, re.IGNORECASE)
    if with_unit is not None:
        return float(with_unit.group(1))
    bare = re.fullmatch(r"([\d.]+)", text)
    if bare is not None:
        return float(bare.group(1))
    raise ValueError(f"Cannot parse bandwidth: '{text}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class RecorderConfig:
    """SDR recorder configuration."""

    device: str
    center_freq: float  # Hz
    sample_rate: float  # Hz
    gain: float | str  # dB as a float, or the literal string "auto"
    bandwidth: Optional[float] = None  # Hz; None = match sample_rate

    @classmethod
    def from_dict(cls, d: dict) -> "RecorderConfig":
        """Build a RecorderConfig from a raw campaign-YAML mapping.

        Accepts human-friendly strings ("2.45GHz", "40dB", "auto") for
        the numeric fields; "bandwidth_hz" is honoured as a fallback key
        for "bandwidth".
        """
        raw_bandwidth = d.get("bandwidth") or d.get("bandwidth_hz")
        return cls(
            device=str(d["device"]),
            center_freq=parse_frequency(d["center_freq"]),
            sample_rate=parse_frequency(d["sample_rate"]),
            gain=parse_gain(d.get("gain", "auto")),
            bandwidth=parse_frequency(raw_bandwidth) if raw_bandwidth else None,
        )
|
||||
|
||||
|
||||
@dataclass
class CaptureStep:
    """A single timed capture within a transmitter schedule."""

    duration: float  # seconds
    label: str  # used as filename component

    # WiFi-specific
    channel: Optional[int] = None
    bandwidth_mhz: Optional[float] = None  # MHz
    traffic: Optional[str] = None

    # Bluetooth-specific
    connection_interval_ms: Optional[float] = None

    # Power (dBm), optional
    power_dbm: Optional[float] = None

    @staticmethod
    def _parse_power_dbm(raw) -> Optional[float]:
        """Parse an optional transmit power ("10dBm", "10 dBm", or numeric 10)."""
        if not raw:
            return None
        if isinstance(raw, (int, float)):
            # YAML numeric scalars carry no unit suffix to strip.
            return float(raw)
        return float(str(raw).strip().removesuffix("dBm").strip())

    @classmethod
    def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
        """Build a CaptureStep from a raw schedule entry.

        When no explicit ``label`` is given and *auto_label* is true, a
        label is derived from the channel/bandwidth/traffic fields
        (falling back to "capture").
        """
        duration = parse_duration(d["duration"])
        label = d.get("label", "")
        if not label and auto_label:
            parts = []
            if d.get("channel"):
                # int() coercion: the ":02d" format would otherwise raise
                # if YAML delivered the channel as a quoted string.
                parts.append(f"ch{int(d['channel']):02d}")
            if d.get("bandwidth"):
                bw = parse_bandwidth_mhz(d["bandwidth"])
                parts.append(f"{int(bw)}mhz")
            if d.get("traffic"):
                parts.append(str(d["traffic"]).replace(" ", "_"))
            label = "_".join(parts) if parts else "capture"
        return cls(
            duration=duration,
            label=label,
            channel=d.get("channel"),
            bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
            traffic=d.get("traffic"),
            connection_interval_ms=d.get("connection_interval_ms"),
            # The previous inline expression assumed a string and crashed
            # with AttributeError on YAML numeric scalars (``power: 10``).
            power_dbm=cls._parse_power_dbm(d.get("power")),
        )
|
||||
|
||||
|
||||
@dataclass
class TransmitterConfig:
    """Configuration for a single transmitter device in the campaign."""

    id: str
    type: str  # "wifi", "bluetooth", "sdr", "external"
    control_method: str  # "external_script" | "sdr"
    schedule: list[CaptureStep]

    # Only used when control_method == "external_script".
    script: Optional[str] = None  # path to the control script
    device: Optional[str] = None  # e.g. "/dev/wlan0"

    @classmethod
    def from_dict(cls, d: dict) -> "TransmitterConfig":
        """Build a TransmitterConfig from a raw campaign-YAML mapping."""
        steps = [CaptureStep.from_dict(entry) for entry in d.get("schedule", [])]
        return cls(
            id=str(d["id"]),
            type=str(d["type"]),
            control_method=str(d.get("control_method", "external_script")),
            schedule=steps,
            script=d.get("script"),
            device=d.get("device"),
        )
|
||||
|
||||
|
||||
@dataclass
class QAConfig:
    """Quality assurance thresholds."""

    snr_threshold_db: float = 10.0  # minimum acceptable SNR
    min_duration_s: float = 25.0  # minimum acceptable capture length
    flag_for_review: bool = True  # whether failing captures are flagged for review

    @classmethod
    def from_dict(cls, d: dict) -> "QAConfig":
        """Build a QAConfig from a raw campaign-YAML mapping.

        ``snr_threshold`` accepts "10dB"/"10 dB"/"10db" strings or numerics.
        """
        raw_snr = str(d.get("snr_threshold", "10")).strip()
        # The previous str.rstrip("dB") stripped trailing *characters*
        # 'd'/'B' rather than the suffix, and missed lowercase "db";
        # remove the unit suffix explicitly instead.
        if raw_snr.lower().endswith("db"):
            raw_snr = raw_snr[:-2].strip()
        return cls(
            snr_threshold_db=float(raw_snr),
            min_duration_s=parse_duration(d.get("min_duration", "25s")),
            flag_for_review=bool(d.get("flag_for_review", True)),
        )
|
||||
|
||||
|
||||
@dataclass
class OutputConfig:
    """Where to save captured recordings."""

    format: str = "sigmf"
    path: str = "recordings"
    device_id: Optional[str] = None  # for device-profile campaigns
    repo: Optional[str] = None

    @classmethod
    def from_dict(cls, d: dict) -> "OutputConfig":
        """Build an OutputConfig from a raw campaign-YAML mapping."""
        fmt = d.get("format", "sigmf")
        out_path = d.get("path", "recordings")
        return cls(
            format=str(fmt),
            path=str(out_path),
            device_id=d.get("device_id"),
            repo=d.get("repo"),
        )
|
||||
|
||||
|
||||
@dataclass
class CampaignConfig:
    """Full campaign configuration parsed from YAML."""

    name: str  # sanitised: unsafe filename characters replaced with "_"
    recorder: RecorderConfig
    transmitters: list[TransmitterConfig]
    qa: QAConfig = field(default_factory=QAConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
    mode: str = "controlled_testbed"

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    @staticmethod
    def _load_yaml(path: Path, what: str) -> dict:
        """Read *path* with yaml.safe_load, normalising errors.

        Args:
            path: File to read.
            what: Human-readable file kind for error messages
                (e.g. "Campaign config").

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not valid YAML.
        """
        try:
            with open(path) as f:
                raw = yaml.safe_load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"{what} not found: {path}")
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in {path}: {e}")
        # An empty file parses to None; return an empty mapping so callers
        # get the usual "missing section" errors instead of AttributeError.
        return raw or {}

    @classmethod
    def _from_raw(cls, raw: dict, default_name: str, source: str = "") -> "CampaignConfig":
        """Shared construction logic behind from_dict/from_yaml.

        Args:
            raw: Parsed campaign mapping.
            default_name: Campaign name used when the ``campaign`` section
                does not provide one.
            source: Text appended to the missing-recorder error (e.g.
                f" in {path}") to identify the offending file.

        Raises:
            ValueError: If no transmitters are defined or the ``recorder``
                section is absent.
        """
        campaign_meta = raw.get("campaign", {})
        transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
        if not transmitters:
            raise ValueError("Campaign config must define at least one transmitter")
        if "recorder" not in raw:
            raise ValueError(f"Campaign config is missing required 'recorder' section{source}")
        raw_name = str(campaign_meta.get("name", default_name))
        # Campaign names become filename components; neutralise unsafe characters.
        safe_name = _SAFE_NAME_RE.sub("_", raw_name)
        return cls(
            name=safe_name,
            mode=str(campaign_meta.get("mode", "controlled_testbed")),
            recorder=RecorderConfig.from_dict(raw["recorder"]),
            transmitters=transmitters,
            qa=QAConfig.from_dict(raw.get("qa", {})),
            output=OutputConfig.from_dict(raw.get("output", {})),
        )

    # ---------------------------------------------------------------------------
    # Loaders
    # ---------------------------------------------------------------------------

    @classmethod
    def from_dict(cls, raw: dict) -> "CampaignConfig":
        """Build a CampaignConfig from a parsed dictionary.

        Accepts the same structure as the campaign YAML, already loaded into
        a Python dict (e.g. from a JSON HTTP request body).

        Raises:
            ValueError: If required fields are missing or malformed.
        """
        return cls._from_raw(raw, default_name="unnamed")

    @classmethod
    def from_yaml(cls, path: str | Path) -> "CampaignConfig":
        """Load a full campaign config YAML.

        Expected format::

            campaign:
              name: "wifi_capture_001"
              mode: "controlled_testbed"

            transmitters:
              - id: "laptop_wifi"
                type: "wifi"
                control_method: "external_script"
                script: "./scripts/wifi_control.sh"
                device: "/dev/wlan0"
                schedule:
                  - channel: 6
                    bandwidth: "20MHz"
                    traffic: "iperf_udp"
                    duration: "30s"

            recorder:
              device: "usrp_b210"
              center_freq: "2.45GHz"
              sample_rate: "40MHz"
              gain: "40dB"

            qa:
              snr_threshold: "10dB"
              min_duration: "25s"
              flag_for_review: true

            output:
              format: "sigmf"
              path: "./recordings"

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the YAML is invalid or required sections are missing.
        """
        path = Path(path)
        raw = cls._load_yaml(path, "Campaign config")
        # The file stem is the fallback campaign name, and the error context
        # names the file so multi-campaign runs stay debuggable.
        return cls._from_raw(raw, default_name=path.stem, source=f" in {path}")

    @classmethod
    def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
        """Build a campaign config from an App 1 device profile YAML.

        Expected format::

            device:
              name: "iPhone_13_WiFi"
              type: "wifi"
              protocol: "wifi_24ghz"

            capture:
              channels: [1, 6, 11]          # WiFi only
              bandwidth: "20MHz"            # WiFi only
              traffic_patterns: ["idle", "ping", "iperf_udp"]
              duration_per_config: "30s"

            recorder:
              device: "usrp_b210"
              center_freq: "2.45GHz"
              sample_rate: "40MHz"
              gain: "auto"

            output:
              path: "./recordings"
              device_id: "iphone13_wifi_001"

        For WiFi devices, schedule is expanded as channels × traffic_patterns.
        For Bluetooth devices (no channels), schedule is traffic_patterns only.
        """
        path = Path(path)
        raw = cls._load_yaml(path, "Device profile")

        device = raw.get("device", {})
        capture = raw.get("capture", {})
        device_type = str(device.get("type", "wifi")).lower()
        device_name = str(device.get("name", path.stem))
        duration = parse_duration(capture.get("duration_per_config", "30s"))
        traffic_patterns = capture.get("traffic_patterns", ["idle"])

        # Build capture schedule
        schedule: list[CaptureStep] = []

        if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
            # WiFi: cartesian product of channels × traffic patterns.
            channels = capture.get("channels", [6])
            bw_mhz = parse_bandwidth_mhz(capture.get("bandwidth", "20MHz"))
            for ch in channels:
                for traffic in traffic_patterns:
                    schedule.append(
                        CaptureStep(
                            duration=duration,
                            label=f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}",
                            channel=ch,
                            bandwidth_mhz=bw_mhz,
                            traffic=traffic,
                        )
                    )
        else:
            # Bluetooth / generic — no channels
            for traffic in traffic_patterns:
                schedule.append(
                    CaptureStep(
                        duration=duration,
                        label=traffic,
                        traffic=traffic,
                    )
                )

        device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
        transmitter = TransmitterConfig(
            id=device_id,
            type=device_type,
            control_method=str(capture.get("control_method", "external_script")),
            schedule=schedule,
            script=capture.get("script"),
            device=capture.get("device"),
        )

        return cls(
            name=f"enroll_{device_id}",
            mode="controlled_testbed",
            recorder=RecorderConfig.from_dict(raw["recorder"]),
            transmitters=[transmitter],
            qa=QAConfig.from_dict(raw.get("qa", {})),
            output=OutputConfig.from_dict(raw.get("output", {})),
        )

    # ---------------------------------------------------------------------------
    # Derived totals
    # ---------------------------------------------------------------------------

    def total_capture_time_s(self) -> float:
        """Sum of all step durations across all transmitters."""
        return sum(step.duration for tx in self.transmitters for step in tx.schedule)

    def total_steps(self) -> int:
        """Total number of capture steps across all transmitters."""
        return sum(len(tx.schedule) for tx in self.transmitters)
|
||||
|
|
@ -1,444 +0,0 @@
|
|||
"""Campaign executor: runs a capture campaign end-to-end."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.io.recording import to_sigmf
|
||||
|
||||
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||
from .labeler import build_output_filename, label_recording
|
||||
from .qa import QAResult, check_recording
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Device name aliases: campaign YAML names → get_sdr_device() names.
# Lookup is done on the lowercased device string; unrecognised names fall
# through unchanged (see CampaignExecutor._init_sdr).
_DEVICE_ALIASES = {
    "usrp_b210": "usrp",
    "usrp_b200": "usrp",
    "usrp": "usrp",
    "plutosdr": "pluto",
    "pluto": "pluto",
    "hackrf": "hackrf",
    "hackrf_one": "hackrf",
    "bladerf": "bladerf",
    "rtlsdr": "rtlsdr",
    "rtl_sdr": "rtlsdr",
    "thinkrf": "thinkrf",
    # Simulated device — no hardware required
    "mock": "mock",
    "sim": "mock",
}
|
||||
|
||||
|
||||
@dataclass
class StepResult:
    """Outcome of a single capture step."""

    transmitter_id: str
    step_label: str
    output_path: Optional[str]
    qa: QAResult
    capture_timestamp: float
    error: Optional[str] = None

    @property
    def ok(self) -> bool:
        """True when the step captured without error and passed QA."""
        return self.error is None and self.qa.passed

    def to_dict(self) -> dict:
        """Serialise to a JSON-compatible dict for the QA report."""
        return dict(
            transmitter_id=self.transmitter_id,
            step_label=self.step_label,
            output_path=self.output_path,
            capture_timestamp=self.capture_timestamp,
            qa=self.qa.to_dict(),
            error=self.error,
        )
|
||||
|
||||
|
||||
@dataclass
class CampaignResult:
    """Aggregate outcome of a full campaign."""

    campaign_name: str
    steps: list[StepResult] = field(default_factory=list)
    start_time: float = field(default_factory=time.time)
    end_time: Optional[float] = None

    @property
    def total_steps(self) -> int:
        """Number of steps executed so far."""
        return len(self.steps)

    @property
    def passed(self) -> int:
        """Steps that captured cleanly and passed QA."""
        return len([s for s in self.steps if s.ok])

    @property
    def flagged(self) -> int:
        """Steps captured without error but flagged by QA for review."""
        return len([s for s in self.steps if not s.error and s.qa.flagged])

    @property
    def failed(self) -> int:
        """Steps with a capture error or a failed QA check."""
        return len([s for s in self.steps if s.error or not s.qa.passed])

    @property
    def duration_s(self) -> float:
        """Elapsed wall-clock time; live (now - start) until end_time is set."""
        return (self.end_time or time.time()) - self.start_time

    def to_dict(self) -> dict:
        """Serialise the campaign summary to a JSON-compatible dict."""
        return dict(
            campaign_name=self.campaign_name,
            total_steps=self.total_steps,
            passed=self.passed,
            flagged=self.flagged,
            failed=self.failed,
            duration_s=round(self.duration_s, 1),
            steps=[s.to_dict() for s in self.steps],
        )

    def write_report(self, path: str | Path) -> None:
        """Write a JSON QA report to disk."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(self.to_dict(), indent=2))
        logger.info(f"QA report written to {path}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# External script interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
||||
"""Run an external control script and return stdout.
|
||||
|
||||
The script is called as::
|
||||
|
||||
<script> <arg1> <arg2> ...
|
||||
|
||||
A non-zero return code raises RuntimeError.
|
||||
|
||||
Args:
|
||||
script: Path to executable script. Must be an absolute path to an
|
||||
existing regular file. Relative paths are rejected to prevent
|
||||
accidentally executing files that are not the intended script.
|
||||
*args: Positional arguments forwarded to the script.
|
||||
timeout: Maximum seconds to wait.
|
||||
|
||||
Returns:
|
||||
Script stdout as a string.
|
||||
"""
|
||||
if not Path(script).is_absolute():
|
||||
raise RuntimeError(f"Script path must be absolute: {script}")
|
||||
script_path = Path(script).resolve()
|
||||
if not script_path.is_file():
|
||||
raise RuntimeError(f"Script not found or is not a regular file: {script}")
|
||||
|
||||
cmd = [str(script_path), *args]
|
||||
logger.debug(f"Running script: {' '.join(cmd)}")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"Script not found: {script}")
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Campaign executor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CampaignExecutor:
|
||||
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||
|
||||
Initialises the SDR recorder once, then for each (transmitter, step):
|
||||
1. Configures the transmitter (via external script or SDR TX)
|
||||
2. Records IQ samples
|
||||
3. Labels the recording with device/config metadata
|
||||
4. Runs QA checks
|
||||
5. Saves the recording to disk
|
||||
6. Stops/resets the transmitter
|
||||
|
||||
Args:
|
||||
config: Parsed campaign configuration.
|
||||
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
|
||||
called after each step completes. Useful for status reporting.
|
||||
verbose: Enable debug logging.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CampaignConfig,
|
||||
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.progress_cb = progress_cb
|
||||
self._sdr = None
|
||||
|
||||
if verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def run(self) -> CampaignResult:
        """Execute the full campaign and return a :class:`CampaignResult`.

        Initialises the SDR, runs all steps across all transmitters,
        then closes the SDR. If SDR initialisation fails the exception
        propagates immediately (nothing is captured).
        """
        result = CampaignResult(campaign_name=self.config.name)

        logger.info(
            f"Starting campaign '{self.config.name}': "
            f"{self.config.total_steps()} steps, "
            f"~{self.config.total_capture_time_s():.0f}s capture time"
        )

        self._init_sdr()
        try:
            total = self.config.total_steps()
            step_index = 0

            for transmitter in self.config.transmitters:
                logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
                for step in transmitter.schedule:
                    # Capture/save failures are reported via StepResult.error
                    # (see _execute_step), so one bad step does not abort the
                    # remainder of the campaign.
                    step_result = self._execute_step(transmitter, step)
                    result.steps.append(step_result)
                    step_index += 1

                    # Notify external observers (CLI progress, server status).
                    if self.progress_cb:
                        self.progress_cb(step_index, total, step_result)

                    if step_result.error:
                        logger.warning(f"Step '{step.label}' error: {step_result.error}")
                    elif step_result.qa.flagged:
                        logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
                    else:
                        logger.info(
                            f"Step '{step.label}' OK "
                            f"(SNR {step_result.qa.snr_db:.1f} dB, "
                            f"{step_result.qa.duration_s:.1f}s)"
                        )
        finally:
            # Always release the SDR, even if a step raised unexpectedly.
            self._close_sdr()

        result.end_time = time.time()
        logger.info(
            f"Campaign complete: {result.passed}/{result.total_steps} passed, "
            f"{result.flagged} flagged, {result.failed} failed"
        )
        return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SDR management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def _init_sdr(self) -> None:
        """Initialise and configure the SDR recorder.

        Resolves the campaign's device name through ``_DEVICE_ALIASES``,
        opens the device via ``get_sdr_device``, and configures the RX
        chain from the campaign's recorder settings.
        """
        # Local import — presumably to defer SDR driver dependencies until
        # a campaign actually runs; confirm before relying on this.
        from ria_toolkit_oss.sdr import get_sdr_device

        rec = self.config.recorder
        # Unrecognised device names fall through unchanged (lowercased).
        device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
        logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")

        self._sdr = get_sdr_device(device_name)
        # "auto" gain is forwarded as None — assumed to select driver-side
        # automatic gain control; TODO confirm against the SDR interface.
        gain = None if rec.gain == "auto" else float(rec.gain)
        self._sdr.init_rx(
            sample_rate=rec.sample_rate,
            center_frequency=rec.center_freq,
            gain=gain,
            channel=0,
        )
        # Only set the analog bandwidth when configured AND the driver
        # exposes the control — not all devices do.
        if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
            self._sdr.set_rx_bandwidth(rec.bandwidth)
|
||||
|
||||
def _close_sdr(self) -> None:
|
||||
if self._sdr is not None:
|
||||
try:
|
||||
self._sdr.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"SDR close error: {e}")
|
||||
self._sdr = None
|
||||
|
||||
def _record(self, duration_s: float) -> Recording:
|
||||
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||
return self._sdr.record(num_samples=num_samples)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
        """Run a single capture step.

        Sequence: start the transmitter, record for the step's duration,
        stop the transmitter, then label, QA-check, and save the recording.
        Capture and save failures are converted to a StepResult with an
        ``error`` string rather than raised.

        Returns:
            StepResult with QA outcome and output path (or error string).
        """
        capture_timestamp = time.time()
        output_path: Optional[str] = None

        try:
            self._start_transmitter(transmitter, step)
            recording = self._record(step.duration)
            self._stop_transmitter(transmitter, step)
        except Exception as e:
            # Best-effort stop on error
            try:
                self._stop_transmitter(transmitter, step)
            except Exception:
                pass
            # Capture failed: synthesise a failed/flagged QA result so the
            # report still carries a row for this step.
            return StepResult(
                transmitter_id=transmitter.id,
                step_label=step.label,
                output_path=None,
                qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
                capture_timestamp=capture_timestamp,
                error=str(e),
            )

        # Label recording
        recording = label_recording(
            recording=recording,
            device_id=transmitter.id,
            step=step,
            capture_timestamp=capture_timestamp,
            campaign_name=self.config.name,
        )

        # QA
        qa_result = check_recording(recording, self.config.qa)

        # Save
        try:
            output_path = self._save(recording, transmitter.id, step)
        except Exception as e:
            # Keep the real QA outcome but report the save failure as the
            # step error (output_path stays None).
            return StepResult(
                transmitter_id=transmitter.id,
                step_label=step.label,
                output_path=None,
                qa=qa_result,
                capture_timestamp=capture_timestamp,
                error=f"Save failed: {e}",
            )

        return StepResult(
            transmitter_id=transmitter.id,
            step_label=step.label,
            output_path=output_path,
            qa=qa_result,
            capture_timestamp=capture_timestamp,
        )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Transmitter control (external script interface)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||
"""Configure the transmitter for this step.
|
||||
|
||||
For ``external_script`` control method the script is called as::
|
||||
|
||||
<script> configure <step_params_json>
|
||||
|
||||
where ``step_params_json`` is a JSON object with channel, bandwidth,
|
||||
traffic, etc. The script is responsible for applying the configuration
|
||||
and returning promptly (i.e. not blocking for the capture duration).
|
||||
|
||||
For SDR transmitters this is a no-op placeholder (TX not yet implemented).
|
||||
"""
|
||||
if transmitter.control_method == "external_script":
|
||||
if not transmitter.script:
|
||||
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
|
||||
return
|
||||
params = self._step_params_json(transmitter, step)
|
||||
_run_script(transmitter.script, "configure", params)
|
||||
|
||||
elif transmitter.control_method == "sdr":
|
||||
logger.debug("SDR TX not yet implemented — skipping start")
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||
|
||||
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||
"""Signal the transmitter to stop.
|
||||
|
||||
Calls ``<script> stop`` for external_script transmitters.
|
||||
"""
|
||||
if transmitter.control_method == "external_script":
|
||||
if not transmitter.script:
|
||||
return
|
||||
try:
|
||||
_run_script(transmitter.script, "stop")
|
||||
except Exception as e:
|
||||
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||
"""Serialise step parameters to a JSON string for the control script."""
|
||||
params: dict = {"device": transmitter.device or ""}
|
||||
if step.channel is not None:
|
||||
params["channel"] = step.channel
|
||||
if step.bandwidth_mhz is not None:
|
||||
params["bandwidth_mhz"] = step.bandwidth_mhz
|
||||
if step.traffic is not None:
|
||||
params["traffic"] = step.traffic
|
||||
if step.power_dbm is not None:
|
||||
params["power_dbm"] = step.power_dbm
|
||||
return json.dumps(params)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Output
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
    """Save a recording to disk and return the data file path.

    Args:
        recording: Captured recording to persist as SigMF.
        device_id: Transmitter identifier; becomes the output subdirectory.
        step: Capture step; its label becomes the file base name.

    Returns:
        Absolute path to the written ``.sigmf-data`` file.

    Raises:
        RuntimeError: If the computed output subdirectory resolves outside
            the configured output directory (path-traversal guard).
    """
    out = self.config.output
    rel_filename = build_output_filename(device_id, step)
    out_dir = Path(out.path).resolve()

    # build_output_filename returns "<device_id>/<label>"
    # to_sigmf needs filename (base) and path (dir) separately
    parts = Path(rel_filename)
    subdir = (out_dir / parts.parent).resolve()

    # Prevent path traversal: the resolved subdir must stay within the configured output directory.
    # relative_to raises ValueError when subdir is not under out_dir.
    try:
        subdir.relative_to(out_dir)
    except ValueError:
        raise RuntimeError(
            f"Output path escape detected: '{subdir}' is outside configured output directory '{out_dir}'"
        )

    subdir.mkdir(parents=True, exist_ok=True)
    base = parts.name

    # overwrite=True: re-running a step replaces the previous capture.
    to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
    return str(subdir / f"{base}.sigmf-data")
|
|
@ -1,77 +0,0 @@
|
|||
"""Timestamp-based labeling for captured recordings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
|
||||
from .campaign import CaptureStep
|
||||
|
||||
|
||||
def label_recording(
    recording: Recording,
    device_id: str,
    step: CaptureStep,
    capture_timestamp: float,
    campaign_name: Optional[str] = None,
) -> Recording:
    """Attach device identity and capture-configuration labels to *recording*.

    Labels land in the ``ria:*`` namespace when the recording is later
    written as SigMF, through the existing ``update_metadata`` mechanism.

    Args:
        recording: The recording to label (mutated in place).
        device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
        step: The capture step that was active during this recording.
        capture_timestamp: Unix timestamp (float) of when capture started.
        campaign_name: Optional campaign name for cross-recording reference.

    Returns:
        The same recording object, for chaining.
    """
    # Identity and timing labels are always written.
    recording.update_metadata("device_id", device_id)
    recording.update_metadata("capture_timestamp", capture_timestamp)
    recording.update_metadata("step_label", step.label)
    recording.update_metadata("step_duration_s", step.duration)

    # Optional labels: only written when a value is present. The campaign
    # entry also skips empty strings (falsy), matching the other fields'
    # None-skipping via the `or None` normalisation.
    optional_labels = [
        ("campaign", campaign_name or None),
        ("wifi_channel", step.channel),                            # WiFi-specific
        ("wifi_bandwidth_mhz", step.bandwidth_mhz),                # WiFi-specific
        ("bt_connection_interval_ms", step.connection_interval_ms),  # Bluetooth-specific
        ("traffic_pattern", step.traffic),                         # WiFi + BT
        ("tx_power_dbm", step.power_dbm),                          # TX power
    ]
    for key, value in optional_labels:
        if value is not None:
            recording.update_metadata(key, value)

    return recording
||||
|
||||
def build_output_filename(device_id: str, step: CaptureStep) -> str:
    """Generate a deterministic relative filename for a labeled recording.

    Format: ``<device_id>/<step_label>``, with slashes and spaces in each
    component replaced by underscores so neither part spawns unintended
    subdirectories or awkward path segments.

    Args:
        device_id: Device identifier string.
        step: Capture step whose label names the file.

    Returns:
        Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
    """

    def _sanitise(text: str) -> str:
        return text.replace("/", "_").replace(" ", "_")

    return f"{_sanitise(device_id)}/{_sanitise(step.label)}"
|
|
@ -1,109 +0,0 @@
|
|||
"""QA metrics for captured RF recordings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
|
||||
from .campaign import QAConfig
|
||||
|
||||
|
||||
@dataclass
class QAResult:
    """Outcome of the QA checks run against a single recording."""

    passed: bool
    flagged: bool  # True if any metric is below threshold (but not hard-failed)
    snr_db: float
    duration_s: float
    issues: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a JSON-serialisable summary with rounded metric values."""
        summary: dict = {"passed": self.passed, "flagged": self.flagged}
        # Round for compact, stable reporting output.
        summary["snr_db"] = round(self.snr_db, 2)
        summary["duration_s"] = round(self.duration_s, 3)
        summary["issues"] = self.issues
        return summary
||||
|
||||
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
    """Estimate SNR from IQ samples using PSD-based signal/noise separation.

    Computes an FFT of the samples and assumes the top ``signal_fraction``
    of power bins are signal and the remainder are noise. This is a
    heuristic appropriate for a controlled testbed where a single dominant
    signal is expected.

    Args:
        samples: 1-D complex array of IQ samples.
        signal_fraction: Fraction of PSD bins to treat as signal (0–1).

    Returns:
        Estimated SNR in dB, or 0.0 if the noise floor is zero or there are
        fewer than two samples (previously this produced NaN and a numpy
        empty-slice warning).
    """
    n_fft = min(4096, len(samples))
    # Guard: with fewer than 2 bins there is nothing to split into
    # signal vs noise, and the empty-slice means below would yield NaN.
    if n_fft < 2:
        return 0.0

    window = np.hanning(n_fft)
    psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2

    # Sort bins by power, strongest first; the split index is clamped so
    # both partitions are non-empty.
    psd_sorted = np.sort(psd)[::-1]
    n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1)
    signal_power = psd_sorted[:n_signal].mean()
    noise_power = psd_sorted[n_signal:].mean()

    if noise_power <= 0.0:
        return 0.0
    return float(10.0 * np.log10(signal_power / noise_power))
||||
|
||||
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
    """Run QA checks on a recording against the campaign QA config.

    Checks performed:
    - Duration: number of samples / sample_rate >= min_duration_s
    - SNR: estimated SNR >= snr_threshold_db

    Args:
        recording: Recording to evaluate.
        config: QA thresholds from the campaign config.

    Returns:
        QAResult with pass/flag status and per-metric details.
    """
    problems: list[str] = []

    # --- Duration check --- (guard against a missing or zero sample rate)
    sample_rate = recording.metadata.get("sample_rate", 1.0)
    n_samples = recording.data.shape[-1]
    duration_s = n_samples / sample_rate if sample_rate else 0.0
    if duration_s < config.min_duration_s:
        problems.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")

    # --- SNR check --- (first channel when multi-channel, else the whole array)
    samples = recording.data[0] if recording.data.ndim > 1 else recording.data
    snr_db = estimate_snr_db(samples)
    if snr_db < config.snr_threshold_db:
        problems.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")

    flagged = bool(problems)
    # flag_for_review mode: always accept — a human reviews flagged
    # recordings. Otherwise any issue is a hard failure.
    passed = True if config.flag_for_review else not flagged

    return QAResult(
        passed=passed,
        flagged=flagged,
        snr_db=snr_db,
        duration_s=duration_s,
        issues=problems,
    )
|
|
@ -4,39 +4,6 @@ It streamlines tasks involving signal reception and transmission, as well as com
|
|||
operations such as detecting and configuring available devices.
|
||||
"""
|
||||
|
||||
__all__ = ["SDR", "SDRError", "SDRParameterError", "MockSDR", "get_sdr_device"]
|
||||
__all__ = ["SDR", "SDRError", "SDRParameterError"]
|
||||
|
||||
from .mock import MockSDR
|
||||
from .sdr import SDR, SDRError, SDRParameterError
|
||||
|
||||
|
||||
def get_sdr_device(device_type: str, ident: str | None = None, tx: bool = False) -> SDR:
    """Return an SDR instance for *device_type*.

    ``"mock"`` / ``"sim"`` device types yield a :class:`MockSDR` immediately,
    with no hardware required. Any other device type is delegated to
    ``ria_toolkit_oss_cli.ria_toolkit_oss.common.get_sdr_device`` when the
    CLI package is installed; otherwise an ``ImportError`` with install
    instructions is raised.

    Args:
        device_type: Device name (``"mock"``, ``"pluto"``, ``"usrp"``, …).
        ident: Optional device identifier (IP address, serial number, …).
        tx: If True, require TX capability.
    """
    if device_type in {"mock", "sim"}:
        return MockSDR()

    # Real hardware lives behind the CLI package, which keeps the driver
    # imports gated on hardware-specific optional dependencies.
    try:
        from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device as _cli_get
    except ImportError as exc:
        raise ImportError(
            f"ria_toolkit_oss_cli is required to use hardware SDR device '{device_type}'. "
            "Install it with: pip install ria-toolkit-oss-cli"
        ) from exc

    return _cli_get(device_type, ident=ident, tx=tx)
|
|
|
|||
|
|
@ -1,131 +0,0 @@
|
|||
"""Simulated SDR device for testing without hardware.
|
||||
|
||||
Set ``recorder.device = "mock"`` (or ``"sim"``) in a campaign config to use
|
||||
this driver. The inference loop can also use it by specifying ``device:
|
||||
"mock"`` in the SDR start request.
|
||||
|
||||
The mock generates complex float32 AWGN samples normalised to [-1, 1].
|
||||
It satisfies both interfaces used in this codebase:
|
||||
|
||||
- ``record(num_samples)`` / ``_stream_rx(callback)`` — used by
|
||||
``CampaignExecutor`` (inherits from ``SDR`` base class).
|
||||
- ``rx(num_samples)`` — PlutoSDR-style interface used by the controller
|
||||
inference loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.sdr.sdr import SDR
|
||||
|
||||
_DEFAULT_BUFFER_SIZE = 4096
|
||||
# Simulated sample rate throttle: sleep this long between buffers so the
|
||||
# loop does not spin at 100% CPU. 10 ms ≈ 100 buffers/s which is fine for
|
||||
# tests and campaign execution timing.
|
||||
_SLEEP_PER_BUFFER_S = 0.01
|
||||
|
||||
|
||||
class MockSDR(SDR):
    """Software-simulated SDR that generates AWGN noise.

    Satisfies both receive interfaces used in this codebase:
    ``record``/``_stream_rx`` (SDR base class) and PlutoSDR-style ``rx``.
    Samples are complex float32 AWGN normalised to peak magnitude 1.

    Args:
        buffer_size: Number of complex samples per streaming buffer.
        seed: Optional RNG seed for reproducible output.
    """

    def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE, seed: int | None = None):
        super().__init__()
        self.rx_buffer_size: int = buffer_size
        self._rng = np.random.default_rng(seed)

        # Direct attribute aliases used by _apply_sdr_config in the controller.
        self.center_freq: float = 2.45e9
        self.sample_rate: float = 10e6
        self.gain: float = 40.0

    # ------------------------------------------------------------------
    # Abstract method implementations
    # ------------------------------------------------------------------

    def init_rx(
        self,
        sample_rate: float,
        center_frequency: float,
        gain,
        channel: int = 0,
        gain_mode: str = "manual",
    ) -> None:
        """Record the requested RX parameters; no hardware is touched.

        ``gain=None`` falls back to a fixed 40.0 dB default.
        ``channel`` and ``gain_mode`` are accepted for interface
        compatibility but have no effect on the simulation.
        """
        self.rx_sample_rate = float(sample_rate)
        self.rx_center_frequency = float(center_frequency)
        self.rx_gain = 40.0 if gain is None else float(gain)
        # Mirror to the attribute names used by _apply_sdr_config.
        self.sample_rate = self.rx_sample_rate
        self.center_freq = self.rx_center_frequency
        self.gain = self.rx_gain
        self._rx_initialized = True

    def init_tx(
        self,
        sample_rate: float,
        center_frequency: float,
        gain,
        channel: int = 0,
        gain_mode: str = "manual",
    ) -> None:
        """Record the requested TX parameters; no hardware is touched."""
        self.tx_sample_rate = float(sample_rate)
        self.tx_center_frequency = float(center_frequency)
        self.tx_gain = 40.0 if gain is None else float(gain)
        self._tx_initialized = True

    def _stream_rx(self, callback) -> None:
        """Generate 1-D AWGN buffers and pass each to *callback* until stopped.

        Uses 1-D arrays so the base class ``_validate_buffer`` check does not
        incorrectly flag them as corrupted (the (1, N) form triggers a false
        positive in the all-same-value check).
        """
        self._enable_rx = True
        while self._enable_rx:
            buf = self._awgn(self.rx_buffer_size)
            callback(buf)
            # Throttle so the loop doesn't spin at 100% CPU (see module note).
            time.sleep(_SLEEP_PER_BUFFER_S)

    def _stream_tx(self, callback) -> None:
        # TX simulation only reports the buffer size to the callback; no
        # samples are actually "transmitted".
        self._enable_tx = True
        while self._enable_tx:
            callback(self.rx_buffer_size)
            time.sleep(_SLEEP_PER_BUFFER_S)

    def set_clock_source(self, source: str) -> None:
        pass  # no-op

    def close(self) -> None:
        # Stops both streaming loops and marks the device uninitialised.
        self._enable_rx = False
        self._enable_tx = False
        self._rx_initialized = False
        self._tx_initialized = False

    # ------------------------------------------------------------------
    # PlutoSDR-style interface used by the controller inference loop
    # ------------------------------------------------------------------

    def rx(self, num_samples: int) -> np.ndarray:
        """Return *num_samples* complex64 AWGN samples (PlutoSDR-style)."""
        return self._awgn(num_samples)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _awgn(self, n: int) -> np.ndarray:
        """Return *n* normalised complex64 AWGN samples as a 1-D array."""
        real = self._rng.standard_normal(n).astype(np.float32)
        imag = self._rng.standard_normal(n).astype(np.float32)
        buf = real + 1j * imag
        # Normalise to [-1, 1] by peak magnitude; the epsilon guard avoids
        # a divide-by-zero for the (theoretical) all-zero buffer.
        peak = np.abs(buf).max()
        if peak > 1e-9:
            buf /= peak
        return buf
|
|
@ -329,12 +329,7 @@ class Pluto(SDR):
|
|||
elif tx_time is not None:
|
||||
pass
|
||||
else:
|
||||
if isinstance(recording, Recording):
|
||||
tx_time = recording.data.shape[-1] / self.tx_sample_rate
|
||||
elif isinstance(recording, np.ndarray):
|
||||
tx_time = recording.shape[-1] / self.tx_sample_rate
|
||||
else:
|
||||
tx_time = len(recording[0]) / self.tx_sample_rate
|
||||
tx_time = len(recording) / self.tx_sample_rate
|
||||
|
||||
data = self._format_tx_data(recording=recording)
|
||||
|
||||
|
|
@ -436,7 +431,7 @@ class Pluto(SDR):
|
|||
abs_gain = gain
|
||||
|
||||
if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
|
||||
abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
|
||||
abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
|
||||
print(f"Gain {gain} out of range for Pluto.")
|
||||
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
|
||||
|
||||
|
|
@ -588,8 +583,6 @@ class Pluto(SDR):
|
|||
self.tx_buffer_size = buffer_size
|
||||
|
||||
def close(self):
    """Release the underlying radio handle.

    Safe to call before initialisation or more than once: returns
    immediately when no ``radio`` attribute exists. When a cyclic TX
    buffer is active it is destroyed first — presumably required by the
    driver before the device object can be dropped (TODO confirm against
    pyadi-iio semantics).
    """
    if not hasattr(self, "radio"):
        return
    if self.radio.tx_cyclic_buffer:
        self.radio.tx_destroy_buffer()
    del self.radio
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class SDR(ABC):
|
|||
self._accumulated_buffer = None
|
||||
self._max_num_buffers = None
|
||||
self._num_buffers_processed = 0
|
||||
self._accumulated_buffer = None
|
||||
self._last_buffer = None
|
||||
self._corrupted_buffer_count = 0
|
||||
|
||||
|
|
@ -281,7 +282,7 @@ class SDR(ABC):
|
|||
elif num_samples is not None:
|
||||
self._num_samples_to_transmit = num_samples
|
||||
elif tx_time is not None:
|
||||
self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
|
||||
self._num_samples_to_transmit = tx_time * self.tx_sample_rate
|
||||
else:
|
||||
self._num_samples_to_transmit = len(recording)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
"""RT-OSS HTTP server for RIA Hub integration."""
|
||||
|
||||
from .app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
"""FastAPI application factory for the RT-OSS HTTP server."""
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from .auth import require_api_key
|
||||
from .routers import inference, orchestrator
|
||||
|
||||
|
||||
def create_app(api_key: str = "") -> FastAPI:
    """Create and configure the RT-OSS FastAPI application.

    Args:
        api_key: Secret key required in the ``X-API-Key`` request header.
            Pass an empty string to disable authentication (development only).

    Returns:
        Configured FastAPI application instance.
    """
    app = FastAPI(
        title="RIA Toolkit OSS Server",
        version="0.1.0",
        description=(
            "HTTP API for RT-OSS campaign orchestration and RF zone inference. "
            "All endpoints (except /health) require the X-API-Key header when "
            "an API key is configured."
        ),
    )
    # The auth dependency reads the expected key back from app.state.
    app.state.api_key = api_key

    # Both routers share the same API-key guard; /health stays open.
    mounted = (
        (orchestrator.router, "/orchestrator", "Orchestrator"),
        (inference.router, "/inference", "Inference"),
    )
    for router_obj, prefix, tag in mounted:
        app.include_router(
            router_obj,
            prefix=prefix,
            tags=[tag],
            dependencies=[Depends(require_api_key)],
        )

    @app.get("/health", tags=["Health"])
    async def health():
        """Health check — always returns 200."""
        return {"status": "ok"}

    return app
|
@ -1,36 +0,0 @@
|
|||
"""API key authentication dependency."""
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def require_api_key(
    request: Request,
    api_key: str | None = Depends(_api_key_header),
) -> None:
    """FastAPI dependency that enforces X-API-Key header authentication.

    When the server has no API key configured (empty string) every request
    is allowed — intended for local development only. The comparison uses
    ``hmac.compare_digest`` so key matching is constant-time.

    Raises:
        HTTPException: 403 when the supplied key is missing or wrong.
    """
    expected: str = request.app.state.api_key
    if not expected:
        # Dev mode: no key configured, allow all requests.
        return
    supplied = api_key or ""
    if hmac.compare_digest(supplied, expected):
        return
    # Log the failure with enough context to trace the caller.
    client = getattr(request.client, "host", "unknown")
    logger.warning(
        "Authentication failure from %s — %s %s",
        client,
        request.method,
        request.url.path,
    )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid or missing API key",
    )
|
@ -1,47 +0,0 @@
|
|||
"""CLI entry point for the RT-OSS HTTP server.
|
||||
|
||||
Usage:
|
||||
ria-server # default: 0.0.0.0:8080, no auth
|
||||
RT_OSS_API_KEY=secret ria-server # enforce X-API-Key header
|
||||
RT_OSS_PORT=9000 ria-server # custom port
|
||||
|
||||
Environment variables:
|
||||
RT_OSS_API_KEY Shared secret for X-API-Key auth (empty = dev mode, no auth)
|
||||
RT_OSS_PORT TCP port to listen on (default: 8080)
|
||||
RT_OSS_HOST Bind address (default: 0.0.0.0)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def serve() -> None:
    """Run the RT-OSS HTTP server (blocking).

    Reads RT_OSS_API_KEY / RT_OSS_HOST / RT_OSS_PORT from the environment,
    builds the app, warns loudly when no API key is configured, and hands
    control to uvicorn.
    """
    try:
        import uvicorn
    except ImportError:
        raise SystemExit(
            "uvicorn is required to run the RT-OSS server.\n" "Install it with: pip install 'ria-toolkit-oss[server]'"
        )

    from .app import create_app

    env = os.environ
    api_key = env.get("RT_OSS_API_KEY", "")
    host = env.get("RT_OSS_HOST", "0.0.0.0")
    port = int(env.get("RT_OSS_PORT", "8080"))

    app = create_app(api_key=api_key)

    if not api_key:
        # Loud banner: running unauthenticated is only acceptable locally.
        print(
            "\n"
            "╔══════════════════════════════════════════════════════════════╗\n"
            "║ WARNING: RT_OSS_API_KEY is not set. ║\n"
            "║ The server is running with NO authentication. ║\n"
            "║ Anyone who can reach this port has full API access. ║\n"
            "║ Set RT_OSS_API_KEY=<secret> before exposing to a network. ║\n"
            "╚══════════════════════════════════════════════════════════════╝\n",
            flush=True,
        )

    uvicorn.run(app, host=host, port=port)
|
|
@ -1,114 +0,0 @@
|
|||
"""Pydantic request and response models for the RT-OSS HTTP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeployRequest(BaseModel):
|
||||
config: dict
|
||||
|
||||
|
||||
class DeployResponse(BaseModel):
|
||||
campaign_id: str
|
||||
|
||||
|
||||
class CampaignStatusResponse(BaseModel):
|
||||
campaign_id: str
|
||||
status: str
|
||||
config_name: str
|
||||
progress: int
|
||||
total_steps: int
|
||||
started_at: float
|
||||
ended_at: float | None = None
|
||||
result: dict | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class CancelResponse(BaseModel):
|
||||
campaign_id: str
|
||||
cancelled: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SdrConfig(BaseModel):
|
||||
device: str
|
||||
center_freq: float
|
||||
sample_rate: float
|
||||
gain: float | str = "auto"
|
||||
|
||||
|
||||
class LoadModelRequest(BaseModel):
    """Request body for POST /inference/load."""

    # Filesystem path to the ONNX model; validated below.
    model_path: str
    label_map: dict[str, int]  # class_name -> class_index

    @field_validator("model_path")
    @classmethod
    def validate_model_path(cls, v: str) -> str:
        """Reject traversal components and non-.onnx paths.

        Returns the resolved absolute path so downstream code always works
        with the real filesystem location.
        """
        p = Path(v)
        if ".." in p.parts:
            raise ValueError("model_path must not contain path traversal components")
        if p.suffix.lower() != ".onnx":
            raise ValueError("model_path must point to an .onnx file")
        # Resolve to catch symlink-based traversal; return the resolved absolute path
        # so callers always work with the real filesystem location.
        resolved = p.resolve()
        if resolved.suffix.lower() != ".onnx":
            raise ValueError("Resolved model_path must point to an .onnx file")
        return str(resolved)
||||
|
||||
class LoadModelResponse(BaseModel):
|
||||
loaded: bool
|
||||
model_path: str
|
||||
num_classes: int
|
||||
|
||||
|
||||
class StartInferenceRequest(BaseModel):
|
||||
sdr_config: SdrConfig
|
||||
|
||||
|
||||
class StartInferenceResponse(BaseModel):
|
||||
running: bool
|
||||
|
||||
|
||||
class StopInferenceResponse(BaseModel):
|
||||
stopped: bool
|
||||
|
||||
|
||||
class ConfigureRequest(BaseModel):
|
||||
"""Partial SDR reconfiguration — only supplied fields are updated."""
|
||||
|
||||
center_freq: float | None = None
|
||||
sample_rate: float | None = None
|
||||
gain: float | str | None = None
|
||||
|
||||
|
||||
class ConfigureResponse(BaseModel):
|
||||
configured: bool
|
||||
|
||||
|
||||
class InferenceStatusResponse(BaseModel):
    """Latest inference result as returned by GET /inference/status.

    When ``idle`` is True the radio is scanning but no signal was detected.
    ``device_id`` is the raw prediction label from the model's label map.
    The frontend is responsible for mapping device_id to a human name and
    determining whether the device is authorized.
    """

    # Unix timestamp of the prediction this status reflects.
    timestamp: float
    idle: bool = False
    device_id: str | None = None  # prediction label; None when idle
    # Softmax probability of the predicted class (0.0 when idle/unset).
    confidence: float = 0.0
    # Estimated SNR of the analysed capture, in dB.
    snr_db: float = 0.0
|
@ -1,253 +0,0 @@
|
|||
"""Inference routes: model loading, inference loop control, and status polling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from scipy.special import softmax
|
||||
|
||||
from ..models import (
|
||||
ConfigureRequest,
|
||||
ConfigureResponse,
|
||||
InferenceStatusResponse,
|
||||
LoadModelRequest,
|
||||
LoadModelResponse,
|
||||
StartInferenceRequest,
|
||||
StartInferenceResponse,
|
||||
StopInferenceResponse,
|
||||
)
|
||||
from ..state import InferenceState, get_inference, set_inference
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_INFERENCE_NUM_SAMPLES = 4096
|
||||
|
||||
# Prediction labels that mean "no signal detected" — UI should treat these as idle.
|
||||
_IDLE_LABELS: frozenset[str] = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
||||
|
||||
|
||||
def _load_onnx_session(model_path: str):
    """Load an ONNX model into a CPU-only inference session.

    Args:
        model_path: Path to the ``.onnx`` file (already validated by the
            request model, but resolved again here before the file check).

    Returns:
        An ``onnxruntime.InferenceSession`` bound to CPUExecutionProvider.

    Raises:
        HTTPException: 503 when onnxruntime is not installed; 422 when the
            file is missing or cannot be parsed as an ONNX model.
    """
    # Imported lazily so the server can run without the optional dependency
    # until a model load is actually requested.
    try:
        import onnxruntime as ort
    except ImportError:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
        )
    resolved = Path(model_path).resolve()
    if not resolved.is_file():
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Model file not found: {model_path}",
        )
    try:
        return ort.InferenceSession(str(resolved), providers=["CPUExecutionProvider"])
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
||||
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
|
||||
"""Reshape complex IQ samples to float32 matching the model's expected input.
|
||||
|
||||
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
|
||||
"""
|
||||
iq = samples.astype(np.complex64)
|
||||
i_ch, q_ch = iq.real, iq.imag
|
||||
|
||||
if len(expected_shape) == 2:
|
||||
n = expected_shape[1] // 2
|
||||
interleaved = np.empty(expected_shape[1], dtype=np.float32)
|
||||
interleaved[0::2] = i_ch[:n]
|
||||
interleaved[1::2] = q_ch[:n]
|
||||
return interleaved.reshape(1, -1)
|
||||
elif len(expected_shape) == 3:
|
||||
n = expected_shape[2]
|
||||
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model input shape: {expected_shape}")
|
||||
|
||||
|
||||
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
|
||||
state.stop_event.set()
|
||||
if state.thread and state.thread.is_alive():
|
||||
state.thread.join(timeout=timeout)
|
||||
if state.thread.is_alive():
|
||||
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
|
||||
|
||||
|
||||
def _apply_sdr_config(sdr, config: dict) -> None:
|
||||
"""Re-initialise the SDR receiver with updated parameters."""
|
||||
gain = config.get("gain")
|
||||
if gain == "auto":
|
||||
gain = None
|
||||
elif gain is not None:
|
||||
gain = float(gain)
|
||||
kwargs: dict = {}
|
||||
if config.get("center_freq") is not None:
|
||||
kwargs["center_frequency"] = float(config["center_freq"])
|
||||
if config.get("sample_rate") is not None:
|
||||
kwargs["sample_rate"] = float(config["sample_rate"])
|
||||
if gain is not None:
|
||||
kwargs["gain"] = gain
|
||||
if kwargs:
|
||||
sdr.init_rx(**kwargs, channel=0)
|
||||
|
||||
|
||||
def _inference_loop(state: InferenceState, sdr) -> None:
    """Continuous capture→predict loop run on a background thread.

    Owns *sdr* for its lifetime: the device is stored on ``state.sdr`` while
    running and closed in the ``finally`` block regardless of how the loop
    exits. The loop ends when ``state.stop_event`` is set.
    """
    # Imported here (not module level) to avoid a hard dependency of the
    # server module on the orchestration package at import time.
    from ria_toolkit_oss.orchestration.qa import estimate_snr_db

    state.sdr = sdr
    state.set_running(True)
    session = state.session
    input_name = session.get_inputs()[0].name
    # Replace symbolic/dynamic dimensions (non-positive or non-int) with the
    # fixed capture size so _preprocess_samples gets concrete numbers.
    expected_shape = tuple(
        d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
    )

    try:
        while not state.stop_event.is_set():
            # Apply any pending SDR reconfiguration before the next capture.
            pending = state.pop_pending_config()
            if pending:
                try:
                    _apply_sdr_config(sdr, pending)
                except Exception as exc:
                    # Keep looping with the previous configuration.
                    logger.warning("SDR reconfigure failed: %s", exc)

            recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
            # First channel when multi-channel, else the whole array.
            samples = recording.data[0] if recording.data.ndim > 1 else recording.data
            snr_db = estimate_snr_db(samples)

            try:
                model_input = _preprocess_samples(samples, expected_shape)
                logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
                probs = softmax(logits)
                pred_idx = int(np.argmax(probs))
                # Fall back to the raw index as a string for unmapped classes.
                prediction = state.index_to_label.get(pred_idx, str(pred_idx))
            except Exception as exc:
                # A single bad capture/inference should not kill the loop.
                logger.warning("Inference prediction failed: %s", exc)
                continue

            # Labels such as "noise"/"idle" mean "no signal" to the UI.
            is_idle = prediction in _IDLE_LABELS

            state.set_latest(
                {
                    "timestamp": time.time(),
                    "idle": is_idle,
                    "device_id": prediction if not is_idle else None,
                    "confidence": round(float(probs[pred_idx]), 4),
                    "snr_db": round(snr_db, 2),
                }
            )
    finally:
        # Always release the SDR and clear running state, even on error.
        state.sdr = None
        try:
            sdr.close()
        except Exception:
            pass
        state.set_running(False)
|
||||
@router.post("/load", response_model=LoadModelResponse)
async def load_model(request: LoadModelRequest):
    """Load an ONNX model. Stops any running inference first.

    ``label_map`` maps class names to integer indices (e.g. ``{"iphone13_wifi_001": 0}``).
    """
    # NOTE(review): a previous docstring also mentioned an `enrolled_devices`
    # field, but LoadModelRequest carries only model_path and label_map.
    existing = get_inference()
    if existing and existing.get_running():
        # A new model replaces the old state wholesale, so halt the old loop
        # (and release its SDR) before swapping.
        _stop_current_inference(existing)

    session = _load_onnx_session(request.model_path)
    set_inference(
        InferenceState(
            model_path=request.model_path,
            label_map=request.label_map,
            # Inverted map for index -> class-name lookup at prediction time.
            index_to_label={v: k for k, v in request.label_map.items()},
            session=session,
        )
    )
    return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
||||
@router.post("/start", response_model=StartInferenceResponse)
async def start_inference(request: StartInferenceRequest):
    """Start continuous inference. Requires a model to be loaded first.

    Spawns a daemon thread running ``_inference_loop`` on a freshly initialised
    SDR. Responds 409 when no model is loaded or a loop is already running,
    503 when the SDR stack cannot be imported, and 422 when the SDR fails to
    initialise.
    """
    state = get_inference()
    if not state:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
        )
    if state.get_running():
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")

    # Imported lazily so the server can start even when SDR drivers are absent.
    try:
        from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
        from ria_toolkit_oss.sdr import get_sdr_device
    except ImportError as e:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")

    sdr_cfg = request.sdr_config
    # Merge any pending configure request on top of the start config.
    pending = state.pop_pending_config() or {}
    center_freq = float(pending.get("center_freq") or sdr_cfg.center_freq)
    sample_rate = float(pending.get("sample_rate") or sdr_cfg.sample_rate)
    # "gain" is keyed by presence (not truthiness) so a pending value of 0 still wins.
    raw_gain = pending.get("gain") if "gain" in pending else sdr_cfg.gain
    # "auto" maps to None (presumably automatic gain control in the SDR layer — confirm driver contract).
    # NOTE(review): float(raw_gain) raises TypeError if raw_gain is None — confirm the
    # request schema guarantees a concrete gain value or the string "auto".
    gain = None if raw_gain == "auto" else float(raw_gain)
    try:
        sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
        sdr.init_rx(sample_rate=sample_rate, center_frequency=center_freq, gain=gain, channel=0)
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")

    # Daemon thread so a hung capture can never block interpreter shutdown.
    state.stop_event.clear()
    state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
    state.thread.start()
    return StartInferenceResponse(running=True)
|
||||
|
||||
|
||||
@router.post("/stop", response_model=StopInferenceResponse)
async def stop_inference():
    """Stop the running inference loop.

    Returns ``stopped=False`` when no model is loaded or nothing is running.
    """
    state = get_inference()
    is_active = bool(state) and state.get_running()
    if is_active:
        _stop_current_inference(state)
    return StopInferenceResponse(stopped=is_active)
|
||||
|
||||
|
||||
@router.post("/configure", response_model=ConfigureResponse)
async def configure_inference(request: ConfigureRequest):
    """Update SDR parameters (center_freq, sample_rate, gain) on the fly.

    If inference is running the change is applied at the next capture boundary.
    If inference is not running the config is stored and applied when it starts.
    Only fields present in the request body are updated.
    """
    state = get_inference()
    if state is None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="No model loaded. Call POST /inference/load first.",
        )
    # Drop unset fields so a partial body never clobbers existing parameters.
    updates = {}
    for field_name, value in request.model_dump().items():
        if value is not None:
            updates[field_name] = value
    if updates:
        state.set_pending_config(updates)
    return ConfigureResponse(configured=bool(updates))
|
||||
|
||||
|
||||
@router.get("/status", response_model=InferenceStatusResponse | None)
async def inference_status():
    """Return the latest inference result, or null if no model is loaded."""
    state = get_inference()
    if state:
        latest = state.get_latest()
        if latest:
            return InferenceStatusResponse(**latest)
    return None
|
||||
|
|
@ -1,112 +0,0 @@
|
|||
"""Orchestrator routes: campaign deployment, status, and cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||
|
||||
from ..models import (
|
||||
CampaignStatusResponse,
|
||||
CancelResponse,
|
||||
DeployRequest,
|
||||
DeployResponse,
|
||||
)
|
||||
from ..state import (
|
||||
CampaignCancelledError,
|
||||
CampaignState,
|
||||
get_campaign,
|
||||
set_campaign,
|
||||
update_campaign,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
    """Build the per-step callback the executor uses to report progress.

    The returned callback records progress into the shared campaign store and
    aborts the run (by raising :class:`CampaignCancelledError`) once a cancel
    has been requested.
    """

    def _on_step(step_index: int, total_steps: int, step_result: Any) -> None:
        update_campaign(campaign_id, progress=step_index)
        if not cancel_event.is_set():
            return
        raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")

    return _on_step
|
||||
|
||||
|
||||
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
    """Thread target: execute the campaign and record its terminal status.

    Runs in a daemon thread started by ``deploy``. All outcomes are written
    back to the shared campaign store via ``update_campaign``:

    * normal completion -> status ``"completed"`` with the result payload,
    * cancellation (raised by the progress callback at a step boundary)
      -> status ``"cancelled"``,
    * any other exception -> status ``"failed"`` with the error message.
    """
    state = get_campaign(campaign_id)
    try:
        result = CampaignExecutor(
            config=cfg,
            progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
        ).run()
        update_campaign(
            campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
        )
    except CampaignCancelledError:
        update_campaign(campaign_id, status="cancelled", ended_at=time.time())
    except Exception as e:
        # Broad catch is deliberate: this is a thread boundary, so an escaping
        # exception would otherwise be lost; the failure is surfaced via status.
        update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
|
||||
|
||||
|
||||
@router.post("/deploy", response_model=DeployResponse)
async def deploy(request: DeployRequest):
    """Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
    Cancellation takes effect at step boundaries, not mid-capture.

    External scripts are not permitted in server-deployed campaigns. Configure
    transmitters without the ``script`` field, or run campaigns via the CLI.

    Raises 422 when the config cannot be parsed or contains transmitter scripts.
    """
    try:
        cfg = CampaignConfig.from_dict(request.config)
    except (ValueError, KeyError) as e:
        # Chain the original parse error so its cause survives in tracebacks/logs.
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) from e

    # Security guard: a server-deployed config must not execute arbitrary code.
    if cfg.transmitters and any(t.script for t in cfg.transmitters):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="External scripts are not permitted in server-deployed campaigns. "
            "Remove the 'script' field from all transmitters, or run the campaign via the CLI.",
        )

    campaign_id = str(uuid.uuid4())
    cancel_event = threading.Event()
    thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
    # Register the state *before* starting the thread so the worker's
    # get_campaign() lookup is guaranteed to succeed.
    set_campaign(
        CampaignState(
            campaign_id=campaign_id,
            status="running",
            config_name=cfg.name,
            cancel_event=cancel_event,
            thread=thread,
            total_steps=cfg.total_steps(),
        )
    )
    thread.start()
    return DeployResponse(campaign_id=campaign_id)
|
||||
|
||||
|
||||
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
async def get_status(campaign_id: str):
    """Get the status and progress of a deployed campaign."""
    state = get_campaign(campaign_id)
    if state is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
    return CampaignStatusResponse(**state.to_dict())
|
||||
|
||||
|
||||
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
async def cancel(campaign_id: str):
    """Request cancellation. Takes effect at the next step boundary."""
    state = get_campaign(campaign_id)
    if state is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
    # Only a live campaign can be cancelled; terminal states are left untouched.
    if state.status == "running":
        state.cancel_event.set()
        return CancelResponse(campaign_id=campaign_id, cancelled=True)
    return CancelResponse(campaign_id=campaign_id, cancelled=False)
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
"""In-memory state for running campaigns and inference sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class CampaignCancelledError(Exception):
    """Raised by the progress callback when a cancel is requested.

    Thrown from inside the executor's per-step callback so the run unwinds at
    a step boundary; the campaign worker thread catches it and marks the
    campaign as cancelled.
    """
|
||||
|
||||
|
||||
@dataclass
class CampaignState:
    """Mutable record of one deployed campaign's lifecycle."""

    campaign_id: str
    status: str  # "running" | "completed" | "failed" | "cancelled"
    config_name: str
    cancel_event: threading.Event
    thread: threading.Thread
    total_steps: int = 0
    progress: int = 0
    result: Optional[dict] = None
    error: Optional[str] = None
    started_at: float = field(default_factory=time.time)
    ended_at: Optional[float] = None

    def to_dict(self) -> dict:
        """Serialize the JSON-safe fields (the thread and event are omitted)."""
        keys = (
            "campaign_id",
            "status",
            "config_name",
            "progress",
            "total_steps",
            "result",
            "error",
            "started_at",
            "ended_at",
        )
        return {key: getattr(self, key) for key in keys}
|
||||
|
||||
|
||||
@dataclass
class InferenceState:
    """Shared state for one loaded model and its (optional) live capture loop.

    Mutable cross-thread values are accessed through the lock-guarded methods
    below; the dataclass fields themselves are the public interface used by
    the routes and the inference thread.
    """

    model_path: str
    label_map: dict[str, int]  # class_name -> class_index
    index_to_label: dict[int, str]  # reverse: class_index -> class_name
    session: Any  # onnxruntime.InferenceSession
    stop_event: threading.Event = field(default_factory=threading.Event)
    thread: Optional[threading.Thread] = None
    sdr: Any = None  # live SDR object while inference is running
    running: bool = False
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    _latest: Optional[dict] = field(default=None, repr=False)
    _pending_sdr_config: Optional[dict] = field(default=None, repr=False)

    def set_latest(self, result: dict) -> None:
        """Publish the most recent inference result."""
        with self._lock:
            self._latest = result

    def get_latest(self) -> Optional[dict]:
        """Return the most recent inference result, or None if none yet."""
        with self._lock:
            return self._latest

    def set_pending_config(self, config: dict) -> None:
        """Stage an SDR config change to be applied at the next opportunity."""
        with self._lock:
            self._pending_sdr_config = config

    def pop_pending_config(self) -> Optional[dict]:
        """Atomically take and clear any staged SDR config."""
        with self._lock:
            staged, self._pending_sdr_config = self._pending_sdr_config, None
            return staged

    def set_running(self, value: bool) -> None:
        """Flip the running flag (inference-loop lifecycle)."""
        with self._lock:
            self.running = value

    def get_running(self) -> bool:
        """Whether the inference loop is currently active."""
        with self._lock:
            return self.running
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Module-level stores (each guarded by its own lock; shared between route
# handlers and worker threads)
# ---------------------------------------------------------------------------

_campaigns: dict[str, CampaignState] = {}
_campaigns_lock = threading.Lock()

_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()


def get_campaign(campaign_id: str) -> Optional[CampaignState]:
    """Look up a campaign by id, or None when unknown."""
    with _campaigns_lock:
        return _campaigns.get(campaign_id)


def set_campaign(state: CampaignState) -> None:
    """Insert or replace the store entry keyed by ``state.campaign_id``."""
    with _campaigns_lock:
        _campaigns[state.campaign_id] = state


def update_campaign(campaign_id: str, **kwargs) -> None:
    """Apply attribute updates to a stored campaign; unknown ids are ignored."""
    with _campaigns_lock:
        target = _campaigns.get(campaign_id)
        if target is None:
            return
        for attr_name, attr_value in kwargs.items():
            setattr(target, attr_name, attr_value)


def get_inference() -> Optional[InferenceState]:
    """Return the currently loaded inference state, or None."""
    with _inference_lock:
        return _inference


def set_inference(state: Optional[InferenceState]) -> None:
    """Replace (or clear, with None) the global inference state."""
    global _inference
    with _inference_lock:
        _inference = state
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
.. todo:: Need to add some information here about signal generation and the signal generators in this module.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -228,7 +227,7 @@ def noise(
|
|||
|
||||
# TODO figure out a better way to make it conform to [-1,1]
|
||||
if not np.array_equal(magnitude, magnitude2):
|
||||
warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")
|
||||
print("Warning: clipping in basic_signal_generator.noise")
|
||||
|
||||
phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
|
||||
complex_awgn = magnitude2 * np.exp(1j * phase)
|
||||
|
|
@ -269,9 +268,6 @@ def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float]
|
|||
.. todo:: Usage examples coming soon!
|
||||
"""
|
||||
# Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
|
||||
if num_samples < 2:
|
||||
raise ValueError("num_samples must be >= 2 for chirp generation")
|
||||
|
||||
chirp_start_frequency = center_frequency - sample_rate / 4
|
||||
chirp_end_frequency = center_frequency + sample_rate / 4
|
||||
|
||||
|
|
@ -311,9 +307,6 @@ def lfm_chirp_complex(
|
|||
down_part = np.flip(up_part)
|
||||
baseband_chirp = np.concatenate([up_part, down_part])
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.")
|
||||
|
||||
# Generate the full signal by tiling the windowed chirp
|
||||
num_chirps = round(total_time / chirp_period)
|
||||
full_signal = np.tile(baseband_chirp, num_chirps)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.signal.block_generator.block import Block
|
||||
from ria_toolkit_oss.signal.block_generator.data_types import DataType
|
||||
from utils.signal.block_generator.block import Block
|
||||
from utils.signal.block_generator.data_types import DataType
|
||||
|
||||
|
||||
class FrequencyUpConversion(Block):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ and return a corresponding numpy.ndarray with the impairment model applied;
|
|||
we call the latter the impaired data.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -59,14 +58,13 @@ def generate_awgn(signal: ArrayLike | Recording, snr: Optional[float] = 1) -> np
|
|||
|
||||
# Calculate the RMS power of the signal to solve for the RMS power of the noise
|
||||
signal_rms_power = np.sqrt(np.mean(np.abs(data) ** 2))
|
||||
noise_rms_power = signal_rms_power / np.sqrt(snr_linear)
|
||||
noise_rms_power = signal_rms_power / snr_linear
|
||||
|
||||
# Generate complex AWGN: independent Gaussian I and Q components.
|
||||
# Each component has std = noise_rms_power / sqrt(2) so total power = noise_rms_power^2.
|
||||
component_std = noise_rms_power / np.sqrt(2)
|
||||
complex_awgn = np.random.normal(scale=component_std, size=(c, n)) + 1j * np.random.normal(
|
||||
scale=component_std, size=(c, n)
|
||||
)
|
||||
# Generate the AWGN noise which has the same shape as data
|
||||
variance = noise_rms_power**2
|
||||
magnitude = np.random.normal(loc=0, scale=np.sqrt(variance), size=(c, n))
|
||||
phase = np.random.uniform(low=0, high=2 * np.pi, size=(c, n))
|
||||
complex_awgn = magnitude * np.exp(1j * phase)
|
||||
|
||||
if isinstance(signal, Recording):
|
||||
return Recording(data=complex_awgn, metadata=signal.metadata)
|
||||
|
|
@ -380,8 +378,7 @@ def quantize_tape(
|
|||
raise ValueError("signal must be CxN complex.")
|
||||
|
||||
if rounding_type not in {"ceiling", "floor"}:
|
||||
warnings.warn('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
||||
rounding_type = "floor"
|
||||
raise UserWarning('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
||||
|
||||
if c == 1:
|
||||
iq_data = convert_to_2xn(data)
|
||||
|
|
@ -458,8 +455,7 @@ def quantize_parts(
|
|||
raise ValueError("signal must be CxN complex.")
|
||||
|
||||
if rounding_type not in {"ceiling", "floor"}:
|
||||
warnings.warn('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
||||
rounding_type = "floor"
|
||||
raise UserWarning('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
||||
|
||||
if c == 1:
|
||||
iq_data = convert_to_2xn(data)
|
||||
|
|
@ -614,11 +610,8 @@ def cut_out( # noqa: C901 # TODO: Simplify function
|
|||
raise ValueError("signal must be CxN complex.")
|
||||
|
||||
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
|
||||
warnings.warn(
|
||||
'fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr", '
|
||||
'"ones" has been selected by default'
|
||||
)
|
||||
fill_type = "ones"
|
||||
raise UserWarning("""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
|
||||
"ones" has been selected by default""")
|
||||
|
||||
if max_section_size < 1 or max_section_size >= n:
|
||||
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ not the same as the signal at the end of the medium. What is sent is not what is
|
|||
Three causes of impairment are attenuation, distortion, and noise.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -56,6 +55,8 @@ def add_awgn_to_signal(signal: ArrayLike | Recording, snr: Optional[float] = 1)
|
|||
raise ValueError("signal must be CxN complex.")
|
||||
|
||||
noise = iq_augmentations.generate_awgn(signal=data, snr=snr)
|
||||
print(f"noise is {noise}")
|
||||
|
||||
noisy_signal = data + noise
|
||||
|
||||
if isinstance(signal, Recording):
|
||||
|
|
@ -100,18 +101,16 @@ def time_shift(signal: ArrayLike | Recording, shift: Optional[int] = 1) -> np.nd
|
|||
raise ValueError("signal must be CxN complex.")
|
||||
|
||||
if shift > n:
|
||||
warnings.warn("shift is greater than signal length")
|
||||
raise UserWarning("shift is greater than signal length")
|
||||
|
||||
shifted_data = np.zeros_like(data)
|
||||
|
||||
if c == 1:
|
||||
# New iq array shifted left or right depending on sign of shift
|
||||
# This should work even if shift > iqdata.shape[1]
|
||||
if shift > 0:
|
||||
if shift >= 0:
|
||||
# Shift to right
|
||||
shifted_data[:, shift:] = data[:, :-shift]
|
||||
elif shift == 0:
|
||||
shifted_data[:] = data
|
||||
|
||||
else:
|
||||
# Shift to the left
|
||||
|
|
@ -204,7 +203,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
|
|||
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
|
||||
>>> new_rec = phase_shift(rec, np.pi/2)
|
||||
>>> new_rec.data
|
||||
array([[-1+1j, -2+2j, -3+3j, -4+4j]])
|
||||
array([[-1+1j, -2+2j -3+3j -4+4j]])
|
||||
"""
|
||||
# TODO: Additional info needs to be added to docstring description
|
||||
|
||||
|
|
@ -355,9 +354,8 @@ def resample(signal: ArrayLike | Recording, up: Optional[int] = 4, down: Optiona
|
|||
resampled_iqdata = resampled_iqdata[:, :n]
|
||||
|
||||
else:
|
||||
empty_array = np.zeros((1, n), dtype=resampled_iqdata.dtype)
|
||||
empty_array = np.zeros(resampled_iqdata.shape, dtype=resampled_iqdata.dtype)
|
||||
empty_array[:, : resampled_iqdata.shape[1]] = resampled_iqdata
|
||||
resampled_iqdata = empty_array
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 130 B After Width: | Height: | Size: 90 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 130 B After Width: | Height: | Size: 19 KiB |
|
|
@ -1,258 +0,0 @@
|
|||
"""Campaign orchestration CLI commands.
|
||||
|
||||
Usage examples::
|
||||
|
||||
# Enroll a single device using a device profile YAML (App 1 workflow)
|
||||
ria campaign enroll --config iphone13.yml
|
||||
|
||||
# Run a full custom campaign config
|
||||
ria campaign run --config my_campaign.yml
|
||||
|
||||
# Validate a config file without running it
|
||||
ria campaign validate --config my_campaign.yml
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
|
||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor, StepResult
|
||||
|
||||
|
||||
# Root Click group. Subcommands (validate, enroll, run) register themselves
# below via the ``@campaign.command()`` decorator. The docstring doubles as
# the ``ria campaign --help`` text, so it is kept short and user-facing.
@click.group()
def campaign():
    """Orchestrate automated RF capture campaigns."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@campaign.command()
@click.option(
    "--config",
    "-c",
    required=True,
    type=click.Path(exists=True),
    help="Campaign YAML config or device profile YAML.",
)
@click.option(
    "--profile",
    is_flag=True,
    default=False,
    help="Parse as a device profile (App 1 style) rather than a full campaign config.",
)
def validate(config, profile):
    """Validate a campaign or device profile YAML without running it.

    \b
    Examples:
        ria campaign validate --config iphone13.yml --profile
        ria campaign validate --config campaign.yml
    """
    # Parse with the loader matching the file flavour; parsing alone is the validation.
    try:
        if profile:
            cfg = CampaignConfig.from_device_profile(config)
        else:
            cfg = CampaignConfig.from_yaml(config)
    except (FileNotFoundError, ValueError, KeyError) as e:
        # ClickException prints the message and exits non-zero, without a traceback.
        raise click.ClickException(str(e))

    # Config-level summary.
    click.echo(click.style("✓ Config valid", fg="green", bold=True))
    click.echo(f" Campaign name : {cfg.name}")
    click.echo(f" Mode : {cfg.mode}")
    click.echo(f" Transmitters : {len(cfg.transmitters)}")
    click.echo(f" Total steps : {cfg.total_steps()}")
    click.echo(f" Capture time : {cfg.total_capture_time_s():.0f}s")
    click.echo(f" Recorder : {cfg.recorder.device} @ {cfg.recorder.center_freq/1e6:.2f} MHz")
    click.echo(f" Sample rate : {cfg.recorder.sample_rate/1e6:.1f} MS/s")
    click.echo(f" Output path : {cfg.output.path}")
    click.echo()

    # Per-transmitter schedule breakdown, with optional channel/bandwidth/traffic tags.
    for tx in cfg.transmitters:
        click.echo(f" Transmitter: {tx.id} ({tx.type}, {tx.control_method}, {len(tx.schedule)} steps)")
        for step in tx.schedule:
            extras = []
            if step.channel is not None:
                extras.append(f"ch={step.channel}")
            if step.bandwidth_mhz is not None:
                extras.append(f"{int(step.bandwidth_mhz)}MHz")
            if step.traffic:
                extras.append(step.traffic)
            suffix = f" [{', '.join(extras)}]" if extras else ""
            click.echo(f" [{step.duration:.0f}s] {step.label}{suffix}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign enroll
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@campaign.command()
@click.option(
    "--config",
    "-c",
    required=True,
    type=click.Path(exists=True),
    help="Device profile YAML (App 1 enrollment format).",
)
@click.option(
    "--output",
    "-o",
    default=None,
    help="Override output directory from config.",
)
@click.option(
    "--report",
    default="qa_report.json",
    show_default=True,
    help="Path for the JSON QA report.",
)
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
def enroll(config, output, report, verbose, dry_run):
    """Enroll a single device by running its capture profile.

    Parses a device profile YAML (App 1 format), generates a capture
    campaign, and runs it. Outputs labelled SigMF recordings and a
    JSON QA report.

    \b
    Examples:
        ria campaign enroll --config iphone13.yml
        ria campaign enroll --config airpods.yml --output ./my_recordings
        ria campaign enroll --config iphone13.yml --dry-run
    """
    # A device profile is translated into a full campaign config by the loader.
    try:
        cfg = CampaignConfig.from_device_profile(config)
    except (FileNotFoundError, ValueError, KeyError) as e:
        raise click.ClickException(str(e))

    # CLI flag wins over the output path baked into the profile.
    if output:
        cfg.output.path = output

    _print_campaign_summary(cfg)

    if dry_run:
        click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
        return

    result = _run_campaign(cfg, verbose=verbose)
    result.write_report(report)

    _print_result_summary(result, report)
    # Non-zero exit when any step failed, so CI/scripts can detect a bad run.
    sys.exit(0 if result.failed == 0 else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@campaign.command()
@click.option(
    "--config",
    "-c",
    required=True,
    type=click.Path(exists=True),
    help="Full campaign YAML config.",
)
@click.option(
    "--output",
    "-o",
    default=None,
    help="Override output directory from config.",
)
@click.option(
    "--report",
    default="qa_report.json",
    show_default=True,
    help="Path for the JSON QA report.",
)
@click.option("--verbose", "-v", is_flag=True, help="Verbose output.")
@click.option("--dry-run", is_flag=True, help="Parse and validate config, then exit.")
def run(config, output, report, verbose, dry_run):
    """Run a full campaign from a campaign config YAML.

    \b
    Examples:
        ria campaign run --config wifi_capture.yml
        ria campaign run --config campaign.yml --output ./data --dry-run
    """
    try:
        cfg = CampaignConfig.from_yaml(config)
    except (FileNotFoundError, ValueError, KeyError) as e:
        # ClickException prints the message and exits non-zero, without a traceback.
        raise click.ClickException(str(e))

    # CLI flag wins over the output path baked into the config.
    if output:
        cfg.output.path = output

    _print_campaign_summary(cfg)

    if dry_run:
        click.echo(click.style("Dry run — exiting before capture.", fg="yellow"))
        return

    result = _run_campaign(cfg, verbose=verbose)
    result.write_report(report)

    _print_result_summary(result, report)
    # Non-zero exit when any step failed, so CI/scripts can detect a bad run.
    sys.exit(0 if result.failed == 0 else 1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _print_campaign_summary(cfg: CampaignConfig) -> None:
    """Echo a short pre-run overview of the campaign to the terminal."""
    summary = [
        "",
        click.style(f"Campaign: {cfg.name}", bold=True),
        f" Transmitters : {len(cfg.transmitters)}",
        f" Total steps : {cfg.total_steps()}",
        f" Capture time : ~{cfg.total_capture_time_s():.0f}s",
        f" Output : {cfg.output.path}",
        "",
    ]
    for line in summary:
        click.echo(line)
|
||||
|
||||
|
||||
def _make_progress_cb(total: int):
    """Return a progress callback that prints step results to stderr."""

    def _report(idx: int, _total: int, step: StepResult) -> None:
        # Marker: pass / flagged-for-review / failed.
        if step.ok:
            marker = click.style("✓", fg="green")
        elif step.qa.flagged:
            marker = click.style("⚑", fg="yellow")
        else:
            marker = click.style("✗", fg="red")
        if step.error:
            detail = f"ERROR: {step.error}"
        else:
            detail = f"SNR {step.qa.snr_db:.1f} dB"
        click.echo(
            f" [{idx:>3}/{_total}] {marker} {step.transmitter_id}/{step.step_label} — {detail}",
            err=True,
        )

    return _report
|
||||
|
||||
|
||||
def _run_campaign(cfg: CampaignConfig, verbose: bool = False):
    """Execute *cfg* with a terminal progress callback and return the result."""
    return CampaignExecutor(
        config=cfg,
        progress_cb=_make_progress_cb(cfg.total_steps()),
        verbose=verbose,
    ).run()
|
||||
|
||||
|
||||
def _print_result_summary(result, report_path: str) -> None:
    """Echo a post-run summary: pass/flag/fail counts, duration, report path."""
    echo = click.echo
    echo()
    echo(click.style("Campaign complete", bold=True))
    echo(f" Steps : {result.total_steps}")
    echo(f" Passed : {click.style(str(result.passed), fg='green')}")
    if result.flagged:
        echo(f" Flagged : {click.style(str(result.flagged), fg='yellow')} (review required)")
    if result.failed:
        echo(f" Failed : {click.style(str(result.failed), fg='red')}")
    echo(f" Duration: {result.duration_s:.0f}s")
    echo(f" Report : {report_path}")
    echo()
||||
|
|
@ -315,7 +315,7 @@ def capture(
|
|||
ident = ident or config.get("ident") or config.get("serial") # Support legacy 'serial' in config
|
||||
sample_rate = sample_rate or config.get("sample_rate")
|
||||
center_frequency = center_frequency or config.get("center_frequency")
|
||||
gain = gain if gain is not None else config.get("gain")
|
||||
gain = gain or config.get("gain")
|
||||
bandwidth = bandwidth or config.get("bandwidth")
|
||||
num_samples = num_samples or config.get("num_samples")
|
||||
duration = duration or config.get("duration")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
This module contains all the CLI bindings for the ria package.
|
||||
"""
|
||||
|
||||
from .campaign import campaign
|
||||
from .capture import capture
|
||||
from .combine import combine
|
||||
from .convert import convert
|
||||
|
|
@ -14,7 +13,6 @@ from .generate import generate
|
|||
|
||||
# from .generate import generate
|
||||
from .init import init
|
||||
from .serve import serve
|
||||
from .split import split
|
||||
from .transform import transform
|
||||
from .transmit import transmit
|
||||
|
|
|
|||
|
|
@ -332,7 +332,7 @@ def parse_ident(ident: Optional[str]) -> tuple[Optional[str], Optional[str]]:
|
|||
return ident, None
|
||||
|
||||
|
||||
def get_sdr_device(device_type: str, ident: Optional[str] = None, tx=False): # noqa: C901
|
||||
def get_sdr_device(device_type: str, ident: Optional[str] = None, tx=False):
|
||||
"""
|
||||
Get TX-capable SDR device instance.
|
||||
|
||||
|
|
@ -346,11 +346,6 @@ def get_sdr_device(device_type: str, ident: Optional[str] = None, tx=False): #
|
|||
Raises:
|
||||
click.ClickException: If device cannot be initialized or doesn't support TX
|
||||
"""
|
||||
if device_type in ("mock", "sim"):
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
return MockSDR()
|
||||
|
||||
TX_CAPABLE_DEVICES = ["pluto", "hackrf", "bladerf", "usrp"]
|
||||
if tx and device_type not in TX_CAPABLE_DEVICES:
|
||||
raise click.ClickException(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Device discovery utilities for SDR devices."""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
|
|
@ -43,28 +42,15 @@ def load_sdr_drivers(verbose: bool = False) -> Tuple[List[str], List[str], Dict[
|
|||
for driver_name, module_path in drivers.items():
|
||||
try:
|
||||
# Attempt to import the driver module
|
||||
import warnings
|
||||
|
||||
if not verbose:
|
||||
# Suppress output for quiet loading
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
mod = importlib.import_module(module_path)
|
||||
__import__(module_path)
|
||||
else:
|
||||
mod = importlib.import_module(module_path)
|
||||
|
||||
# Verify the loaded module is from the expected package to guard against
|
||||
# dependency-confusion / sys.path injection attacks.
|
||||
mod_file = getattr(mod, "__file__", None) or ""
|
||||
expected_pkg = module_path.split(".")[0] # e.g. "ria_toolkit_oss"
|
||||
pkg_root = importlib.import_module(expected_pkg).__file__ or ""
|
||||
import os as _os
|
||||
|
||||
pkg_dir = _os.path.dirname(_os.path.dirname(pkg_root))
|
||||
if mod_file and not _os.path.realpath(mod_file).startswith(_os.path.realpath(pkg_dir)):
|
||||
warnings.warn(
|
||||
f"SDR driver '{driver_name}' loaded from unexpected location: {mod_file}",
|
||||
stacklevel=2,
|
||||
)
|
||||
__import__(module_path)
|
||||
|
||||
_loaded_drivers.append(driver_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ def apply_post_processing(
|
|||
)
|
||||
|
||||
# 3. AWGN (Final stage usually)
|
||||
if add_noise:
|
||||
if add_noise == "awgn":
|
||||
npow = channel_params.get("noise_power", 0.1)
|
||||
echo_verbose(f"Applying AWGN (Power={npow})", verbose)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,51 +0,0 @@
|
|||
"""``ria serve`` — start the RT-OSS HTTP server for RIA Hub integration."""
|
||||
|
||||
import click
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--host", default="0.0.0.0", show_default=True, help="Bind host.")
|
||||
@click.option("--port", default=8080, show_default=True, type=int, help="Bind port.")
|
||||
@click.option(
|
||||
"--api-key",
|
||||
envvar="RT_OSS_API_KEY",
|
||||
default="",
|
||||
help="Required X-API-Key value. Also reads RT_OSS_API_KEY. Empty = no auth (dev only).",
|
||||
)
|
||||
@click.option(
|
||||
"--log-level",
|
||||
default="info",
|
||||
show_default=True,
|
||||
type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
|
||||
)
|
||||
def serve(host: str, port: int, api_key: str, log_level: str):
|
||||
"""Start the RT-OSS HTTP server.
|
||||
|
||||
\b
|
||||
Endpoints:
|
||||
POST /orchestrator/deploy
|
||||
GET /orchestrator/status/{campaign_id}
|
||||
POST /orchestrator/cancel/{campaign_id}
|
||||
POST /inference/load
|
||||
POST /inference/start
|
||||
POST /inference/stop
|
||||
GET /inference/status
|
||||
GET /health
|
||||
"""
|
||||
try:
|
||||
import uvicorn
|
||||
|
||||
from ria_toolkit_oss.server.app import create_app
|
||||
except ImportError as e:
|
||||
raise click.ClickException(
|
||||
f"Server dependencies missing: {e}\nInstall with: pip install ria-toolkit-oss[server]"
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
click.echo(
|
||||
click.style("Warning: ", fg="yellow", bold=True) + "no API key set — all requests unauthenticated.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
click.echo(f"Starting RT-OSS server on http://{host}:{port}")
|
||||
uvicorn.run(create_app(api_key=api_key), host=host, port=port, log_level=log_level.lower())
|
||||
|
|
@ -204,15 +204,9 @@ def load_custom_transforms(transform_dir):
|
|||
if not py_files:
|
||||
raise click.ClickException(f"No .py files found in {transform_dir}")
|
||||
|
||||
click.echo(
|
||||
f"WARNING: Loading custom transforms from '{transform_dir}'. "
|
||||
"Each .py file will be executed as Python code — only use directories you trust.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
for py_file in py_files:
|
||||
try:
|
||||
# Load module dynamically — executes the file as Python code.
|
||||
# Load module dynamically
|
||||
spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
|
||||
if spec is None or spec.loader is None:
|
||||
click.echo(f"Warning: Could not load {py_file.name}")
|
||||
|
|
|
|||
|
|
@ -393,7 +393,7 @@ def transmit(
|
|||
ident = ident or config.get("ident") or config.get("serial") # Support legacy 'serial' in config
|
||||
sample_rate = sample_rate or config.get("sample_rate")
|
||||
center_frequency = center_frequency or config.get("center_frequency")
|
||||
gain = gain if gain is not None else config.get("gain")
|
||||
gain = gain or config.get("gain")
|
||||
bandwidth = bandwidth or config.get("bandwidth")
|
||||
input_file = input_file or config.get("input")
|
||||
generate = generate or config.get("generate")
|
||||
|
|
|
|||
|
|
@ -67,135 +67,3 @@ def test_annotation_area():
|
|||
annotation_area = sample_annotation.area()
|
||||
|
||||
assert annotation_area == 600000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_annotation_area_zero_sample_count():
|
||||
# An annotation with sample_count=0 has area 0 even with a wide frequency band.
|
||||
ann = Annotation(0, 0, 0.0, 1000.0)
|
||||
assert ann.area() == 0
|
||||
|
||||
|
||||
def test_annotation_area_zero_bandwidth():
|
||||
# An annotation with equal freq edges has area 0 (degenerate band).
|
||||
ann = Annotation(0, 100, 500.0, 500.0)
|
||||
assert ann.area() == 0
|
||||
|
||||
|
||||
def test_annotation_overlap_no_overlap_disjoint_time():
|
||||
# Annotations that are completely separate in time have zero overlap.
|
||||
ann1 = Annotation(sample_start=0, sample_count=5, freq_lower_edge=0, freq_upper_edge=100)
|
||||
ann2 = Annotation(sample_start=10, sample_count=5, freq_lower_edge=0, freq_upper_edge=100)
|
||||
assert ann1.overlap(ann2) == 0
|
||||
|
||||
|
||||
def test_annotation_overlap_no_overlap_disjoint_frequency():
|
||||
# Annotations that are completely separate in frequency have zero overlap.
|
||||
ann1 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||
ann2 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=200, freq_upper_edge=300)
|
||||
assert ann1.overlap(ann2) == 0
|
||||
|
||||
|
||||
def test_annotation_overlap_touching_only_time():
|
||||
# Annotations that share only a single sample boundary do NOT overlap.
|
||||
# ann1 covers [0, 5), ann2 covers [5, 10) — they touch but don't overlap.
|
||||
ann1 = Annotation(sample_start=0, sample_count=5, freq_lower_edge=0, freq_upper_edge=100)
|
||||
ann2 = Annotation(sample_start=5, sample_count=5, freq_lower_edge=0, freq_upper_edge=100)
|
||||
assert ann1.overlap(ann2) == 0
|
||||
|
||||
|
||||
def test_annotation_overlap_touching_only_frequency():
|
||||
# Annotations that share only a single frequency edge do NOT overlap.
|
||||
ann1 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||
ann2 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=100, freq_upper_edge=200)
|
||||
assert ann1.overlap(ann2) == 0
|
||||
|
||||
|
||||
def test_annotation_overlap_with_self():
|
||||
# An annotation's overlap with itself equals its own area.
|
||||
ann = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||
assert ann.overlap(ann) == ann.area()
|
||||
|
||||
|
||||
def test_annotation_overlap_symmetry():
|
||||
# overlap(a, b) == overlap(b, a)
|
||||
ann1 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||
ann2 = Annotation(sample_start=5, sample_count=10, freq_lower_edge=50, freq_upper_edge=150)
|
||||
assert ann1.overlap(ann2) == ann2.overlap(ann1)
|
||||
|
||||
|
||||
def test_annotation_overlap_partial_known_value():
|
||||
# ann1: samples [0,10), freq [0,100) → area = 10*100 = 1000
|
||||
# ann2: samples [5,15), freq [50,150) → area = 10*100 = 1000
|
||||
# overlap in samples: [5,10) = 5; in freq: [50,100) = 50 → overlap = 250
|
||||
ann1 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||
ann2 = Annotation(sample_start=5, sample_count=10, freq_lower_edge=50, freq_upper_edge=150)
|
||||
assert ann1.overlap(ann2) == 5 * 50
|
||||
|
||||
|
||||
def test_annotation_detail_default_is_empty_dict():
|
||||
ann = Annotation(0, 10, 0.0, 100.0)
|
||||
assert ann.detail == {}
|
||||
|
||||
|
||||
def test_annotation_detail_accepts_valid_dict():
|
||||
ann = Annotation(0, 10, 0.0, 100.0, detail={"snr": 10.5, "modulation": "BPSK"})
|
||||
assert ann.detail == {"snr": 10.5, "modulation": "BPSK"}
|
||||
|
||||
|
||||
def test_annotation_detail_rejects_non_serializable():
|
||||
# A dict containing a non-JSON-serializable value should raise ValueError.
|
||||
import pytest
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Annotation(0, 10, 0.0, 100.0, detail={"obj": object()})
|
||||
|
||||
|
||||
def test_annotation_to_sigmf_format_keys():
|
||||
# to_sigmf_format() should include the SigMF standard keys.
|
||||
ann = Annotation(
|
||||
sample_start=100,
|
||||
sample_count=200,
|
||||
freq_lower_edge=1000.0,
|
||||
freq_upper_edge=2000.0,
|
||||
label="WiFi",
|
||||
comment="test signal",
|
||||
detail={"snr_db": 15},
|
||||
)
|
||||
result = ann.to_sigmf_format()
|
||||
|
||||
# Top-level keys: sample_start index and sample_count length
|
||||
assert "sample_start" in result or any("start" in k.lower() for k in result)
|
||||
assert "metadata" in result
|
||||
|
||||
metadata = result["metadata"]
|
||||
# Frequency bounds must be present
|
||||
assert ann.freq_lower_edge in metadata.values()
|
||||
assert ann.freq_upper_edge in metadata.values()
|
||||
|
||||
# Label and comment
|
||||
assert ann.label in metadata.values()
|
||||
assert ann.comment in metadata.values()
|
||||
|
||||
# detail passthrough
|
||||
assert metadata.get("ria:detail") == {"snr_db": 15}
|
||||
|
||||
|
||||
def test_annotation_to_sigmf_format_values():
|
||||
# Check that numeric values are correctly mapped.
|
||||
ann = Annotation(
|
||||
sample_start=50,
|
||||
sample_count=100,
|
||||
freq_lower_edge=500.0,
|
||||
freq_upper_edge=1500.0,
|
||||
)
|
||||
result = ann.to_sigmf_format()
|
||||
|
||||
# sample_start and sample_count must appear at the top level
|
||||
values = list(result.values())
|
||||
assert 50 in values or ann.sample_start in values
|
||||
assert 100 in values or ann.sample_count in values
|
||||
|
|
|
|||
|
|
@ -218,249 +218,3 @@ def test_remove_from_metadata_1():
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
recording.remove_from_metadata("timestamp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# --- Invalid construction ---
|
||||
|
||||
|
||||
def test_real_data_raises():
|
||||
# Real (non-complex) input must raise ValueError.
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=[[1.0, 2.0, 3.0]])
|
||||
|
||||
|
||||
def test_3d_data_raises():
|
||||
# 3-D complex array must raise ValueError.
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=np.ones((2, 3, 4), dtype=np.complex128))
|
||||
|
||||
|
||||
def test_non_dict_metadata_raises():
|
||||
# Metadata must be a python dict.
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=COMPLEX_DATA_1, metadata="sample_rate=1e6")
|
||||
|
||||
|
||||
def test_non_serializable_metadata_raises():
|
||||
# Metadata containing a non-JSON-serializable value must raise ValueError.
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=COMPLEX_DATA_1, metadata={"bad": object()})
|
||||
|
||||
|
||||
def test_non_annotation_list_raises():
|
||||
# annotations must be a list of Annotation objects.
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=COMPLEX_DATA_1, annotations=["not an annotation"])
|
||||
|
||||
|
||||
def test_non_list_annotations_raises():
|
||||
# annotations must be a list (not some other type).
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=COMPLEX_DATA_1, annotations=Annotation(0, 10, 0, 100))
|
||||
|
||||
|
||||
def test_invalid_timestamp_type_raises():
|
||||
# timestamp must be int or float, not str.
|
||||
with pytest.raises(ValueError):
|
||||
Recording(data=COMPLEX_DATA_1, timestamp="now")
|
||||
|
||||
|
||||
# --- generate_recording_id ---
|
||||
|
||||
|
||||
def test_generate_recording_id_length():
|
||||
# SHA-256 hex digest is always 64 characters.
|
||||
rid = generate_recording_id(data=np.array(COMPLEX_DATA_1), timestamp=123.0)
|
||||
assert len(rid) == 64
|
||||
|
||||
|
||||
def test_generate_recording_id_is_hex():
|
||||
rid = generate_recording_id(data=np.array(COMPLEX_DATA_1), timestamp=123.0)
|
||||
assert all(c in "0123456789abcdef" for c in rid)
|
||||
|
||||
|
||||
def test_generate_recording_id_deterministic():
|
||||
# Same data + timestamp must always produce the same ID.
|
||||
data = np.array(COMPLEX_DATA_1)
|
||||
rid1 = generate_recording_id(data=data, timestamp=42.0)
|
||||
rid2 = generate_recording_id(data=data, timestamp=42.0)
|
||||
assert rid1 == rid2
|
||||
|
||||
|
||||
def test_generate_recording_id_differs_by_data():
|
||||
data1 = np.array([[1 + 1j, 2 + 2j]])
|
||||
data2 = np.array([[3 + 3j, 4 + 4j]])
|
||||
rid1 = generate_recording_id(data=data1, timestamp=1.0)
|
||||
rid2 = generate_recording_id(data=data2, timestamp=1.0)
|
||||
assert rid1 != rid2
|
||||
|
||||
|
||||
def test_generate_recording_id_differs_by_timestamp():
|
||||
data = np.array(COMPLEX_DATA_1)
|
||||
rid1 = generate_recording_id(data=data, timestamp=1.0)
|
||||
rid2 = generate_recording_id(data=data, timestamp=2.0)
|
||||
assert rid1 != rid2
|
||||
|
||||
|
||||
def test_generate_recording_id_no_timestamp_uses_current_time():
|
||||
# Without a timestamp the function should still return a 64-char hex string.
|
||||
rid = generate_recording_id(data=np.array(COMPLEX_DATA_1))
|
||||
assert len(rid) == 64
|
||||
|
||||
|
||||
# --- add_to_metadata validation ---
|
||||
|
||||
|
||||
def test_add_to_metadata_camelcase_key_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(ValueError):
|
||||
rec.add_to_metadata(key="sampleRate", value=1e6)
|
||||
|
||||
|
||||
def test_add_to_metadata_key_with_space_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(ValueError):
|
||||
rec.add_to_metadata(key="sample rate", value=1e6)
|
||||
|
||||
|
||||
def test_add_to_metadata_key_with_digit_raises():
|
||||
# Regex ^[a-z_]+$ does NOT allow digits; "freq_2" is therefore invalid.
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(ValueError):
|
||||
rec.add_to_metadata(key="freq_2", value=1e6)
|
||||
|
||||
|
||||
def test_add_to_metadata_duplicate_key_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
rec.add_to_metadata(key="author", value="alice")
|
||||
with pytest.raises(ValueError):
|
||||
rec.add_to_metadata(key="author", value="bob")
|
||||
|
||||
|
||||
def test_add_to_metadata_valid_underscore_key():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
rec.add_to_metadata(key="sample_rate", value=1e6)
|
||||
assert rec.metadata["sample_rate"] == 1e6
|
||||
|
||||
|
||||
# --- update_metadata protected key enforcement ---
|
||||
|
||||
|
||||
def test_update_metadata_rec_id_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1, metadata=SAMPLE_METADATA)
|
||||
with pytest.raises(ValueError):
|
||||
rec.update_metadata(key="rec_id", value="fakeid")
|
||||
|
||||
|
||||
def test_update_metadata_timestamp_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1, metadata=SAMPLE_METADATA)
|
||||
with pytest.raises(ValueError):
|
||||
rec.update_metadata(key="timestamp", value=0.0)
|
||||
|
||||
|
||||
# --- remove_from_metadata ---
|
||||
|
||||
|
||||
def test_remove_from_metadata_rec_id_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(ValueError):
|
||||
rec.remove_from_metadata("rec_id")
|
||||
|
||||
|
||||
def test_remove_from_metadata_removes_key():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
rec.add_to_metadata("foo", "bar")
|
||||
assert "foo" in rec.metadata
|
||||
rec.remove_from_metadata("foo")
|
||||
assert "foo" not in rec.metadata
|
||||
|
||||
|
||||
# --- setitem is blocked ---
|
||||
|
||||
|
||||
def test_setitem_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(ValueError):
|
||||
rec[0, 0] = 999 + 0j
|
||||
|
||||
|
||||
# --- data property read-only for large recordings ---
|
||||
|
||||
|
||||
def test_data_read_only_for_large_recording():
|
||||
# For recordings with more than 1024 samples the data property returns a
|
||||
# read-only view; writing to it must raise ValueError.
|
||||
large_data = np.ones(2048, dtype=np.complex128)
|
||||
rec = Recording(data=large_data)
|
||||
view = rec.data
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
view[0] = 0 + 0j
|
||||
|
||||
|
||||
def test_data_copy_for_small_recording():
|
||||
# For recordings with 1024 or fewer samples the property returns a copy;
|
||||
# mutating the copy must NOT affect the recording's internal data.
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
copy = rec.data
|
||||
copy[0, 0] = -999 + 0j # mutate the copy
|
||||
assert rec.data[0, 0] != -999 + 0j # internal data is unchanged
|
||||
|
||||
|
||||
# --- trim edge cases ---
|
||||
|
||||
|
||||
def test_trim_negative_start_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(IndexError):
|
||||
rec.trim(start_sample=-1, num_samples=3)
|
||||
|
||||
|
||||
def test_trim_beyond_end_raises():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
with pytest.raises(IndexError):
|
||||
rec.trim(start_sample=3, num_samples=10)
|
||||
|
||||
|
||||
def test_trim_preserves_metadata():
|
||||
# Use a fresh dict to avoid pollution from tests that mutate SAMPLE_METADATA via Recording.
|
||||
meta = {"source": "original", "timestamp": 1723472227.698788}
|
||||
rec = Recording(data=COMPLEX_DATA_1, metadata=meta)
|
||||
trimmed = rec.trim(start_sample=0, num_samples=3)
|
||||
assert trimmed.metadata["source"] == "original"
|
||||
|
||||
|
||||
# --- annotations ---
|
||||
|
||||
|
||||
def test_recording_with_annotations_stores_them():
|
||||
ann = Annotation(sample_start=0, sample_count=2, freq_lower_edge=0, freq_upper_edge=100)
|
||||
rec = Recording(data=COMPLEX_DATA_1, annotations=[ann])
|
||||
assert len(rec.annotations) == 1
|
||||
assert rec.annotations[0] == ann
|
||||
|
||||
|
||||
def test_recording_annotations_is_copy():
|
||||
# Mutating the returned list must not affect the internal annotation list.
|
||||
ann = Annotation(sample_start=0, sample_count=2, freq_lower_edge=0, freq_upper_edge=100)
|
||||
rec = Recording(data=COMPLEX_DATA_1, annotations=[ann])
|
||||
returned = rec.annotations
|
||||
returned.append(ann) # mutate the copy
|
||||
assert len(rec.annotations) == 1 # internal list unchanged
|
||||
|
||||
|
||||
# --- n_chan property ---
|
||||
|
||||
|
||||
def test_n_chan_single_channel():
|
||||
rec = Recording(data=COMPLEX_DATA_1)
|
||||
assert rec.n_chan == 1
|
||||
|
||||
|
||||
def test_n_chan_multi_channel():
|
||||
rec = Recording(data=COMPLEX_DATA_2)
|
||||
assert rec.n_chan == len(COMPLEX_DATA_2)
|
||||
|
|
|
|||
|
|
@ -1,489 +0,0 @@
|
|||
"""Tests for orchestration campaign schema and YAML parsing."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from ria_toolkit_oss.orchestration.campaign import (
|
||||
CampaignConfig,
|
||||
CaptureStep,
|
||||
QAConfig,
|
||||
RecorderConfig,
|
||||
parse_bandwidth_mhz,
|
||||
parse_duration,
|
||||
parse_frequency,
|
||||
parse_gain,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_duration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDuration:
|
||||
def test_seconds_suffix(self):
|
||||
assert parse_duration("30s") == 30.0
|
||||
|
||||
def test_seconds_suffix_long(self):
|
||||
assert parse_duration("30sec") == 30.0
|
||||
|
||||
def test_minutes_suffix(self):
|
||||
assert parse_duration("1.5m") == 90.0
|
||||
|
||||
def test_minutes_suffix_long(self):
|
||||
assert parse_duration("2min") == 120.0
|
||||
|
||||
def test_hours_suffix(self):
|
||||
assert parse_duration("2h") == 7200.0
|
||||
|
||||
def test_hours_suffix_long(self):
|
||||
assert parse_duration("1hr") == 3600.0
|
||||
|
||||
def test_numeric_int(self):
|
||||
assert parse_duration(45) == 45.0
|
||||
|
||||
def test_numeric_float(self):
|
||||
assert parse_duration(1.5) == 1.5
|
||||
|
||||
def test_bare_number_string(self):
|
||||
# No unit → treated as seconds
|
||||
assert parse_duration("60") == 60.0
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_duration("two minutes")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_frequency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseFrequency:
|
||||
def test_ghz(self):
|
||||
assert parse_frequency("2.45GHz") == pytest.approx(2.45e9)
|
||||
|
||||
def test_mhz(self):
|
||||
assert parse_frequency("40MHz") == pytest.approx(40e6)
|
||||
|
||||
def test_khz(self):
|
||||
assert parse_frequency("433k") == pytest.approx(433e3)
|
||||
|
||||
def test_scientific_notation_string(self):
|
||||
assert parse_frequency("915e6") == pytest.approx(915e6)
|
||||
|
||||
def test_numeric_float(self):
|
||||
assert parse_frequency(2.45e9) == pytest.approx(2.45e9)
|
||||
|
||||
def test_numeric_int(self):
|
||||
assert parse_frequency(1000000) == pytest.approx(1e6)
|
||||
|
||||
def test_hz_suffix_optional(self):
|
||||
# "40M" and "40MHz" should both work
|
||||
assert parse_frequency("40M") == pytest.approx(40e6)
|
||||
assert parse_frequency("40MHz") == pytest.approx(40e6)
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_frequency("two point four gigs")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_gain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseGain:
|
||||
def test_db_suffix(self):
|
||||
assert parse_gain("40dB") == pytest.approx(40.0)
|
||||
|
||||
def test_db_suffix_lowercase(self):
|
||||
assert parse_gain("32db") == pytest.approx(32.0)
|
||||
|
||||
def test_auto(self):
|
||||
assert parse_gain("auto") == "auto"
|
||||
|
||||
def test_auto_case_insensitive(self):
|
||||
assert parse_gain("AUTO") == "auto"
|
||||
|
||||
def test_numeric_int(self):
|
||||
assert parse_gain(32) == pytest.approx(32.0)
|
||||
|
||||
def test_numeric_float(self):
|
||||
assert parse_gain(32.5) == pytest.approx(32.5)
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_gain("high")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_bandwidth_mhz
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBandwidthMhz:
|
||||
def test_mhz_suffix(self):
|
||||
assert parse_bandwidth_mhz("20MHz") == pytest.approx(20.0)
|
||||
|
||||
def test_numeric(self):
|
||||
assert parse_bandwidth_mhz(40) == pytest.approx(40.0)
|
||||
|
||||
def test_none(self):
|
||||
assert parse_bandwidth_mhz(None) is None
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_bandwidth_mhz("wide")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CaptureStep.from_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCaptureStep:
|
||||
def test_wifi_step_auto_label(self):
|
||||
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.duration == 30.0
|
||||
assert step.channel == 6
|
||||
assert step.bandwidth_mhz == 20.0
|
||||
assert step.traffic == "iperf_udp"
|
||||
assert step.label == "ch06_20mhz_iperf_udp"
|
||||
|
||||
def test_explicit_label(self):
|
||||
d = {"channel": 1, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "label": "my_label"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.label == "my_label"
|
||||
|
||||
def test_fallback_label(self):
|
||||
# No channel/bandwidth/traffic → label falls back to "capture"
|
||||
d = {"duration": "10s"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.label == "capture"
|
||||
|
||||
def test_power_parsed(self):
|
||||
d = {"channel": 6, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s", "power": "15dBm"}
|
||||
step = CaptureStep.from_dict(d)
|
||||
assert step.power_dbm == pytest.approx(15.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RecorderConfig.from_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecorderConfig:
|
||||
def test_basic(self):
|
||||
d = {"device": "usrp_b210", "center_freq": "2.45GHz", "sample_rate": "40MHz", "gain": "40dB"}
|
||||
rec = RecorderConfig.from_dict(d)
|
||||
assert rec.device == "usrp_b210"
|
||||
assert rec.center_freq == pytest.approx(2.45e9)
|
||||
assert rec.sample_rate == pytest.approx(40e6)
|
||||
assert rec.gain == pytest.approx(40.0)
|
||||
assert rec.bandwidth is None
|
||||
|
||||
def test_auto_gain(self):
|
||||
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": "auto"}
|
||||
rec = RecorderConfig.from_dict(d)
|
||||
assert rec.gain == "auto"
|
||||
|
||||
def test_bandwidth_set(self):
|
||||
d = {"device": "pluto", "center_freq": "2.45GHz", "sample_rate": "20MHz", "gain": 32, "bandwidth": "20MHz"}
|
||||
rec = RecorderConfig.from_dict(d)
|
||||
assert rec.bandwidth == pytest.approx(20e6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QAConfig.from_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQAConfig:
|
||||
def test_defaults(self):
|
||||
qa = QAConfig.from_dict({})
|
||||
assert qa.snr_threshold_db == pytest.approx(10.0)
|
||||
assert qa.min_duration_s == pytest.approx(25.0)
|
||||
assert qa.flag_for_review is True
|
||||
|
||||
def test_custom_values(self):
|
||||
d = {"snr_threshold": "15dB", "min_duration": "28s", "flag_for_review": False}
|
||||
qa = QAConfig.from_dict(d)
|
||||
assert qa.snr_threshold_db == pytest.approx(15.0)
|
||||
assert qa.min_duration_s == pytest.approx(28.0)
|
||||
assert qa.flag_for_review is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CampaignConfig.from_device_profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _write_device_profile(d: dict) -> str:
|
||||
"""Write a dict as YAML to a temp file and return the path."""
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False)
|
||||
yaml.dump(d, f)
|
||||
f.close()
|
||||
return f.name
|
||||
|
||||
|
||||
WIFI_PROFILE = {
|
||||
"device": {"name": "iPhone_13_WiFi", "type": "wifi"},
|
||||
"capture": {
|
||||
"channels": [1, 6, 11],
|
||||
"bandwidth": "20MHz",
|
||||
"traffic_patterns": ["idle", "ping", "iperf_udp"],
|
||||
"duration_per_config": "30s",
|
||||
"script": "./scripts/wifi_control.sh",
|
||||
},
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "40MHz",
|
||||
"gain": "auto",
|
||||
},
|
||||
"output": {"path": "/tmp/test_recordings", "device_id": "iphone13_wifi_001"},
|
||||
}
|
||||
|
||||
BT_PROFILE = {
|
||||
"device": {"name": "AirPods_Pro", "type": "bluetooth"},
|
||||
"capture": {
|
||||
"traffic_patterns": ["idle", "audio_stream", "data_transfer"],
|
||||
"duration_per_config": "30s",
|
||||
},
|
||||
"recorder": {
|
||||
"device": "usrp_b210",
|
||||
"center_freq": "2.45GHz",
|
||||
"sample_rate": "40MHz",
|
||||
"gain": "auto",
|
||||
},
|
||||
"output": {"path": "/tmp/test_recordings", "device_id": "airpods_pro_bt_001"},
|
||||
}
|
||||
|
||||
|
||||
class TestDeviceProfileParsing:
|
||||
def test_wifi_schedule_count(self):
|
||||
"""WiFi: 3 channels × 3 traffic = 9 steps."""
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert len(cfg.transmitters) == 1
|
||||
assert len(cfg.transmitters[0].schedule) == 9
|
||||
|
||||
def test_wifi_campaign_name(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.name == "enroll_iphone13_wifi_001"
|
||||
|
||||
def test_wifi_step_labels(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
labels = [s.label for s in cfg.transmitters[0].schedule]
|
||||
assert "ch01_20mhz_idle" in labels
|
||||
assert "ch06_20mhz_ping" in labels
|
||||
assert "ch11_20mhz_iperf_udp" in labels
|
||||
|
||||
def test_wifi_step_ordering(self):
|
||||
"""Steps iterate channels first, then traffic."""
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
steps = cfg.transmitters[0].schedule
|
||||
assert steps[0].channel == 1 and steps[0].traffic == "idle"
|
||||
assert steps[1].channel == 1 and steps[1].traffic == "ping"
|
||||
assert steps[3].channel == 6 and steps[3].traffic == "idle"
|
||||
assert steps[8].channel == 11 and steps[8].traffic == "iperf_udp"
|
||||
|
||||
def test_wifi_step_duration(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
for step in cfg.transmitters[0].schedule:
|
||||
assert step.duration == pytest.approx(30.0)
|
||||
|
||||
def test_wifi_bandwidth(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
for step in cfg.transmitters[0].schedule:
|
||||
assert step.bandwidth_mhz == pytest.approx(20.0)
|
||||
|
||||
def test_wifi_recorder(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.recorder.device == "usrp_b210"
|
||||
assert cfg.recorder.center_freq == pytest.approx(2.45e9)
|
||||
assert cfg.recorder.sample_rate == pytest.approx(40e6)
|
||||
assert cfg.recorder.gain == "auto"
|
||||
|
||||
def test_wifi_total_capture_time(self):
|
||||
path = _write_device_profile(WIFI_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert cfg.total_capture_time_s() == pytest.approx(270.0) # 9 × 30s
|
||||
|
||||
def test_bt_schedule_count(self):
|
||||
"""BT: no channels, 3 traffic patterns = 3 steps."""
|
||||
path = _write_device_profile(BT_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
assert len(cfg.transmitters[0].schedule) == 3
|
||||
|
||||
def test_bt_no_channel(self):
|
||||
path = _write_device_profile(BT_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
for step in cfg.transmitters[0].schedule:
|
||||
assert step.channel is None
|
||||
|
||||
def test_bt_step_labels(self):
|
||||
path = _write_device_profile(BT_PROFILE)
|
||||
try:
|
||||
cfg = CampaignConfig.from_device_profile(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
labels = [s.label for s in cfg.transmitters[0].schedule]
|
||||
assert labels == ["idle", "audio_stream", "data_transfer"]
|
||||
|
||||
def test_missing_file_raises(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
CampaignConfig.from_device_profile("/nonexistent/path/profile.yml")
|
||||
|
||||
def test_invalid_yaml_raises(self):
    """Malformed YAML is reported as ValueError matching 'Invalid YAML'."""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as handle:
        handle.write(": bad: yaml: [\n")
        broken_path = handle.name
    try:
        with pytest.raises(ValueError, match="Invalid YAML"):
            CampaignConfig.from_device_profile(broken_path)
    finally:
        os.unlink(broken_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CampaignConfig.from_yaml (full campaign format)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Full campaign document (campaign/transmitters/recorder/qa/output) exercised
# by TestFullCampaignParsing below via CampaignConfig.from_yaml.
FULL_CAMPAIGN = {
    "campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
    "transmitters": [
        {
            "id": "laptop_wifi",
            "type": "wifi",
            "control_method": "external_script",
            "script": "./scripts/wifi_control.sh",
            "device": "/dev/wlan0",
            "schedule": [
                {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_tcp", "duration": "30s"},
                {"channel": 36, "bandwidth": "40MHz", "traffic": "ping_flood", "duration": "30s"},
            ],
        }
    ],
    "recorder": {
        "device": "usrp_b210",
        "center_freq": "2.45GHz",
        "sample_rate": "20MHz",
        "gain": "40dB",
    },
    "qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
    "output": {"format": "sigmf", "path": "./recordings"},
}
|
||||
|
||||
|
||||
class TestFullCampaignParsing:
    """Tests for CampaignConfig.from_yaml using the full campaign format.

    The write-temp-file / parse / unlink dance was duplicated in six of the
    eight methods; it is factored into the ``_load`` helper so each test
    states only what it asserts.
    """

    @staticmethod
    def _load(doc):
        """Write *doc* to a temp YAML file, parse it, and always unlink the file."""
        path = _write_device_profile(doc)
        try:
            return CampaignConfig.from_yaml(path)
        finally:
            os.unlink(path)

    def test_name(self):
        assert self._load(FULL_CAMPAIGN).name == "wifi_capture_001"

    def test_mode(self):
        assert self._load(FULL_CAMPAIGN).mode == "controlled_testbed"

    def test_transmitter_id(self):
        tx = self._load(FULL_CAMPAIGN).transmitters[0]
        assert tx.id == "laptop_wifi"
        assert tx.control_method == "external_script"
        assert tx.script == "./scripts/wifi_control.sh"

    def test_schedule_count(self):
        assert len(self._load(FULL_CAMPAIGN).transmitters[0].schedule) == 2

    def test_qa_config(self):
        qa = self._load(FULL_CAMPAIGN).qa
        assert qa.snr_threshold_db == pytest.approx(10.0)
        assert qa.min_duration_s == pytest.approx(25.0)
        assert qa.flag_for_review is True

    def test_total_steps(self):
        assert self._load(FULL_CAMPAIGN).total_steps() == 2

    def test_no_transmitters_raises(self):
        # Shallow copy is enough: we replace the key, not mutate nested data.
        bad = dict(FULL_CAMPAIGN)
        bad["transmitters"] = []
        path = _write_device_profile(bad)
        try:
            with pytest.raises(ValueError, match="at least one transmitter"):
                CampaignConfig.from_yaml(path)
        finally:
            os.unlink(path)

    def test_missing_recorder_raises(self):
        bad = {k: v for k, v in FULL_CAMPAIGN.items() if k != "recorder"}
        path = _write_device_profile(bad)
        try:
            with pytest.raises((KeyError, ValueError)):
                CampaignConfig.from_yaml(path)
        finally:
            os.unlink(path)
|
||||
|
|
@ -1,274 +0,0 @@
|
|||
"""Tests for the `ria campaign` CLI commands."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import yaml
|
||||
from click.testing import CliRunner
|
||||
|
||||
from ria_toolkit_oss_cli.cli import cli
|
||||
|
||||
|
||||
def _write_yaml(d: dict, suffix: str = ".yml") -> str:
    """Serialize *d* to a temporary YAML file and return its path.

    The file is created with ``delete=False`` because the caller owns cleanup
    (tests unlink it in a ``finally`` block). A context manager guarantees the
    handle is closed even if ``yaml.dump`` raises, so no open handle is leaked;
    an explicit UTF-8 encoding keeps the dump platform-independent.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8") as f:
        yaml.dump(d, f)
        return f.name
|
||||
|
||||
|
||||
# Simplified device-profile document (device/capture/recorder/output) used by
# the `ria campaign validate --profile` and `enroll --dry-run` CLI tests.
WIFI_PROFILE = {
    "device": {"name": "iPhone_13_WiFi", "type": "wifi"},
    "capture": {
        "channels": [1, 6, 11],
        "bandwidth": "20MHz",
        "traffic_patterns": ["idle", "ping", "iperf_udp"],
        "duration_per_config": "30s",
    },
    "recorder": {
        "device": "usrp_b210",
        "center_freq": "2.45GHz",
        "sample_rate": "40MHz",
        "gain": "auto",
    },
    "output": {"path": "/tmp/test_enroll", "device_id": "iphone13_wifi_001"},
}
|
||||
|
||||
# Full campaign document used by the `ria campaign validate` / `run --dry-run`
# CLI tests; mirrors the format parsed by CampaignConfig.from_yaml.
FULL_CAMPAIGN = {
    "campaign": {"name": "wifi_capture_001", "mode": "controlled_testbed"},
    "transmitters": [
        {
            "id": "laptop_wifi",
            "type": "wifi",
            "control_method": "external_script",
            "schedule": [
                {"channel": 6, "bandwidth": "20MHz", "traffic": "iperf_udp", "duration": "30s"},
                {"channel": 11, "bandwidth": "20MHz", "traffic": "idle", "duration": "30s"},
            ],
        }
    ],
    "recorder": {
        "device": "usrp_b210",
        "center_freq": "2.45GHz",
        "sample_rate": "20MHz",
        "gain": "40dB",
    },
    "qa": {"snr_threshold": "10dB", "min_duration": "25s", "flag_for_review": True},
    "output": {"format": "sigmf", "path": "/tmp/test_campaign"},
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign --help
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignHelp:
    """Smoke tests for the `ria campaign` help screens."""

    @staticmethod
    def _invoke(*args):
        """Run `ria campaign <args>` through a fresh CliRunner."""
        return CliRunner().invoke(cli, ["campaign", *args])

    def test_campaign_help(self):
        outcome = self._invoke("--help")
        assert outcome.exit_code == 0
        assert "campaign" in outcome.output.lower()

    def test_subcommands_listed(self):
        outcome = self._invoke("--help")
        assert outcome.exit_code == 0
        for sub in ("validate", "enroll", "run"):
            assert sub in outcome.output

    def test_validate_help(self):
        assert self._invoke("validate", "--help").exit_code == 0

    def test_enroll_help(self):
        assert self._invoke("enroll", "--help").exit_code == 0

    def test_run_help(self):
        assert self._invoke("run", "--help").exit_code == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignValidate:
    """Tests for `ria campaign validate` (profile and full-campaign modes)."""

    @staticmethod
    def _validate_profile(doc):
        """Validate *doc* as a device profile; the temp file is always removed."""
        path = _write_yaml(doc)
        try:
            return CliRunner().invoke(cli, ["campaign", "validate", "--config", path, "--profile"])
        finally:
            os.unlink(path)

    def test_validate_device_profile(self):
        outcome = self._validate_profile(WIFI_PROFILE)
        assert outcome.exit_code == 0
        assert "✓" in outcome.output or "valid" in outcome.output.lower()

    def test_validate_shows_campaign_name(self):
        assert "enroll_iphone13_wifi_001" in self._validate_profile(WIFI_PROFILE).output

    def test_validate_shows_step_count(self):
        # 3 channels x 3 traffic patterns = 9 total steps
        assert "9" in self._validate_profile(WIFI_PROFILE).output

    def test_validate_shows_capture_time(self):
        # 9 steps x 30 s = 270 s total capture time
        assert "270" in self._validate_profile(WIFI_PROFILE).output

    def test_validate_full_campaign(self):
        path = _write_yaml(FULL_CAMPAIGN)
        try:
            outcome = CliRunner().invoke(cli, ["campaign", "validate", "--config", path])
            assert outcome.exit_code == 0
            assert "wifi_capture_001" in outcome.output
        finally:
            os.unlink(path)

    def test_validate_shows_steps(self):
        output = self._validate_profile(WIFI_PROFILE).output
        assert "ch01_20mhz_idle" in output
        assert "ch06_20mhz_ping" in output
        assert "ch11_20mhz_iperf_udp" in output

    def test_validate_missing_file(self):
        outcome = CliRunner().invoke(cli, ["campaign", "validate", "--config", "/nonexistent/file.yml"])
        assert outcome.exit_code != 0

    def test_validate_bad_yaml(self):
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as handle:
            handle.write(": broken yaml [\n")
            path = handle.name
        try:
            outcome = CliRunner().invoke(cli, ["campaign", "validate", "--config", path])
            assert outcome.exit_code != 0
        finally:
            os.unlink(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign enroll --dry-run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignEnrollDryRun:
    """Tests for `ria campaign enroll --dry-run`."""

    @staticmethod
    def _enroll_dry(path, *extra):
        """Invoke `campaign enroll --config <path> [extra...] --dry-run`."""
        args = ["campaign", "enroll", "--config", path, *extra, "--dry-run"]
        return CliRunner().invoke(cli, args)

    def test_dry_run_exits_cleanly(self):
        path = _write_yaml(WIFI_PROFILE)
        try:
            assert self._enroll_dry(path).exit_code == 0
        finally:
            os.unlink(path)

    def test_dry_run_shows_campaign_info(self):
        path = _write_yaml(WIFI_PROFILE)
        try:
            outcome = self._enroll_dry(path)
            assert "enroll_iphone13_wifi_001" in outcome.output
            assert "9" in outcome.output
        finally:
            os.unlink(path)

    def test_dry_run_does_not_capture(self):
        """Dry run should not create any output files."""
        path = _write_yaml(WIFI_PROFILE)
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                self._enroll_dry(path, "--output", tmpdir)
                # No .sigmf-data files should have been written anywhere below tmpdir.
                written = [
                    name
                    for _, _, names in os.walk(tmpdir)
                    for name in names
                    if name.endswith(".sigmf-data")
                ]
                assert written == []
        finally:
            os.unlink(path)

    def test_dry_run_output_override(self):
        path = _write_yaml(WIFI_PROFILE)
        try:
            outcome = self._enroll_dry(path, "--output", "/tmp/custom_out")
            assert outcome.exit_code == 0
            assert "Dry run" in outcome.output
        finally:
            os.unlink(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ria campaign run --dry-run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCampaignRunDryRun:
    """Tests for `ria campaign run --dry-run`."""

    def test_dry_run_exits_cleanly(self):
        config_path = _write_yaml(FULL_CAMPAIGN)
        try:
            outcome = CliRunner().invoke(cli, ["campaign", "run", "--config", config_path, "--dry-run"])
            assert outcome.exit_code == 0
        finally:
            os.unlink(config_path)

    def test_dry_run_shows_campaign_name(self):
        config_path = _write_yaml(FULL_CAMPAIGN)
        try:
            outcome = CliRunner().invoke(cli, ["campaign", "run", "--config", config_path, "--dry-run"])
            assert "wifi_capture_001" in outcome.output
        finally:
            os.unlink(config_path)

    def test_dry_run_does_not_create_report(self):
        """A dry run must not write the QA report file."""
        config_path = _write_yaml(FULL_CAMPAIGN)
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                report_path = os.path.join(tmpdir, "qa_report.json")
                outcome = CliRunner().invoke(
                    cli,
                    ["campaign", "run", "--config", config_path, "--dry-run", "--report", report_path],
                )
                assert outcome.exit_code == 0
                assert not os.path.exists(report_path)
        finally:
            os.unlink(config_path)

    def test_missing_config_fails(self):
        outcome = CliRunner().invoke(cli, ["campaign", "run", "--config", "/nonexistent.yml"])
        assert outcome.exit_code != 0
|
||||
|
|
@ -1,145 +0,0 @@
|
|||
"""Tests for orchestration labeler."""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.orchestration.campaign import CaptureStep
|
||||
from ria_toolkit_oss.orchestration.labeler import build_output_filename, label_recording
|
||||
|
||||
|
||||
def _simple_recording() -> Recording:
    """Return a 1000-sample all-ones recording at 1 MS/s, centered at 2.45 GHz."""
    samples = np.ones(1000, dtype=np.complex64)
    return Recording(samples, metadata={"sample_rate": 1e6, "center_frequency": 2.45e9})
|
||||
|
||||
|
||||
def _wifi_step() -> CaptureStep:
    """Representative WiFi capture step: channel 6, 20 MHz, idle traffic."""
    return CaptureStep(
        duration=30.0,
        label="ch06_20mhz_idle",
        channel=6,
        bandwidth_mhz=20.0,
        traffic="idle",
    )
|
||||
|
||||
|
||||
def _bt_step() -> CaptureStep:
    """Representative BT capture step: audio stream, 7.5 ms connection interval."""
    return CaptureStep(
        duration=30.0,
        label="audio_stream",
        traffic="audio_stream",
        connection_interval_ms=7.5,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# label_recording
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelRecording:
    """Tests for label_recording metadata stamping.

    Fourteen methods previously repeated the same
    ``label_recording(_simple_recording(), <id>, <step>, time.time())`` call;
    the ``_label`` helper centralizes it with the common WiFi defaults.
    """

    @staticmethod
    def _label(device_id="iphone13_001", step=None, ts=None, **kwargs):
        """Label a fresh recording; defaults to the WiFi step and 'now'."""
        if step is None:
            step = _wifi_step()
        if ts is None:
            ts = time.time()
        return label_recording(_simple_recording(), device_id, step, ts, **kwargs)

    def test_device_id_set(self):
        assert self._label().metadata["device_id"] == "iphone13_001"

    def test_capture_timestamp_set(self):
        ts = time.time()
        rec = self._label(ts=ts)
        assert rec.metadata["capture_timestamp"] == pytest.approx(ts, abs=1.0)

    def test_step_label_set(self):
        assert self._label().metadata["step_label"] == "ch06_20mhz_idle"

    def test_step_duration_set(self):
        assert self._label().metadata["step_duration_s"] == pytest.approx(30.0)

    def test_campaign_name_optional(self):
        assert "campaign" not in self._label().metadata

    def test_campaign_name_when_provided(self):
        rec = self._label(campaign_name="test_campaign")
        assert rec.metadata["campaign"] == "test_campaign"

    def test_wifi_channel_set(self):
        assert self._label().metadata["wifi_channel"] == 6

    def test_wifi_bandwidth_set(self):
        assert self._label().metadata["wifi_bandwidth_mhz"] == pytest.approx(20.0)

    def test_traffic_pattern_set(self):
        assert self._label().metadata["traffic_pattern"] == "idle"

    def test_bt_connection_interval_set(self):
        rec = self._label("airpods_001", _bt_step())
        assert rec.metadata["bt_connection_interval_ms"] == pytest.approx(7.5)

    def test_no_channel_key_for_bt(self):
        """BT steps with no channel should not add wifi_channel to metadata."""
        assert "wifi_channel" not in self._label("airpods_001", _bt_step()).metadata

    def test_no_bandwidth_key_for_bt(self):
        assert "wifi_bandwidth_mhz" not in self._label("airpods_001", _bt_step()).metadata

    def test_power_dbm_set(self):
        step = CaptureStep(duration=30.0, label="test", traffic="idle", power_dbm=15.0)
        assert self._label("dev_001", step).metadata["tx_power_dbm"] == pytest.approx(15.0)

    def test_no_power_key_when_unset(self):
        assert "tx_power_dbm" not in self._label().metadata

    def test_returns_same_recording(self):
        """label_recording should mutate and return the same Recording object."""
        rec = _simple_recording()
        assert label_recording(rec, "iphone13_001", _wifi_step(), time.time()) is rec
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_output_filename
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildOutputFilename:
    """Tests for build_output_filename path construction."""

    def test_basic_wifi(self):
        step = CaptureStep(duration=30.0, label="ch06_20mhz_idle")
        name = build_output_filename("iphone13_wifi_001", step)
        assert name == "iphone13_wifi_001/ch06_20mhz_idle"

    def test_bt_step(self):
        step = CaptureStep(duration=30.0, label="audio_stream")
        name = build_output_filename("airpods_pro_bt_001", step)
        assert name == "airpods_pro_bt_001/audio_stream"

    def test_spaces_in_device_id_replaced(self):
        name = build_output_filename("my device", CaptureStep(duration=30.0, label="idle"))
        assert " " not in name
        assert name == "my_device/idle"

    def test_slashes_in_label_replaced(self):
        name = build_output_filename("dev_001", CaptureStep(duration=30.0, label="ch/6/idle"))
        # Only the device/label separator slash should remain.
        assert "/" not in name.split("/", 1)[1]

    def test_path_structure(self):
        """Filename should be exactly '<device_id>/<label>' (one level of nesting)."""
        name = build_output_filename("dev_001", CaptureStep(duration=30.0, label="idle"))
        assert len(name.split("/")) == 2
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
"""Tests for orchestration QA metrics."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.orchestration.campaign import QAConfig
|
||||
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
    """Wrap *signal* (cast to complex64) in a Recording with QA metadata.

    ``n_samples`` is accepted for call-site symmetry; the signal's own
    length is what actually matters.
    """
    metadata = {"sample_rate": sample_rate, "center_frequency": 2.45e9}
    return Recording(signal.astype(np.complex64), metadata=metadata)
|
||||
|
||||
|
||||
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
|
||||
t = np.arange(n) / sr
|
||||
return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64)
|
||||
|
||||
|
||||
def _noise(n: int, amplitude: float = 0.001) -> np.ndarray:
|
||||
rng = np.random.default_rng(42)
|
||||
return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64)
|
||||
|
||||
|
||||
# Shared QA thresholds for the check_recording tests below.
DEFAULT_QA = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# estimate_snr_db
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEstimateSnrDb:
    """Sanity checks for estimate_snr_db."""

    def test_high_snr_tone(self):
        sr = 1e6
        snr = estimate_snr_db(_tone(int(sr * 1), sr))
        assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"

    def test_pure_noise_low_snr(self):
        sr = 1e6
        n = int(sr)
        rng = np.random.default_rng(0)
        samples = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64)
        snr = estimate_snr_db(samples)
        # Pure noise should yield a low (possibly negative) SNR
        assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"

    def test_snr_increases_with_amplitude(self):
        sr = 1e6
        n = int(sr)
        rng = np.random.default_rng(1)
        noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01
        t = np.arange(n) / sr
        tone = np.exp(1j * 2 * np.pi * 100e3 * t).astype(np.complex64)
        weak = estimate_snr_db(noise + tone * 0.1)
        strong = estimate_snr_db(noise + tone * 1.0)
        assert strong > weak

    def test_short_input_still_works(self):
        # Input shorter than n_fft=4096 should not raise
        assert np.isfinite(estimate_snr_db(_tone(512, 1e6)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_recording — pass cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecordingPass:
    """check_recording cases that should pass with no flags."""

    def test_clean_tone_passes(self):
        sr = 1e6
        n = int(sr * 30)
        result = check_recording(_make_recording(n, sr, _tone(n, sr)), DEFAULT_QA)
        assert result.passed is True
        assert result.flagged is False
        assert result.snr_db > 10.0
        assert abs(result.duration_s - 30.0) < 0.1

    def test_duration_exactly_at_threshold(self):
        sr = 1e6
        n = int(sr * 25)  # exactly at min_duration_s
        result = check_recording(_make_recording(n, sr, _tone(n, sr)), DEFAULT_QA)
        assert result.flagged is False

    def test_issues_empty_when_passing(self):
        sr = 1e6
        n = int(sr * 30)
        result = check_recording(_make_recording(n, sr, _tone(n, sr)), DEFAULT_QA)
        assert result.issues == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_recording — flag cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecordingFlag:
    """check_recording cases that should be flagged."""

    def test_short_recording_flagged(self):
        sr = 1e6
        n = int(sr * 10)  # shorter than the 25 s minimum
        result = check_recording(_make_recording(n, sr, _tone(n, sr)), DEFAULT_QA)
        assert result.flagged is True
        assert any("Duration" in issue for issue in result.issues)

    def test_low_snr_flagged(self):
        sr = 1e6
        n = int(sr * 30)
        result = check_recording(_make_recording(n, sr, _noise(n, amplitude=0.001)), DEFAULT_QA)
        assert result.flagged is True
        assert any("SNR" in issue for issue in result.issues)

    def test_flag_for_review_still_passes(self):
        """With flag_for_review=True, flagged recordings are still marked passed."""
        sr = 1e6
        n = int(sr * 10)  # short, so it will be flagged
        qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
        result = check_recording(_make_recording(n, sr, _tone(n, sr)), qa)
        assert result.flagged is True
        assert result.passed is True  # human review, not auto-reject

    def test_flag_for_review_false_fails(self):
        """With flag_for_review=False, a flagged recording is also marked failed."""
        sr = 1e6
        n = int(sr * 10)
        qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
        result = check_recording(_make_recording(n, sr, _tone(n, sr)), qa)
        assert result.flagged is True
        assert result.passed is False

    def test_multiple_issues_reported(self):
        """Both short duration AND low SNR should appear in the issues list."""
        sr = 1e6
        n = int(sr * 5)  # very short AND pure noise
        result = check_recording(_make_recording(n, sr, _noise(n, amplitude=0.0001)), DEFAULT_QA)
        assert result.flagged is True
        assert len(result.issues) >= 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_recording — multichannel input
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecordingMultichannel:
    """check_recording on multichannel input."""

    def test_multichannel_recording(self):
        """2-channel recording should evaluate channel 0 without error."""
        sr = 1e6
        n = int(sr * 30)
        stacked = np.stack([_tone(n, sr), _tone(n, sr, freq_hz=200e3)])  # shape (2, N)
        rec = Recording(stacked, metadata={"sample_rate": sr, "center_frequency": 2.45e9})
        result = check_recording(rec, DEFAULT_QA)
        assert result.passed is True
        assert result.flagged is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QAResult.to_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQAResultToDict:
    """QAResult.to_dict serialization."""

    def test_to_dict_keys(self):
        serialized = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0).to_dict()
        assert set(serialized) == {"passed", "flagged", "snr_db", "duration_s", "issues"}

    def test_to_dict_values(self):
        result = QAResult(passed=False, flagged=True, snr_db=7.5, duration_s=10.2, issues=["SNR below threshold"])
        serialized = result.to_dict()
        assert serialized["passed"] is False
        assert serialized["flagged"] is True
        assert serialized["snr_db"] == pytest.approx(7.5, abs=0.01)
        assert serialized["duration_s"] == pytest.approx(10.2, abs=0.01)
        assert serialized["issues"] == ["SNR below threshold"]
|
||||
|
|
@ -1,463 +0,0 @@
|
|||
"""Tests for the RT-OSS HTTP server.
|
||||
|
||||
Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator
|
||||
lifecycle (with mocked executor), and state helpers.
|
||||
|
||||
``start_inference`` and ``_inference_loop`` require real SDR hardware and an
|
||||
ONNX model file — those are integration tests left for hardware-in-the-loop CI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import ria_toolkit_oss.server.state as state_module
|
||||
from ria_toolkit_oss.server.app import create_app
|
||||
from ria_toolkit_oss.server.state import CampaignState, InferenceState, set_inference
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_state():
    """Wipe the module-global inference/campaign state around every test."""
    state_module._inference = None
    state_module._campaigns.clear()
    yield
    state_module._inference = None
    state_module._campaigns.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """TestClient for an app with auth disabled (dev mode, empty API key)."""
    app = create_app(api_key="")
    return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
def auth_client():
    """TestClient for an app that requires the API key 'test-secret'."""
    app = create_app(api_key="test-secret")
    return TestClient(app)
|
||||
|
||||
|
||||
def _mock_inference_state(**overrides) -> InferenceState:
    """Build a minimal InferenceState whose ONNX session is a MagicMock."""
    params = {
        "model_path": "/models/test.onnx",
        "label_map": {"iphone13": 0, "noise": 1},
        "index_to_label": {0: "iphone13", 1: "noise"},
        "session": MagicMock(),
    }
    params.update(overrides)
    return InferenceState(**params)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHealth:
    """GET /health."""

    def test_health_returns_ok(self, client):
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    def test_health_requires_no_auth(self, auth_client):
        # /health is deliberately outside the auth dependency — 200 without a key.
        assert auth_client.get("/health").status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuth:
    """API-key enforcement on protected routes."""

    def test_missing_key_rejected(self, auth_client):
        assert auth_client.get("/inference/status").status_code == 403

    def test_wrong_key_rejected(self, auth_client):
        response = auth_client.get("/inference/status", headers={"X-API-Key": "wrong"})
        assert response.status_code == 403

    def test_correct_key_accepted(self, auth_client):
        response = auth_client.get("/inference/status", headers={"X-API-Key": "test-secret"})
        # 200 with a null body is fine here — no model loaded yet.
        assert response.status_code == 200

    def test_dev_mode_no_key_required(self, client):
        assert client.get("/inference/status").status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /inference/load
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferenceLoad:
    """POST /inference/load with the ONNX session loader patched out."""

    def test_load_returns_loaded_true(self, client):
        with patch("ria_toolkit_oss.server.routers.inference._load_onnx_session", return_value=MagicMock()):
            response = client.post(
                "/inference/load",
                json={"model_path": "/models/m.onnx", "label_map": {"iphone13": 0, "noise": 1}},
            )
        assert response.status_code == 200
        payload = response.json()
        assert payload["loaded"] is True
        assert payload["model_path"] == "/models/m.onnx"
        assert payload["num_classes"] == 2

    def test_load_stores_state(self, client):
        with patch("ria_toolkit_oss.server.routers.inference._load_onnx_session", return_value=MagicMock()):
            client.post(
                "/inference/load",
                json={"model_path": "/models/m.onnx", "label_map": {"zone_a": 0}},
            )
        assert state_module._inference is not None
        assert state_module._inference.model_path == "/models/m.onnx"

    def test_load_builds_reverse_index(self, client):
        with patch("ria_toolkit_oss.server.routers.inference._load_onnx_session", return_value=MagicMock()):
            client.post(
                "/inference/load",
                json={"model_path": "/m.onnx", "label_map": {"cat": 0, "dog": 1}},
            )
        assert state_module._inference.index_to_label == {0: "cat", 1: "dog"}

    def test_load_503_when_onnxruntime_missing(self, client):
        from fastapi import HTTPException as FastAPIHTTPException

        with patch(
            "ria_toolkit_oss.server.routers.inference._load_onnx_session",
            side_effect=FastAPIHTTPException(status_code=503, detail="onnxruntime not installed"),
        ):
            response = client.post(
                "/inference/load",
                json={"model_path": "/m.onnx", "label_map": {}},
            )
        assert response.status_code == 503
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /inference/status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferenceStatus:
    """Tests for GET /inference/status."""

    def test_returns_null_when_no_model_loaded(self, client):
        resp = client.get("/inference/status")
        assert resp.status_code == 200
        assert resp.json() is None

    def test_returns_null_when_model_loaded_but_no_result_yet(self, client):
        set_inference(_mock_inference_state())
        resp = client.get("/inference/status")
        assert resp.status_code == 200
        assert resp.json() is None

    def test_returns_latest_result(self, client):
        inference = _mock_inference_state()
        inference.set_latest(
            {
                "timestamp": 1234567890.0,
                "idle": False,
                "device_id": "iphone13",
                "confidence": 0.94,
                "snr_db": 18.5,
            }
        )
        set_inference(inference)

        resp = client.get("/inference/status")
        assert resp.status_code == 200
        result = resp.json()
        assert result["device_id"] == "iphone13"
        assert result["confidence"] == 0.94
        assert result["idle"] is False

    def test_idle_result_returned(self, client):
        inference = _mock_inference_state()
        inference.set_latest(
            {
                "timestamp": 1234567890.0,
                "idle": True,
                "device_id": None,
                "confidence": 0.55,
                "snr_db": 2.1,
            }
        )
        set_inference(inference)

        resp = client.get("/inference/status")
        assert resp.status_code == 200
        result = resp.json()
        assert result["idle"] is True
        assert result["device_id"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /inference/configure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferenceConfigure:
    """Tests for POST /inference/configure."""

    def test_configure_409_when_no_model_loaded(self, client):
        # Configuring before any model is loaded conflicts with server state.
        resp = client.post("/inference/configure", json={"center_freq": 2450000000})
        assert resp.status_code == 409

    def test_configure_stores_pending_config(self, client):
        set_inference(_mock_inference_state())
        resp = client.post(
            "/inference/configure",
            json={"center_freq": 915000000, "gain": 30},
        )
        assert resp.status_code == 200
        assert resp.json()["configured"] is True

        queued = state_module._inference.pop_pending_config()
        assert queued["center_freq"] == 915000000
        assert queued["gain"] == 30

    def test_configure_empty_body_returns_configured_false(self, client):
        set_inference(_mock_inference_state())
        resp = client.post("/inference/configure", json={})
        assert resp.status_code == 200
        assert resp.json()["configured"] is False

    def test_configure_only_sends_provided_fields(self, client):
        set_inference(_mock_inference_state())
        client.post("/inference/configure", json={"sample_rate": 20000000})
        queued = state_module._inference.pop_pending_config()
        # Only the field supplied in the request body may be queued.
        assert "sample_rate" in queued
        assert "center_freq" not in queued
        assert "gain" not in queued
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /inference/stop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferenceStop:
    """Tests for POST /inference/stop."""

    def test_stop_returns_false_when_not_running(self, client):
        resp = client.post("/inference/stop")
        assert resp.status_code == 200
        assert resp.json()["stopped"] is False

    def test_stop_returns_false_when_model_loaded_but_not_started(self, client):
        set_inference(_mock_inference_state())
        resp = client.post("/inference/stop")
        assert resp.status_code == 200
        assert resp.json()["stopped"] is False

    def test_stop_signals_running_thread(self, client):
        inference = _mock_inference_state()
        inference.running = True

        started = threading.Event()  # set once the worker is actually alive

        def fake_worker():
            # Simulate a run loop that blocks until the stop event fires.
            started.set()
            inference.stop_event.wait(timeout=2)
            inference.running = False

        inference.thread = threading.Thread(target=fake_worker, daemon=True)
        inference.thread.start()
        started.wait(timeout=1)
        set_inference(inference)

        resp = client.post("/inference/stop")
        assert resp.status_code == 200
        assert resp.json()["stopped"] is True
        assert inference.stop_event.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /orchestrator/deploy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrchestratorDeploy:
    """Tests for POST /orchestrator/deploy."""

    @staticmethod
    def _fake_config(steps):
        # Minimal stand-in for a parsed CampaignConfig.
        cfg = MagicMock()
        cfg.name = "test_campaign"
        cfg.total_steps.return_value = steps
        return cfg

    @staticmethod
    def _fake_executor():
        # Executor whose run() yields an empty, dict-serialisable report.
        executor = MagicMock()
        executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
        return executor

    def test_deploy_422_on_invalid_config(self, client):
        with patch(
            "ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict",
            side_effect=ValueError("missing required field 'name'"),
        ):
            resp = client.post("/orchestrator/deploy", json={"config": {}})
            assert resp.status_code == 422

    def test_deploy_returns_campaign_id(self, client):
        with (
            patch(
                "ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict",
                return_value=self._fake_config(5),
            ),
            patch(
                "ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor",
                self._fake_executor(),
            ),
        ):
            resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}})

        assert resp.status_code == 200
        payload = resp.json()
        assert "campaign_id" in payload
        assert len(payload["campaign_id"]) > 0

    def test_deploy_registers_campaign_in_state(self, client):
        with (
            patch(
                "ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict",
                return_value=self._fake_config(3),
            ),
            patch(
                "ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor",
                self._fake_executor(),
            ),
        ):
            resp = client.post("/orchestrator/deploy", json={"config": {}})

        campaign_id = resp.json()["campaign_id"]
        assert state_module._campaigns.get(campaign_id) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /orchestrator/status/{campaign_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrchestratorStatus:
    """Tests for GET /orchestrator/status/{campaign_id}."""

    def test_status_404_for_unknown_id(self, client):
        assert client.get("/orchestrator/status/nonexistent-id").status_code == 404

    def test_status_returns_campaign_state(self, client):
        campaign = CampaignState(
            campaign_id="abc-123",
            status="running",
            config_name="test",
            cancel_event=threading.Event(),
            thread=MagicMock(),
            total_steps=10,
            progress=3,
        )
        state_module._campaigns["abc-123"] = campaign

        resp = client.get("/orchestrator/status/abc-123")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["campaign_id"] == "abc-123"
        assert payload["status"] == "running"
        assert payload["progress"] == 3
        assert payload["total_steps"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /orchestrator/cancel/{campaign_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrchestratorCancel:
    """Tests for POST /orchestrator/cancel/{campaign_id}."""

    @staticmethod
    def _register_campaign(campaign_id, status, cancel_event):
        # Place a campaign directly into the server's registry.
        state_module._campaigns[campaign_id] = CampaignState(
            campaign_id=campaign_id,
            status=status,
            config_name="test",
            cancel_event=cancel_event,
            thread=MagicMock(),
        )

    def test_cancel_404_for_unknown_id(self, client):
        assert client.post("/orchestrator/cancel/no-such-id").status_code == 404

    def test_cancel_sets_cancel_event(self, client):
        cancel_event = threading.Event()
        self._register_campaign("camp-to-cancel", "running", cancel_event)

        resp = client.post("/orchestrator/cancel/camp-to-cancel")
        assert resp.status_code == 200
        assert resp.json()["cancelled"] is True
        assert cancel_event.is_set()

    def test_cancel_already_completed_returns_false(self, client):
        cancel_event = threading.Event()
        self._register_campaign("done", "completed", cancel_event)

        resp = client.post("/orchestrator/cancel/done")
        assert resp.status_code == 200
        # A finished campaign cannot be cancelled, and its event stays clear.
        assert resp.json()["cancelled"] is False
        assert not cancel_event.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferenceStateHelpers:
    """Direct unit tests for the inference-state helper object."""

    def test_set_and_get_latest(self):
        helper = _mock_inference_state()
        result = {"timestamp": 1.0, "idle": False, "device_id": "dev1", "confidence": 0.9, "snr_db": 15.0}
        helper.set_latest(result)
        assert helper.get_latest() == result

    def test_get_latest_returns_none_initially(self):
        assert _mock_inference_state().get_latest() is None

    def test_set_and_pop_pending_config(self):
        helper = _mock_inference_state()
        helper.set_pending_config({"center_freq": 915e6})
        assert helper.pop_pending_config() == {"center_freq": 915e6}
        # pop clears the slot, so a second pop yields nothing
        assert helper.pop_pending_config() is None

    def test_pending_config_overwrite(self):
        helper = _mock_inference_state()
        helper.set_pending_config({"center_freq": 915e6})
        helper.set_pending_config({"center_freq": 2450e6, "gain": 40})
        # The most recent config wins.
        assert helper.pop_pending_config()["center_freq"] == 2450e6

    def test_thread_safety_latest(self):
        """Multiple threads writing latest; final read should not raise."""
        helper = _mock_inference_state()
        seen = []

        def writer(val):
            for _ in range(100):
                helper.set_latest({"v": val})

        def reader():
            for _ in range(100):
                seen.append(helper.get_latest())

        workers = [threading.Thread(target=writer, args=(i,)) for i in range(4)]
        workers.append(threading.Thread(target=reader))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=5)

        # No exception raised and the reader observed at least one write.
        assert any(item is not None for item in seen)
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.datatypes import Recording
|
||||
from ria_toolkit_oss.transforms import iq_augmentations
|
||||
|
|
@ -207,7 +206,7 @@ def test_cut_out_avg_snr_1():
|
|||
transformed_data = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="avg-snr")
|
||||
assert np.allclose(
|
||||
transformed_data,
|
||||
np.asarray([[1.04504475 - 3.19650874j, 2.18835276 + 1.87922077j, 3 + 3j, 3.38706877 - 0.53958902j]]),
|
||||
np.asarray([[-1.26516288 - 0.36655702j, -2.44693984 + 1.27294267j, 3 + 3j, 4.1583403 - 0.96625365j]]),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -225,198 +224,3 @@ def test_patch_shuffle_rec():
|
|||
transformed_rec = iq_augmentations.patch_shuffle(rec, max_patch_size=3)
|
||||
assert np.array_equal(transformed_rec.data, np.asarray([[3 + 2j, 1 + 4j, 5 + 5j, 2 - 6j, 4 + 4j]]))
|
||||
assert rec.metadata == transformed_rec.metadata
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage: error paths and missing Recording variants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# --- generate_awgn ---
|
||||
|
||||
|
||||
def test_generate_awgn_recording_input():
    """generate_awgn() on a Recording yields a Recording with identical metadata."""
    rec = Recording(data=TEST_DATA1, metadata=TEST_METADATA)
    out = iq_augmentations.generate_awgn(rec, snr=10)
    assert isinstance(out, Recording)
    assert out.metadata == rec.metadata


def test_generate_awgn_invalid_real_raises():
    """Real-valued (non-complex) input is rejected with ValueError."""
    with pytest.raises(ValueError):
        iq_augmentations.generate_awgn(np.array([[1.0, 2.0, 3.0]]))


def test_generate_awgn_invalid_1d_raises():
    """1-D input (no channel axis) is rejected with ValueError."""
    with pytest.raises(ValueError):
        iq_augmentations.generate_awgn(np.array([1 + 1j, 2 + 2j]))
|
||||
|
||||
|
||||
# --- time_reversal ---
|
||||
|
||||
|
||||
def test_time_reversal_invalid_real_raises():
    """time_reversal() rejects real-valued input."""
    with pytest.raises(ValueError):
        iq_augmentations.time_reversal(np.array([[1.0, 2.0, 3.0]]))


def test_time_reversal_multi_channel_raises():
    """Multi-channel input is not yet supported by time_reversal()."""
    with pytest.raises(NotImplementedError):
        iq_augmentations.time_reversal([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])


def test_spectral_inversion_invalid_real_raises():
    """spectral_inversion() rejects real-valued input."""
    with pytest.raises(ValueError):
        iq_augmentations.spectral_inversion(np.array([[1.0, 2.0, 3.0]]))


def test_spectral_inversion_multi_channel_raises():
    """Multi-channel input is not yet supported by spectral_inversion()."""
    with pytest.raises(NotImplementedError):
        iq_augmentations.spectral_inversion([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])


def test_channel_swap_invalid_real_raises():
    """channel_swap() rejects real-valued input."""
    with pytest.raises(ValueError):
        iq_augmentations.channel_swap(np.array([[1.0, 2.0, 3.0]]))


def test_amplitude_reversal_invalid_real_raises():
    """amplitude_reversal() rejects real-valued input."""
    with pytest.raises(ValueError):
        iq_augmentations.amplitude_reversal(np.array([[1.0, 2.0, 3.0]]))
|
||||
|
||||
|
||||
# --- drop_samples ---
|
||||
|
||||
|
||||
def test_drop_samples_rec_input():
    """drop_samples() on a Recording returns a Recording with metadata intact."""
    np.random.seed(0)
    rec = Recording(data=TEST_DATA1, metadata=TEST_METADATA)
    out = iq_augmentations.drop_samples(rec, max_section_size=2, fill_type="zeros")
    assert isinstance(out, Recording)
    assert out.metadata == rec.metadata


def test_drop_samples_invalid_max_section_size_zero():
    """max_section_size below 1 is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.drop_samples(TEST_DATA1, max_section_size=0)


def test_drop_samples_invalid_max_section_size_too_large():
    """max_section_size >= the signal length is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.drop_samples(TEST_DATA1, max_section_size=len(TEST_DATA1[0]))


def test_drop_samples_invalid_fill_type_raises():
    """An unrecognised fill_type is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.drop_samples(TEST_DATA1, max_section_size=2, fill_type="unknown")


def test_drop_samples_invalid_real_raises():
    """Real-valued input is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.drop_samples(np.array([[1.0, 2.0, 3.0]]))
|
||||
|
||||
|
||||
# --- quantize_tape ---
|
||||
|
||||
|
||||
def test_quantize_tape_invalid_rounding_type_raises():
    """An unrecognised rounding_type triggers a UserWarning (not an exception,
    despite the test name)."""
    with pytest.warns(UserWarning):
        iq_augmentations.quantize_tape(TEST_DATA1, rounding_type="round")


def test_quantize_tape_invalid_real_raises():
    """Real-valued input is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.quantize_tape(np.array([[1.0, 2.0, 3.0]]))


def test_quantize_parts_invalid_rounding_type_raises():
    """An unrecognised rounding_type triggers a UserWarning."""
    with pytest.warns(UserWarning):
        iq_augmentations.quantize_parts(TEST_DATA1, rounding_type="round")


def test_magnitude_rescale_invalid_bounds_negative_raises():
    """A negative starting bound is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.magnitude_rescale(TEST_DATA1, starting_bounds=(-1, 2))


def test_magnitude_rescale_invalid_bounds_too_large_raises():
    """An upper starting bound past the end of the signal is rejected."""
    signal_length = len(TEST_DATA1[0])
    with pytest.raises(ValueError):
        iq_augmentations.magnitude_rescale(TEST_DATA1, starting_bounds=(0, signal_length))
|
||||
|
||||
|
||||
# --- cut_out ---
|
||||
|
||||
|
||||
def test_cut_out_zeros():
    """fill_type='zeros' keeps the output complex / on the input's dtype."""
    np.random.seed(0)
    out = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="zeros")
    assert out.dtype == np.asarray(TEST_DATA1).dtype or np.iscomplexobj(out)


def test_cut_out_low_snr():
    """fill_type='low-snr' preserves the input shape."""
    np.random.seed(0)
    out = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="low-snr")
    assert out.shape == np.asarray(TEST_DATA1).shape


def test_cut_out_high_snr():
    """fill_type='high-snr' preserves the input shape."""
    np.random.seed(0)
    out = iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="high-snr")
    assert out.shape == np.asarray(TEST_DATA1).shape


def test_cut_out_rec_input():
    """cut_out() on a Recording returns a Recording with metadata intact."""
    np.random.seed(0)
    rec = Recording(data=TEST_DATA1, metadata=TEST_METADATA)
    out = iq_augmentations.cut_out(rec, max_section_size=2, fill_type="zeros")
    assert isinstance(out, Recording)
    assert out.metadata == rec.metadata


def test_cut_out_invalid_fill_type_raises():
    """An unrecognised fill_type triggers a UserWarning (despite the test name)."""
    with pytest.warns(UserWarning):
        iq_augmentations.cut_out(TEST_DATA1, max_section_size=2, fill_type="bad")


def test_cut_out_invalid_max_section_size_raises():
    """max_section_size below 1 is rejected."""
    with pytest.raises(ValueError):
        iq_augmentations.cut_out(TEST_DATA1, max_section_size=0)
|
||||
|
||||
|
||||
# --- patch_shuffle ---
|
||||
|
||||
|
||||
def test_patch_shuffle_max_patch_size_leq_1_raises():
    """max_patch_size must exceed 1."""
    with pytest.raises(ValueError):
        iq_augmentations.patch_shuffle(TEST_DATA1, max_patch_size=1)


def test_patch_shuffle_max_patch_size_too_large_raises():
    """max_patch_size must not exceed the signal length."""
    signal_length = len(TEST_DATA1[0])
    with pytest.raises(ValueError):
        iq_augmentations.patch_shuffle(TEST_DATA1, max_patch_size=signal_length + 1)
|
||||
|
|
|
|||
|
|
@ -1,406 +0,0 @@
|
|||
"""
|
||||
Unit tests for ria_toolkit_oss.transforms.iq_impairments.
|
||||
|
||||
Bugs/issues identified during review:
|
||||
- time_shift(signal, shift=0) returns all-zeros instead of the original signal.
|
||||
This is because `data[:, :-0]` evaluates as `data[:, :0]` (empty slice).
|
||||
Tests marked with BUG comments document this known failure.
|
||||
- resample() 'else' branch creates 'empty_array' but never returns it (dead code).
|
||||
When up < down, a shorter-than-input array is returned instead of zero-padded.
|
||||
- add_awgn_to_signal() contains a leftover debug print() call.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.datatypes import Recording
|
||||
from ria_toolkit_oss.transforms import iq_impairments
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_METADATA = {"source": "test", "timestamp": 1700000000.0}
|
||||
|
||||
# 1×4 complex signal
|
||||
DATA_4 = np.array([[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j]], dtype=np.complex128)
|
||||
|
||||
# 1×5 complex signal
|
||||
DATA_5 = np.array([[1 + 0j, 2 + 0j, 3 + 0j, 4 + 0j, 5 + 0j]], dtype=np.complex128)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_awgn_to_signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_add_awgn_array_shape():
    """Output shape matches input."""
    assert iq_impairments.add_awgn_to_signal(DATA_4, snr=10).shape == DATA_4.shape


def test_add_awgn_array_is_complex():
    """Result must be complex."""
    assert np.iscomplexobj(iq_impairments.add_awgn_to_signal(DATA_4, snr=10))


def test_add_awgn_not_identical_to_input():
    """AWGN must actually change the signal."""
    np.random.seed(42)
    noisy = iq_impairments.add_awgn_to_signal(DATA_4, snr=10)
    assert not np.array_equal(noisy, DATA_4)


def test_add_awgn_recording_input():
    """Returns a Recording when given a Recording; metadata is preserved."""
    rec = Recording(data=DATA_4.copy(), metadata=SAMPLE_METADATA)
    noisy = iq_impairments.add_awgn_to_signal(rec, snr=10)
    assert isinstance(noisy, Recording)
    assert noisy.metadata["source"] == "test"
    assert noisy.data.shape == DATA_4.shape


def test_add_awgn_recording_data_changed():
    """AWGN must change the data even when a Recording is passed in."""
    np.random.seed(42)
    rec = Recording(data=DATA_4.copy(), metadata=SAMPLE_METADATA)
    noisy = iq_impairments.add_awgn_to_signal(rec, snr=10)
    assert not np.array_equal(noisy.data, DATA_4)


def test_add_awgn_invalid_real_input():
    """Raises ValueError for real (non-complex) input."""
    with pytest.raises(ValueError):
        iq_impairments.add_awgn_to_signal(np.array([[1.0, 2.0, 3.0]]))


def test_add_awgn_snr_approximated():
    """With a large SNR the output should be close to the original signal."""
    np.random.seed(0)
    # At 60 dB SNR the noise term is negligible, so the signal dominates.
    clean = np.ones((1, 100000), dtype=np.complex128)
    assert np.allclose(iq_impairments.add_awgn_to_signal(clean, snr=60), clean, atol=0.01)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# time_shift
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_time_shift_positive():
    """Positive shift moves samples right; leading samples become zero."""
    shifted = iq_impairments.time_shift(DATA_5, shift=2)
    assert np.array_equal(shifted, np.array([[0 + 0j, 0 + 0j, 1 + 0j, 2 + 0j, 3 + 0j]]))


def test_time_shift_negative():
    """Negative shift moves samples left; trailing samples become zero."""
    shifted = iq_impairments.time_shift(DATA_5, shift=-2)
    assert np.array_equal(shifted, np.array([[3 + 0j, 4 + 0j, 5 + 0j, 0 + 0j, 0 + 0j]]))


def test_time_shift_shape_preserved():
    """Output shape must equal input shape."""
    assert iq_impairments.time_shift(DATA_5, shift=1).shape == DATA_5.shape


def test_time_shift_recording_input():
    """Returns a Recording when given a Recording; metadata preserved."""
    rec = Recording(data=DATA_5.copy(), metadata=SAMPLE_METADATA)
    shifted = iq_impairments.time_shift(rec, shift=2)
    assert isinstance(shifted, Recording)
    assert shifted.metadata["source"] == "test"
    assert np.array_equal(shifted.data, np.array([[0 + 0j, 0 + 0j, 1 + 0j, 2 + 0j, 3 + 0j]]))


def test_time_shift_invalid_real_input():
    """Raises ValueError for real input."""
    with pytest.raises(ValueError):
        iq_impairments.time_shift(np.array([[1.0, 2.0, 3.0]]))


def test_time_shift_large_shift_warns():
    """shift > n raises a UserWarning."""
    with pytest.warns(UserWarning):
        iq_impairments.time_shift(DATA_5, shift=100)


def test_time_shift_zero_is_identity():
    """shift=0 returns the original signal unchanged.

    NOTE(review): the module docstring documents a known bug where shift=0
    returns all zeros (``data[:, :-0]`` is an empty slice); this test pins
    the intended behaviour.
    """
    assert np.array_equal(iq_impairments.time_shift(DATA_5, shift=0), DATA_5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# frequency_shift
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_frequency_shift_zero_is_identity():
    """A shift of 0 leaves the signal unchanged (cos(0)=1, sin(0)=0)."""
    assert np.allclose(iq_impairments.frequency_shift(DATA_4, shift=0.0), DATA_4)


def test_frequency_shift_shape_preserved():
    """Output shape must equal input shape."""
    assert iq_impairments.frequency_shift(DATA_4, shift=0.25).shape == DATA_4.shape


def test_frequency_shift_is_complex():
    """Output must be complex."""
    assert np.iscomplexobj(iq_impairments.frequency_shift(DATA_4, shift=0.1))


def test_frequency_shift_half_nyquist():
    """Shift of 0.5 (Nyquist) alternates sign: exp(j*π*n) = (-1)^n."""
    # Constant real-axis tone makes the alternation easy to predict.
    tone = np.array([[1 + 0j, 1 + 0j, 1 + 0j, 1 + 0j]], dtype=np.complex128)
    shifted = iq_impairments.frequency_shift(tone, shift=0.5)
    expected = tone * np.exp(1j * 2 * np.pi * 0.5 * np.arange(4))
    assert np.allclose(shifted, expected)


def test_frequency_shift_recording_input():
    """Returns a Recording when given a Recording; metadata preserved."""
    rec = Recording(data=DATA_4.copy(), metadata=SAMPLE_METADATA)
    shifted = iq_impairments.frequency_shift(rec, shift=0.25)
    assert isinstance(shifted, Recording)
    assert shifted.metadata["source"] == "test"
    assert shifted.data.shape == DATA_4.shape


def test_frequency_shift_out_of_range_positive():
    """shift > 0.5 raises ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.frequency_shift(DATA_4, shift=0.6)


def test_frequency_shift_out_of_range_negative():
    """shift < -0.5 raises ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.frequency_shift(DATA_4, shift=-0.51)


def test_frequency_shift_invalid_real_input():
    """Raises ValueError for real (non-complex) input."""
    with pytest.raises(ValueError):
        iq_impairments.frequency_shift(np.array([[1.0, 2.0, 3.0]]), shift=0.1)


def test_frequency_shift_boundary_values():
    """Boundary values ±0.5 are accepted without error."""
    iq_impairments.frequency_shift(DATA_4, shift=0.5)
    iq_impairments.frequency_shift(DATA_4, shift=-0.5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# phase_shift
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_phase_shift_zero_is_identity():
    """Phase shift of 0 leaves signal unchanged."""
    assert np.allclose(iq_impairments.phase_shift(DATA_4, phase=0.0), DATA_4)


def test_phase_shift_pi_negates():
    """Phase shift of π negates the signal: exp(jπ) = -1."""
    assert np.allclose(iq_impairments.phase_shift(DATA_4, phase=np.pi), -DATA_4)


def test_phase_shift_half_pi():
    """Phase shift of π/2 multiplies by j: exp(j π/2) = j."""
    assert np.allclose(iq_impairments.phase_shift(DATA_4, phase=np.pi / 2), DATA_4 * 1j)


def test_phase_shift_shape_preserved():
    """Output shape must equal input shape."""
    assert iq_impairments.phase_shift(DATA_4, phase=np.pi / 4).shape == DATA_4.shape


def test_phase_shift_recording_input():
    """Returns a Recording when given a Recording; metadata preserved."""
    rec = Recording(data=DATA_4.copy(), metadata=SAMPLE_METADATA)
    rotated = iq_impairments.phase_shift(rec, phase=np.pi / 2)
    assert isinstance(rotated, Recording)
    assert rotated.metadata["source"] == "test"
    assert np.allclose(rotated.data, DATA_4 * 1j)


def test_phase_shift_out_of_range_positive():
    """phase > π raises ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.phase_shift(DATA_4, phase=np.pi + 0.01)


def test_phase_shift_out_of_range_negative():
    """phase < -π raises ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.phase_shift(DATA_4, phase=-np.pi - 0.01)


def test_phase_shift_boundary_values():
    """Boundary values ±π are accepted without error."""
    iq_impairments.phase_shift(DATA_4, phase=np.pi)
    iq_impairments.phase_shift(DATA_4, phase=-np.pi)


def test_phase_shift_invalid_real_input():
    """Raises ValueError for real (non-complex) input."""
    with pytest.raises(ValueError):
        iq_impairments.phase_shift(np.array([[1.0, 2.0, 3.0]]), phase=0.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# iq_imbalance
# ---------------------------------------------------------------------------


def test_iq_imbalance_basic_shape():
    """The impaired output keeps the input's shape."""
    out = iq_impairments.iq_imbalance(DATA_4, amplitude_imbalance=1.0, phase_imbalance=0.1, dc_offset=0.0)
    assert out.shape == DATA_4.shape


def test_iq_imbalance_is_complex():
    """The impaired output is a complex array."""
    out = iq_impairments.iq_imbalance(DATA_4, amplitude_imbalance=1.0, phase_imbalance=0.1, dc_offset=0.0)
    assert np.iscomplexobj(out)


def test_iq_imbalance_changes_signal():
    """Non-zero imbalance parameters must actually perturb the samples."""
    out = iq_impairments.iq_imbalance(DATA_4, amplitude_imbalance=3.0, phase_imbalance=0.5, dc_offset=2.0)
    assert not np.allclose(out, DATA_4)


def test_iq_imbalance_recording_input():
    """A Recording input yields a Recording output with metadata intact."""
    recording = Recording(data=DATA_4.copy(), metadata=SAMPLE_METADATA)
    out = iq_impairments.iq_imbalance(recording, amplitude_imbalance=1.0, phase_imbalance=0.1, dc_offset=0.0)
    assert isinstance(out, Recording)
    assert out.metadata["source"] == "test"
    assert out.data.shape == DATA_4.shape


def test_iq_imbalance_phase_out_of_range_positive():
    """A phase_imbalance just above pi is rejected with ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.iq_imbalance(DATA_4, phase_imbalance=np.pi + 0.01)


def test_iq_imbalance_phase_out_of_range_negative():
    """A phase_imbalance just below -pi is rejected with ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.iq_imbalance(DATA_4, phase_imbalance=-np.pi - 0.01)


def test_iq_imbalance_phase_boundary_values():
    """The inclusive endpoints +pi and -pi must not raise."""
    for boundary in (np.pi, -np.pi):
        iq_impairments.iq_imbalance(DATA_4, phase_imbalance=boundary)


def test_iq_imbalance_invalid_real_input():
    """A purely real array is not valid IQ data and raises ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.iq_imbalance(np.array([[1.0, 2.0, 3.0]]))


def test_iq_imbalance_amplitude_symmetry():
    """Flipping the sign of amplitude_imbalance swaps the I/Q scaling."""
    plus = iq_impairments.iq_imbalance(DATA_4, amplitude_imbalance=3.0, phase_imbalance=0.0, dc_offset=0.0)
    minus = iq_impairments.iq_imbalance(DATA_4, amplitude_imbalance=-3.0, phase_imbalance=0.0, dc_offset=0.0)
    # With phase and DC both zero, opposite-sign amplitude imbalance scales
    # I and Q oppositely, so the two outputs cannot coincide.
    assert not np.allclose(plus, minus)


def test_iq_imbalance_dc_offset_zero_doubles_signal():
    """BUG documentation: dc_offset=0 dB adds 1x the signal to itself, doubling it.

    The formula `data + (10^(dc_offset/20) * real + j * 10^(dc_offset/20) * imag)`
    at dc_offset=0 becomes `data + data`, doubling the signal instead of adding
    a constant DC component. This test deliberately pins the *actual* (buggy)
    behaviour so that a future fix is immediately detectable.
    """
    # A single real-valued sample keeps the arithmetic easy to follow.
    sample = np.array([[2 + 0j]], dtype=np.complex128)
    out = iq_impairments.iq_imbalance(sample, amplitude_imbalance=0.0, phase_imbalance=0.0, dc_offset=0.0)
    # If dc_offset=0 meant "no DC", out would equal sample; the bug instead
    # yields 2 * sample = [[4+0j]]. Assert the buggy value on purpose:
    assert np.allclose(out.real, 4.0), (
        "dc_offset=0 currently doubles the signal (adds 1× copy). "
        "If this assertion fails, the dc_offset formula has been fixed — update this test."
    )
# ---------------------------------------------------------------------------
# resample
# ---------------------------------------------------------------------------


def test_resample_upsample_shape():
    """up=2, down=1 — the resampled signal is truncated to the original length."""
    samples = np.array([[1 + 1j, 2 + 2j, 4 + 4j, 8 + 8j]], dtype=np.complex128)
    out = iq_impairments.resample(samples, up=2, down=1)
    # The implementation trims any excess so the output length matches the input.
    assert out.shape[0] == 1
    assert out.shape[1] == samples.shape[1]


def test_resample_is_complex():
    """The resampled output is a complex array."""
    assert np.iscomplexobj(iq_impairments.resample(DATA_4, up=2, down=1))


def test_resample_recording_input():
    """A Recording input yields a Recording output with metadata intact."""
    recording = Recording(data=DATA_4.copy(), metadata=SAMPLE_METADATA)
    out = iq_impairments.resample(recording, up=2, down=1)
    assert isinstance(out, Recording)
    assert out.metadata["source"] == "test"


def test_resample_unchanged_ratio():
    """Equal up/down factors leave the sample count untouched."""
    out = iq_impairments.resample(DATA_4, up=3, down=3)
    assert out.shape[1] == DATA_4.shape[1]


def test_resample_invalid_real_input():
    """A purely real array is not valid IQ data and raises ValueError."""
    with pytest.raises(ValueError):
        iq_impairments.resample(np.array([[1.0, 2.0, 3.0]]))


def test_resample_downsample_returns_same_length():
    """Downsampling zero-pads the output back to the input length."""
    samples = np.array([[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j, 5 + 5j, 6 + 6j]], dtype=np.complex128)
    out = iq_impairments.resample(samples, up=1, down=2)
    assert out.shape[1] == samples.shape[1]
|
@ -1,209 +0,0 @@
|
|||
"""
|
||||
Unit tests for ria_toolkit_oss.utils.array_conversion.
|
||||
|
||||
Covers:
|
||||
- is_1xn / is_2xn classification
|
||||
- convert_to_1xn / convert_to_2xn conversion
|
||||
- Round-trip invariance
|
||||
- Error paths for invalid inputs
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ria_toolkit_oss.utils.array_conversion import (
|
||||
convert_to_1xn,
|
||||
convert_to_2xn,
|
||||
is_1xn,
|
||||
is_2xn,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
COMPLEX_1XN = np.array([[1 + 2j, 3 + 4j, 5 + 6j]], dtype=np.complex128) # shape (1, 3)
|
||||
REAL_2XN = np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]], dtype=np.float64) # shape (2, 3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_1xn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_1xn_true_for_complex_1xn():
|
||||
assert is_1xn(COMPLEX_1XN) is True
|
||||
|
||||
|
||||
def test_is_1xn_false_for_real_2xn():
|
||||
assert is_1xn(REAL_2XN) is False
|
||||
|
||||
|
||||
def test_is_1xn_false_for_1d_complex():
|
||||
arr = np.array([1 + 2j, 3 + 4j]) # 1-D
|
||||
assert is_1xn(arr) is False
|
||||
|
||||
|
||||
def test_is_1xn_false_for_3d():
|
||||
arr = np.ones((1, 3, 3), dtype=np.complex128)
|
||||
assert is_1xn(arr) is False
|
||||
|
||||
|
||||
def test_is_1xn_false_for_real_1xn():
|
||||
arr = np.array([[1.0, 2.0, 3.0]]) # real 1×N
|
||||
assert is_1xn(arr) is False
|
||||
|
||||
|
||||
def test_is_1xn_false_for_complex_2xn():
|
||||
arr = np.array([[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j]]) # complex 2×N
|
||||
assert is_1xn(arr) is False
|
||||
|
||||
|
||||
def test_is_1xn_single_sample():
|
||||
arr = np.array([[1 + 0j]]) # shape (1, 1)
|
||||
assert is_1xn(arr) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# is_2xn
# ---------------------------------------------------------------------------


def test_is_2xn_true_for_real_2xn():
    assert is_2xn(REAL_2XN) is True


def test_is_2xn_false_for_complex_1xn():
    assert is_2xn(COMPLEX_1XN) is False


def test_is_2xn_false_for_1d():
    # One-dimensional data is not 2xN.
    assert is_2xn(np.array([1.0, 2.0, 3.0])) is False


def test_is_2xn_false_for_3xn():
    # Three rows do not fit the 2xN layout.
    assert is_2xn(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) is False


def test_is_2xn_false_for_complex_2xn():
    # The 2xN format requires real-valued rows.
    assert is_2xn(np.array([[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j]])) is False


def test_is_2xn_single_column():
    # A single column, shape (2, 1), still counts as 2xN.
    assert is_2xn(np.array([[1.0], [2.0]])) is True
# ---------------------------------------------------------------------------
# convert_to_2xn
# ---------------------------------------------------------------------------


def test_convert_to_2xn_from_1xn_shape():
    converted = convert_to_2xn(COMPLEX_1XN)
    assert converted.shape == (2, COMPLEX_1XN.shape[1])


def test_convert_to_2xn_from_1xn_values():
    """Row 0 carries the real parts, row 1 the imaginary parts."""
    converted = convert_to_2xn(COMPLEX_1XN)
    assert np.array_equal(converted[0], COMPLEX_1XN[0].real)
    assert np.array_equal(converted[1], COMPLEX_1XN[0].imag)


def test_convert_to_2xn_from_1xn_is_real():
    assert not np.iscomplexobj(convert_to_2xn(COMPLEX_1XN))


def test_convert_to_2xn_from_2xn_is_copy():
    """Input that is already 2xN comes back as an equal but distinct array."""
    converted = convert_to_2xn(REAL_2XN)
    assert np.array_equal(converted, REAL_2XN)
    assert converted is not REAL_2XN


def test_convert_to_2xn_invalid_raises():
    """A 1-D array fits neither layout, so conversion must fail."""
    with pytest.raises(ValueError):
        convert_to_2xn(np.array([1.0, 2.0, 3.0]))


def test_convert_to_2xn_invalid_complex_2xn_raises():
    """A complex 2xN array is not a recognised layout, so conversion must fail."""
    with pytest.raises(ValueError):
        convert_to_2xn(np.array([[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j]]))
# ---------------------------------------------------------------------------
# convert_to_1xn
# ---------------------------------------------------------------------------


def test_convert_to_1xn_from_2xn_shape():
    converted = convert_to_1xn(REAL_2XN)
    assert converted.shape == (1, REAL_2XN.shape[1])


def test_convert_to_1xn_from_2xn_values():
    """The real parts come from row 0 and the imaginary parts from row 1."""
    converted = convert_to_1xn(REAL_2XN)
    assert np.array_equal(converted[0].real, REAL_2XN[0])
    assert np.array_equal(converted[0].imag, REAL_2XN[1])


def test_convert_to_1xn_from_2xn_is_complex():
    assert np.iscomplexobj(convert_to_1xn(REAL_2XN))


def test_convert_to_1xn_from_1xn_is_copy():
    """Input that is already 1xN comes back as an equal but distinct array."""
    converted = convert_to_1xn(COMPLEX_1XN)
    assert np.array_equal(converted, COMPLEX_1XN)
    assert converted is not COMPLEX_1XN


def test_convert_to_1xn_invalid_raises():
    """A 1-D array fits neither layout, so conversion must fail."""
    with pytest.raises(ValueError):
        convert_to_1xn(np.array([1.0, 2.0, 3.0]))


def test_convert_to_1xn_invalid_3xn_raises():
    """A 3xN array is not a recognised layout, so conversion must fail."""
    with pytest.raises(ValueError):
        convert_to_1xn(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
# ---------------------------------------------------------------------------
# Round-trip invariance
# ---------------------------------------------------------------------------


def test_roundtrip_1xn_to_2xn_to_1xn():
    """Converting 1xN → 2xN → 1xN reproduces the starting values."""
    recovered = convert_to_1xn(convert_to_2xn(COMPLEX_1XN))
    assert np.allclose(recovered, COMPLEX_1XN)


def test_roundtrip_2xn_to_1xn_to_2xn():
    """Converting 2xN → 1xN → 2xN reproduces the starting values."""
    recovered = convert_to_2xn(convert_to_1xn(REAL_2XN))
    assert np.allclose(recovered, REAL_2XN)


def test_roundtrip_preserves_precision():
    """A double conversion keeps full float64 precision."""
    original = np.array([[1.23456789 + 9.87654321j, -0.1 - 0.2j]], dtype=np.complex128)
    recovered = convert_to_1xn(convert_to_2xn(original))
    assert np.allclose(recovered, original, atol=1e-14)
Loading…
Reference in New Issue
Block a user