ria-toolkit-oss/tests/orchestration/test_labeler.py
2026-04-20 12:33:14 -04:00

182 lines
7.0 KiB
Python

"""Tests for orchestration labeler."""
import time
import numpy as np
import pytest
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.orchestration.campaign import CaptureStep
from ria_toolkit_oss.orchestration.labeler import build_output_filename, label_recording
def _simple_recording() -> Recording:
    """Build a tiny constant-valued Recording used as labeler input.

    Returns:
        A Recording of 1000 complex64 samples, all equal to 1+0j, whose
        metadata holds ``sample_rate`` (1e6) and ``center_frequency`` (2.45e9).
    """
    samples = np.full(1000, 1 + 0j, dtype=np.complex64)
    meta = {"sample_rate": 1e6, "center_frequency": 2.45e9}
    return Recording(samples, metadata=meta)
def _wifi_step() -> CaptureStep:
    """Return the canonical Wi-Fi capture step used across these tests.

    30 s duration, label ``ch06_20mhz_idle``, channel 6, 20 MHz bandwidth,
    idle traffic. No ``power_dbm`` is set.
    """
    wifi_fields = dict(
        duration=30.0,
        label="ch06_20mhz_idle",
        channel=6,
        bandwidth_mhz=20.0,
        traffic="idle",
    )
    return CaptureStep(**wifi_fields)
def _bt_step() -> CaptureStep:
    """Return the canonical Bluetooth capture step used across these tests.

    30 s duration, label and traffic both ``audio_stream``, and a 7.5 ms
    connection interval. No channel or bandwidth is set.
    """
    bt_fields = dict(
        duration=30.0,
        label="audio_stream",
        traffic="audio_stream",
        connection_interval_ms=7.5,
    )
    return CaptureStep(**bt_fields)
# ---------------------------------------------------------------------------
# label_recording
# ---------------------------------------------------------------------------
class TestLabelRecording:
    """Tests for label_recording: which metadata keys it writes, and when."""

    def test_device_id_set(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert rec.metadata["device_id"] == "iphone13_001"

    def test_capture_timestamp_set(self):
        ts = time.time()
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), ts)
        assert rec.metadata["capture_timestamp"] == pytest.approx(ts, abs=1.0)

    def test_step_label_set(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert rec.metadata["step_label"] == "ch06_20mhz_idle"

    def test_step_duration_set(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert rec.metadata["step_duration_s"] == pytest.approx(30.0)

    def test_campaign_name_optional(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert "campaign" not in rec.metadata

    def test_campaign_name_when_provided(self):
        rec = label_recording(
            _simple_recording(), "iphone13_001", _wifi_step(), time.time(), campaign_name="test_campaign"
        )
        assert rec.metadata["campaign"] == "test_campaign"

    def test_wifi_channel_set(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert rec.metadata["wifi_channel"] == 6

    def test_wifi_bandwidth_set(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert rec.metadata["wifi_bandwidth_mhz"] == pytest.approx(20.0)

    def test_traffic_pattern_set(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert rec.metadata["traffic_pattern"] == "idle"

    def test_bt_connection_interval_set(self):
        rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
        assert rec.metadata["bt_connection_interval_ms"] == pytest.approx(7.5)

    def test_no_channel_key_for_bt(self):
        """BT steps with no channel should not add wifi_channel to metadata."""
        rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
        assert "wifi_channel" not in rec.metadata

    def test_no_bandwidth_key_for_bt(self):
        """BT steps with no bandwidth should not add wifi_bandwidth_mhz to metadata."""
        rec = label_recording(_simple_recording(), "airpods_001", _bt_step(), time.time())
        assert "wifi_bandwidth_mhz" not in rec.metadata

    def test_power_dbm_set(self):
        step = CaptureStep(duration=30.0, label="test", traffic="idle", power_dbm=15.0)
        rec = label_recording(_simple_recording(), "dev_001", step, time.time())
        assert rec.metadata["tx_power_dbm"] == pytest.approx(15.0)

    def test_no_power_key_when_unset(self):
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        assert "tx_power_dbm" not in rec.metadata

    def test_returns_same_recording(self):
        """label_recording should mutate and return the same Recording object."""
        rec = _simple_recording()
        result = label_recording(rec, "iphone13_001", _wifi_step(), time.time())
        assert result is rec

    def test_tx_params_none_by_default(self):
        """Without tx_params (and no power_dbm on the step) no tx_* keys appear."""
        rec = label_recording(_simple_recording(), "iphone13_001", _wifi_step(), time.time())
        tx_keys = [k for k in rec.metadata if k.startswith("tx_")]
        assert tx_keys == []

    def test_tx_params_written_as_tx_prefix_keys(self):
        params = {"modulation": "QPSK", "symbol_rate": 1e6}
        rec = label_recording(
            _simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params
        )
        assert rec.metadata["tx_modulation"] == "QPSK"
        assert rec.metadata["tx_symbol_rate"] == pytest.approx(1e6)

    def test_tx_params_multiple_fields(self):
        """Every tx_params entry is copied to metadata under a 'tx_' prefix."""
        params = {
            "modulation": "16QAM",
            "order": 4,
            "symbol_rate": 5e6,
            "center_frequency": 915e6,
            "filter": "rrc",
            "rolloff": 0.35,
        }
        rec = label_recording(
            _simple_recording(), "dev", _wifi_step(), time.time(), tx_params=params
        )
        for key, value in params.items():
            meta_key = f"tx_{key}"
            assert meta_key in rec.metadata
            # Floats are compared approximately; str/int values exactly.
            # Computing `expected` first avoids an `assert a == b if c else d == e`
            # one-liner whose meaning hinges on operator precedence.
            expected = pytest.approx(value) if isinstance(value, float) else value
            assert rec.metadata[meta_key] == expected

    def test_tx_params_empty_dict_writes_nothing(self):
        rec = label_recording(
            _simple_recording(), "dev", _wifi_step(), time.time(), tx_params={}
        )
        # tx_power_dbm is derived from the step's power_dbm (see test_power_dbm_set),
        # not from tx_params, so it is excluded from this check.
        tx_keys = [k for k in rec.metadata if k.startswith("tx_") and k != "tx_power_dbm"]
        assert tx_keys == []
# ---------------------------------------------------------------------------
# build_output_filename
# ---------------------------------------------------------------------------
class TestBuildOutputFilename:
    """Tests for build_output_filename, which yields '<device_id>/<label>' paths."""

    def test_basic_wifi(self):
        step = CaptureStep(duration=30.0, label="ch06_20mhz_idle")
        fn = build_output_filename("iphone13_wifi_001", step)
        assert fn == "iphone13_wifi_001/ch06_20mhz_idle"

    def test_bt_step(self):
        step = CaptureStep(duration=30.0, label="audio_stream")
        fn = build_output_filename("airpods_pro_bt_001", step)
        assert fn == "airpods_pro_bt_001/audio_stream"

    def test_spaces_in_device_id_replaced(self):
        step = CaptureStep(duration=30.0, label="idle")
        fn = build_output_filename("my device", step)
        assert " " not in fn
        assert fn == "my_device/idle"

    def test_slashes_in_label_replaced(self):
        """Slashes inside the label must not create extra path levels."""
        step = CaptureStep(duration=30.0, label="ch/6/idle")
        fn = build_output_filename("dev_001", step)
        # str.partition never raises, so a missing separator shows up as an
        # assertion failure rather than an IndexError from split(...)[1].
        device_part, sep, label_part = fn.partition("/")
        assert sep == "/"
        assert device_part == "dev_001"
        assert label_part  # the label must not be dropped entirely
        assert "/" not in label_part  # only the separator slash should remain

    def test_path_structure(self):
        """Filename should be exactly '<device_id>/<label>' (one level of nesting)."""
        step = CaptureStep(duration=30.0, label="idle")
        fn = build_output_filename("dev_001", step)
        parts = fn.split("/")
        assert len(parts) == 2
        assert parts == ["dev_001", "idle"]