490 lines
16 KiB
Python
490 lines
16 KiB
Python
"""Tests for orchestration campaign schema and YAML parsing."""
|
||
|
||
import os
|
||
import tempfile
|
||
|
||
import pytest
|
||
import yaml
|
||
|
||
from ria_toolkit_oss.orchestration.campaign import (
|
||
CampaignConfig,
|
||
CaptureStep,
|
||
QAConfig,
|
||
RecorderConfig,
|
||
parse_bandwidth_mhz,
|
||
parse_duration,
|
||
parse_frequency,
|
||
parse_gain,
|
||
)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# parse_duration
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParseDuration:
    """parse_duration: unit-suffixed strings and plain numbers, converted to seconds."""

    def test_seconds_suffix(self):
        seconds = parse_duration("30s")
        assert seconds == 30.0

    def test_seconds_suffix_long(self):
        seconds = parse_duration("30sec")
        assert seconds == 30.0

    def test_minutes_suffix(self):
        # Fractional minutes are allowed: 1.5 min -> 90 s.
        seconds = parse_duration("1.5m")
        assert seconds == 90.0

    def test_minutes_suffix_long(self):
        seconds = parse_duration("2min")
        assert seconds == 120.0

    def test_hours_suffix(self):
        seconds = parse_duration("2h")
        assert seconds == 7200.0

    def test_hours_suffix_long(self):
        seconds = parse_duration("1hr")
        assert seconds == 3600.0

    def test_numeric_int(self):
        seconds = parse_duration(45)
        assert seconds == 45.0

    def test_numeric_float(self):
        seconds = parse_duration(1.5)
        assert seconds == 1.5

    def test_bare_number_string(self):
        # A string without a unit suffix is interpreted as seconds.
        seconds = parse_duration("60")
        assert seconds == 60.0

    def test_invalid_raises(self):
        pytest.raises(ValueError, parse_duration, "two minutes")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# parse_frequency
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParseFrequency:
    """parse_frequency: SI-prefixed strings, scientific notation and numbers, in hertz."""

    def test_ghz(self):
        hz = parse_frequency("2.45GHz")
        assert hz == pytest.approx(2.45e9)

    def test_mhz(self):
        hz = parse_frequency("40MHz")
        assert hz == pytest.approx(40e6)

    def test_khz(self):
        hz = parse_frequency("433k")
        assert hz == pytest.approx(433e3)

    def test_scientific_notation_string(self):
        hz = parse_frequency("915e6")
        assert hz == pytest.approx(915e6)

    def test_numeric_float(self):
        hz = parse_frequency(2.45e9)
        assert hz == pytest.approx(2.45e9)

    def test_numeric_int(self):
        hz = parse_frequency(1000000)
        assert hz == pytest.approx(1e6)

    def test_hz_suffix_optional(self):
        # The trailing "Hz" is optional: "40M" and "40MHz" are equivalent.
        for text in ("40M", "40MHz"):
            assert parse_frequency(text) == pytest.approx(40e6)

    def test_invalid_raises(self):
        pytest.raises(ValueError, parse_frequency, "two point four gigs")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# parse_gain
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParseGain:
    """parse_gain: dB strings and numbers become floats; "auto" is passed through."""

    def test_db_suffix(self):
        gain = parse_gain("40dB")
        assert gain == pytest.approx(40.0)

    def test_db_suffix_lowercase(self):
        gain = parse_gain("32db")
        assert gain == pytest.approx(32.0)

    def test_auto(self):
        gain = parse_gain("auto")
        assert gain == "auto"

    def test_auto_case_insensitive(self):
        # Any casing of "auto" is normalized to lowercase.
        gain = parse_gain("AUTO")
        assert gain == "auto"

    def test_numeric_int(self):
        gain = parse_gain(32)
        assert gain == pytest.approx(32.0)

    def test_numeric_float(self):
        gain = parse_gain(32.5)
        assert gain == pytest.approx(32.5)

    def test_invalid_raises(self):
        pytest.raises(ValueError, parse_gain, "high")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# parse_bandwidth_mhz
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParseBandwidthMhz:
    """parse_bandwidth_mhz: results are in MHz; None means "not specified"."""

    def test_mhz_suffix(self):
        mhz = parse_bandwidth_mhz("20MHz")
        assert mhz == pytest.approx(20.0)

    def test_numeric(self):
        mhz = parse_bandwidth_mhz(40)
        assert mhz == pytest.approx(40.0)

    def test_none(self):
        mhz = parse_bandwidth_mhz(None)
        assert mhz is None

    def test_invalid_raises(self):
        pytest.raises(ValueError, parse_bandwidth_mhz, "wide")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CaptureStep.from_dict
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestCaptureStep:
    """CaptureStep.from_dict: field parsing and label derivation."""

    def test_wifi_step_auto_label(self):
        raw = dict(channel=6, bandwidth="20MHz", traffic="iperf_udp", duration="30s")
        step = CaptureStep.from_dict(raw)
        assert step.duration == 30.0
        assert step.channel == 6
        assert step.bandwidth_mhz == 20.0
        assert step.traffic == "iperf_udp"
        # Without an explicit label, one is built from channel, bandwidth and traffic.
        assert step.label == "ch06_20mhz_iperf_udp"

    def test_explicit_label(self):
        raw = dict(channel=1, bandwidth="20MHz", traffic="idle", duration="30s", label="my_label")
        step = CaptureStep.from_dict(raw)
        assert step.label == "my_label"

    def test_fallback_label(self):
        # Nothing to derive a label from (no channel/bandwidth/traffic) -> "capture".
        step = CaptureStep.from_dict(dict(duration="10s"))
        assert step.label == "capture"

    def test_power_parsed(self):
        raw = dict(channel=6, bandwidth="20MHz", traffic="idle", duration="30s", power="15dBm")
        step = CaptureStep.from_dict(raw)
        assert step.power_dbm == pytest.approx(15.0)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RecorderConfig.from_dict
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRecorderConfig:
    """RecorderConfig.from_dict: SDR recorder settings parsed into numeric fields."""

    def test_basic(self):
        raw = dict(device="usrp_b210", center_freq="2.45GHz", sample_rate="40MHz", gain="40dB")
        rec = RecorderConfig.from_dict(raw)
        assert rec.device == "usrp_b210"
        assert rec.center_freq == pytest.approx(2.45e9)
        assert rec.sample_rate == pytest.approx(40e6)
        assert rec.gain == pytest.approx(40.0)
        # No "bandwidth" key supplied -> attribute left as None.
        assert rec.bandwidth is None

    def test_auto_gain(self):
        raw = dict(device="pluto", center_freq="2.45GHz", sample_rate="20MHz", gain="auto")
        rec = RecorderConfig.from_dict(raw)
        assert rec.gain == "auto"

    def test_bandwidth_set(self):
        raw = dict(device="pluto", center_freq="2.45GHz", sample_rate="20MHz", gain=32, bandwidth="20MHz")
        rec = RecorderConfig.from_dict(raw)
        # The recorder's bandwidth is expressed in Hz (20 MHz -> 20e6).
        assert rec.bandwidth == pytest.approx(20e6)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# QAConfig.from_dict
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestQAConfig:
    """QAConfig.from_dict: default thresholds and explicit overrides."""

    def test_defaults(self):
        # An empty mapping yields the built-in defaults.
        qa = QAConfig.from_dict({})
        assert qa.snr_threshold_db == pytest.approx(10.0)
        assert qa.min_duration_s == pytest.approx(25.0)
        assert qa.flag_for_review is True

    def test_custom_values(self):
        raw = dict(snr_threshold="15dB", min_duration="28s", flag_for_review=False)
        qa = QAConfig.from_dict(raw)
        assert qa.snr_threshold_db == pytest.approx(15.0)
        assert qa.min_duration_s == pytest.approx(28.0)
        assert qa.flag_for_review is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CampaignConfig.from_device_profile
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _write_device_profile(d: dict) -> str:
    """Write *d* as YAML to a new temporary ``.yml`` file and return its path.

    The file is created with ``delete=False``, so it outlives the handle.
    The caller must remove it (the tests here use ``os.unlink`` in a
    ``finally`` clause).

    If serialization fails, the file is closed and deleted before the
    exception propagates. This way a failing fixture does not leak the handle
    or leave stray files in the temp directory.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False, encoding="utf-8") as f:
        try:
            # Fixtures are plain dicts/lists/scalars; safe_dump refuses anything
            # that would need Python-specific YAML tags.
            yaml.safe_dump(d, f)
        except Exception:
            f.close()  # close before unlinking (required on Windows)
            os.unlink(f.name)
            raise
    return f.name
|
||
|
||
|
||
# Wi-Fi device-profile fixture: 3 channels x 3 traffic patterns, 30 s each.
WIFI_PROFILE = dict(
    device=dict(name="iPhone_13_WiFi", type="wifi"),
    capture=dict(
        channels=[1, 6, 11],
        bandwidth="20MHz",
        traffic_patterns=["idle", "ping", "iperf_udp"],
        duration_per_config="30s",
        script="./scripts/wifi_control.sh",
    ),
    recorder=dict(
        device="usrp_b210",
        center_freq="2.45GHz",
        sample_rate="40MHz",
        gain="auto",
    ),
    output=dict(path="/tmp/test_recordings", device_id="iphone13_wifi_001"),
)
|
||
|
||
# Bluetooth device-profile fixture: no channel list, 3 traffic patterns, 30 s each.
BT_PROFILE = dict(
    device=dict(name="AirPods_Pro", type="bluetooth"),
    capture=dict(
        traffic_patterns=["idle", "audio_stream", "data_transfer"],
        duration_per_config="30s",
    ),
    recorder=dict(
        device="usrp_b210",
        center_freq="2.45GHz",
        sample_rate="40MHz",
        gain="auto",
    ),
    output=dict(path="/tmp/test_recordings", device_id="airpods_pro_bt_001"),
)
|
||
|
||
|
||
class TestDeviceProfileParsing:
    """CampaignConfig.from_device_profile: expanding a device profile into a schedule."""

    @staticmethod
    def _parse(profile: dict) -> "CampaignConfig":
        """Write *profile* to a temp YAML file, parse it, and always delete the file.

        Centralizes the write/parse/unlink sequence the tests previously
        copy-pasted, so cleanup cannot be forgotten in a new test.
        """
        path = _write_device_profile(profile)
        try:
            return CampaignConfig.from_device_profile(path)
        finally:
            os.unlink(path)

    def test_wifi_schedule_count(self):
        """WiFi: 3 channels × 3 traffic = 9 steps."""
        cfg = self._parse(WIFI_PROFILE)
        assert len(cfg.transmitters) == 1
        assert len(cfg.transmitters[0].schedule) == 9

    def test_wifi_campaign_name(self):
        cfg = self._parse(WIFI_PROFILE)
        assert cfg.name == "enroll_iphone13_wifi_001"

    def test_wifi_step_labels(self):
        cfg = self._parse(WIFI_PROFILE)
        labels = [s.label for s in cfg.transmitters[0].schedule]
        assert "ch01_20mhz_idle" in labels
        assert "ch06_20mhz_ping" in labels
        assert "ch11_20mhz_iperf_udp" in labels

    def test_wifi_step_ordering(self):
        """Steps iterate channels first, then traffic."""
        steps = self._parse(WIFI_PROFILE).transmitters[0].schedule
        assert steps[0].channel == 1 and steps[0].traffic == "idle"
        assert steps[1].channel == 1 and steps[1].traffic == "ping"
        assert steps[3].channel == 6 and steps[3].traffic == "idle"
        assert steps[8].channel == 11 and steps[8].traffic == "iperf_udp"

    def test_wifi_step_duration(self):
        cfg = self._parse(WIFI_PROFILE)
        for step in cfg.transmitters[0].schedule:
            assert step.duration == pytest.approx(30.0)

    def test_wifi_bandwidth(self):
        cfg = self._parse(WIFI_PROFILE)
        for step in cfg.transmitters[0].schedule:
            assert step.bandwidth_mhz == pytest.approx(20.0)

    def test_wifi_recorder(self):
        cfg = self._parse(WIFI_PROFILE)
        assert cfg.recorder.device == "usrp_b210"
        assert cfg.recorder.center_freq == pytest.approx(2.45e9)
        assert cfg.recorder.sample_rate == pytest.approx(40e6)
        assert cfg.recorder.gain == "auto"

    def test_wifi_total_capture_time(self):
        cfg = self._parse(WIFI_PROFILE)
        assert cfg.total_capture_time_s() == pytest.approx(270.0)  # 9 × 30s

    def test_bt_schedule_count(self):
        """BT: no channels, 3 traffic patterns = 3 steps."""
        cfg = self._parse(BT_PROFILE)
        assert len(cfg.transmitters[0].schedule) == 3

    def test_bt_no_channel(self):
        cfg = self._parse(BT_PROFILE)
        for step in cfg.transmitters[0].schedule:
            assert step.channel is None

    def test_bt_step_labels(self):
        cfg = self._parse(BT_PROFILE)
        labels = [s.label for s in cfg.transmitters[0].schedule]
        assert labels == ["idle", "audio_stream", "data_transfer"]

    def test_missing_file_raises(self):
        with pytest.raises(FileNotFoundError):
            CampaignConfig.from_device_profile("/nonexistent/path/profile.yml")

    def test_invalid_yaml_raises(self):
        # Raw text is written directly: the content must be unparseable YAML,
        # which _write_device_profile (a serializer) cannot produce.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
            f.write(": bad: yaml: [\n")
            path = f.name
        try:
            with pytest.raises(ValueError, match="Invalid YAML"):
                CampaignConfig.from_device_profile(path)
        finally:
            os.unlink(path)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CampaignConfig.from_yaml (full campaign format)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Full-format campaign fixture: one scripted Wi-Fi transmitter with a 2-step
# schedule, plus recorder, QA and output sections.
FULL_CAMPAIGN = dict(
    campaign=dict(name="wifi_capture_001", mode="controlled_testbed"),
    transmitters=[
        dict(
            id="laptop_wifi",
            type="wifi",
            control_method="external_script",
            script="./scripts/wifi_control.sh",
            device="/dev/wlan0",
            schedule=[
                dict(channel=6, bandwidth="20MHz", traffic="iperf_tcp", duration="30s"),
                dict(channel=36, bandwidth="40MHz", traffic="ping_flood", duration="30s"),
            ],
        )
    ],
    recorder=dict(
        device="usrp_b210",
        center_freq="2.45GHz",
        sample_rate="20MHz",
        gain="40dB",
    ),
    qa=dict(snr_threshold="10dB", min_duration="25s", flag_for_review=True),
    output=dict(format="sigmf", path="./recordings"),
)
|
||
|
||
|
||
class TestFullCampaignParsing:
    """CampaignConfig.from_yaml on the full multi-section campaign format."""

    @staticmethod
    def _parse(campaign: dict) -> "CampaignConfig":
        """Write *campaign* to a temp YAML file, parse it with from_yaml, and always delete the file.

        The ``finally`` runs whether parsing succeeds or raises. So the
        ``pytest.raises`` tests below also clean up their temp file.
        """
        path = _write_device_profile(campaign)
        try:
            return CampaignConfig.from_yaml(path)
        finally:
            os.unlink(path)

    def test_name(self):
        cfg = self._parse(FULL_CAMPAIGN)
        assert cfg.name == "wifi_capture_001"

    def test_mode(self):
        cfg = self._parse(FULL_CAMPAIGN)
        assert cfg.mode == "controlled_testbed"

    def test_transmitter_id(self):
        tx = self._parse(FULL_CAMPAIGN).transmitters[0]
        assert tx.id == "laptop_wifi"
        assert tx.control_method == "external_script"
        assert tx.script == "./scripts/wifi_control.sh"

    def test_schedule_count(self):
        cfg = self._parse(FULL_CAMPAIGN)
        assert len(cfg.transmitters[0].schedule) == 2

    def test_qa_config(self):
        cfg = self._parse(FULL_CAMPAIGN)
        assert cfg.qa.snr_threshold_db == pytest.approx(10.0)
        assert cfg.qa.min_duration_s == pytest.approx(25.0)
        assert cfg.qa.flag_for_review is True

    def test_total_steps(self):
        cfg = self._parse(FULL_CAMPAIGN)
        assert cfg.total_steps() == 2

    def test_no_transmitters_raises(self):
        # Build a new mapping instead of mutating FULL_CAMPAIGN, which every test shares.
        bad = {**FULL_CAMPAIGN, "transmitters": []}
        with pytest.raises(ValueError, match="at least one transmitter"):
            self._parse(bad)

    def test_missing_recorder_raises(self):
        bad = {k: v for k, v in FULL_CAMPAIGN.items() if k != "recorder"}
        with pytest.raises((KeyError, ValueError)):
            self._parse(bad)
|