Merge pull request 'zfp-oss' (#27) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 44m43s
Test with tox / Test with tox (3.10) (push) Successful in 1h4m45s
Build Project / Build Project (3.10) (push) Successful in 1h16m56s
Build Project / Build Project (3.12) (push) Successful in 1h16m52s
Test with tox / Test with tox (3.12) (push) Successful in 31m45s
Test with tox / Test with tox (3.11) (push) Successful in 47m45s
Build Project / Build Project (3.11) (push) Failing after 1h9m0s
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 44m43s
Test with tox / Test with tox (3.10) (push) Successful in 1h4m45s
Build Project / Build Project (3.10) (push) Successful in 1h16m56s
Build Project / Build Project (3.12) (push) Successful in 1h16m52s
Test with tox / Test with tox (3.12) (push) Successful in 31m45s
Test with tox / Test with tox (3.11) (push) Successful in 47m45s
Build Project / Build Project (3.11) (push) Failing after 1h9m0s
Reviewed-on: #27
This commit is contained in:
commit
2881aaf06e
12
poetry.lock
generated
12
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alabaster"
|
name = "alabaster"
|
||||||
|
|
@ -230,14 +230,14 @@ uvloop = ["uvloop (>=0.15.2) ; sys_platform != \"win32\"", "winloop (>=0.5.0) ;
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cachetools"
|
name = "cachetools"
|
||||||
version = "7.0.5"
|
version = "7.0.6"
|
||||||
description = "Extensible memoizing collections and decorators"
|
description = "Extensible memoizing collections and decorators"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
groups = ["test"]
|
groups = ["test"]
|
||||||
files = [
|
files = [
|
||||||
{file = "cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114"},
|
{file = "cachetools-7.0.6-py3-none-any.whl", hash = "sha256:4e94956cfdd3086f12042cdd29318f5ced3893014f7d0d059bf3ead3f85b7f8b"},
|
||||||
{file = "cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990"},
|
{file = "cachetools-7.0.6.tar.gz", hash = "sha256:e5d524d36d65703a87243a26ff08ad84f73352adbeafb1cde81e207b456aaf24"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1271,7 +1271,7 @@ files = [
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
attrs = ">=22.2.0"
|
attrs = ">=22.2.0"
|
||||||
jsonschema-specifications = ">=2023.3.6"
|
jsonschema-specifications = ">=2023.03.6"
|
||||||
referencing = ">=0.28.4"
|
referencing = ">=0.28.4"
|
||||||
rpds-py = ">=0.25.0"
|
rpds-py = ">=0.25.0"
|
||||||
|
|
||||||
|
|
@ -3749,4 +3749,4 @@ files = [
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
content-hash = "ffde300b2fc93161d2279a6e2b899bc988d3b5eb3833135821830affc9a5fb62"
|
content-hash = "66c9adf647316db90f963da05e8a83574378bfa4db2c69ce751446b5ee7c408c"
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ dependencies = [
|
||||||
"pyyaml (>=6.0.3,<7.0.0)",
|
"pyyaml (>=6.0.3,<7.0.0)",
|
||||||
"click (>=8.1.0,<9.0.0)",
|
"click (>=8.1.0,<9.0.0)",
|
||||||
"matplotlib (>=3.8.0,<4.0.0)",
|
"matplotlib (>=3.8.0,<4.0.0)",
|
||||||
"paramiko (>=4.0.0)"
|
"paramiko (>=3.5.1)"
|
||||||
]
|
]
|
||||||
|
|
||||||
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
||||||
|
|
@ -149,6 +149,11 @@ exclude = '''
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
pythonpath = ["src"]
|
pythonpath = ["src"]
|
||||||
|
filterwarnings = [
|
||||||
|
# FastAPI emits this internally when handling 422 responses; the constant
|
||||||
|
# is not yet renamed in the installed starlette version, so we can't migrate.
|
||||||
|
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ _HEARTBEAT_INTERVAL = 30 # seconds between heartbeats
|
||||||
_POLL_TIMEOUT = 30 # server-side long-poll duration
|
_POLL_TIMEOUT = 30 # server-side long-poll duration
|
||||||
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
_POLL_CLIENT_TIMEOUT = 40 # client read timeout — slightly longer than server
|
||||||
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
_RECONNECT_PAUSE = 5 # seconds to wait after a poll error before retrying
|
||||||
_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB — well below Cloudflare's 100 MB limit
|
_CHUNK_SIZE = 10 * 1024 * 1024 # 10 MB per chunk — fast enough for git-LFS to process within timeout
|
||||||
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
_DIRECT_THRESHOLD = 90 * 1024 * 1024 # files above this use chunked upload
|
||||||
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
|
_CAPTURE_SAMPLES = 4096 # IQ samples per inference window
|
||||||
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
_IDLE_LABELS = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
||||||
|
|
@ -93,16 +93,24 @@ class NodeAgent:
|
||||||
name: str,
|
name: str,
|
||||||
sdr_device: str = "unknown",
|
sdr_device: str = "unknown",
|
||||||
insecure: bool = False,
|
insecure: bool = False,
|
||||||
|
role: str = "general",
|
||||||
|
session_code: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.hub_url = hub_url.rstrip("/")
|
self.hub_url = hub_url.rstrip("/")
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.name = name
|
self.name = name
|
||||||
self.sdr_device = sdr_device
|
self.sdr_device = sdr_device
|
||||||
self.insecure = insecure
|
self.insecure = insecure
|
||||||
|
self.role = role
|
||||||
|
self.session_code = session_code
|
||||||
|
|
||||||
self.node_id: str | None = None
|
self.node_id: str | None = None
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
|
|
||||||
|
# ── TX state ────────────────────────────────────────────────────────
|
||||||
|
self._tx_stop = threading.Event()
|
||||||
|
self._tx_thread: threading.Thread | None = None
|
||||||
|
|
||||||
# ── Inference state ─────────────────────────────────────────────────
|
# ── Inference state ─────────────────────────────────────────────────
|
||||||
# Protected by _inf_lock for cross-thread model swaps.
|
# Protected by _inf_lock for cross-thread model swaps.
|
||||||
self._inf_lock = threading.Lock()
|
self._inf_lock = threading.Lock()
|
||||||
|
|
@ -172,19 +180,27 @@ class NodeAgent:
|
||||||
capabilities = ["campaign"]
|
capabilities = ["campaign"]
|
||||||
if self._ort_available:
|
if self._ort_available:
|
||||||
capabilities.append("inference")
|
capabilities.append("inference")
|
||||||
resp = self._post(
|
if self.role == "tx":
|
||||||
"/composer/nodes/register",
|
capabilities.append("transmit")
|
||||||
json={
|
payload: dict = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"sdr_device": self.sdr_device,
|
"sdr_device": self.sdr_device,
|
||||||
"ria_toolkit_version": self._ria_version,
|
"ria_toolkit_version": self._ria_version,
|
||||||
"capabilities": capabilities,
|
"capabilities": capabilities,
|
||||||
},
|
"role": self.role,
|
||||||
timeout=15,
|
}
|
||||||
)
|
if self.session_code:
|
||||||
|
payload["session_code"] = self.session_code
|
||||||
|
resp = self._post("/composer/nodes/register", json=payload, timeout=15)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
self.node_id = resp.json()["node_id"]
|
self.node_id = resp.json()["node_id"]
|
||||||
logger.info("Registered as %r (node_id=%s)", self.name, self.node_id)
|
logger.info(
|
||||||
|
"Registered as %r (node_id=%s, role=%s%s)",
|
||||||
|
self.name,
|
||||||
|
self.node_id,
|
||||||
|
self.role,
|
||||||
|
f", session_code={self.session_code!r}" if self.session_code else "",
|
||||||
|
)
|
||||||
|
|
||||||
def _deregister(self) -> None:
|
def _deregister(self) -> None:
|
||||||
if not self.node_id:
|
if not self.node_id:
|
||||||
|
|
@ -245,9 +261,10 @@ class NodeAgent:
|
||||||
if command == "run_campaign":
|
if command == "run_campaign":
|
||||||
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
||||||
config_dict: dict = cmd.get("payload") or {}
|
config_dict: dict = cmd.get("payload") or {}
|
||||||
|
skip_local_tx: bool = bool(cmd.get("skip_local_tx", False))
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self._run_campaign,
|
target=self._run_campaign,
|
||||||
args=(campaign_id, config_dict),
|
args=(campaign_id, config_dict, skip_local_tx),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name=f"campaign-{campaign_id[:8]}",
|
name=f"campaign-{campaign_id[:8]}",
|
||||||
).start()
|
).start()
|
||||||
|
|
@ -269,6 +286,17 @@ class NodeAgent:
|
||||||
self._stop_inference()
|
self._stop_inference()
|
||||||
elif command == "configure_inference":
|
elif command == "configure_inference":
|
||||||
self._queue_sdr_config(cmd)
|
self._queue_sdr_config(cmd)
|
||||||
|
elif command == "start_transmit":
|
||||||
|
threading.Thread(
|
||||||
|
target=self._start_transmit,
|
||||||
|
args=(cmd,),
|
||||||
|
daemon=True,
|
||||||
|
name="ria-start-tx",
|
||||||
|
).start()
|
||||||
|
elif command == "stop_transmit":
|
||||||
|
self._stop_transmit()
|
||||||
|
elif command == "configure_transmit":
|
||||||
|
logger.info("configure_transmit received — will apply on next step boundary")
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown command %r — ignored", command)
|
logger.warning("Unknown command %r — ignored", command)
|
||||||
|
|
||||||
|
|
@ -276,7 +304,7 @@ class NodeAgent:
|
||||||
# Campaign execution
|
# Campaign execution
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _run_campaign(self, campaign_id: str, config_dict: dict) -> None:
|
def _run_campaign(self, campaign_id: str, config_dict: dict, skip_local_tx: bool = False) -> None:
|
||||||
try:
|
try:
|
||||||
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||||
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||||
|
|
@ -288,10 +316,10 @@ class NodeAgent:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Campaign %s starting", campaign_id[:8])
|
logger.info("Campaign %s starting (skip_local_tx=%s)", campaign_id[:8], skip_local_tx)
|
||||||
try:
|
try:
|
||||||
config = CampaignConfig.from_dict(config_dict)
|
config = CampaignConfig.from_dict(config_dict)
|
||||||
executor = CampaignExecutor(config)
|
executor = CampaignExecutor(config, skip_local_tx=skip_local_tx)
|
||||||
result = executor.run()
|
result = executor.run()
|
||||||
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
|
logger.info("Campaign %s completed — uploading recordings", campaign_id[:8])
|
||||||
self._upload_recordings(campaign_id, config, result)
|
self._upload_recordings(campaign_id, config, result)
|
||||||
|
|
@ -301,6 +329,58 @@ class NodeAgent:
|
||||||
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
logger.error("Campaign %s failed: %s", campaign_id[:8], exc)
|
||||||
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# TX execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _start_transmit(self, cmd: dict) -> None:
|
||||||
|
"""Execute a synthetic transmit campaign using TxExecutor.
|
||||||
|
|
||||||
|
The command payload mirrors a TransmitterConfig dict with an optional
|
||||||
|
``schedule`` of steps. Each step synthesises a signal and transmits it
|
||||||
|
via the local SDR in TX mode.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
|
||||||
|
except ImportError as exc:
|
||||||
|
logger.error("start_transmit: TxExecutor not available: %s", exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._tx_thread and self._tx_thread.is_alive():
|
||||||
|
logger.warning("start_transmit: TX already running — ignoring duplicate command")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._tx_stop.clear()
|
||||||
|
campaign_id: str = cmd.get("campaign_id") or str(uuid.uuid4())
|
||||||
|
executor = TxExecutor(
|
||||||
|
config=cmd,
|
||||||
|
sdr_device=self.sdr_device,
|
||||||
|
stop_event=self._tx_stop,
|
||||||
|
)
|
||||||
|
self._tx_thread = threading.Thread(
|
||||||
|
target=self._run_tx_campaign,
|
||||||
|
args=(executor, campaign_id),
|
||||||
|
daemon=True,
|
||||||
|
name=f"tx-campaign-{campaign_id[:8]}",
|
||||||
|
)
|
||||||
|
self._tx_thread.start()
|
||||||
|
|
||||||
|
def _run_tx_campaign(self, executor: Any, campaign_id: str) -> None:
|
||||||
|
try:
|
||||||
|
executor.run()
|
||||||
|
logger.info("TX campaign %s completed", campaign_id[:8])
|
||||||
|
self._report_campaign_status(campaign_id, "completed")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("TX campaign %s failed: %s", campaign_id[:8], exc)
|
||||||
|
self._report_campaign_status(campaign_id, "failed", error=str(exc))
|
||||||
|
|
||||||
|
def _stop_transmit(self) -> None:
|
||||||
|
"""Signal the TX loop to stop gracefully."""
|
||||||
|
self._tx_stop.set()
|
||||||
|
if self._tx_thread and self._tx_thread.is_alive():
|
||||||
|
self._tx_thread.join(timeout=5.0)
|
||||||
|
logger.info("TX stopped")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Inference — model loading
|
# Inference — model loading
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -579,13 +659,18 @@ class NodeAgent:
|
||||||
base_url = f"{self.hub_url}/datasets/upload"
|
base_url = f"{self.hub_url}/datasets/upload"
|
||||||
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
|
steps = (result.get("steps") if isinstance(result, dict) else getattr(result, "steps", None)) or []
|
||||||
|
|
||||||
|
output_obj = getattr(config, "output", None)
|
||||||
|
folder = getattr(output_obj, "folder", None)
|
||||||
|
campaign_name: str = folder if folder is not None else (getattr(config, "name", None) or "")
|
||||||
for step in steps:
|
for step in steps:
|
||||||
output_path: str | None = getattr(step, "output_path", None)
|
output_path: str | None = getattr(step, "output_path", None)
|
||||||
if not output_path:
|
if not output_path:
|
||||||
continue
|
continue
|
||||||
device_id: str = getattr(step, "transmitter_id", "") or ""
|
device_id: str = getattr(step, "transmitter_id", "") or ""
|
||||||
for fpath in _sigmf_files(output_path):
|
for fpath in _sigmf_files(output_path):
|
||||||
filename = os.path.basename(fpath)
|
basename = os.path.basename(fpath)
|
||||||
|
path_parts = [p for p in (campaign_name, device_id) if p]
|
||||||
|
filename = "/".join(path_parts + [basename])
|
||||||
metadata = {
|
metadata = {
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"repo_owner": repo_owner,
|
"repo_owner": repo_owner,
|
||||||
|
|
@ -671,7 +756,7 @@ class NodeAgent:
|
||||||
headers=headers,
|
headers=headers,
|
||||||
files={"file": (filename, chunk, "application/octet-stream")},
|
files={"file": (filename, chunk, "application/octet-stream")},
|
||||||
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
|
data={**metadata, "upload_id": upload_id, "chunk_index": i, "total_chunks": total_chunks},
|
||||||
timeout=120,
|
timeout=(30, None), # 30s connect, no read timeout — server may take minutes on final chunk
|
||||||
verify=verify,
|
verify=verify,
|
||||||
)
|
)
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
|
|
@ -848,6 +933,21 @@ def main() -> None:
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
help="Logging verbosity (default: INFO)",
|
help="Logging verbosity (default: INFO)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--role",
|
||||||
|
default=None,
|
||||||
|
choices=["general", "rx", "tx"],
|
||||||
|
help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--session-code",
|
||||||
|
default=None,
|
||||||
|
metavar="CODE",
|
||||||
|
help=(
|
||||||
|
"3-word session code to pair this TX agent with a waiting campaign, "
|
||||||
|
"e.g. 'amber-peak-transmit'. Supplied by the campaign UI."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -861,6 +961,8 @@ def main() -> None:
|
||||||
device = args.device or cfg.get("device", "unknown")
|
device = args.device or cfg.get("device", "unknown")
|
||||||
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
|
insecure = args.insecure if args.insecure is not None else cfg.get("insecure", False)
|
||||||
log_level = args.log_level or cfg.get("log_level", "INFO")
|
log_level = args.log_level or cfg.get("log_level", "INFO")
|
||||||
|
role = args.role or cfg.get("role", "general")
|
||||||
|
session_code = args.session_code or cfg.get("session_code")
|
||||||
|
|
||||||
if not hub:
|
if not hub:
|
||||||
parser.error("--hub is required (or set 'hub' in the config file)")
|
parser.error("--hub is required (or set 'hub' in the config file)")
|
||||||
|
|
@ -888,6 +990,8 @@ def main() -> None:
|
||||||
name=name,
|
name=name,
|
||||||
sdr_device=device,
|
sdr_device=device,
|
||||||
insecure=insecure,
|
insecure=insecure,
|
||||||
|
role=role,
|
||||||
|
session_code=session_code,
|
||||||
)
|
)
|
||||||
agent.run()
|
agent.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -233,6 +233,9 @@ class TransmitterConfig:
|
||||||
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
|
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
|
||||||
sdr_remote: Optional[dict] = None
|
sdr_remote: Optional[dict] = None
|
||||||
|
|
||||||
|
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
|
||||||
|
sdr_agent: Optional[dict] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
||||||
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
||||||
|
|
@ -244,6 +247,7 @@ class TransmitterConfig:
|
||||||
script=d.get("script"),
|
script=d.get("script"),
|
||||||
device=d.get("device"),
|
device=d.get("device"),
|
||||||
sdr_remote=d.get("sdr_remote"),
|
sdr_remote=d.get("sdr_remote"),
|
||||||
|
sdr_agent=d.get("sdr_agent"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -272,6 +276,7 @@ class OutputConfig:
|
||||||
path: str = "recordings"
|
path: str = "recordings"
|
||||||
device_id: Optional[str] = None # for device-profile campaigns
|
device_id: Optional[str] = None # for device-profile campaigns
|
||||||
repo: Optional[str] = None
|
repo: Optional[str] = None
|
||||||
|
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: dict) -> "OutputConfig":
|
def from_dict(cls, d: dict) -> "OutputConfig":
|
||||||
|
|
@ -280,6 +285,7 @@ class OutputConfig:
|
||||||
path=str(d.get("path", "recordings")),
|
path=str(d.get("path", "recordings")),
|
||||||
device_id=d.get("device_id"),
|
device_id=d.get("device_id"),
|
||||||
repo=d.get("repo"),
|
repo=d.get("repo"),
|
||||||
|
folder=d.get("folder"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -293,6 +299,7 @@ class CampaignConfig:
|
||||||
qa: QAConfig = field(default_factory=QAConfig)
|
qa: QAConfig = field(default_factory=QAConfig)
|
||||||
output: OutputConfig = field(default_factory=OutputConfig)
|
output: OutputConfig = field(default_factory=OutputConfig)
|
||||||
mode: str = "controlled_testbed"
|
mode: str = "controlled_testbed"
|
||||||
|
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Loaders
|
# Loaders
|
||||||
|
|
@ -320,6 +327,7 @@ class CampaignConfig:
|
||||||
return cls(
|
return cls(
|
||||||
name=safe_name,
|
name=safe_name,
|
||||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
|
loops=max(1, int(campaign_meta.get("loops", 1))),
|
||||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
transmitters=transmitters,
|
transmitters=transmitters,
|
||||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
|
@ -384,6 +392,7 @@ class CampaignConfig:
|
||||||
return cls(
|
return cls(
|
||||||
name=safe_name,
|
name=safe_name,
|
||||||
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
|
loops=max(1, int(campaign_meta.get("loops", 1))),
|
||||||
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
transmitters=transmitters,
|
transmitters=transmitters,
|
||||||
qa=QAConfig.from_dict(raw.get("qa", {})),
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
|
@ -486,9 +495,9 @@ class CampaignConfig:
|
||||||
)
|
)
|
||||||
|
|
||||||
def total_capture_time_s(self) -> float:
|
def total_capture_time_s(self) -> float:
|
||||||
"""Sum of all step durations across all transmitters."""
|
"""Sum of all step durations across all transmitters and loops."""
|
||||||
return sum(step.duration for tx in self.transmitters for step in tx.schedule)
|
return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops
|
||||||
|
|
||||||
def total_steps(self) -> int:
|
def total_steps(self) -> int:
|
||||||
"""Total number of capture steps across all transmitters."""
|
"""Total number of capture steps across all transmitters and loops."""
|
||||||
return sum(len(tx.schedule) for tx in self.transmitters)
|
return sum(len(tx.schedule) for tx in self.transmitters) * self.loops
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,9 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field, replace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
@ -16,6 +17,7 @@ from ria_toolkit_oss.io.recording import to_sigmf
|
||||||
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||||
from .labeler import build_output_filename, label_recording
|
from .labeler import build_output_filename, label_recording
|
||||||
from .qa import QAResult, check_recording
|
from .qa import QAResult, check_recording
|
||||||
|
from .tx_executor import TxExecutor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -169,6 +171,21 @@ def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
|
||||||
|
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
|
||||||
|
|
||||||
|
For sdr_agent transmitters, returns the synthetic generation parameters
|
||||||
|
(modulation, order, symbol_rate, etc.) so recordings capture what was
|
||||||
|
transmitted. Returns None for control methods without signal-level params.
|
||||||
|
"""
|
||||||
|
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
|
||||||
|
if not sdr_agent_cfg:
|
||||||
|
return None
|
||||||
|
# Extract known signal-level fields; ignore infra fields
|
||||||
|
_INFRA_KEYS = {"node_id", "session_code"}
|
||||||
|
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
|
||||||
|
|
||||||
|
|
||||||
class CampaignExecutor:
|
class CampaignExecutor:
|
||||||
"""Executes a :class:`CampaignConfig` end-to-end.
|
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||||
|
|
||||||
|
|
@ -192,11 +209,14 @@ class CampaignExecutor:
|
||||||
config: CampaignConfig,
|
config: CampaignConfig,
|
||||||
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
skip_local_tx: bool = False,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.progress_cb = progress_cb
|
self.progress_cb = progress_cb
|
||||||
|
self.skip_local_tx = skip_local_tx
|
||||||
self._sdr = None
|
self._sdr = None
|
||||||
self._remote_tx_controllers: dict = {}
|
self._remote_tx_controllers: dict = {}
|
||||||
|
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
@ -216,10 +236,12 @@ class CampaignExecutor:
|
||||||
"""
|
"""
|
||||||
result = CampaignResult(campaign_name=self.config.name)
|
result = CampaignResult(campaign_name=self.config.name)
|
||||||
|
|
||||||
|
loops = self.config.loops
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Starting campaign '{self.config.name}': "
|
f"Starting campaign '{self.config.name}': "
|
||||||
f"{self.config.total_steps()} steps, "
|
f"{self.config.total_steps()} steps"
|
||||||
f"~{self.config.total_capture_time_s():.0f}s capture time"
|
+ (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
|
||||||
|
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._init_sdr()
|
self._init_sdr()
|
||||||
|
|
@ -228,29 +250,36 @@ class CampaignExecutor:
|
||||||
total = self.config.total_steps()
|
total = self.config.total_steps()
|
||||||
step_index = 0
|
step_index = 0
|
||||||
|
|
||||||
for transmitter in self.config.transmitters:
|
for loop_idx in range(loops):
|
||||||
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
if loops > 1:
|
||||||
for step in transmitter.schedule:
|
logger.info(f"Loop {loop_idx + 1}/{loops}")
|
||||||
step_result = self._execute_step(transmitter, step)
|
for transmitter in self.config.transmitters:
|
||||||
result.steps.append(step_result)
|
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
||||||
step_index += 1
|
for step in transmitter.schedule:
|
||||||
|
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
|
||||||
|
step_result = self._execute_step(transmitter, looped_step)
|
||||||
|
result.steps.append(step_result)
|
||||||
|
step_index += 1
|
||||||
|
|
||||||
if self.progress_cb:
|
if self.progress_cb:
|
||||||
self.progress_cb(step_index, total, step_result)
|
self.progress_cb(step_index, total, step_result)
|
||||||
|
|
||||||
if step_result.error:
|
if step_result.error:
|
||||||
logger.warning(f"Step '{step.label}' error: {step_result.error}")
|
logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
|
||||||
elif step_result.qa.flagged:
|
elif step_result.qa.flagged:
|
||||||
logger.warning(f"Step '{step.label}' flagged for review: " + "; ".join(step_result.qa.issues))
|
logger.warning(
|
||||||
else:
|
f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
|
||||||
logger.info(
|
)
|
||||||
f"Step '{step.label}' OK "
|
else:
|
||||||
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
logger.info(
|
||||||
f"{step_result.qa.duration_s:.1f}s)"
|
f"Step '{looped_step.label}' OK "
|
||||||
)
|
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
||||||
|
f"{step_result.qa.duration_s:.1f}s)"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
self._close_sdr()
|
self._close_sdr()
|
||||||
self._close_remote_tx_controllers()
|
self._close_remote_tx_controllers()
|
||||||
|
self._close_tx_executors()
|
||||||
|
|
||||||
result.end_time = time.time()
|
result.end_time = time.time()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -325,6 +354,12 @@ class CampaignExecutor:
|
||||||
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
|
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
|
||||||
self._remote_tx_controllers.clear()
|
self._remote_tx_controllers.clear()
|
||||||
|
|
||||||
|
def _close_tx_executors(self) -> None:
|
||||||
|
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
|
||||||
|
stop_event.set()
|
||||||
|
t.join(timeout=5.0)
|
||||||
|
self._tx_executors.clear()
|
||||||
|
|
||||||
def _record(self, duration_s: float) -> Recording:
|
def _record(self, duration_s: float) -> Recording:
|
||||||
"""Capture ``duration_s`` seconds of IQ samples."""
|
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||||
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||||
|
|
@ -369,6 +404,7 @@ class CampaignExecutor:
|
||||||
step=step,
|
step=step,
|
||||||
capture_timestamp=capture_timestamp,
|
capture_timestamp=capture_timestamp,
|
||||||
campaign_name=self.config.name,
|
campaign_name=self.config.name,
|
||||||
|
tx_params=_extract_tx_params(transmitter),
|
||||||
)
|
)
|
||||||
|
|
||||||
# QA
|
# QA
|
||||||
|
|
@ -437,6 +473,30 @@ class CampaignExecutor:
|
||||||
# Start transmission in background; _record() runs concurrently
|
# Start transmission in background; _record() runs concurrently
|
||||||
ctrl.transmit_async(step.duration + 1.0)
|
ctrl.transmit_async(step.duration + 1.0)
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr_agent":
|
||||||
|
if self.skip_local_tx:
|
||||||
|
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
|
||||||
|
return
|
||||||
|
if not transmitter.sdr_agent:
|
||||||
|
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
|
||||||
|
return
|
||||||
|
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
|
||||||
|
if step.power_dbm is not None:
|
||||||
|
step_dict["power_dbm"] = step.power_dbm
|
||||||
|
tx_config = {
|
||||||
|
"id": transmitter.id,
|
||||||
|
"sdr_agent": transmitter.sdr_agent,
|
||||||
|
"schedule": [step_dict],
|
||||||
|
}
|
||||||
|
rec = self.config.recorder
|
||||||
|
tx_device = transmitter.device or rec.device
|
||||||
|
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
|
||||||
|
stop_event = threading.Event()
|
||||||
|
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
|
||||||
|
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
|
||||||
|
self._tx_executors[transmitter.id] = (executor, stop_event, t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||||
|
|
||||||
|
|
@ -459,6 +519,13 @@ class CampaignExecutor:
|
||||||
if ctrl is not None:
|
if ctrl is not None:
|
||||||
ctrl.wait_transmit(timeout=step.duration + 10.0)
|
ctrl.wait_transmit(timeout=step.duration + 10.0)
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr_agent":
|
||||||
|
entry = self._tx_executors.pop(transmitter.id, None)
|
||||||
|
if entry is not None:
|
||||||
|
_, stop_event, t = entry
|
||||||
|
stop_event.set()
|
||||||
|
t.join(timeout=step.duration + 10.0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||||
"""Serialise step parameters to a JSON string for the control script."""
|
"""Serialise step parameters to a JSON string for the control script."""
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ def label_recording(
|
||||||
step: CaptureStep,
|
step: CaptureStep,
|
||||||
capture_timestamp: float,
|
capture_timestamp: float,
|
||||||
campaign_name: Optional[str] = None,
|
campaign_name: Optional[str] = None,
|
||||||
|
tx_params: Optional[dict] = None,
|
||||||
) -> Recording:
|
) -> Recording:
|
||||||
"""Apply device identity and capture configuration labels to a recording's metadata.
|
"""Apply device identity and capture configuration labels to a recording's metadata.
|
||||||
|
|
||||||
|
|
@ -27,6 +28,9 @@ def label_recording(
|
||||||
step: The capture step that was active during this recording.
|
step: The capture step that was active during this recording.
|
||||||
capture_timestamp: Unix timestamp (float) of when capture started.
|
capture_timestamp: Unix timestamp (float) of when capture started.
|
||||||
campaign_name: Optional campaign name for cross-recording reference.
|
campaign_name: Optional campaign name for cross-recording reference.
|
||||||
|
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
|
||||||
|
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
|
||||||
|
training pipelines know what was transmitted into the recording.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The same recording with updated metadata.
|
The same recording with updated metadata.
|
||||||
|
|
@ -57,6 +61,11 @@ def label_recording(
|
||||||
if step.power_dbm is not None:
|
if step.power_dbm is not None:
|
||||||
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
||||||
|
|
||||||
|
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
|
||||||
|
if tx_params:
|
||||||
|
for key, value in tx_params.items():
|
||||||
|
recording.update_metadata(f"tx_{key}", value)
|
||||||
|
|
||||||
return recording
|
return recording
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
299
src/ria_toolkit_oss/orchestration/tx_executor.py
Normal file
299
src/ria_toolkit_oss/orchestration/tx_executor.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""TX campaign executor — synthesises and transmits signals via a local SDR.
|
||||||
|
|
||||||
|
The TxExecutor receives a transmitter config dict (matching the
|
||||||
|
``sdr_agent`` control method's schema) and a step schedule, then for each
|
||||||
|
step builds a signal chain with the block generator and transmits it via
|
||||||
|
the local SDR device.
|
||||||
|
|
||||||
|
Supported modulations (``modulation`` field in config):
|
||||||
|
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
|
||||||
|
|
||||||
|
Example config dict (matches CampaignConfig transmitter with
|
||||||
|
``control_method: sdr_agent``)::
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "synthetic-tx",
|
||||||
|
"type": "sdr",
|
||||||
|
"control_method": "sdr_agent",
|
||||||
|
"sdr_agent": {
|
||||||
|
"modulation": "QPSK",
|
||||||
|
"order": 4,
|
||||||
|
"symbol_rate": 1000000,
|
||||||
|
"center_frequency": 0.0,
|
||||||
|
"filter": "rrc",
|
||||||
|
"rolloff": 0.35
|
||||||
|
},
|
||||||
|
"schedule": [
|
||||||
|
{"label": "step1", "duration": 10, "power_dbm": -10}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_hz(val: object) -> float:
|
||||||
|
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return float(val)
|
||||||
|
s = str(val).strip()
|
||||||
|
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
|
||||||
|
if s.endswith(suffix):
|
||||||
|
return float(s[: -len(suffix)]) * mult
|
||||||
|
return float(s)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_seconds(val: object) -> float:
|
||||||
|
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return float(val)
|
||||||
|
s = str(val).strip()
|
||||||
|
return float(s[:-1]) if s.endswith("s") else float(s)
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from modulation name → (PSK/QAM order, generator_type)
|
||||||
|
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
|
||||||
|
_MOD_TABLE: dict[str, tuple[int, str]] = {
|
||||||
|
"BPSK": (1, "psk"),
|
||||||
|
"QPSK": (2, "psk"),
|
||||||
|
"8PSK": (3, "psk"),
|
||||||
|
"16QAM": (4, "qam"),
|
||||||
|
"64QAM": (6, "qam"),
|
||||||
|
"256QAM": (8, "qam"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}
|
||||||
|
|
||||||
|
# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
|
||||||
|
# source buffer for the full tx_time, so only this many samples ever need to
|
||||||
|
# be in RAM regardless of step duration or sample rate.
|
||||||
|
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
|
||||||
|
_SYNTH_BLOCK_SAMPLES = 50_000
|
||||||
|
|
||||||
|
|
||||||
|
class TxExecutor:
|
||||||
|
"""Synthesise and transmit a signal campaign via a local SDR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
|
||||||
|
modulation params, and ``schedule`` list of step dicts).
|
||||||
|
sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
|
||||||
|
stop_event: External event that aborts the TX loop mid-step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: dict,
|
||||||
|
sdr_device: str = "unknown",
|
||||||
|
stop_event: threading.Event | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.sdr_device = sdr_device
|
||||||
|
self.stop_event = stop_event or threading.Event()
|
||||||
|
self._sdr: Any = None
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""Execute all steps in the schedule, transmitting for each step duration."""
|
||||||
|
agent_cfg: dict = self.config.get("sdr_agent") or {}
|
||||||
|
schedule: list[dict] = self.config.get("schedule") or []
|
||||||
|
|
||||||
|
if not schedule:
|
||||||
|
logger.warning("TxExecutor: no schedule steps — nothing to transmit")
|
||||||
|
return
|
||||||
|
|
||||||
|
modulation: str = agent_cfg.get("modulation", "QPSK").upper()
|
||||||
|
symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
|
||||||
|
center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
|
||||||
|
filter_type: str = agent_cfg.get("filter", "rrc").lower()
|
||||||
|
rolloff: float = float(agent_cfg.get("rolloff", 0.35))
|
||||||
|
loops: int = max(1, int(self.config.get("loops", 1)))
|
||||||
|
|
||||||
|
# Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
|
||||||
|
sps = 8
|
||||||
|
sample_rate = symbol_rate * sps
|
||||||
|
|
||||||
|
self._init_sdr(sample_rate, center_freq)
|
||||||
|
try:
|
||||||
|
for loop_idx in range(loops):
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
break
|
||||||
|
if loops > 1:
|
||||||
|
logger.info("TX loop %d/%d", loop_idx + 1, loops)
|
||||||
|
for step in schedule:
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
break
|
||||||
|
looped_step = (
|
||||||
|
{**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
|
||||||
|
)
|
||||||
|
self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
|
||||||
|
finally:
|
||||||
|
self._close_sdr()
|
||||||
|
|
||||||
|
def _execute_step(
|
||||||
|
self,
|
||||||
|
step: dict,
|
||||||
|
modulation: str,
|
||||||
|
sps: int,
|
||||||
|
symbol_rate: float,
|
||||||
|
filter_type: str,
|
||||||
|
rolloff: float,
|
||||||
|
) -> None:
|
||||||
|
duration: float = _parse_seconds(step.get("duration", 10.0))
|
||||||
|
label: str = step.get("label", "step")
|
||||||
|
gain: float = float(step.get("power_dbm") or 0.0)
|
||||||
|
sample_rate = symbol_rate * sps
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
|
||||||
|
label,
|
||||||
|
duration,
|
||||||
|
modulation,
|
||||||
|
symbol_rate / 1e6,
|
||||||
|
sps,
|
||||||
|
filter_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_samples = int(duration * sample_rate)
|
||||||
|
|
||||||
|
# Synthesise a short representative block. tx_recording() loops this
|
||||||
|
# buffer for the full tx_time using a 2 000-sample streaming callback,
|
||||||
|
# so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
|
||||||
|
block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
|
||||||
|
signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
|
||||||
|
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
# Apply gain update if SDR supports it
|
||||||
|
if hasattr(self._sdr, "set_tx_gain"):
|
||||||
|
self._sdr.set_tx_gain(gain)
|
||||||
|
self._sdr.tx_recording(signal, tx_time=duration)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("TX step '%s' SDR error: %s", label, exc)
|
||||||
|
else:
|
||||||
|
# No SDR available — simulate by sleeping for the step duration.
|
||||||
|
logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
|
||||||
|
self.stop_event.wait(timeout=duration)
|
||||||
|
|
||||||
|
def _synthesise(
|
||||||
|
self,
|
||||||
|
modulation: str,
|
||||||
|
sps: int,
|
||||||
|
num_samples: int,
|
||||||
|
filter_type: str,
|
||||||
|
rolloff: float,
|
||||||
|
):
|
||||||
|
"""Build a block-generator chain and return IQ samples as a numpy array."""
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.signal.block_generator import (
|
||||||
|
BinarySource,
|
||||||
|
GMSKModulator,
|
||||||
|
Mapper,
|
||||||
|
OOKModulator,
|
||||||
|
OQPSKModulator,
|
||||||
|
RaisedCosineFilter,
|
||||||
|
RootRaisedCosineFilter,
|
||||||
|
Upsampling,
|
||||||
|
)
|
||||||
|
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
|
||||||
|
FSKModulator,
|
||||||
|
)
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
|
||||||
|
|
||||||
|
# ── Special modulations with their own source-connected modulator ──
|
||||||
|
if modulation in ("OOK", "GMSK", "OQPSK"):
|
||||||
|
src = BinarySource()
|
||||||
|
if modulation == "OOK":
|
||||||
|
mod = OOKModulator(src, samples_per_symbol=sps)
|
||||||
|
elif modulation == "GMSK":
|
||||||
|
mod = GMSKModulator(src, samples_per_symbol=sps)
|
||||||
|
else:
|
||||||
|
mod = OQPSKModulator(src, samples_per_symbol=sps)
|
||||||
|
recording = mod.record(num_samples)
|
||||||
|
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
||||||
|
if len(flat) < num_samples:
|
||||||
|
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||||
|
return flat[:num_samples]
|
||||||
|
|
||||||
|
if modulation == "FSK":
|
||||||
|
symbol_rate = num_samples / sps
|
||||||
|
bits_per_sym = 1 # 2-FSK
|
||||||
|
num_bits = max(num_samples // sps, 128) * bits_per_sym
|
||||||
|
bits = BinarySource()((1, num_bits))
|
||||||
|
mod = FSKModulator(
|
||||||
|
num_bits_per_symbol=bits_per_sym,
|
||||||
|
frequency_spacing=symbol_rate * 0.5,
|
||||||
|
symbol_duration=1.0 / max(symbol_rate, 1.0),
|
||||||
|
sampling_frequency=symbol_rate * sps,
|
||||||
|
)
|
||||||
|
flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
|
||||||
|
if len(flat) < num_samples:
|
||||||
|
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||||
|
return flat[:num_samples]
|
||||||
|
|
||||||
|
# ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
|
||||||
|
if modulation not in _MOD_TABLE:
|
||||||
|
logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
|
||||||
|
modulation = "QPSK"
|
||||||
|
|
||||||
|
bits_per_sym, gen_type = _MOD_TABLE[modulation]
|
||||||
|
mod_family = "QAM" if gen_type == "qam" else "PSK"
|
||||||
|
|
||||||
|
source = BinarySource()
|
||||||
|
mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
|
||||||
|
upsampler = Upsampling(factor=sps)
|
||||||
|
|
||||||
|
mapper.connect_input([source])
|
||||||
|
upsampler.connect_input([mapper])
|
||||||
|
|
||||||
|
if filter_type in ("rrc",):
|
||||||
|
pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
||||||
|
pulse_filter.connect_input([upsampler])
|
||||||
|
recording = pulse_filter.record(num_samples)
|
||||||
|
elif filter_type in ("rc",):
|
||||||
|
pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
||||||
|
pulse_filter.connect_input([upsampler])
|
||||||
|
recording = pulse_filter.record(num_samples)
|
||||||
|
else:
|
||||||
|
# "none", "rect", "gaussian" — use upsampler output directly
|
||||||
|
recording = upsampler.record(num_samples)
|
||||||
|
|
||||||
|
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
||||||
|
if len(flat) < num_samples:
|
||||||
|
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||||
|
return flat[:num_samples]
|
||||||
|
|
||||||
|
def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.sdr import get_sdr_device
|
||||||
|
|
||||||
|
self._sdr = get_sdr_device(self.sdr_device)
|
||||||
|
self._sdr.init_tx(
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
center_frequency=center_freq,
|
||||||
|
gain=0,
|
||||||
|
channel=0,
|
||||||
|
gain_mode="manual",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
def _close_sdr(self) -> None:
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
self._sdr.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("TX SDR close error: %s", exc)
|
||||||
|
self._sdr = None
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
from .auth import require_api_key
|
from .auth import require_api_key
|
||||||
from .routers import inference, orchestrator
|
from .routers import conductor, inference
|
||||||
|
|
||||||
|
|
||||||
def create_app(api_key: str = "") -> FastAPI:
|
def create_app(api_key: str = "") -> FastAPI:
|
||||||
|
|
@ -28,9 +28,9 @@ def create_app(api_key: str = "") -> FastAPI:
|
||||||
app.state.api_key = api_key
|
app.state.api_key = api_key
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
orchestrator.router,
|
conductor.router,
|
||||||
prefix="/orchestrator",
|
prefix="/conductor",
|
||||||
tags=["Orchestrator"],
|
tags=["Conductor"],
|
||||||
dependencies=[Depends(require_api_key)],
|
dependencies=[Depends(require_api_key)],
|
||||||
)
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Orchestrator
|
# Conductor
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Orchestrator routes: campaign deployment, status, and cancellation."""
|
"""Conductor routes: campaign deployment, status, and cancellation."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -23,9 +23,9 @@ def serve(host: str, port: int, api_key: str, log_level: str):
|
||||||
|
|
||||||
\b
|
\b
|
||||||
Endpoints:
|
Endpoints:
|
||||||
POST /orchestrator/deploy
|
POST /conductor/deploy
|
||||||
GET /orchestrator/status/{campaign_id}
|
GET /conductor/status/{campaign_id}
|
||||||
POST /orchestrator/cancel/{campaign_id}
|
POST /conductor/cancel/{campaign_id}
|
||||||
POST /inference/load
|
POST /inference/load
|
||||||
POST /inference/start
|
POST /inference/start
|
||||||
POST /inference/stop
|
POST /inference/stop
|
||||||
|
|
|
||||||
314
tests/orchestration/test_executor.py
Normal file
314
tests/orchestration/test_executor.py
Normal file
|
|
@ -0,0 +1,314 @@
|
||||||
|
"""Tests for orchestration executor — StepResult, CampaignResult, _run_script, _extract_tx_params."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import stat
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ria_toolkit_oss.orchestration.executor import (
|
||||||
|
CampaignResult,
|
||||||
|
StepResult,
|
||||||
|
_extract_tx_params,
|
||||||
|
_run_script,
|
||||||
|
)
|
||||||
|
from ria_toolkit_oss.orchestration.qa import QAResult
|
||||||
|
|
||||||
|
|
||||||
|
def _ok_qa() -> QAResult:
|
||||||
|
return QAResult(passed=True, flagged=False, snr_db=20.0, duration_s=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _flagged_qa() -> QAResult:
|
||||||
|
return QAResult(passed=True, flagged=True, snr_db=5.0, duration_s=1.0, issues=["low SNR"])
|
||||||
|
|
||||||
|
|
||||||
|
def _failed_qa() -> QAResult:
|
||||||
|
return QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=["no signal"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StepResult
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStepResult:
|
||||||
|
def test_ok_true_when_no_error_and_qa_passed(self):
|
||||||
|
r = StepResult(
|
||||||
|
transmitter_id="tx1",
|
||||||
|
step_label="step1",
|
||||||
|
output_path="/out/rec.sigmf-data",
|
||||||
|
qa=_ok_qa(),
|
||||||
|
capture_timestamp=0.0,
|
||||||
|
)
|
||||||
|
assert r.ok is True
|
||||||
|
|
||||||
|
def test_ok_false_when_error_set(self):
|
||||||
|
r = StepResult(
|
||||||
|
transmitter_id="tx1",
|
||||||
|
step_label="step1",
|
||||||
|
output_path=None,
|
||||||
|
qa=_ok_qa(),
|
||||||
|
capture_timestamp=0.0,
|
||||||
|
error="SDR failed",
|
||||||
|
)
|
||||||
|
assert r.ok is False
|
||||||
|
|
||||||
|
def test_ok_false_when_qa_not_passed(self):
|
||||||
|
r = StepResult(
|
||||||
|
transmitter_id="tx1",
|
||||||
|
step_label="step1",
|
||||||
|
output_path="/out",
|
||||||
|
qa=_failed_qa(),
|
||||||
|
capture_timestamp=0.0,
|
||||||
|
)
|
||||||
|
assert r.ok is False
|
||||||
|
|
||||||
|
def test_to_dict_contains_required_keys(self):
|
||||||
|
r = StepResult(
|
||||||
|
transmitter_id="tx1",
|
||||||
|
step_label="step1",
|
||||||
|
output_path="/out/rec.sigmf-data",
|
||||||
|
qa=_ok_qa(),
|
||||||
|
capture_timestamp=1234.5,
|
||||||
|
)
|
||||||
|
d = r.to_dict()
|
||||||
|
assert d["transmitter_id"] == "tx1"
|
||||||
|
assert d["step_label"] == "step1"
|
||||||
|
assert d["output_path"] == "/out/rec.sigmf-data"
|
||||||
|
assert d["capture_timestamp"] == pytest.approx(1234.5)
|
||||||
|
assert d["error"] is None
|
||||||
|
assert d["qa"]["passed"] is True
|
||||||
|
|
||||||
|
def test_to_dict_includes_error_when_set(self):
|
||||||
|
r = StepResult(
|
||||||
|
transmitter_id="tx1",
|
||||||
|
step_label="step1",
|
||||||
|
output_path=None,
|
||||||
|
qa=_failed_qa(),
|
||||||
|
capture_timestamp=0.0,
|
||||||
|
error="disk full",
|
||||||
|
)
|
||||||
|
assert r.to_dict()["error"] == "disk full"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CampaignResult
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCampaignResult:
|
||||||
|
def _make(self, steps: list) -> CampaignResult:
|
||||||
|
r = CampaignResult(campaign_name="test_campaign")
|
||||||
|
r.steps = steps
|
||||||
|
r.end_time = r.start_time + 5.0
|
||||||
|
return r
|
||||||
|
|
||||||
|
def test_total_steps(self):
|
||||||
|
r = self._make(
|
||||||
|
[
|
||||||
|
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||||
|
StepResult("tx1", "s2", "/out", _ok_qa(), 0.0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert r.total_steps == 2
|
||||||
|
|
||||||
|
def test_passed_count(self):
|
||||||
|
r = self._make(
|
||||||
|
[
|
||||||
|
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||||
|
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert r.passed == 1
|
||||||
|
|
||||||
|
def test_failed_count(self):
|
||||||
|
r = self._make(
|
||||||
|
[
|
||||||
|
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||||
|
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert r.failed == 1
|
||||||
|
|
||||||
|
def test_flagged_count(self):
|
||||||
|
r = self._make(
|
||||||
|
[
|
||||||
|
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
|
||||||
|
StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert r.flagged == 1
|
||||||
|
|
||||||
|
def test_error_step_counts_as_failed_not_passed(self):
|
||||||
|
r = self._make(
|
||||||
|
[
|
||||||
|
StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert r.failed == 1
|
||||||
|
assert r.passed == 0
|
||||||
|
|
||||||
|
def test_duration_s_from_end_time(self):
|
||||||
|
r = CampaignResult(campaign_name="c")
|
||||||
|
r.start_time = 100.0
|
||||||
|
r.end_time = 115.0
|
||||||
|
assert r.duration_s == pytest.approx(15.0)
|
||||||
|
|
||||||
|
def test_to_dict_structure(self):
|
||||||
|
r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
|
||||||
|
d = r.to_dict()
|
||||||
|
assert d["campaign_name"] == "test_campaign"
|
||||||
|
assert d["total_steps"] == 1
|
||||||
|
assert d["passed"] == 1
|
||||||
|
assert len(d["steps"]) == 1
|
||||||
|
|
||||||
|
def test_write_report(self, tmp_path):
|
||||||
|
r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)])
|
||||||
|
out = tmp_path / "report.json"
|
||||||
|
r.write_report(str(out))
|
||||||
|
assert out.exists()
|
||||||
|
data = json.loads(out.read_text())
|
||||||
|
assert data["campaign_name"] == "test_campaign"
|
||||||
|
|
||||||
|
def test_write_report_creates_nested_dirs(self, tmp_path):
|
||||||
|
r = self._make([])
|
||||||
|
out = tmp_path / "nested" / "deep" / "report.json"
|
||||||
|
r.write_report(str(out))
|
||||||
|
assert out.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _run_script
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunScript:
|
||||||
|
def _script(self, tmp_path, body: str) -> str:
|
||||||
|
s = tmp_path / "script.sh"
|
||||||
|
s.write_text("#!/bin/sh\n" + body)
|
||||||
|
s.chmod(s.stat().st_mode | stat.S_IEXEC)
|
||||||
|
return str(s)
|
||||||
|
|
||||||
|
def test_returns_stdout(self, tmp_path):
|
||||||
|
out = _run_script(self._script(tmp_path, 'echo "hello world"'))
|
||||||
|
assert out == "hello world"
|
||||||
|
|
||||||
|
def test_passes_args_to_script(self, tmp_path):
|
||||||
|
out = _run_script(self._script(tmp_path, 'echo "$1 $2"'), "configure", "arg2")
|
||||||
|
assert "configure" in out
|
||||||
|
|
||||||
|
def test_raises_on_nonzero_exit(self, tmp_path):
|
||||||
|
with pytest.raises(RuntimeError, match="exited 1"):
|
||||||
|
_run_script(self._script(tmp_path, "exit 1"))
|
||||||
|
|
||||||
|
def test_raises_on_relative_path(self):
|
||||||
|
with pytest.raises(RuntimeError, match="absolute"):
|
||||||
|
_run_script("relative/script.sh")
|
||||||
|
|
||||||
|
def test_raises_on_missing_file(self, tmp_path):
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
_run_script(str(tmp_path / "nonexistent.sh"))
|
||||||
|
|
||||||
|
def test_raises_on_timeout(self, tmp_path):
|
||||||
|
with pytest.raises(RuntimeError, match="timed out"):
|
||||||
|
_run_script(self._script(tmp_path, "sleep 60"), timeout=0.1)
|
||||||
|
|
||||||
|
def test_stderr_included_in_error_message(self, tmp_path):
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
_run_script(self._script(tmp_path, "echo 'bad thing' >&2; exit 1"))
|
||||||
|
assert "bad thing" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_tx_params
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractTxParams:
|
||||||
|
def test_returns_none_when_no_sdr_agent_attribute(self):
|
||||||
|
tx = SimpleNamespace()
|
||||||
|
assert _extract_tx_params(tx) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_sdr_agent_is_none(self):
|
||||||
|
tx = SimpleNamespace(sdr_agent=None)
|
||||||
|
assert _extract_tx_params(tx) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_sdr_agent_is_empty_dict(self):
|
||||||
|
tx = SimpleNamespace(sdr_agent={})
|
||||||
|
assert _extract_tx_params(tx) is None
|
||||||
|
|
||||||
|
def test_returns_signal_params(self):
|
||||||
|
tx = SimpleNamespace(
|
||||||
|
sdr_agent={
|
||||||
|
"modulation": "QPSK",
|
||||||
|
"symbol_rate": 1e6,
|
||||||
|
"center_frequency": 2.4e9,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = _extract_tx_params(tx)
|
||||||
|
assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9}
|
||||||
|
|
||||||
|
def test_strips_infra_key_node_id(self):
|
||||||
|
tx = SimpleNamespace(
|
||||||
|
sdr_agent={
|
||||||
|
"modulation": "BPSK",
|
||||||
|
"node_id": "node_abc123",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = _extract_tx_params(tx)
|
||||||
|
assert "node_id" not in result
|
||||||
|
assert result == {"modulation": "BPSK"}
|
||||||
|
|
||||||
|
def test_strips_infra_key_session_code(self):
|
||||||
|
tx = SimpleNamespace(
|
||||||
|
sdr_agent={
|
||||||
|
"modulation": "FSK",
|
||||||
|
"session_code": "amber-peak-transmit",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = _extract_tx_params(tx)
|
||||||
|
assert "session_code" not in result
|
||||||
|
|
||||||
|
def test_strips_none_values(self):
|
||||||
|
tx = SimpleNamespace(
|
||||||
|
sdr_agent={
|
||||||
|
"modulation": "QPSK",
|
||||||
|
"order": None,
|
||||||
|
"rolloff": 0.35,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = _extract_tx_params(tx)
|
||||||
|
assert "order" not in result
|
||||||
|
assert result == {"modulation": "QPSK", "rolloff": 0.35}
|
||||||
|
|
||||||
|
def test_does_not_mutate_source_dict(self):
|
||||||
|
cfg = {"modulation": "QPSK", "node_id": "nid", "session_code": "code"}
|
||||||
|
tx = SimpleNamespace(sdr_agent=cfg)
|
||||||
|
_extract_tx_params(tx)
|
||||||
|
assert "node_id" in cfg
|
||||||
|
|
||||||
|
def test_full_sdr_agent_config(self):
|
||||||
|
tx = SimpleNamespace(
|
||||||
|
sdr_agent={
|
||||||
|
"modulation": "16QAM",
|
||||||
|
"order": 4,
|
||||||
|
"symbol_rate": 5e6,
|
||||||
|
"center_frequency": 915e6,
|
||||||
|
"filter": "rrc",
|
||||||
|
"rolloff": 0.35,
|
||||||
|
"node_id": "node_xyz",
|
||||||
|
"session_code": "some-code",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = _extract_tx_params(tx)
|
||||||
|
assert result == {
|
||||||
|
"modulation": "16QAM",
|
||||||
|
"order": 4,
|
||||||
|
"symbol_rate": 5e6,
|
||||||
|
"center_frequency": 915e6,
|
||||||
|
"filter": "rrc",
|
||||||
|
"rolloff": 0.35,
|
||||||
|
}
|
||||||
|
|
@ -109,6 +109,38 @@ class TestLabelRecording:
|
||||||
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
|
||||||
assert result is rec
|
assert result is rec
|
||||||
|
|
||||||
|
def test_tx_params_none_by_default(self):
|
||||||
|
rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
|
||||||
|
tx_keys = [k for k in rec.metadata if k.startswith("tx_")]
|
||||||
|
assert tx_keys == []
|
||||||
|
|
||||||
|
def test_tx_params_written_as_tx_prefix_keys(self):
|
||||||
|
params = {"modulation": "QPSK", "symbol_rate": 1e6}
|
||||||
|
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
|
||||||
|
assert rec.metadata["tx_modulation"] == "QPSK"
|
||||||
|
assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6)
|
||||||
|
|
||||||
|
def test_tx_params_multiple_fields(self):
|
||||||
|
params = {
|
||||||
|
"modulation": "16QAM",
|
||||||
|
"order": 4,
|
||||||
|
"symbol_rate": 5e6,
|
||||||
|
"center_frequency": 915e6,
|
||||||
|
"filter": "rrc",
|
||||||
|
"rolloff": 0.35,
|
||||||
|
}
|
||||||
|
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
|
||||||
|
for k, v in params.items():
|
||||||
|
assert f"tx_{k}" in rec.metadata
|
||||||
|
assert (
|
||||||
|
rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tx_params_empty_dict_writes_nothing(self):
|
||||||
|
rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={})
|
||||||
|
tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"]
|
||||||
|
assert tx_keys == []
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# build_output_filename
|
# build_output_filename
|
||||||
|
|
|
||||||
153
tests/orchestration/test_tx_executor.py
Normal file
153
tests/orchestration/test_tx_executor.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""Tests for TxExecutor — signal synthesis and step execution."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ria_toolkit_oss.orchestration.tx_executor import TxExecutor
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(modulation="QPSK", symbol_rate=100_000, steps=None):
|
||||||
|
return {
|
||||||
|
"id": "test-tx",
|
||||||
|
"type": "sdr",
|
||||||
|
"control_method": "sdr_agent",
|
||||||
|
"sdr_agent": {
|
||||||
|
"modulation": modulation,
|
||||||
|
"symbol_rate": symbol_rate,
|
||||||
|
"center_frequency": 0.0,
|
||||||
|
"filter": "rrc",
|
||||||
|
"rolloff": 0.35,
|
||||||
|
},
|
||||||
|
"schedule": steps or [{"label": "step1", "duration": 0.001, "power_dbm": -10}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Initialisation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTxExecutorInit:
|
||||||
|
def test_stores_sdr_device(self):
|
||||||
|
ex = TxExecutor(_cfg(), sdr_device="pluto")
|
||||||
|
assert ex.sdr_device == "pluto"
|
||||||
|
|
||||||
|
def test_stop_event_created_when_not_supplied(self):
|
||||||
|
ex = TxExecutor(_cfg())
|
||||||
|
assert isinstance(ex.stop_event, threading.Event)
|
||||||
|
assert not ex.stop_event.is_set()
|
||||||
|
|
||||||
|
def test_accepts_external_stop_event(self):
|
||||||
|
ev = threading.Event()
|
||||||
|
ex = TxExecutor(_cfg(), stop_event=ev)
|
||||||
|
assert ex.stop_event is ev
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run() — schedule iteration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTxExecutorRun:
|
||||||
|
def test_empty_schedule_returns_immediately(self):
|
||||||
|
cfg = _cfg(steps=[])
|
||||||
|
ex = TxExecutor(cfg)
|
||||||
|
ex.run() # must not raise or block
|
||||||
|
|
||||||
|
def test_pre_set_stop_event_skips_all_steps(self):
|
||||||
|
ev = threading.Event()
|
||||||
|
ev.set()
|
||||||
|
ex = TxExecutor(_cfg(), stop_event=ev)
|
||||||
|
# If stop was set, _execute_step should never be called.
|
||||||
|
# run() should return cleanly without attempting synthesis.
|
||||||
|
ex.run()
|
||||||
|
|
||||||
|
def test_no_sdr_falls_back_to_simulation(self, monkeypatch):
|
||||||
|
"""Without SDR hardware TxExecutor simulates by calling stop_event.wait."""
|
||||||
|
cfg = _cfg(steps=[{"label": "s", "duration": 0.001, "power_dbm": 0}])
|
||||||
|
waited = []
|
||||||
|
real_ev = threading.Event()
|
||||||
|
|
||||||
|
def _fake_wait(timeout=None):
|
||||||
|
waited.append(timeout)
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr(real_ev, "wait", _fake_wait)
|
||||||
|
|
||||||
|
# Patch SDR init to always fail (forces simulation path)
|
||||||
|
with patch.object(TxExecutor, "_init_sdr", lambda self, *a, **kw: setattr(self, "_sdr", None)):
|
||||||
|
ex = TxExecutor(cfg, sdr_device="nonexistent_xyz", stop_event=real_ev)
|
||||||
|
ex.run()
|
||||||
|
|
||||||
|
assert len(waited) >= 1, "expected stop_event.wait to be called for simulation"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _synthesise — all modulation types and filter types
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesise:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _ex(self):
|
||||||
|
self.ex = TxExecutor(_cfg())
|
||||||
|
|
||||||
|
def _synth(self, mod, num_samples=256):
|
||||||
|
return self.ex._synthesise(mod, sps=4, num_samples=num_samples, filter_type="rrc", rolloff=0.35)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("mod", ["BPSK", "QPSK", "8PSK", "16QAM", "64QAM", "256QAM"])
|
||||||
|
def test_psk_qam_returns_complex64_array(self, mod):
|
||||||
|
sig = self._synth(mod)
|
||||||
|
assert sig.dtype == np.complex64
|
||||||
|
assert len(sig) == 256
|
||||||
|
|
||||||
|
def test_fsk_returns_correct_length(self):
|
||||||
|
sig = self._synth("FSK")
|
||||||
|
assert len(sig) == 256
|
||||||
|
|
||||||
|
def test_ook_returns_correct_length(self):
|
||||||
|
sig = self._synth("OOK")
|
||||||
|
assert len(sig) == 256
|
||||||
|
|
||||||
|
def test_gmsk_returns_correct_length(self):
|
||||||
|
sig = self._synth("GMSK")
|
||||||
|
assert len(sig) == 256
|
||||||
|
|
||||||
|
def test_oqpsk_returns_correct_length(self):
|
||||||
|
sig = self._synth("OQPSK")
|
||||||
|
assert len(sig) == 256
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("mod", ["BPSK", "QPSK", "16QAM", "FSK", "OOK", "GMSK"])
|
||||||
|
def test_samples_are_finite(self, mod):
|
||||||
|
sig = self._synth(mod)
|
||||||
|
assert np.all(np.isfinite(sig.real)), f"{mod}: non-finite real samples"
|
||||||
|
assert np.all(np.isfinite(sig.imag)), f"{mod}: non-finite imag samples"
|
||||||
|
|
||||||
|
def test_unknown_modulation_defaults_to_qpsk(self):
|
||||||
|
sig = self._synth("UNKNOWN_MOD_XYZ")
|
||||||
|
assert len(sig) == 256
|
||||||
|
assert sig.dtype == np.complex64
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("filter_type", ["rrc", "rc", "gaussian", "rect", "none"])
|
||||||
|
def test_all_filter_types(self, filter_type):
|
||||||
|
sig = self.ex._synthesise("QPSK", sps=4, num_samples=128, filter_type=filter_type, rolloff=0.35)
|
||||||
|
assert len(sig) == 128
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n", [64, 128, 512, 1024])
|
||||||
|
def test_output_length_matches_requested_samples(self, n):
|
||||||
|
sig = self._synth("QPSK", num_samples=n)
|
||||||
|
assert len(sig) == n
|
||||||
|
|
||||||
|
def test_bpsk_output_is_complex_not_real(self):
|
||||||
|
sig = self._synth("BPSK")
|
||||||
|
# complex64 always has imag part; just check dtype
|
||||||
|
assert sig.dtype == np.complex64
|
||||||
|
|
||||||
|
def test_256qam_correct_length(self):
|
||||||
|
sig = self._synth("256QAM")
|
||||||
|
assert len(sig) == 256
|
||||||
|
|
@ -189,6 +189,8 @@ class TestNoiseCommand:
|
||||||
"10000",
|
"10000",
|
||||||
"--noise-type",
|
"--noise-type",
|
||||||
"gaussian",
|
"gaussian",
|
||||||
|
"--power",
|
||||||
|
"0.01",
|
||||||
"--output",
|
"--output",
|
||||||
output,
|
output,
|
||||||
"-q",
|
"-q",
|
||||||
|
|
@ -234,7 +236,7 @@ class TestNoiseCommand:
|
||||||
"--num-samples",
|
"--num-samples",
|
||||||
"10000",
|
"10000",
|
||||||
"--power",
|
"--power",
|
||||||
"0.5",
|
"0.01",
|
||||||
"--output",
|
"--output",
|
||||||
output,
|
output,
|
||||||
"-q",
|
"-q",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Tests for the RT-OSS HTTP server.
|
"""Tests for the RT-OSS HTTP server.
|
||||||
|
|
||||||
Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator
|
Covers: auth, inference lifecycle (without SDR/ONNX hardware), conductor
|
||||||
lifecycle (with mocked executor), and state helpers.
|
lifecycle (with mocked executor), and state helpers.
|
||||||
|
|
||||||
``start_inference`` and ``_inference_loop`` require real SDR hardware and an
|
``start_inference`` and ``_inference_loop`` require real SDR hardware and an
|
||||||
|
|
@ -286,17 +286,17 @@ class TestInferenceStop:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# POST /orchestrator/deploy
|
# POST /conductor/deploy
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestratorDeploy:
|
class TestConductorDeploy:
|
||||||
def test_deploy_422_on_invalid_config(self, client):
|
def test_deploy_422_on_invalid_config(self, client):
|
||||||
with patch(
|
with patch(
|
||||||
"ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict",
|
"ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict",
|
||||||
side_effect=ValueError("missing required field 'name'"),
|
side_effect=ValueError("missing required field 'name'"),
|
||||||
):
|
):
|
||||||
resp = client.post("/orchestrator/deploy", json={"config": {}})
|
resp = client.post("/conductor/deploy", json={"config": {}})
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
|
|
||||||
def test_deploy_returns_campaign_id(self, client):
|
def test_deploy_returns_campaign_id(self, client):
|
||||||
|
|
@ -307,10 +307,10 @@ class TestOrchestratorDeploy:
|
||||||
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg),
|
patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor),
|
patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
|
||||||
):
|
):
|
||||||
resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}})
|
resp = client.post("/conductor/deploy", json={"config": {"name": "test_campaign"}})
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
|
|
@ -325,23 +325,23 @@ class TestOrchestratorDeploy:
|
||||||
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {})
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg),
|
patch("ria_toolkit_oss.server.routers.conductor.CampaignConfig.from_dict", return_value=mock_cfg),
|
||||||
patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor),
|
patch("ria_toolkit_oss.server.routers.conductor.CampaignExecutor", mock_executor),
|
||||||
):
|
):
|
||||||
resp = client.post("/orchestrator/deploy", json={"config": {}})
|
resp = client.post("/conductor/deploy", json={"config": {}})
|
||||||
|
|
||||||
campaign_id = resp.json()["campaign_id"]
|
campaign_id = resp.json()["campaign_id"]
|
||||||
assert state_module._campaigns.get(campaign_id) is not None
|
assert state_module._campaigns.get(campaign_id) is not None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /orchestrator/status/{campaign_id}
|
# GET /conductor/status/{campaign_id}
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestratorStatus:
|
class TestConductorStatus:
|
||||||
def test_status_404_for_unknown_id(self, client):
|
def test_status_404_for_unknown_id(self, client):
|
||||||
resp = client.get("/orchestrator/status/nonexistent-id")
|
resp = client.get("/conductor/status/nonexistent-id")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_status_returns_campaign_state(self, client):
|
def test_status_returns_campaign_state(self, client):
|
||||||
|
|
@ -357,7 +357,7 @@ class TestOrchestratorStatus:
|
||||||
)
|
)
|
||||||
state_module._campaigns["abc-123"] = state
|
state_module._campaigns["abc-123"] = state
|
||||||
|
|
||||||
resp = client.get("/orchestrator/status/abc-123")
|
resp = client.get("/conductor/status/abc-123")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert body["campaign_id"] == "abc-123"
|
assert body["campaign_id"] == "abc-123"
|
||||||
|
|
@ -367,13 +367,13 @@ class TestOrchestratorStatus:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# POST /orchestrator/cancel/{campaign_id}
|
# POST /conductor/cancel/{campaign_id}
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestratorCancel:
|
class TestConductorCancel:
|
||||||
def test_cancel_404_for_unknown_id(self, client):
|
def test_cancel_404_for_unknown_id(self, client):
|
||||||
resp = client.post("/orchestrator/cancel/no-such-id")
|
resp = client.post("/conductor/cancel/no-such-id")
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_cancel_sets_cancel_event(self, client):
|
def test_cancel_sets_cancel_event(self, client):
|
||||||
|
|
@ -387,7 +387,7 @@ class TestOrchestratorCancel:
|
||||||
)
|
)
|
||||||
state_module._campaigns["camp-to-cancel"] = state
|
state_module._campaigns["camp-to-cancel"] = state
|
||||||
|
|
||||||
resp = client.post("/orchestrator/cancel/camp-to-cancel")
|
resp = client.post("/conductor/cancel/camp-to-cancel")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["cancelled"] is True
|
assert resp.json()["cancelled"] is True
|
||||||
assert cancel_event.is_set()
|
assert cancel_event.is_set()
|
||||||
|
|
@ -403,7 +403,7 @@ class TestOrchestratorCancel:
|
||||||
)
|
)
|
||||||
state_module._campaigns["done"] = state
|
state_module._campaigns["done"] = state
|
||||||
|
|
||||||
resp = client.post("/orchestrator/cancel/done")
|
resp = client.post("/conductor/cancel/done")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["cancelled"] is False
|
assert resp.json()["cancelled"] is False
|
||||||
assert not cancel_event.is_set()
|
assert not cancel_event.is_set()
|
||||||
|
|
|
||||||
247
tests/test_agent.py
Normal file
247
tests/test_agent.py
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
"""Tests for NodeAgent — TX role, session code, and TX command dispatch."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent import NodeAgent
|
||||||
|
|
||||||
|
|
||||||
|
def _agent(role="general", session_code=None, **kwargs):
|
||||||
|
return NodeAgent(
|
||||||
|
hub_url="http://hub.test",
|
||||||
|
api_key="test-key",
|
||||||
|
name="test-node",
|
||||||
|
sdr_device="mock",
|
||||||
|
role=role,
|
||||||
|
session_code=session_code,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_register(agent, node_id="node_abc123"):
|
||||||
|
"""Patch _post so _register() returns a fake node_id response."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.json.return_value = {"node_id": node_id}
|
||||||
|
resp.raise_for_status.return_value = None
|
||||||
|
agent._post = MagicMock(return_value=resp)
|
||||||
|
return agent._post
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Initialisation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNodeAgentInit:
|
||||||
|
def test_stores_role_general(self):
|
||||||
|
assert _agent(role="general").role == "general"
|
||||||
|
|
||||||
|
def test_stores_role_tx(self):
|
||||||
|
assert _agent(role="tx").role == "tx"
|
||||||
|
|
||||||
|
def test_stores_role_rx(self):
|
||||||
|
assert _agent(role="rx").role == "rx"
|
||||||
|
|
||||||
|
def test_session_code_stored(self):
|
||||||
|
assert _agent(session_code="amber-peak-transmit").session_code == "amber-peak-transmit"
|
||||||
|
|
||||||
|
def test_session_code_none_by_default(self):
|
||||||
|
assert _agent().session_code is None
|
||||||
|
|
||||||
|
def test_tx_stop_event_created(self):
|
||||||
|
a = _agent()
|
||||||
|
assert isinstance(a._tx_stop, threading.Event)
|
||||||
|
|
||||||
|
def test_tx_thread_none_initially(self):
|
||||||
|
assert _agent()._tx_thread is None
|
||||||
|
|
||||||
|
def test_hub_url_trailing_slash_stripped(self):
|
||||||
|
a = NodeAgent(hub_url="http://hub.test/", api_key="k", name="n")
|
||||||
|
assert a.hub_url == "http://hub.test"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _register payload
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNodeAgentRegisterPayload:
|
||||||
|
def _payload(self, agent):
|
||||||
|
post = _mock_register(agent)
|
||||||
|
agent._register()
|
||||||
|
_, kwargs = post.call_args
|
||||||
|
return kwargs["json"]
|
||||||
|
|
||||||
|
def test_general_role_in_payload(self):
|
||||||
|
payload = self._payload(_agent(role="general"))
|
||||||
|
assert payload["role"] == "general"
|
||||||
|
|
||||||
|
def test_tx_role_in_payload(self):
|
||||||
|
payload = self._payload(_agent(role="tx"))
|
||||||
|
assert payload["role"] == "tx"
|
||||||
|
|
||||||
|
def test_tx_role_adds_transmit_capability(self):
|
||||||
|
payload = self._payload(_agent(role="tx"))
|
||||||
|
assert "transmit" in payload["capabilities"]
|
||||||
|
|
||||||
|
def test_general_role_omits_transmit_capability(self):
|
||||||
|
payload = self._payload(_agent(role="general"))
|
||||||
|
assert "transmit" not in payload.get("capabilities", [])
|
||||||
|
|
||||||
|
def test_session_code_included_when_set(self):
|
||||||
|
payload = self._payload(_agent(role="tx", session_code="amber-peak-transmit"))
|
||||||
|
assert payload["session_code"] == "amber-peak-transmit"
|
||||||
|
|
||||||
|
def test_session_code_omitted_when_none(self):
|
||||||
|
payload = self._payload(_agent())
|
||||||
|
assert "session_code" not in payload
|
||||||
|
|
||||||
|
def test_register_stores_returned_node_id(self):
|
||||||
|
a = _agent()
|
||||||
|
_mock_register(a, node_id="node_xyz999")
|
||||||
|
a._register()
|
||||||
|
assert a.node_id == "node_xyz999"
|
||||||
|
|
||||||
|
def test_name_in_payload(self):
|
||||||
|
a = NodeAgent(hub_url="http://h", api_key="k", name="my-bench")
|
||||||
|
_mock_register(a)
|
||||||
|
a._register()
|
||||||
|
_, kwargs = a._post.call_args
|
||||||
|
assert kwargs["json"]["name"] == "my-bench"
|
||||||
|
|
||||||
|
def test_sdr_device_in_payload(self):
|
||||||
|
a = _agent()
|
||||||
|
post = _mock_register(a)
|
||||||
|
a._register()
|
||||||
|
_, kwargs = post.call_args
|
||||||
|
assert kwargs["json"]["sdr_device"] == "mock"
|
||||||
|
|
||||||
|
def test_campaign_capability_always_present(self):
|
||||||
|
for role in ("general", "rx", "tx"):
|
||||||
|
a = _agent(role=role)
|
||||||
|
payload = self._payload(a)
|
||||||
|
assert "campaign" in payload["capabilities"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _dispatch — TX commands
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNodeAgentDispatch:
|
||||||
|
def _make_agent(self):
|
||||||
|
a = _agent(role="tx")
|
||||||
|
a.node_id = "node_abc"
|
||||||
|
a._report_campaign_status = MagicMock()
|
||||||
|
return a
|
||||||
|
|
||||||
|
def test_start_transmit_spawns_thread(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
done = threading.Event()
|
||||||
|
|
||||||
|
class _FakeExecutor:
|
||||||
|
def run(self_):
|
||||||
|
done.wait(timeout=2)
|
||||||
|
|
||||||
|
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
||||||
|
a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
|
||||||
|
time.sleep(0.05)
|
||||||
|
assert a._tx_thread is not None
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
def test_start_transmit_clears_stop_event(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
a._tx_stop.set() # pre-set
|
||||||
|
|
||||||
|
done = threading.Event()
|
||||||
|
|
||||||
|
class _FakeExecutor:
|
||||||
|
def run(self_):
|
||||||
|
done.wait(timeout=2)
|
||||||
|
|
||||||
|
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
||||||
|
a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
|
||||||
|
time.sleep(0.05)
|
||||||
|
assert not a._tx_stop.is_set()
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
def test_stop_transmit_sets_stop_event(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
a._dispatch({"command": "stop_transmit"})
|
||||||
|
assert a._tx_stop.is_set()
|
||||||
|
|
||||||
|
def test_configure_transmit_does_not_raise(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
a._dispatch({"command": "configure_transmit", "modulation": "BPSK"})
|
||||||
|
|
||||||
|
def test_unknown_command_is_ignored(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
a._dispatch({"command": "frobnicate_xyz"})
|
||||||
|
|
||||||
|
def test_duplicate_start_transmit_ignored_while_running(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
done = threading.Event()
|
||||||
|
run_calls = []
|
||||||
|
|
||||||
|
class _FakeExecutor:
|
||||||
|
def run(self_):
|
||||||
|
run_calls.append(1)
|
||||||
|
done.wait(timeout=2)
|
||||||
|
|
||||||
|
with patch("ria_toolkit_oss.orchestration.tx_executor.TxExecutor", return_value=_FakeExecutor()):
|
||||||
|
a._dispatch({"command": "start_transmit"})
|
||||||
|
time.sleep(0.05)
|
||||||
|
a._dispatch({"command": "start_transmit"}) # second while first alive
|
||||||
|
done.set()
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
assert len(run_calls) == 1
|
||||||
|
|
||||||
|
def test_run_campaign_dispatched_in_thread(self):
|
||||||
|
a = self._make_agent()
|
||||||
|
done = threading.Event()
|
||||||
|
|
||||||
|
with patch("ria_toolkit_oss.agent.NodeAgent._run_campaign") as mock_run:
|
||||||
|
mock_run.side_effect = lambda *_: done.set()
|
||||||
|
a._dispatch({"command": "run_campaign", "campaign_id": "c1", "payload": {}})
|
||||||
|
done.wait(timeout=2)
|
||||||
|
assert mock_run.called
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _stop_transmit
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopTransmit:
|
||||||
|
def test_no_thread_noop(self):
|
||||||
|
a = _agent()
|
||||||
|
a._stop_transmit() # must not raise
|
||||||
|
|
||||||
|
def test_sets_stop_event(self):
|
||||||
|
a = _agent()
|
||||||
|
a._stop_transmit()
|
||||||
|
assert a._tx_stop.is_set()
|
||||||
|
|
||||||
|
def test_joins_live_thread(self):
|
||||||
|
a = _agent()
|
||||||
|
finished = threading.Event()
|
||||||
|
unblock = threading.Event()
|
||||||
|
|
||||||
|
def _task():
|
||||||
|
unblock.wait(timeout=2)
|
||||||
|
finished.set()
|
||||||
|
|
||||||
|
t = threading.Thread(target=_task, daemon=True)
|
||||||
|
t.start()
|
||||||
|
a._tx_thread = t
|
||||||
|
|
||||||
|
# Signal stop and trigger thread exit
|
||||||
|
a._tx_stop.set()
|
||||||
|
unblock.set()
|
||||||
|
a._stop_transmit()
|
||||||
|
|
||||||
|
assert not t.is_alive()
|
||||||
Loading…
Reference in New Issue
Block a user