zfp-oss #27

Merged
benchinnery merged 15 commits from zfp-oss into main 2026-04-23 11:10:43 -04:00
6 changed files with 92 additions and 79 deletions
Showing only changes of commit c27a5944c7 - Show all commits

View File

@ -931,11 +931,7 @@ def main() -> None:
"--role", "--role",
default=None, default=None,
choices=["general", "rx", "tx"], choices=["general", "rx", "tx"],
help=( help=("Node role reported to the hub. " "'tx' enables synthetic transmission commands. " "Default: general"),
"Node role reported to the hub. "
"'tx' enables synthetic transmission commands. "
"Default: general"
),
) )
parser.add_argument( parser.add_argument(
"--session-code", "--session-code",

View File

@ -33,7 +33,6 @@ from __future__ import annotations
import logging import logging
import threading import threading
import time
from typing import Any from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,11 +40,11 @@ logger = logging.getLogger(__name__)
# Mapping from modulation name → (PSK/QAM order, generator_type) # Mapping from modulation name → (PSK/QAM order, generator_type)
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator # 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
_MOD_TABLE: dict[str, tuple[int, str]] = { _MOD_TABLE: dict[str, tuple[int, str]] = {
"BPSK": (1, "psk"), "BPSK": (1, "psk"),
"QPSK": (2, "psk"), "QPSK": (2, "psk"),
"8PSK": (3, "psk"), "8PSK": (3, "psk"),
"16QAM": (4, "qam"), "16QAM": (4, "qam"),
"64QAM": (6, "qam"), "64QAM": (6, "qam"),
"256QAM": (8, "qam"), "256QAM": (8, "qam"),
} }
@ -117,7 +116,12 @@ class TxExecutor:
logger.info( logger.info(
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)", "TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
label, duration, modulation, symbol_rate / 1e6, sps, filter_type, label,
duration,
modulation,
symbol_rate / 1e6,
sps,
filter_type,
) )
num_samples = int(duration * sample_rate) num_samples = int(duration * sample_rate)
@ -133,9 +137,7 @@ class TxExecutor:
logger.error("TX step '%s' SDR error: %s", label, exc) logger.error("TX step '%s' SDR error: %s", label, exc)
else: else:
# No SDR available — simulate by sleeping for the step duration. # No SDR available — simulate by sleeping for the step duration.
logger.warning( logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
"TX step '%s': no SDR — simulating %.0f s delay", label, duration
)
self.stop_event.wait(timeout=duration) self.stop_event.wait(timeout=duration)
def _synthesise( def _synthesise(
@ -149,6 +151,7 @@ class TxExecutor:
"""Build a block-generator chain and return IQ samples as a numpy array.""" """Build a block-generator chain and return IQ samples as a numpy array."""
try: try:
import numpy as np import numpy as np
from ria_toolkit_oss.signal.block_generator import ( from ria_toolkit_oss.signal.block_generator import (
BinarySource, BinarySource,
GMSKModulator, GMSKModulator,
@ -231,6 +234,7 @@ class TxExecutor:
def _init_sdr(self, sample_rate: float, center_freq: float) -> None: def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
try: try:
from ria_toolkit_oss.sdr import get_sdr_device from ria_toolkit_oss.sdr import get_sdr_device
self._sdr = get_sdr_device(self.sdr_device) self._sdr = get_sdr_device(self.sdr_device)
self._sdr.init_tx( self._sdr.init_tx(
sample_rate=sample_rate, sample_rate=sample_rate,
@ -239,7 +243,9 @@ class TxExecutor:
channel=0, channel=0,
gain_mode="manual", gain_mode="manual",
) )
logger.info("TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6) logger.info(
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
)
except Exception as exc: except Exception as exc:
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc) logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
self._sdr = None self._sdr = None

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import json import json
import stat import stat
import threading
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
@ -108,37 +107,47 @@ class TestCampaignResult:
return r return r
def test_total_steps(self): def test_total_steps(self):
r = self._make([ r = self._make(
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), [
StepResult("tx1", "s2", "/out", _ok_qa(), 0.0), StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
]) StepResult("tx1", "s2", "/out", _ok_qa(), 0.0),
]
)
assert r.total_steps == 2 assert r.total_steps == 2
def test_passed_count(self): def test_passed_count(self):
r = self._make([ r = self._make(
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), [
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
]) StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
]
)
assert r.passed == 1 assert r.passed == 1
def test_failed_count(self): def test_failed_count(self):
r = self._make([ r = self._make(
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), [
StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
]) StepResult("tx1", "s2", "/out", _failed_qa(), 0.0),
]
)
assert r.failed == 1 assert r.failed == 1
def test_flagged_count(self): def test_flagged_count(self):
r = self._make([ r = self._make(
StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), [
StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0), StepResult("tx1", "s1", "/out", _ok_qa(), 0.0),
]) StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0),
]
)
assert r.flagged == 1 assert r.flagged == 1
def test_error_step_counts_as_failed_not_passed(self): def test_error_step_counts_as_failed_not_passed(self):
r = self._make([ r = self._make(
StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"), [
]) StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"),
]
)
assert r.failed == 1 assert r.failed == 1
assert r.passed == 0 assert r.passed == 0
@ -232,37 +241,45 @@ class TestExtractTxParams:
assert _extract_tx_params(tx) is None assert _extract_tx_params(tx) is None
def test_returns_signal_params(self): def test_returns_signal_params(self):
tx = SimpleNamespace(sdr_agent={ tx = SimpleNamespace(
"modulation": "QPSK", sdr_agent={
"symbol_rate": 1e6, "modulation": "QPSK",
"center_frequency": 2.4e9, "symbol_rate": 1e6,
}) "center_frequency": 2.4e9,
}
)
result = _extract_tx_params(tx) result = _extract_tx_params(tx)
assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9} assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9}
def test_strips_infra_key_node_id(self): def test_strips_infra_key_node_id(self):
tx = SimpleNamespace(sdr_agent={ tx = SimpleNamespace(
"modulation": "BPSK", sdr_agent={
"node_id": "node_abc123", "modulation": "BPSK",
}) "node_id": "node_abc123",
}
)
result = _extract_tx_params(tx) result = _extract_tx_params(tx)
assert "node_id" not in result assert "node_id" not in result
assert result == {"modulation": "BPSK"} assert result == {"modulation": "BPSK"}
def test_strips_infra_key_session_code(self): def test_strips_infra_key_session_code(self):
tx = SimpleNamespace(sdr_agent={ tx = SimpleNamespace(
"modulation": "FSK", sdr_agent={
"session_code": "amber-peak-transmit", "modulation": "FSK",
}) "session_code": "amber-peak-transmit",
}
)
result = _extract_tx_params(tx) result = _extract_tx_params(tx)
assert "session_code" not in result assert "session_code" not in result
def test_strips_none_values(self): def test_strips_none_values(self):
tx = SimpleNamespace(sdr_agent={ tx = SimpleNamespace(
"modulation": "QPSK", sdr_agent={
"order": None, "modulation": "QPSK",
"rolloff": 0.35, "order": None,
}) "rolloff": 0.35,
}
)
result = _extract_tx_params(tx) result = _extract_tx_params(tx)
assert "order" not in result assert "order" not in result
assert result == {"modulation": "QPSK", "rolloff": 0.35} assert result == {"modulation": "QPSK", "rolloff": 0.35}
@ -274,16 +291,18 @@ class TestExtractTxParams:
assert "node_id" in cfg assert "node_id" in cfg
def test_full_sdr_agent_config(self): def test_full_sdr_agent_config(self):
tx = SimpleNamespace(sdr_agent={ tx = SimpleNamespace(
"modulation": "16QAM", sdr_agent={
"order": 4, "modulation": "16QAM",
"symbol_rate": 5e6, "order": 4,
"center_frequency": 915e6, "symbol_rate": 5e6,
"filter": "rrc", "center_frequency": 915e6,
"rolloff": 0.35, "filter": "rrc",
"node_id": "node_xyz", "rolloff": 0.35,
"session_code": "some-code", "node_id": "node_xyz",
}) "session_code": "some-code",
}
)
result = _extract_tx_params(tx) result = _extract_tx_params(tx)
assert result == { assert result == {
"modulation": "16QAM", "modulation": "16QAM",

View File

@ -116,9 +116,7 @@ class TestLabelRecording:
def test_tx_params_written_as_tx_prefix_keys(self): def test_tx_params_written_as_tx_prefix_keys(self):
params = {"modulation": "QPSK", "symbol_rate": 1e6} params = {"modulation": "QPSK", "symbol_rate": 1e6}
rec = label_recording( rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params
)
assert rec.metadata["tx_modulation"] == "QPSK" assert rec.metadata["tx_modulation"] == "QPSK"
assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6) assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6)
@ -131,17 +129,15 @@ class TestLabelRecording:
"filter": "rrc", "filter": "rrc",
"rolloff": 0.35, "rolloff": 0.35,
} }
rec = label_recording( rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params)
_simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params
)
for k, v in params.items(): for k, v in params.items():
assert f"tx_{k}" in rec.metadata assert f"tx_{k}" in rec.metadata
assert rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v assert (
rec.metadata[f"tx_{k}"] == pytest.approx(v) if isinstance(v, float) else rec.metadata[f"tx_{k}"] == v
)
def test_tx_params_empty_dict_writes_nothing(self): def test_tx_params_empty_dict_writes_nothing(self):
rec = label_recording( rec = label_recording(_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={})
_simple_recording(), "dev", _wifi_step(), time.time(), tx_params={}
)
tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"] tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"]
assert tx_keys == [] assert tx_keys == []

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import threading import threading
from unittest.mock import MagicMock, patch from unittest.mock import patch
import numpy as np import numpy as np
import pytest import pytest
@ -73,8 +73,6 @@ class TestTxExecutorRun:
waited = [] waited = []
real_ev = threading.Event() real_ev = threading.Event()
orig_wait = real_ev.wait
def _fake_wait(timeout=None): def _fake_wait(timeout=None):
waited.append(timeout) waited.append(timeout)
return False return False

View File

@ -6,8 +6,6 @@ import threading
import time import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from ria_toolkit_oss.agent import NodeAgent from ria_toolkit_oss.agent import NodeAgent