Compare commits
No commits in common. "2881aaf06e5e6647460cb19c7891d54bbeb45f9e" and "a502dd97a9907968266fff59765c775cdbd7d66b" have entirely different histories.
2881aaf06e
...
a502dd97a9
12
poetry.lock
generated
12
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alabaster"
|
name = "alabaster"
|
||||||
|
|
@ -230,14 +230,14 @@ uvloop = ["uvloop (>=0.15.2) ; sys_platform != \"win32\"", "winloop (>=0.5.0) ;
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cachetools"
|
name = "cachetools"
|
||||||
version = "7.0.6"
|
version = "7.0.5"
|
||||||
description = "Extensible memoizing collections and decorators"
|
description = "Extensible memoizing collections and decorators"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
groups = ["test"]
|
groups = ["test"]
|
||||||
files = [
|
files = [
|
||||||
{file = "cachetools-7.0.6-py3-none-any.whl", hash = "sha256:4e94956cfdd3086f12042cdd29318f5ced3893014f7d0d059bf3ead3f85b7f8b"},
|
{file = "cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114"},
|
||||||
{file = "cachetools-7.0.6.tar.gz", hash = "sha256:e5d524d36d65703a87243a26ff08ad84f73352adbeafb1cde81e207b456aaf24"},
|
{file = "cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1271,7 +1271,7 @@ files = [
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
attrs = ">=22.2.0"
|
attrs = ">=22.2.0"
|
||||||
jsonschema-specifications = ">=2023.03.6"
|
jsonschema-specifications = ">=2023.3.6"
|
||||||
referencing = ">=0.28.4"
|
referencing = ">=0.28.4"
|
||||||
rpds-py = ">=0.25.0"
|
rpds-py = ">=0.25.0"
|
||||||
|
|
||||||
|
|
@ -3749,4 +3749,4 @@ files = [
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
content-hash = "66c9adf647316db90f963da05e8a83574378bfa4db2c69ce751446b5ee7c408c"
|
content-hash = "ffde300b2fc93161d2279a6e2b899bc988d3b5eb3833135821830affc9a5fb62"
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ dependencies = [
|
||||||
"pyyaml (>=6.0.3,<7.0.0)",
|
"pyyaml (>=6.0.3,<7.0.0)",
|
||||||
"click (>=8.1.0,<9.0.0)",
|
"click (>=8.1.0,<9.0.0)",
|
||||||
"matplotlib (>=3.8.0,<4.0.0)",
|
"matplotlib (>=3.8.0,<4.0.0)",
|
||||||
"paramiko (>=3.5.1)"
|
"paramiko (>=4.0.0)"
|
||||||
]
|
]
|
||||||
|
|
||||||
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
||||||
|
|
@ -149,11 +149,6 @@ exclude = '''
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
pythonpath = ["src"]
|
pythonpath = ["src"]
|
||||||
filterwarnings = [
|
|
||||||
# FastAPI emits this internally when handling 422 responses; the constant
|
|
||||||
# is not yet renamed in the installed starlette version, so we can't migrate.
|
|
||||||
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ _HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
|
||||||
_POLL_TIMEOUT = 30 # server-side long-poll duration
|
_POLL_TIMEOUT = 30 # server-side long-poll duration
|
||||||
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
||||||
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
||||||
_CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout
|
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
|
||||||
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
||||||
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
|
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
|
||||||
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
||||||
|
|
@ -93,24 +93,16 @@ class NodeAgent:
|
||||||
name: str,
|
name: str,
|
||||||
sdr_device: str = "unknown",
|
sdr_device: str = "unknown",
|
||||||
insecure: bool = False,
|
insecure: bool = False,
|
||||||
role: str = "general",
|
|
||||||
session_code: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.hub_url = hub_url.rstrip("/")
|
self.hub_url = hub_url.rstrip("/")
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.name = name
|
self.name = name
|
||||||
self.sdr_device = sdr_device
|
self.sdr_device = sdr_device
|
||||||
self.insecure = insecure
|
self.insecure = insecure
|
||||||
self.role = role
|
|
||||||
self.session_code = session_code
|
|
||||||
|
|
||||||
self.node_id: str | None = None
|
self.node_id: str | None = None
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
|
|
||||||
# ── TX state ────────────────────────────────────────────────────────
|
|
||||||
self._tx_stop = threading.Event()
|
|
||||||
self._tx_thread: threading.Thread | None = None
|
|
||||||
|
|
||||||
# ── Inference state ─────────────────────────────────────────────────
|
# ── Inference state ─────────────────────────────────────────────────
|
||||||
# Protected by _inf_lock for cross-thread model swaps.
|
# Protected by _inf_lock for cross-thread model swaps.
|
||||||
self._inf_lock = threading.Lock()
|
self._inf_lock = threading.Lock()
|
||||||
|
|
@ -180,27 +172,19 @@ class NodeAgent:
|
||||||
capabilities = ["campaign"]
|
capabilities = ["campaign"]
|
||||||
if self._ort_available:
|
if self._ort_available:
|
||||||
capabilities.append("inference")
|
capabilities.append("inference")
|
||||||
if self.role == "tx":
|
resp = self._post(
|
||||||
capabilities.append("transmit")
|
"/composer/nodes/register",
|
||||||
payload: dict = {
|
json={
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"sdr_device": self.sdr_device,
|
"sdr_device": self.sdr_device,
|
||||||
"ria_toolkit_version": self._ria_version,
|
"ria_toolkit_version": self._ria_version,
|
||||||
"capabilities": capabilities,
|
"capabilities": capabilities,
|
||||||
"role": self.role,
|
},
|
||||||
}
|
timeout=15,
|
||||||
if self.session_code:
|
)
|
||||||
payload["session_code"] = self.session_code
|
|
||||||
resp = self._post("/composer/nodes/register", json=payload, timeout=15)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
self.node_id = resp.json()["node_id"]
|
self.node_id = resp.json()["node_id"]
|
||||||
logger.info(
|
logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)
|
||||||
"Registered as %r (node_id=%s, role=%s%s)",
|
|
||||||
self.name,
|
|
||||||
self.node_id,
|
|
||||||
self.role,
|
|
||||||
f", session_code={self.session_code!r}" if self.session_code else "",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _deregister(self) -> None:
|
def _deregister(self) -> None:
|
||||||
if not self.node_id:
|
if not self.node_id:
|
||||||
|
|
@ -261,10 +245,9 @@ class NodeAgent:
|
||||||
if command == "run_campaign":
|
if command == "run_campaign":
|
||||||
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
||||||
config_dict: dict = cmd.get("payload") or {}
|
config_dict: dict = cmd.get("payload") or {}
|
||||||
skip_local_tx: bool = bool(cmd.get("skip_local_tx", False))
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self._run_campaign,
|
target=self._run_campaign,
|
||||||
args=(campaign_id, config_dict, skip_local_tx),
|
args=(campaign_id, config_dict),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name=f"campaign-{campaign_id[:8]}",
|
name=f"campaign-{campaign_id[:8]}",
|
||||||
).start()
|
).start()
|
||||||
|
|
@ -286,17 +269,6 @@ class NodeAgent:
|
||||||
self._stop_inference()
|
self._stop_inference()
|
||||||
elif command == "configure_inference":
|
elif command == "configure_inference":
|
||||||
self._queue_sdr_config(cmd)
|
self._queue_sdr_config(cmd)
|
||||||
elif command == "start_transmit":
|
|
||||||
threading.Thread(
|
|
||||||
target=self._start_transmit,
|
|
||||||
args=(cmd,),
|
|
||||||
daemon=True,
|
|
||||||
name="ria-start-tx",
|
|
||||||
).start()
|
|
||||||
elif command == "stop_transmit":
|
|
||||||
self._stop_transmit()
|
|
||||||
elif command == "configure_transmit":
|
|
||||||
logger.info("configure_transmit received — will apply on next step boundary")
|
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown command %r — ignored", command)
|
logger.warning("Unknown command %r — ignored", command)
|
||||||
|
|
||||||
|
|
@ -304,7 +276,7 @@ class NodeAgent:
|
||||||
# Campaign execution
|
# Campaign execution
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None:
|
def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
|
||||||
try:
|
try:
|
||||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||||
|
|
@ -316,10 +288,10 @@ class NodeAgent:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx)
|
logger.info("Campaign %s starting", campaign_id[:8])
|
||||||
try:
|
try:
|
||||||
config = CampaignConfig.from_dict(config_dict)
|
config = CampaignConfig.from_dict(config_dict)
|
||||||
executor = CampaignExecutor(config, skip_local_tx=skip_local_tx)
|
executor = CampaignExecutor(config)
|
||||||
result = executor.run()
|
result = executor.run()
|
||||||
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
|
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
|
||||||
self._upload_recordings(campaign_id, config, result)
|
self._upload_recordings(campaign_id, config, result)
|
||||||
|
|
@ -329,58 +301,6 @@ class NodeAgent:
|
||||||
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
||||||
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# TX execution
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _start_transmit(self, cmd: dict) -> None:
|
|
||||||
"""Execute a synthetic transmit campaign using TxExecutor.
|
|
||||||
|
|
||||||
The command payload mirrors a TransmitterConfig dict with an optional
|
|
||||||
``schedule`` of steps. Each step synthesises a signal and transmits it
|
|
||||||
via the local SDR in TX mode.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
|
|
||||||
except ImportError as exc:
|
|
||||||
logger.error("start_transmit: TxExecutor not available: %s", exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._tx_thread and self._tx_thread.is_alive():
|
|
||||||
logger.warning("start_transmit: TX already running — ignoring duplicate command")
|
|
||||||
return
|
|
||||||
|
|
||||||
self._tx_stop.clear()
|
|
||||||
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
|
||||||
executor = TxExecutor(
|
|
||||||
config=cmd,
|
|
||||||
sdr_device=self.sdr_device,
|
|
||||||
stop_event=self._tx_stop,
|
|
||||||
)
|
|
||||||
self._tx_thread = threading.Thread(
|
|
||||||
target=self._run_tx_campaign,
|
|
||||||
args=(executor, campaign_id),
|
|
||||||
daemon=True,
|
|
||||||
name=f"tx-campaign-{campaign_id[:8]}",
|
|
||||||
)
|
|
||||||
self._tx_thread.start()
|
|
||||||
|
|
||||||
def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
|
|
||||||
try:
|
|
||||||
executor.run()
|
|
||||||
logger.info("TX campaign %s completed", campaign_id[:8])
|
|
||||||
self._report_campaign_status(campaign_id, "completed")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("TX campaign %s failed: %s", campaign_id[:8], exc)
|
|
||||||
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
|
||||||
|
|
||||||
def _stop_transmit(self) -> None:
|
|
||||||
"""Signal the TX loop to stop gracefully."""
|
|
||||||
self._tx_stop.set()
|
|
||||||
if self._tx_thread and self._tx_thread.is_alive():
|
|
||||||
self._tx_thread.join(timeout=5.0)
|
|
||||||
logger.info("TX stopped")
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Inference — model loading
|
# Inference — model loading
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -659,18 +579,13 @@ class NodeAgent:
|
||||||
base_url = f"{self.hub_url}/datasets/upload"
|
base_url = f"{self.hub_url}/datasets/upload"
|
||||||
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
|
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
|
||||||
|
|
||||||
output_obj = getattr(config, "output", None)
|
|
||||||
folder = getattr(output_obj, "folder", None)
|
|
||||||
campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "")
|
|
||||||
for step in steps:
|
for step in steps:
|
||||||
output_path: str | None = getattr(step, "output_path", None)
|
output_path: str | None = getattr(step, "output_path", None)
|
||||||
if not output_path:
|
if not output_path:
|
||||||
continue
|
continue
|
||||||
device_id: str = getattr(step, "transmitter_id", "") or ""
|
device_id: str = getattr(step, "transmitter_id", "") or ""
|
||||||
for fpath in _sigmf_files(output_path):
|
for fpath in _sigmf_files(output_path):
|
||||||
basename = os.path.basename(fpath)
|
filename = os.path.basename(fpath)
|
||||||
path_parts = [p for p in (campaign_name, device_id) if p]
|
|
||||||
filename = "/".join(path_parts + [basename])
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"repo_owner": repo_owner,
|
"repo_owner": repo_owner,
|
||||||
|
|
@ -756,7 +671,7 @@ class NodeAgent:
|
||||||
headers=headers,
|
headers=headers,
|
||||||
files={"file": (filename, chunk, "application/octet-stream")},
|
files={"file": (filename, chunk, "application/octet-stream")},
|
||||||
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
|
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
|
||||||
timeout=(30, None), # 30s connect, no read timeout — server may take minutes on final chunk
|
timeout=120,
|
||||||
verify=verify,
|
verify=verify,
|
||||||
)
|
)
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
|
|
@ -933,21 +848,6 @@ def main() -> None:
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
help="Logging verbosity (default: INFO)",
|
help="Logging verbosity (default: INFO)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--role",
|
|
||||||
default=None,
|
|
||||||
choices=["general", "rx", "tx"],
|
|
||||||
help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--session-code",
|
|
||||||
default=None,
|
|
||||||
metavar="CODE",
|
|
||||||
help=(
|
|
||||||
"3-word session code to pair this TX agent with a waiting campaign, "
|
|
||||||
"e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -961,8 +861,6 @@ def main() -> None:
|
||||||
device = args.device or cfg.get("device", "unknown")
|
device = args.device or cfg.get("device", "unknown")
|
||||||
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
|
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
|
||||||
log_level = args.log_level or cfg.get("log_level", "INFO")
|
log_level = args.log_level or cfg.get("log_level", "INFO")
|
||||||
role = args.role or cfg.get("role", "general")
|
|
||||||
session_code = args.session_code or cfg.get("session_code")
|
|
||||||
|
|
||||||
if not hub:
|
if not hub:
|
||||||
parser.error("--hub is required (or set 'hub' in the config file)")
|
parser.error("--hub is required (or set 'hub' in the config file)")
|
||||||
|
|
@ -990,8 +888,6 @@ def main() -> None:
|
||||||
name=name,
|
name=name,
|
||||||
sdr_device=device,
|
sdr_device=device,
|
||||||
insecure=insecure,
|
insecure=insecure,
|
||||||
role=role,
|
|
||||||
session_code=session_code,
|
|
||||||
)
|
)
|
||||||
agent.run()
|
agent.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -233,9 +233,6 @@ class TransmitterConfig:
|
||||||
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
|
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
|
||||||
sdr_remote: Optional[dict] = None
|
sdr_remote: Optional[dict] = None
|
||||||
|
|
||||||
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
|
|
||||||
sdr_agent: Optional[dict] = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
||||||
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
||||||
|
|
@ -247,7 +244,6 @@ class TransmitterConfig:
|
||||||
script=d.get("script"),
|
script=d.get("script"),
|
||||||
device=d.get("device"),
|
device=d.get("device"),
|
||||||
sdr_remote=d.get("sdr_remote"),
|
sdr_remote=d.get("sdr_remote"),
|
||||||
sdr_agent=d.get("sdr_agent"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -276,7 +272,6 @@ class OutputConfig:
|
||||||
path: str = "recordings"
|
path: str = "recordings"
|
||||||
device_id: Optional[str] = None # for device-profile campaigns
|
device_id: Optional[str] = None # for device-profile campaigns
|
||||||
repo: Optional[str] = None
|
repo: Optional[str] = None
|
||||||
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: dict) -> "OutputConfig":
|
def from_dict(cls, d: dict) -> "OutputConfig":
|
||||||
|
|
@ -285,7 +280,6 @@ class OutputConfig:
|
||||||
path=str(d.get("path", "recordings")),
|
path=str(d.get("path", "recordings")),
|
||||||
device_id=d.get("device_id"),
|
device_id=d.get("device_id"),
|
||||||
repo=d.get("repo"),
|
repo=d.get("repo"),
|
||||||
folder=d.get("folder"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -299,7 +293,6 @@ class CampaignConfig:
|
||||||
qa: QAConfig = field(default_factory=QAConfig)
|
qa: QAConfig = field(default_factory=QAConfig)
|
||||||
output: OutputConfig = field(default_factory=OutputConfig)
|
output: OutputConfig = field(default_factory=OutputConfig)
|
||||||
mode: str = "controlled_testbed"
|
mode: str = "controlled_testbed"
|
||||||
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Loaders
|
# Loaders
|
||||||
|
|
@ -327,7 +320,6 @@ class CampaignConfig:
|
||||||
return cls(
|
return cls(
|
||||||
name=safe_name,
|
name=safe_name,
|
||||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
loops=max(1, int(campaign_meta.get("loops", 1))),
|
|
||||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
transmitters=transmitters,
|
transmitters=transmitters,
|
||||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
|
@ -392,7 +384,6 @@ class CampaignConfig:
|
||||||
return cls(
|
return cls(
|
||||||
name=safe_name,
|
name=safe_name,
|
||||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
loops=max(1, int(campaign_meta.get("loops", 1))),
|
|
||||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
transmitters=transmitters,
|
transmitters=transmitters,
|
||||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
|
@ -495,9 +486,9 @@ class CampaignConfig:
|
||||||
)
|
)
|
||||||
|
|
||||||
def total_capture_time_s(self) -> float:
|
def total_capture_time_s(self) -> float:
|
||||||
"""Sum of all step durations across all transmitters and loops."""
|
"""Sum of all step durations across all transmitters."""
|
||||||
return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops
|
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
|
||||||
|
|
||||||
def total_steps(self) -> int:
|
def total_steps(self) -> int:
|
||||||
"""Total number of capture steps across all transmitters and loops."""
|
"""Total number of capture steps across all transmitters."""
|
||||||
return sum(len(tx.schedule) for tx in self.transmitters) * self.loops
|
return sum(len(tx.schedule) for tx in self.transmitters)
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,8 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
@ -17,7 +16,6 @@ from ria_toolkit_oss.io.recording import to_sigmf
|
||||||
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||||
from .labeler import build_output_filename, label_recording
|
from .labeler import build_output_filename, label_recording
|
||||||
from .qa import QAResult, check_recording
|
from .qa import QAResult, check_recording
|
||||||
from .tx_executor import TxExecutor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -171,21 +169,6 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
|
|
||||||
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
|
|
||||||
|
|
||||||
For sdr_agent transmitters, returns the synthetic generation parameters
|
|
||||||
(modulation, order, symbol_rate, etc.) so recordings capture what was
|
|
||||||
transmitted. Returns None for control methods without signal-level params.
|
|
||||||
"""
|
|
||||||
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
|
|
||||||
if not sdr_agent_cfg:
|
|
||||||
return None
|
|
||||||
# Extract known signal-level fields; ignore infra fields
|
|
||||||
_INFRA_KEYS = {"node_id", "session_code"}
|
|
||||||
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
|
|
||||||
|
|
||||||
|
|
||||||
class CampaignExecutor:
|
class CampaignExecutor:
|
||||||
"""Executes a :class:`CampaignConfig` end-to-end.
|
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||||
|
|
||||||
|
|
@ -209,14 +192,11 @@ class CampaignExecutor:
|
||||||
config: CampaignConfig,
|
config: CampaignConfig,
|
||||||
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
skip_local_tx: bool = False,
|
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.progress_cb = progress_cb
|
self.progress_cb = progress_cb
|
||||||
self.skip_local_tx = skip_local_tx
|
|
||||||
self._sdr = None
|
self._sdr = None
|
||||||
self._remote_tx_controllers: dict = {}
|
self._remote_tx_controllers: dict = {}
|
||||||
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
@ -236,12 +216,10 @@ class CampaignExecutor:
|
||||||
"""
|
"""
|
||||||
result = CampaignResult(campaign_name=self.config.name)
|
result = CampaignResult(campaign_name=self.config.name)
|
||||||
|
|
||||||
loops = self.config.loops
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Starting campaign '{self.config.name}': "
|
f"Starting campaign '{self.config.name}': "
|
||||||
f"{self.config.total_steps()} steps"
|
f"{self.config.total_steps()} steps, "
|
||||||
+ (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
|
f"~{self.config.total_capture_time_s():.0f}s capture time"
|
||||||
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._init_sdr()
|
self._init_sdr()
|
||||||
|
|
@ -250,36 +228,29 @@ class CampaignExecutor:
|
||||||
total = self.config.total_steps()
|
total = self.config.total_steps()
|
||||||
step_index = 0
|
step_index = 0
|
||||||
|
|
||||||
for loop_idx in range(loops):
|
for transmitter in self.config.transmitters:
|
||||||
if loops > 1:
|
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
||||||
logger.info(f"Loop {loop_idx + 1}/{loops}")
|
for step in transmitter.schedule:
|
||||||
for transmitter in self.config.transmitters:
|
step_result = self._execute_step(transmitter, step)
|
||||||
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
result.steps.append(step_result)
|
||||||
for step in transmitter.schedule:
|
step_index += 1
|
||||||
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
|
|
||||||
step_result = self._execute_step(transmitter, looped_step)
|
|
||||||
result.steps.append(step_result)
|
|
||||||
step_index += 1
|
|
||||||
|
|
||||||
if self.progress_cb:
|
if self.progress_cb:
|
||||||
self.progress_cb(step_index, total, step_result)
|
self.progress_cb(step_index, total, step_result)
|
||||||
|
|
||||||
if step_result.error:
|
if step_result.error:
|
||||||
logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
|
logger.warning(f"Step '{step.label}' error: {step_result.error}")
|
||||||
elif step_result.qa.flagged:
|
elif step_result.qa.flagged:
|
||||||
logger.warning(
|
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
|
||||||
f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
|
else:
|
||||||
)
|
logger.info(
|
||||||
else:
|
f"Step '{step.label}' OK "
|
||||||
logger.info(
|
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
||||||
f"Step '{looped_step.label}' OK "
|
f"{step_result.qa.duration_s:.1f}s)"
|
||||||
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
)
|
||||||
f"{step_result.qa.duration_s:.1f}s)"
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
self._close_sdr()
|
self._close_sdr()
|
||||||
self._close_remote_tx_controllers()
|
self._close_remote_tx_controllers()
|
||||||
self._close_tx_executors()
|
|
||||||
|
|
||||||
result.end_time = time.time()
|
result.end_time = time.time()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -354,12 +325,6 @@ class CampaignExecutor:
|
||||||
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
|
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
|
||||||
self._remote_tx_controllers.clear()
|
self._remote_tx_controllers.clear()
|
||||||
|
|
||||||
def _close_tx_executors(self) -> None:
|
|
||||||
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
|
|
||||||
stop_event.set()
|
|
||||||
t.join(timeout=5.0)
|
|
||||||
self._tx_executors.clear()
|
|
||||||
|
|
||||||
def _record(self, duration_s: float) -> Recording:
|
def _record(self, duration_s: float) -> Recording:
|
||||||
"""Capture ``duration_s`` seconds of IQ samples."""
|
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||||
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||||
|
|
@ -404,7 +369,6 @@ class CampaignExecutor:
|
||||||
step=step,
|
step=step,
|
||||||
capture_timestamp=capture_timestamp,
|
capture_timestamp=capture_timestamp,
|
||||||
campaign_name=self.config.name,
|
campaign_name=self.config.name,
|
||||||
tx_params=_extract_tx_params(transmitter),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# QA
|
# QA
|
||||||
|
|
@ -473,30 +437,6 @@ class CampaignExecutor:
|
||||||
# Start transmission in background; _record() runs concurrently
|
# Start transmission in background; _record() runs concurrently
|
||||||
ctrl.transmit_async(step.duration + 1.0)
|
ctrl.transmit_async(step.duration + 1.0)
|
||||||
|
|
||||||
elif transmitter.control_method == "sdr_agent":
|
|
||||||
if self.skip_local_tx:
|
|
||||||
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
|
|
||||||
return
|
|
||||||
if not transmitter.sdr_agent:
|
|
||||||
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
|
|
||||||
return
|
|
||||||
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
|
|
||||||
if step.power_dbm is not None:
|
|
||||||
step_dict["power_dbm"] = step.power_dbm
|
|
||||||
tx_config = {
|
|
||||||
"id": transmitter.id,
|
|
||||||
"sdr_agent": transmitter.sdr_agent,
|
|
||||||
"schedule": [step_dict],
|
|
||||||
}
|
|
||||||
rec = self.config.recorder
|
|
||||||
tx_device = transmitter.device or rec.device
|
|
||||||
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
|
|
||||||
stop_event = threading.Event()
|
|
||||||
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
|
|
||||||
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
|
|
||||||
self._tx_executors[transmitter.id] = (executor, stop_event, t)
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||||
|
|
||||||
|
|
@ -519,13 +459,6 @@ class CampaignExecutor:
|
||||||
if ctrl is not None:
|
if ctrl is not None:
|
||||||
ctrl.wait_transmit(timeout=step.duration + 10.0)
|
ctrl.wait_transmit(timeout=step.duration + 10.0)
|
||||||
|
|
||||||
elif transmitter.control_method == "sdr_agent":
|
|
||||||
entry = self._tx_executors.pop(transmitter.id, None)
|
|
||||||
if entry is not None:
|
|
||||||
_, stop_event, t = entry
|
|
||||||
stop_event.set()
|
|
||||||
t.join(timeout=step.duration + 10.0)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||||
"""Serialise step parameters to a JSON string for the control script."""
|
"""Serialise step parameters to a JSON string for the control script."""
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ def label_recording(
|
||||||
step: CaptureStep,
|
step: CaptureStep,
|
||||||
capture_timestamp: float,
|
capture_timestamp: float,
|
||||||
campaign_name: Optional[str] = None,
|
campaign_name: Optional[str] = None,
|
||||||
tx_params: Optional[dict] = None,
|
|
||||||
) -> Recording:
|
) -> Recording:
|
||||||
"""Apply device identity and capture configuration labels to a recording's metadata.
|
"""Apply device identity and capture configuration labels to a recording's metadata.
|
||||||
|
|
||||||
|
|
@ -28,9 +27,6 @@ def label_recording(
|
||||||
step: The capture step that was active during this recording.
|
step: The capture step that was active during this recording.
|
||||||
capture_timestamp: Unix timestamp (float) of when capture started.
|
capture_timestamp: Unix timestamp (float) of when capture started.
|
||||||
campaign_name: Optional campaign name for cross-recording reference.
|
campaign_name: Optional campaign name for cross-recording reference.
|
||||||
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
|
|
||||||
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
|
|
||||||
training pipelines know what was transmitted into the recording.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The same recording with updated metadata.
|
The same recording with updated metadata.
|
||||||
|
|
@ -61,11 +57,6 @@ def label_recording(
|
||||||
if step.power_dbm is not None:
|
if step.power_dbm is not None:
|
||||||
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
||||||
|
|
||||||
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
|
|
||||||
if tx_params:
|
|
||||||
for key, value in tx_params.items():
|
|
||||||
recording.update_metadata(f"tx_{key}", value)
|
|
||||||
|
|
||||||
return recording
|
return recording
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,299 +0,0 @@
|
||||||
"""TX campaign executor — synthesises and transmits signals via a local SDR.
|
|
||||||
|
|
||||||
The TxExecutor receives a transmitter config dict (matching the
|
|
||||||
``sdr_agent`` control method's schema) and a step schedule, then for each
|
|
||||||
step builds a signal chain with the block generator and transmits it via
|
|
||||||
the local SDR device.
|
|
||||||
|
|
||||||
Supported modulations (``modulation`` field in config):
|
|
||||||
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
|
|
||||||
|
|
||||||
Example config dict (matches CampaignConfig transmitter with
|
|
||||||
``control_method: sdr_agent``)::
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "synthetic-tx",
|
|
||||||
"type": "sdr",
|
|
||||||
"control_method": "sdr_agent",
|
|
||||||
"sdr_agent": {
|
|
||||||
"modulation": "QPSK",
|
|
||||||
"order": 4,
|
|
||||||
"symbol_rate": 1000000,
|
|
||||||
"center_frequency": 0.0,
|
|
||||||
"filter": "rrc",
|
|
||||||
"rolloff": 0.35
|
|
||||||
},
|
|
||||||
"schedule": [
|
|
||||||
{"label": "step1", "duration": 10, "power_dbm": -10}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_hz(val: object) -> float:
|
|
||||||
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
|
|
||||||
if isinstance(val, (int, float)):
|
|
||||||
return float(val)
|
|
||||||
s = str(val).strip()
|
|
||||||
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
|
|
||||||
if s.endswith(suffix):
|
|
||||||
return float(s[: -len(suffix)]) * mult
|
|
||||||
return float(s)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_seconds(val: object) -> float:
|
|
||||||
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
|
|
||||||
if isinstance(val, (int, float)):
|
|
||||||
return float(val)
|
|
||||||
s = str(val).strip()
|
|
||||||
return float(s[:-1]) if s.endswith("s") else float(s)
|
|
||||||
|
|
||||||
|
|
||||||
# Mapping from modulation name → (PSK/QAM order, generator_type)
|
|
||||||
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
|
|
||||||
_MOD_TABLE: dict[str, tuple[int, str]] = {
|
|
||||||
"BPSK": (1, "psk"),
|
|
||||||
"QPSK": (2, "psk"),
|
|
||||||
"8PSK": (3, "psk"),
|
|
||||||
"16QAM": (4, "qam"),
|
|
||||||
"64QAM": (6, "qam"),
|
|
||||||
"256QAM": (8, "qam"),
|
|
||||||
}
|
|
||||||
|
|
||||||
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}
|
|
||||||
|
|
||||||
# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
|
|
||||||
# source buffer for the full tx_time, so only this many samples ever need to
|
|
||||||
# be in RAM regardless of step duration or sample rate.
|
|
||||||
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
|
|
||||||
_SYNTH_BLOCK_SAMPLES = 50_000
|
|
||||||
|
|
||||||
|
|
||||||
class TxExecutor:
|
|
||||||
"""Synthesise and transmit a signal campaign via a local SDR.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
|
|
||||||
modulation params, and ``schedule`` list of step dicts).
|
|
||||||
sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
|
|
||||||
stop_event: External event that aborts the TX loop mid-step.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: dict,
|
|
||||||
sdr_device: str = "unknown",
|
|
||||||
stop_event: threading.Event | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.sdr_device = sdr_device
|
|
||||||
self.stop_event = stop_event or threading.Event()
|
|
||||||
self._sdr: Any = None
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
"""Execute all steps in the schedule, transmitting for each step duration."""
|
|
||||||
agent_cfg: dict = self.config.get("sdr_agent") or {}
|
|
||||||
schedule: list[dict] = self.config.get("schedule") or []
|
|
||||||
|
|
||||||
if not schedule:
|
|
||||||
logger.warning("TxExecutor: no schedule steps — nothing to transmit")
|
|
||||||
return
|
|
||||||
|
|
||||||
modulation: str = agent_cfg.get("modulation", "QPSK").upper()
|
|
||||||
symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
|
|
||||||
center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
|
|
||||||
filter_type: str = agent_cfg.get("filter", "rrc").lower()
|
|
||||||
rolloff: float = float(agent_cfg.get("rolloff", 0.35))
|
|
||||||
loops: int = max(1, int(self.config.get("loops", 1)))
|
|
||||||
|
|
||||||
# Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
|
|
||||||
sps = 8
|
|
||||||
sample_rate = symbol_rate * sps
|
|
||||||
|
|
||||||
self._init_sdr(sample_rate, center_freq)
|
|
||||||
try:
|
|
||||||
for loop_idx in range(loops):
|
|
||||||
if self.stop_event.is_set():
|
|
||||||
break
|
|
||||||
if loops > 1:
|
|
||||||
logger.info("TX loop %d/%d", loop_idx + 1, loops)
|
|
||||||
for step in schedule:
|
|
||||||
if self.stop_event.is_set():
|
|
||||||
break
|
|
||||||
looped_step = (
|
|
||||||
{**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
|
|
||||||
)
|
|
||||||
self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
|
|
||||||
finally:
|
|
||||||
self._close_sdr()
|
|
||||||
|
|
||||||
def _execute_step(
|
|
||||||
self,
|
|
||||||
step: dict,
|
|
||||||
modulation: str,
|
|
||||||
sps: int,
|
|
||||||
symbol_rate: float,
|
|
||||||
filter_type: str,
|
|
||||||
rolloff: float,
|
|
||||||
) -> None:
|
|
||||||
duration: float = _parse_seconds(step.get("duration", 10.0))
|
|
||||||
label: str = step.get("label", "step")
|
|
||||||
gain: float = float(step.get("power_dbm") or 0.0)
|
|
||||||
sample_rate = symbol_rate * sps
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
|
|
||||||
label,
|
|
||||||
duration,
|
|
||||||
modulation,
|
|
||||||
symbol_rate / 1e6,
|
|
||||||
sps,
|
|
||||||
filter_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_samples = int(duration * sample_rate)
|
|
||||||
|
|
||||||
# Synthesise a short representative block. tx_recording() loops this
|
|
||||||
# buffer for the full tx_time using a 2 000-sample streaming callback,
|
|
||||||
# so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
|
|
||||||
block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
|
|
||||||
signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
|
|
||||||
|
|
||||||
if self._sdr is not None:
|
|
||||||
try:
|
|
||||||
# Apply gain update if SDR supports it
|
|
||||||
if hasattr(self._sdr, "set_tx_gain"):
|
|
||||||
self._sdr.set_tx_gain(gain)
|
|
||||||
self._sdr.tx_recording(signal, tx_time=duration)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("TX step '%s' SDR error: %s", label, exc)
|
|
||||||
else:
|
|
||||||
# No SDR available — simulate by sleeping for the step duration.
|
|
||||||
logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
|
|
||||||
self.stop_event.wait(timeout=duration)
|
|
||||||
|
|
||||||
def _synthesise(
|
|
||||||
self,
|
|
||||||
modulation: str,
|
|
||||||
sps: int,
|
|
||||||
num_samples: int,
|
|
||||||
filter_type: str,
|
|
||||||
rolloff: float,
|
|
||||||
):
|
|
||||||
"""Build a block-generator chain and return IQ samples as a numpy array."""
|
|
||||||
try:
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from ria_toolkit_oss.signal.block_generator import (
|
|
||||||
BinarySource,
|
|
||||||
GMSKModulator,
|
|
||||||
Mapper,
|
|
||||||
OOKModulator,
|
|
||||||
OQPSKModulator,
|
|
||||||
RaisedCosineFilter,
|
|
||||||
RootRaisedCosineFilter,
|
|
||||||
Upsampling,
|
|
||||||
)
|
|
||||||
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
|
|
||||||
FSKModulator,
|
|
||||||
)
|
|
||||||
except ImportError as exc:
|
|
||||||
raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
|
|
||||||
|
|
||||||
# ── Special modulations with their own source-connected modulator ──
|
|
||||||
if modulation in ("OOK", "GMSK", "OQPSK"):
|
|
||||||
src = BinarySource()
|
|
||||||
if modulation == "OOK":
|
|
||||||
mod = OOKModulator(src, samples_per_symbol=sps)
|
|
||||||
elif modulation == "GMSK":
|
|
||||||
mod = GMSKModulator(src, samples_per_symbol=sps)
|
|
||||||
else:
|
|
||||||
mod = OQPSKModulator(src, samples_per_symbol=sps)
|
|
||||||
recording = mod.record(num_samples)
|
|
||||||
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
|
||||||
if len(flat) < num_samples:
|
|
||||||
flat = np.tile(flat, num_samples // len(flat) + 1)
|
|
||||||
return flat[:num_samples]
|
|
||||||
|
|
||||||
if modulation == "FSK":
|
|
||||||
symbol_rate = num_samples / sps
|
|
||||||
bits_per_sym = 1 # 2-FSK
|
|
||||||
num_bits = max(num_samples // sps, 128) * bits_per_sym
|
|
||||||
bits = BinarySource()((1, num_bits))
|
|
||||||
mod = FSKModulator(
|
|
||||||
num_bits_per_symbol=bits_per_sym,
|
|
||||||
frequency_spacing=symbol_rate * 0.5,
|
|
||||||
symbol_duration=1.0 / max(symbol_rate, 1.0),
|
|
||||||
sampling_frequency=symbol_rate * sps,
|
|
||||||
)
|
|
||||||
flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
|
|
||||||
if len(flat) < num_samples:
|
|
||||||
flat = np.tile(flat, num_samples // len(flat) + 1)
|
|
||||||
return flat[:num_samples]
|
|
||||||
|
|
||||||
# ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
|
|
||||||
if modulation not in _MOD_TABLE:
|
|
||||||
logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
|
|
||||||
modulation = "QPSK"
|
|
||||||
|
|
||||||
bits_per_sym, gen_type = _MOD_TABLE[modulation]
|
|
||||||
mod_family = "QAM" if gen_type == "qam" else "PSK"
|
|
||||||
|
|
||||||
source = BinarySource()
|
|
||||||
mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
|
|
||||||
upsampler = Upsampling(factor=sps)
|
|
||||||
|
|
||||||
mapper.connect_input([source])
|
|
||||||
upsampler.connect_input([mapper])
|
|
||||||
|
|
||||||
if filter_type in ("rrc",):
|
|
||||||
pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
|
||||||
pulse_filter.connect_input([upsampler])
|
|
||||||
recording = pulse_filter.record(num_samples)
|
|
||||||
elif filter_type in ("rc",):
|
|
||||||
pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
|
||||||
pulse_filter.connect_input([upsampler])
|
|
||||||
recording = pulse_filter.record(num_samples)
|
|
||||||
else:
|
|
||||||
# "none", "rect", "gaussian" — use upsampler output directly
|
|
||||||
recording = upsampler.record(num_samples)
|
|
||||||
|
|
||||||
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
|
||||||
if len(flat) < num_samples:
|
|
||||||
flat = np.tile(flat, num_samples // len(flat) + 1)
|
|
||||||
return flat[:num_samples]
|
|
||||||
|
|
||||||
def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
|
|
||||||
try:
|
|
||||||
from ria_toolkit_oss.sdr import get_sdr_device
|
|
||||||
|
|
||||||
self._sdr = get_sdr_device(self.sdr_device)
|
|
||||||
self._sdr.init_tx(
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
center_frequency=center_freq,
|
|
||||||
gain=0,
|
|
||||||
channel=0,
|
|
||||||
gain_mode="manual",
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
|
|
||||||
self._sdr = None
|
|
||||||
|
|
||||||
def _close_sdr(self) -> None:
|
|
||||||
if self._sdr is not None:
|
|
||||||
try:
|
|
||||||
self._sdr.close()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("TX SDR close error: %s", exc)
|
|
||||||
self._sdr = None
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
from .auth import require_api_key
|
from .auth import require_api_key
|
||||||
from .routers import conductor, inference
|
from .routers import inference, orchestrator
|
||||||
|
|
||||||
|
|
||||||
def create_app(api_key: str = "") -> FastAPI:
|
def create_app(api_key: str = "") -> FastAPI:
|
||||||
|
|
@ -28,9 +28,9 @@ def create_app(api_key: str = "") -> FastAPI:
|
||||||
app.state.api_key = api_key
|
app.state.api_key = api_key
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
conductor.router,
|
orchestrator.router,
|
||||||
prefix="/conductor",
|
prefix="/orchestrator",
|
||||||
tags=["Conductor"],
|
tags=["Orchestrator"],
|
||||||
dependencies=[Depends(require_api_key)],
|
dependencies=[Depends(require_api_key)],
|
||||||
)
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Conductor
|
# Orchestrator
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Conductor routes: campaign deployment, status, and cancellation."""
|
"""Orchestrator routes: campaign deployment, status, and cancellation."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -23,9 +23,9 @@ def serve(host: str, port: int, api_key: str, log_level: str):
|
||||||
|
|
||||||
\b
|
\b
|
||||||
Endpoints:
|
Endpoints:
|
||||||
POST /conductor/deploy
|
POST /orchestrator/deploy
|
||||||
GET /conductor/status/{campaign_id}
|
GET /orchestrator/status/{campaign_id}
|
||||||
POST /conductor/cancel/{campaign_id}
|
POST /orchestrator/cancel/{campaign_id}
|
||||||
POST /inference/load
|
POST /inference/load
|
||||||
POST /inference/start
|
POST /inference/start
|
||||||
POST /inference/stop
|
POST /inference/stop
|
||||||
|
|
|
||||||
|
|
@ -1,314 +0,0 @@
|
||||||
"""Tests for orchestration executor — StepResult, CampaignResult, _run_script, _extract_tx_params."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import stat
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ria_toolkit_oss.orchestration.executor import (
|
|
||||||
CampaignResult,
|
|
||||||
StepResult,
|
|
||||||
_extract_tx_params,
|
|
||||||
_run_script,
|
|
||||||
)
|
|
||||||
from ria_toolkit_oss.orchestration.qa import QAResult
|
|
||||||
|
|
||||||
|
|
||||||
def _ok_qa() -> QAResult:
|
|
||||||
return QAResult(passed=True, flagged=False, snr_db=20.0, duration_s=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
def _flagged_qa() -> QAResult:
|
|
||||||
return QAResult(passed=True, flagged=True, snr_db=5.0, duration_s=1.0, issues=["low SNR"])
|
|
||||||
|
|
||||||
|
|
||||||
def _failed_qa() -> QAResult:
|
|
||||||
return QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=["no signal"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# StepResult
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestStepResult:
|
|
||||||
def test_ok_true_when_no_error_and_qa_passed(self):
|
|
||||||
r = StepResult(
|
|
||||||
transmitter_id="tx1",
|
|
||||||
step_label="step1",
|
|
||||||
output_path="/out/rec.sigmf-data",
|
|
||||||
qa=_ok_qa(),
|
|
||||||
capture_timestamp=0.0,
|
|
||||||
)
|
|
||||||
assert r.ok is True
|
|
||||||
|
|
||||||
def test_ok_false_when_error_set(self):
|
|
||||||
r = StepResult(
|
|
||||||
transmitter_id="tx1",
|
|
||||||
step_label="step1",
|
|
||||||
output_path=None,
|
|
||||||
qa=_ok_qa(),
|
|
||||||
capture_timestamp=0.0,
|
|
||||||
error="SDR failed",
|
|
||||||
)
|
|
||||||
assert r.ok is False
|
|
||||||
|
|
||||||
def test_ok_false_when_qa_not_passed(self):
|
|
||||||
r = StepResult(
|
|
||||||
transmitter_id="tx1",
|
|
||||||
step_label="step1",
|
|
||||||
output_path="/out",
|
|
||||||
qa=_failed_qa(),
|
|
||||||
capture_timestamp=0.0,
|
|
||||||
)
|
|
||||||
assert r.ok is False
|
|
||||||
|
|
||||||
def test_to_dict_contains_required_keys(self):
|
|
||||||
r = StepResult(
|
|
||||||
transmitter_id="tx1",
|
|
||||||
step_label="step1",
|
|
||||||
output_path="/out/rec.sigmf-data",
|
|
||||||
qa=_ok_qa(),
|
|
||||||
capture_timestamp=1234.5,
|
|
||||||
)
|
|
||||||
d = r.to_dict()
|
|
||||||
assert d["transmitter_id"] == "tx1"
|
|
||||||
assert d["step_label"] == "step1"
|
|
||||||
assert d["output_path"] == "/out/rec.sigmf-data"
|
|
||||||
assert d["capture_timestamp"] == pytest.approx(1234.5)
|
|
||||||
assert d["error"] is None
|
|
||||||
assert d["qa"]["passed"] is True
|
|
||||||
|
|
||||||
def test_to_dict_includes_error_when_set(self):
|
|
||||||
r = StepResult(
|
|
||||||
transmitter_id="tx1",
|
|
||||||
step_label="step1",
|
|
||||||
output_path=None,
|
|
||||||
qa=_failed_qa(),
|
|
||||||
capture_timestamp=0.0,
|
|
||||||
error="disk full",
|
|
||||||
)
|
|
||||||
assert r.to_dict()["error"] == "disk full"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# CampaignResult
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestCampaignResult:
|
|
||||||
def _make(self, steps: list) -> CampaignResult:
|
|
||||||
r = CampaignResult(campaign_name="test_campaign")
|
|
||||||
r.steps = steps
|
|
||||||
r.end_time = r.start_time + 5.0
|
|
||||||
return r
|
|
||||||
|
|
||||||
def test_total_steps(self):
|
|
||||||
r = self._make(
|
|
||||||
[
|
|
||||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
|
||||||
StepResult("tx1", "s2", "/out", _ok_qa(), 0.0),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert r.total_steps == 2
|
|
||||||
|
|
||||||
def test_passed_count(self):
|
|
||||||
r = self._make(
|
|
||||||
[
|
|
||||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
|
||||||
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert r.passed == 1
|
|
||||||
|
|
||||||
def test_failed_count(self):
|
|
||||||
r = self._make(
|
|
||||||
[
|
|
||||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
|
||||||
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert r.failed == 1
|
|
||||||
|
|
||||||
def test_flagged_count(self):
|
|
||||||
r = self._make(
|
|
||||||
[
|
|
||||||
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
|
||||||
StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert r.flagged == 1
|
|
||||||
|
|
||||||
def test_error_step_counts_as_failed_not_passed(self):
|
|
||||||
r = self._make(
|
|
||||||
[
|
|
||||||
StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert r.failed == 1
|
|
||||||
assert r.passed == 0
|
|
||||||
|
|
||||||
def test_duration_s_from_end_time(self):
|
|
||||||
r = CampaignResult(campaign_name="c")
|
|
||||||
r.start_time = 100.0
|
|
||||||
r.end_time = 115.0
|
|
||||||
assert r.duration_s == pytest.approx(15.0)
|
|
||||||
|
|
||||||
def test_to_dict_structure(self):
|
|
||||||
r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
|
|
||||||
d = r.to_dict()
|
|
||||||
assert d["campaign_name"] == "test_campaign"
|
|
||||||
assert d["total_steps"] == 1
|
|
||||||
assert d["passed"] == 1
|
|
||||||
assert len(d["steps"]) == 1
|
|
||||||
|
|
||||||
def test_write_report(self, tmp_path):
|
|
||||||
r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
|
|
||||||
out = tmp_path / "report.json"
|
|
||||||
r.write_report(str(out))
|
|
||||||
assert out.exists()
|
|
||||||
data = json.loads(out.read_text())
|
|
||||||
assert data["campaign_name"] == "test_campaign"
|
|
||||||
|
|
||||||
def test_write_report_creates_nested_dirs(self, tmp_path):
|
|
||||||
r = self._make([])
|
|
||||||
out = tmp_path / "nested" / "deep" / "report.json"
|
|
||||||
r.write_report(str(out))
|
|
||||||
assert out.exists()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _run_script
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunScript:
|
|
||||||
def _script(self, tmp_path, body: str) -> str:
|
|
||||||
s = tmp_path / "script.sh"
|
|
||||||
s.write_text("#!/bin/sh\n" + body)
|
|
||||||
s.chmod(s.stat().st_mode | stat.S_IEXEC)
|
|
||||||
return str(s)
|
|
||||||
|
|
||||||
def test_returns_stdout(self, tmp_path):
|
|
||||||
out = _run_script(self._script(tmp_path, 'echo "hello world"'))
|
|
||||||
assert out == "hello world"
|
|
||||||
|
|
||||||
def test_passes_args_to_script(self, tmp_path):
|
|
||||||
out = _run_script(self._script(tmp_path, 'echo "$1 $2"'), "configure", "arg2")
|
|
||||||
assert "configure" in out
|
|
||||||
|
|
||||||
def test_raises_on_nonzero_exit(self, tmp_path):
|
|
||||||
with pytest.raises(RuntimeError, match="exited 1"):
|
|
||||||
_run_script(self._script(tmp_path, "exit 1"))
|
|
||||||
|
|
||||||
def test_raises_on_relative_path(self):
|
|
||||||
with pytest.raises(RuntimeError, match="absolute"):
|
|
||||||
_run_script("relative/script.sh")
|
|
||||||
|
|
||||||
def test_raises_on_missing_file(self, tmp_path):
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
_run_script(str(tmp_path / "nonexistent.sh"))
|
|
||||||
|
|
||||||
def test_raises_on_timeout(self, tmp_path):
|
|
||||||
with pytest.raises(RuntimeError, match="timed out"):
|
|
||||||
_run_script(self._script(tmp_path, "sleep 60"), timeout=0.1)
|
|
||||||
|
|
||||||
def test_stderr_included_in_error_message(self, tmp_path):
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
|
||||||
_run_script(self._script(tmp_path, "echo 'bad thing' >&2; exit 1"))
|
|
||||||
assert "bad thing" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _extract_tx_params
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractTxParams:
|
|
||||||
def test_returns_none_when_no_sdr_agent_attribute(self):
|
|
||||||
tx = SimpleNamespace()
|
|
||||||
assert _extract_tx_params(tx) is None
|
|
||||||
|
|
||||||
def test_returns_none_when_sdr_agent_is_none(self):
|
|
||||||
tx = SimpleNamespace(sdr_agent=None)
|
|
||||||
assert _extract_tx_params(tx) is None
|
|
||||||
|
|
||||||
def test_returns_none_when_sdr_agent_is_empty_dict(self):
|
|
||||||
tx = SimpleNamespace(sdr_agent={})
|
|
||||||
assert _extract_tx_params(tx) is None
|
|
||||||
|
|
||||||
def test_returns_signal_params(self):
|
|
||||||
tx = SimpleNamespace(
|
|
||||||
sdr_agent={
|
|
||||||
"modulation": "QPSK",
|
|
||||||
"symbol_rate": 1e6,
|
|
||||||
"center_frequency": 2.4e9,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = _extract_tx_params(tx)
|
|
||||||
assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9}
|
|
||||||
|
|
||||||
def test_strips_infra_key_node_id(self):
|
|
||||||
tx = SimpleNamespace(
|
|
||||||
sdr_agent={
|
|
||||||
"modulation": "BPSK",
|
|
||||||
"node_id": "node_abc123",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = _extract_tx_params(tx)
|
|
||||||
assert "node_id" not in result
|
|
||||||
assert result == {"modulation": "BPSK"}
|
|
||||||
|
|
||||||
def test_strips_infra_key_session_code(self):
|
|
||||||
tx = SimpleNamespace(
|
|
||||||
sdr_agent={
|
|
||||||
"modulation": "FSK",
|
|
||||||
"session_code": "amber-peak-transmit",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = _extract_tx_params(tx)
|
|
||||||
assert "session_code" not in result
|
|
||||||
|
|
||||||
def test_strips_none_values(self):
|
|
||||||
tx = SimpleNamespace(
|
|
||||||
sdr_agent={
|
|
||||||
"modulation": "QPSK",
|
|
||||||
"order": None,
|
|
||||||
"rolloff": 0.35,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = _extract_tx_params(tx)
|
|
||||||
assert "order" not in result
|
|
||||||
assert result == {"modulation": "QPSK", "rolloff": 0.35}
|
|
||||||
|
|
||||||
def test_does_not_mutate_source_dict(self):
|
|
||||||
cfg = {"modulation": "QPSK", "node_id": "nid", "session_code": "code"}
|
|
||||||
tx = SimpleNamespace(sdr_agent=cfg)
|
|
||||||
_extract_tx_params(tx)
|
|
||||||
assert "node_id" in cfg
|
|
||||||
|
|
||||||
def test_full_sdr_agent_config(self):
|
|
||||||
tx = SimpleNamespace(
|
|
||||||
sdr_agent={
|
|
||||||
"modulation": "16QAM",
|
|
||||||
"order": 4,
|
|
||||||
"symbol_rate": 5e6,
|
|
||||||
"center_frequency": 915e6,
|
|
||||||
"filter": "rrc",
|
|
||||||
"rolloff": 0.35,
|
|
||||||
"node_id": "node_xyz",
|
|
||||||
"session_code": "some-code",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = _extract_tx_params(tx)
|
|
||||||
assert result == {
|
|
||||||
"modulation": "16QAM",
|
|
||||||
"order": 4,
|
|
||||||
"symbol_rate": 5e6,
|
|
||||||
"center_frequency": 915e6,
|
|
||||||
"filter": "rrc",
|
|
||||||
"rolloff": 0.35,
|
|
||||||
}
|
|
||||||
|
|
@ -109,38 +109,6 @@ class TestLabelRecording:
|
||||||
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
||||||
assert result is rec
|
assert result is rec
|
||||||
|
|
||||||
def test_tx_params_none_by_default(self):
|
|
||||||
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
|
||||||
tx_keys = [k for k in rec.metadata if k.startswith("tx_")]
|
|
||||||
assert tx_keys == []
|
|
||||||
|
|
||||||
def test_tx_params_written_as_tx_prefix_keys(self):
|
|
||||||
params = {"modulation": "QPSK", "symbol_rate": 1e6}
|
|
||||||
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
|
|
||||||
assert rec.metadata["tx_modulation"] == "QPSK"
|
|
||||||
assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6)
|
|
||||||
|
|
||||||
def test_tx_params_multiple_fields(self):
|
|
||||||
params = {
|
|
||||||
"modulation": "16QAM",
|
|
||||||
"order": 4,
|
|
||||||
"symbol_rate": 5e6,
|
|
||||||
"center_frequency": 915e6,
|
|
||||||
"filter": "rrc",
|
|
||||||
"rolloff": 0.35,
|
|
||||||
}
|
|
||||||
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
|
|
||||||
for k, v in params.items():
|
|
||||||
assert f"tx_{k}" in rec.metadata
|
|
||||||
assert (
|
|
||||||
rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_tx_params_empty_dict_writes_nothing(self):
|
|
||||||
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={})
|
|
||||||
tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"]
|
|
||||||
assert tx_keys == []
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# build_output_filename
|
# build_output_filename
|
||||||
|
|
|
||||||
|
|
@ -1,153 +0,0 @@
|
||||||
"""Tests for TxExecutor — signal synthesis and step execution."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import threading
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
|
|
||||||
|
|
||||||
|
|
||||||
def _cfg(modulation="QPSK", symbol_rate=100_000, steps=None):
|
|
||||||
return {
|
|
||||||
"id": "test-tx",
|
|
||||||
"type": "sdr",
|
|
||||||
"control_method": "sdr_agent",
|
|
||||||
"sdr_agent": {
|
|
||||||
"modulation": modulation,
|
|
||||||
"symbol_rate": symbol_rate,
|
|
||||||
"center_frequency": 0.0,
|
|
||||||
"filter": "rrc",
|
|
||||||
"rolloff": 0.35,
|
|
||||||
},
|
|
||||||
"schedule": steps or [{"label": "step1", "duration": 0.001, "power_dbm": -10}],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Initialisation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestTxExecutorInit:
|
|
||||||
def test_stores_sdr_device(self):
|
|
||||||
ex = TxExecutor(_cfg(), sdr_device="pluto")
|
|
||||||
assert ex.sdr_device == "pluto"
|
|
||||||
|
|
||||||
def test_stop_event_created_when_not_supplied(self):
|
|
||||||
ex = TxExecutor(_cfg())
|
|
||||||
assert isinstance(ex.stop_event, threading.Event)
|
|
||||||
assert not ex.stop_event.is_set()
|
|
||||||
|
|
||||||
def test_accepts_external_stop_event(self):
|
|
||||||
ev = threading.Event()
|
|
||||||
ex = TxExecutor(_cfg(), stop_event=ev)
|
|
||||||
assert ex.stop_event is ev
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# run() — schedule iteration
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestTxExecutorRun:
|
|
||||||
def test_empty_schedule_returns_immediately(self):
|
|
||||||
cfg = _cfg(steps=[])
|
|
||||||
ex = TxExecutor(cfg)
|
|
||||||
ex.run() # must not raise or block
|
|
||||||
|
|
||||||
def test_pre_set_stop_event_skips_all_steps(self):
|
|
||||||
ev = threading.Event()
|
|
||||||
ev.set()
|
|
||||||
ex = TxExecutor(_cfg(), stop_event=ev)
|
|
||||||
# If stop was set, _execute_step should never be called.
|
|
||||||
# run() should return cleanly without attempting synthesis.
|
|
||||||
ex.run()
|
|
||||||
|
|
||||||
def test_no_sdr_falls_back_to_simulation(self, monkeypatch):
|
|
||||||
"""Without SDR hardware TxExecutor simulates by calling stop_event.wait."""
|
|
||||||
cfg = _cfg(steps=[{"label": "s", "duration": 0.001, "power_dbm": 0}])
|
|
||||||
waited = []
|
|
||||||
real_ev = threading.Event()
|
|
||||||
|
|
||||||
def _fake_wait(timeout=None):
|
|
||||||
waited.append(timeout)
|
|
||||||
return False
|
|
||||||
|
|
||||||
monkeypatch.setattr(real_ev, "wait", _fake_wait)
|
|
||||||
|
|
||||||
# Patch SDR init to always fail (forces simulation path)
|
|
||||||
with patch.object(TxExecutor, "_init_sdr", lambda self, *a, **kw: setattr(self, "_sdr", None)):
|
|
||||||
ex = TxExecutor(cfg, sdr_device="nonexistent_xyz", stop_event=real_ev)
|
|
||||||
ex.run()
|
|
||||||
|
|
||||||
assert len(waited) >= 1, "expected stop_event.wait to be called for simulation"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _synthesise — all modulation types and filter types
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestSynthesise:
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _ex(self):
|
|
||||||
self.ex = TxExecutor(_cfg())
|
|
||||||
|
|
||||||
def _synth(self, mod, num_samples=256):
|
|
||||||
return self.ex._synthesise(mod, sps=4, num_samples=num_samples, filter_type="rrc", rolloff=0.35)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("mod", ["BPSK", "QPSK", "8PSK", "16QAM", "64QAM", "256QAM"])
|
|
||||||
def test_psk_qam_returns_complex64_array(self, mod):
|
|
||||||
sig = self._synth(mod)
|
|
||||||
assert sig.dtype == np.complex64
|
|
||||||
assert len(sig) == 256
|
|
||||||
|
|
||||||
def test_fsk_returns_correct_length(self):
|
|
||||||
sig = self._synth("FSK")
|
|
||||||
assert len(sig) == 256
|
|
||||||
|
|
||||||
def test_ook_returns_correct_length(self):
|
|
||||||
sig = self._synth("OOK")
|
|
||||||
assert len(sig) == 256
|
|
||||||
|
|
||||||
def test_gmsk_returns_correct_length(self):
|
|
||||||
sig = self._synth("GMSK")
|
|
||||||
assert len(sig) == 256
|
|
||||||
|
|
||||||
def test_oqpsk_returns_correct_length(self):
|
|
||||||
sig = self._synth("OQPSK")
|
|
||||||
assert len(sig) == 256
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("mod", ["BPSK", "QPSK", "16QAM", "FSK", "OOK", "GMSK"])
|
|
||||||
def test_samples_are_finite(self, mod):
|
|
||||||
sig = self._synth(mod)
|
|
||||||
assert np.all(np.isfinite(sig.real)), f"{mod}: non-finite real samples"
|
|
||||||
assert np.all(np.isfinite(sig.imag)), f"{mod}: non-finite imag samples"
|
|
||||||
|
|
||||||
def test_unknown_modulation_defaults_to_qpsk(self):
|
|
||||||
sig = self._synth("UNKNOWN_MOD_XYZ")
|
|
||||||
assert len(sig) == 256
|
|
||||||
assert sig.dtype == np.complex64
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filter_type", ["rrc", "rc", "gaussian", "rect", "none"])
|
|
||||||
def test_all_filter_types(self, filter_type):
|
|
||||||
sig = self.ex._synthesise("QPSK", sps=4, num_samples=128, filter_type=filter_type, rolloff=0.35)
|
|
||||||
assert len(sig) == 128
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", [64, 128, 512, 1024])
|
|
||||||
def test_output_length_matches_requested_samples(self, n):
|
|
||||||
sig = self._synth("QPSK", num_samples=n)
|
|
||||||
assert len(sig) == n
|
|
||||||
|
|
||||||
def test_bpsk_output_is_complex_not_real(self):
|
|
||||||
sig = self._synth("BPSK")
|
|
||||||
# complex64 always has imag part; just check dtype
|
|
||||||
assert sig.dtype == np.complex64
|
|
||||||
|
|
||||||
def test_256qam_correct_length(self):
|
|
||||||
sig = self._synth("256QAM")
|
|
||||||
assert len(sig) == 256
|
|
||||||
|
|
@ -189,8 +189,6 @@ class TestNoiseCommand:
|
||||||
"10000",
|
"10000",
|
||||||
"--noise-type",
|
"--noise-type",
|
||||||
"gaussian",
|
"gaussian",
|
||||||
"--power",
|
|
||||||
"0.01",
|
|
||||||
"--output",
|
"--output",
|
||||||
output,
|
output,
|
||||||
"-q",
|
"-q",
|
||||||
|
|
@ -236,7 +234,7 @@ class TestNoiseCommand:
|
||||||
"--num-samples",
|
"--num-samples",
|
||||||
"10000",
|
"10000",
|
||||||
"--power",
|
"--power",
|
||||||
"0.01",
|
"0.5",
|
||||||
"--output",
|
"--output",
|
||||||
output,
|
output,
|
||||||
"-q",
|
"-q",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Tests for the RT-OSS HTTP server.
|
"""Tests for the RT-OSS HTTP server.
|
||||||
|
|
||||||
Covers: auth, inference lifecycle (without SDR/ONNX hardware), conductor
|
Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator
|
||||||
lifecycle (with mocked executor), and state helpers.
|
lifecycle (with mocked executor), and state helpers.
|
||||||
|
|
||||||
``start_inference`` and ``_inference_loop`` require real SDR hardware and an
|
``start_inference`` and ``_inference_loop`` require real SDR hardware and an
|
||||||
|
|
@ -286,17 +286,17 @@ class TestInferenceStop:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# POST /conductor/deploy
|
# POST /orchestrator/deploy
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestConductorDeploy:
|
class TestOrchestratorDeploy:
|
||||||
def test_deploy_422_on_invalid_config(self, client):
|
def test_deploy_422_on_invalid_config(self, client):
|
||||||
with patch(
|
with patch(
|
||||||
"ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict",
|
"ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict",
|
||||||
side_effect=ValueError("missing required field 'name'"),
|
side_effect=ValueError("missing required field 'name'"),
|
||||||
):
|
):
|
||||||
resp = client.post("/conductor/deploy", json={"config": {}})
|
resp = client.post("/orchestrator/deploy", json={"config": {}})
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
|
|
||||||
def test_deploy_returns_campaign_id(self, client):
|
def test_deploy_returns_campaign_id(self, client):
|
||||||
|
|
@ -307,10 +307,10 @@ class TestConductorDeploy:
|
||||||
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
|
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
|
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor),
|
||||||
):
|
):
|
||||||
resp = client.post("/conductor/deploy", json={"config": {"name": "test_campaign"}})
|
resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}})
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
|
|
@ -325,23 +325,23 @@ class TestConductorDeploy:
|
||||||
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
|
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||||
patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
|
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor),
|
||||||
):
|
):
|
||||||
resp = client.post("/conductor/deploy", json={"config": {}})
|
resp = client.post("/orchestrator/deploy", json={"config": {}})
|
||||||
|
|
||||||
campaign_id = resp.json()["campaign_id"]
|
campaign_id = resp.json()["campaign_id"]
|
||||||
assert state_module._campaigns.get(campaign_id) is not None
|
assert state_module._campaigns.get(campaign_id) is not None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /conductor/status/{campaign_id}
|
# GET /orchestrator/status/{campaign_id}
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestConductorStatus:
|
class TestOrchestratorStatus:
|
||||||
def test_status_404_for_unknown_id(self, client):
|
def test_status_404_for_unknown_id(self, client):
|
||||||
resp = client.get("/conductor/status/nonexistent-id")
|
resp = client.get("/orchestrator/status/nonexistent-id")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_status_returns_campaign_state(self, client):
|
def test_status_returns_campaign_state(self, client):
|
||||||
|
|
@ -357,7 +357,7 @@ class TestConductorStatus:
|
||||||
)
|
)
|
||||||
state_module._campaigns["abc-123"] = state
|
state_module._campaigns["abc-123"] = state
|
||||||
|
|
||||||
resp = client.get("/conductor/status/abc-123")
|
resp = client.get("/orchestrator/status/abc-123")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert body["campaign_id"] == "abc-123"
|
assert body["campaign_id"] == "abc-123"
|
||||||
|
|
@ -367,13 +367,13 @@ class TestConductorStatus:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# POST /conductor/cancel/{campaign_id}
|
# POST /orchestrator/cancel/{campaign_id}
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestConductorCancel:
|
class TestOrchestratorCancel:
|
||||||
def test_cancel_404_for_unknown_id(self, client):
|
def test_cancel_404_for_unknown_id(self, client):
|
||||||
resp = client.post("/conductor/cancel/no-such-id")
|
resp = client.post("/orchestrator/cancel/no-such-id")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_cancel_sets_cancel_event(self, client):
|
def test_cancel_sets_cancel_event(self, client):
|
||||||
|
|
@ -387,7 +387,7 @@ class TestConductorCancel:
|
||||||
)
|
)
|
||||||
state_module._campaigns["camp-to-cancel"] = state
|
state_module._campaigns["camp-to-cancel"] = state
|
||||||
|
|
||||||
resp = client.post("/conductor/cancel/camp-to-cancel")
|
resp = client.post("/orchestrator/cancel/camp-to-cancel")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["cancelled"] is True
|
assert resp.json()["cancelled"] is True
|
||||||
assert cancel_event.is_set()
|
assert cancel_event.is_set()
|
||||||
|
|
@ -403,7 +403,7 @@ class TestConductorCancel:
|
||||||
)
|
)
|
||||||
state_module._campaigns["done"] = state
|
state_module._campaigns["done"] = state
|
||||||
|
|
||||||
resp = client.post("/conductor/cancel/done")
|
resp = client.post("/orchestrator/cancel/done")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["cancelled"] is False
|
assert resp.json()["cancelled"] is False
|
||||||
assert not cancel_event.is_set()
|
assert not cancel_event.is_set()
|
||||||
|
|
|
||||||
|
|
@ -1,247 +0,0 @@
|
||||||
"""Tests for NodeAgent — TX role, session code, and TX command dispatch."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from ria_toolkit_oss.agent import NodeAgent
|
|
||||||
|
|
||||||
|
|
||||||
def _agent(role="general", session_code=None, **kwargs):
|
|
||||||
return NodeAgent(
|
|
||||||
hub_url="http://hub.test",
|
|
||||||
api_key="test-key",
|
|
||||||
name="test-node",
|
|
||||||
sdr_device="mock",
|
|
||||||
role=role,
|
|
||||||
session_code=session_code,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_register(agent, node_id="node_abc123"):
|
|
||||||
"""Patch _post so _register() returns a fake node_id response."""
|
|
||||||
resp = MagicMock()
|
|
||||||
resp.json.return_value = {"node_id": node_id}
|
|
||||||
resp.raise_for_status.return_value = None
|
|
||||||
agent._post = MagicMock(return_value=resp)
|
|
||||||
return agent._post
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Initialisation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeAgentInit:
|
|
||||||
def test_stores_role_general(self):
|
|
||||||
assert _agent(role="general").role == "general"
|
|
||||||
|
|
||||||
def test_stores_role_tx(self):
|
|
||||||
assert _agent(role="tx").role == "tx"
|
|
||||||
|
|
||||||
def test_stores_role_rx(self):
|
|
||||||
assert _agent(role="rx").role == "rx"
|
|
||||||
|
|
||||||
def test_session_code_stored(self):
|
|
||||||
assert _agent(session_code="amber-peak-transmit").session_code == "amber-peak-transmit"
|
|
||||||
|
|
||||||
def test_session_code_none_by_default(self):
|
|
||||||
assert _agent().session_code is None
|
|
||||||
|
|
||||||
def test_tx_stop_event_created(self):
|
|
||||||
a = _agent()
|
|
||||||
assert isinstance(a._tx_stop, threading.Event)
|
|
||||||
|
|
||||||
def test_tx_thread_none_initially(self):
|
|
||||||
assert _agent()._tx_thread is None
|
|
||||||
|
|
||||||
def test_hub_url_trailing_slash_stripped(self):
|
|
||||||
a = NodeAgent(hub_url="http://hub.test/", api_key="k", name="n")
|
|
||||||
assert a.hub_url == "http://hub.test"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _register payload
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeAgentRegisterPayload:
|
|
||||||
def _payload(self, agent):
|
|
||||||
post = _mock_register(agent)
|
|
||||||
agent._register()
|
|
||||||
_, kwargs = post.call_args
|
|
||||||
return kwargs["json"]
|
|
||||||
|
|
||||||
def test_general_role_in_payload(self):
|
|
||||||
payload = self._payload(_agent(role="general"))
|
|
||||||
assert payload["role"] == "general"
|
|
||||||
|
|
||||||
def test_tx_role_in_payload(self):
|
|
||||||
payload = self._payload(_agent(role="tx"))
|
|
||||||
assert payload["role"] == "tx"
|
|
||||||
|
|
||||||
def test_tx_role_adds_transmit_capability(self):
|
|
||||||
payload = self._payload(_agent(role="tx"))
|
|
||||||
assert "transmit" in payload["capabilities"]
|
|
||||||
|
|
||||||
def test_general_role_omits_transmit_capability(self):
|
|
||||||
payload = self._payload(_agent(role="general"))
|
|
||||||
assert "transmit" not in payload.get("capabilities", [])
|
|
||||||
|
|
||||||
def test_session_code_included_when_set(self):
|
|
||||||
payload = self._payload(_agent(role="tx", session_code="amber-peak-transmit"))
|
|
||||||
assert payload["session_code"] == "amber-peak-transmit"
|
|
||||||
|
|
||||||
def test_session_code_omitted_when_none(self):
|
|
||||||
payload = self._payload(_agent())
|
|
||||||
assert "session_code" not in payload
|
|
||||||
|
|
||||||
def test_register_stores_returned_node_id(self):
|
|
||||||
a = _agent()
|
|
||||||
_mock_register(a, node_id="node_xyz999")
|
|
||||||
a._register()
|
|
||||||
assert a.node_id == "node_xyz999"
|
|
||||||
|
|
||||||
def test_name_in_payload(self):
|
|
||||||
a = NodeAgent(hub_url="http://h", api_key="k", name="my-bench")
|
|
||||||
_mock_register(a)
|
|
||||||
a._register()
|
|
||||||
_, kwargs = a._post.call_args
|
|
||||||
assert kwargs["json"]["name"] == "my-bench"
|
|
||||||
|
|
||||||
def test_sdr_device_in_payload(self):
|
|
||||||
a = _agent()
|
|
||||||
post = _mock_register(a)
|
|
||||||
a._register()
|
|
||||||
_, kwargs = post.call_args
|
|
||||||
assert kwargs["json"]["sdr_device"] == "mock"
|
|
||||||
|
|
||||||
def test_campaign_capability_always_present(self):
|
|
||||||
for role in ("general", "rx", "tx"):
|
|
||||||
a = _agent(role=role)
|
|
||||||
payload = self._payload(a)
|
|
||||||
assert "campaign" in payload["capabilities"]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _dispatch — TX commands
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeAgentDispatch:
|
|
||||||
def _make_agent(self):
|
|
||||||
a = _agent(role="tx")
|
|
||||||
a.node_id = "node_abc"
|
|
||||||
a._report_campaign_status = MagicMock()
|
|
||||||
return a
|
|
||||||
|
|
||||||
def test_start_transmit_spawns_thread(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
done = threading.Event()
|
|
||||||
|
|
||||||
class _FakeExecutor:
|
|
||||||
def run(self_):
|
|
||||||
done.wait(timeout=2)
|
|
||||||
|
|
||||||
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
|
||||||
a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
|
|
||||||
time.sleep(0.05)
|
|
||||||
assert a._tx_thread is not None
|
|
||||||
done.set()
|
|
||||||
|
|
||||||
def test_start_transmit_clears_stop_event(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
a._tx_stop.set() # pre-set
|
|
||||||
|
|
||||||
done = threading.Event()
|
|
||||||
|
|
||||||
class _FakeExecutor:
|
|
||||||
def run(self_):
|
|
||||||
done.wait(timeout=2)
|
|
||||||
|
|
||||||
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
|
||||||
a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
|
|
||||||
time.sleep(0.05)
|
|
||||||
assert not a._tx_stop.is_set()
|
|
||||||
done.set()
|
|
||||||
|
|
||||||
def test_stop_transmit_sets_stop_event(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
a._dispatch({"command": "stop_transmit"})
|
|
||||||
assert a._tx_stop.is_set()
|
|
||||||
|
|
||||||
def test_configure_transmit_does_not_raise(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
a._dispatch({"command": "configure_transmit", "modulation": "BPSK"})
|
|
||||||
|
|
||||||
def test_unknown_command_is_ignored(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
a._dispatch({"command": "frobnicate_xyz"})
|
|
||||||
|
|
||||||
def test_duplicate_start_transmit_ignored_while_running(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
done = threading.Event()
|
|
||||||
run_calls = []
|
|
||||||
|
|
||||||
class _FakeExecutor:
|
|
||||||
def run(self_):
|
|
||||||
run_calls.append(1)
|
|
||||||
done.wait(timeout=2)
|
|
||||||
|
|
||||||
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
|
||||||
a._dispatch({"command": "start_transmit"})
|
|
||||||
time.sleep(0.05)
|
|
||||||
a._dispatch({"command": "start_transmit"}) # second while first alive
|
|
||||||
done.set()
|
|
||||||
time.sleep(0.05)
|
|
||||||
|
|
||||||
assert len(run_calls) == 1
|
|
||||||
|
|
||||||
def test_run_campaign_dispatched_in_thread(self):
|
|
||||||
a = self._make_agent()
|
|
||||||
done = threading.Event()
|
|
||||||
|
|
||||||
with patch("ria_toolkit_oss.agent.NodeAgent._run_campaign") as mock_run:
|
|
||||||
mock_run.side_effect = lambda *_: done.set()
|
|
||||||
a._dispatch({"command": "run_campaign", "campaign_id": "c1", "payload": {}})
|
|
||||||
done.wait(timeout=2)
|
|
||||||
assert mock_run.called
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _stop_transmit
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopTransmit:
|
|
||||||
def test_no_thread_noop(self):
|
|
||||||
a = _agent()
|
|
||||||
a._stop_transmit() # must not raise
|
|
||||||
|
|
||||||
def test_sets_stop_event(self):
|
|
||||||
a = _agent()
|
|
||||||
a._stop_transmit()
|
|
||||||
assert a._tx_stop.is_set()
|
|
||||||
|
|
||||||
def test_joins_live_thread(self):
|
|
||||||
a = _agent()
|
|
||||||
finished = threading.Event()
|
|
||||||
unblock = threading.Event()
|
|
||||||
|
|
||||||
def _task():
|
|
||||||
unblock.wait(timeout=2)
|
|
||||||
finished.set()
|
|
||||||
|
|
||||||
t = threading.Thread(target=_task, daemon=True)
|
|
||||||
t.start()
|
|
||||||
a._tx_thread = t
|
|
||||||
|
|
||||||
# Signal stop and trigger thread exit
|
|
||||||
a._tx_stop.set()
|
|
||||||
unblock.set()
|
|
||||||
a._stop_transmit()
|
|
||||||
|
|
||||||
assert not t.is_alive()
|
|
||||||
Loading…
Reference in New Issue
Block a user