"""Tests for orchestration executor — StepResult, CampaignResult, _run_script, _extract_tx_params.""" from __future__ import annotations import json import stat import threading from types import SimpleNamespace import pytest from ria_toolkit_oss.orchestration.executor import ( CampaignResult, StepResult, _extract_tx_params, _run_script, ) from ria_toolkit_oss.orchestration.qa import QAResult def _ok_qa() -> QAResult: return QAResult(passed=True, flagged=False, snr_db=20.0, duration_s=1.0) def _flagged_qa() -> QAResult: return QAResult(passed=True, flagged=True, snr_db=5.0, duration_s=1.0, issues=["low SNR"]) def _failed_qa() -> QAResult: return QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=["no signal"]) # --------------------------------------------------------------------------- # StepResult # --------------------------------------------------------------------------- class TestStepResult: def test_ok_true_when_no_error_and_qa_passed(self): r = StepResult( transmitter_id="tx1", step_label="step1", output_path="/out/rec.sigmf-data", qa=_ok_qa(), capture_timestamp=0.0, ) assert r.ok is True def test_ok_false_when_error_set(self): r = StepResult( transmitter_id="tx1", step_label="step1", output_path=None, qa=_ok_qa(), capture_timestamp=0.0, error="SDR failed", ) assert r.ok is False def test_ok_false_when_qa_not_passed(self): r = StepResult( transmitter_id="tx1", step_label="step1", output_path="/out", qa=_failed_qa(), capture_timestamp=0.0, ) assert r.ok is False def test_to_dict_contains_required_keys(self): r = StepResult( transmitter_id="tx1", step_label="step1", output_path="/out/rec.sigmf-data", qa=_ok_qa(), capture_timestamp=1234.5, ) d = r.to_dict() assert d["transmitter_id"] == "tx1" assert d["step_label"] == "step1" assert d["output_path"] == "/out/rec.sigmf-data" assert d["capture_timestamp"] == pytest.approx(1234.5) assert d["error"] is None assert d["qa"]["passed"] is True def test_to_dict_includes_error_when_set(self): r = StepResult( transmitter_id="tx1", step_label="step1", output_path=None, qa=_failed_qa(), capture_timestamp=0.0, error="disk full", ) assert r.to_dict()["error"] == "disk full" # --------------------------------------------------------------------------- # CampaignResult # --------------------------------------------------------------------------- class TestCampaignResult: def _make(self, steps: list) -> CampaignResult: r = CampaignResult(campaign_name="test_campaign") r.steps = steps r.end_time = r.start_time + 5.0 return r def test_total_steps(self): r = self._make([ StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), StepResult("tx1", "s2", "/out", _ok_qa(), 0.0), ]) assert r.total_steps == 2 def test_passed_count(self): r = self._make([ StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), ]) assert r.passed == 1 def test_failed_count(self): r = self._make([ StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), StepResult("tx1", "s2", "/out", _failed_qa(), 0.0), ]) assert r.failed == 1 def test_flagged_count(self): r = self._make([ StepResult("tx1", "s1", "/out", _ok_qa(), 0.0), StepResult("tx1", "s2", "/out", _flagged_qa(), 0.0), ]) assert r.flagged == 1 def test_error_step_counts_as_failed_not_passed(self): r = self._make([ StepResult("tx1", "s1", None, _ok_qa(), 0.0, error="disk full"), ]) assert r.failed == 1 assert r.passed == 0 def test_duration_s_from_end_time(self): r = CampaignResult(campaign_name="c") r.start_time = 100.0 r.end_time = 115.0 assert r.duration_s == pytest.approx(15.0) def test_to_dict_structure(self): r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)]) d = r.to_dict() assert d["campaign_name"] == "test_campaign" assert d["total_steps"] == 1 assert d["passed"] == 1 assert len(d["steps"]) == 1 def test_write_report(self, tmp_path): r = self._make([StepResult("tx1", "s1", "/out", _ok_qa(), 0.0)]) out = tmp_path / "report.json" r.write_report(str(out)) assert out.exists() data = json.loads(out.read_text()) assert data["campaign_name"] == "test_campaign" def test_write_report_creates_nested_dirs(self, tmp_path): r = self._make([]) out = tmp_path / "nested" / "deep" / "report.json" r.write_report(str(out)) assert out.exists() # --------------------------------------------------------------------------- # _run_script # --------------------------------------------------------------------------- class TestRunScript: def _script(self, tmp_path, body: str) -> str: s = tmp_path / "script.sh" s.write_text("#!/bin/sh\n" + body) s.chmod(s.stat().st_mode | stat.S_IEXEC) return str(s) def test_returns_stdout(self, tmp_path): out = _run_script(self._script(tmp_path, 'echo "hello world"')) assert out == "hello world" def test_passes_args_to_script(self, tmp_path): out = _run_script(self._script(tmp_path, 'echo "$1 $2"'), "configure", "arg2") assert "configure" in out def test_raises_on_nonzero_exit(self, tmp_path): with pytest.raises(RuntimeError, match="exited 1"): _run_script(self._script(tmp_path, "exit 1")) def test_raises_on_relative_path(self): with pytest.raises(RuntimeError, match="absolute"): _run_script("relative/script.sh") def test_raises_on_missing_file(self, tmp_path): with pytest.raises(RuntimeError): _run_script(str(tmp_path / "nonexistent.sh")) def test_raises_on_timeout(self, tmp_path): with pytest.raises(RuntimeError, match="timed out"): _run_script(self._script(tmp_path, "sleep 60"), timeout=0.1) def test_stderr_included_in_error_message(self, tmp_path): with pytest.raises(RuntimeError) as exc_info: _run_script(self._script(tmp_path, "echo 'bad thing' >&2; exit 1")) assert "bad thing" in str(exc_info.value) # --------------------------------------------------------------------------- # _extract_tx_params # --------------------------------------------------------------------------- class TestExtractTxParams: def test_returns_none_when_no_sdr_agent_attribute(self): tx = SimpleNamespace() assert _extract_tx_params(tx) is None def test_returns_none_when_sdr_agent_is_none(self): tx = SimpleNamespace(sdr_agent=None) assert _extract_tx_params(tx) is None def test_returns_none_when_sdr_agent_is_empty_dict(self): tx = SimpleNamespace(sdr_agent={}) assert _extract_tx_params(tx) is None def test_returns_signal_params(self): tx = SimpleNamespace(sdr_agent={ "modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9, }) result = _extract_tx_params(tx) assert result == {"modulation": "QPSK", "symbol_rate": 1e6, "center_frequency": 2.4e9} def test_strips_infra_key_node_id(self): tx = SimpleNamespace(sdr_agent={ "modulation": "BPSK", "node_id": "node_abc123", }) result = _extract_tx_params(tx) assert "node_id" not in result assert result == {"modulation": "BPSK"} def test_strips_infra_key_session_code(self): tx = SimpleNamespace(sdr_agent={ "modulation": "FSK", "session_code": "amber-peak-transmit", }) result = _extract_tx_params(tx) assert "session_code" not in result def test_strips_none_values(self): tx = SimpleNamespace(sdr_agent={ "modulation": "QPSK", "order": None, "rolloff": 0.35, }) result = _extract_tx_params(tx) assert "order" not in result assert result == {"modulation": "QPSK", "rolloff": 0.35} def test_does_not_mutate_source_dict(self): cfg = {"modulation": "QPSK", "node_id": "nid", "session_code": "code"} tx = SimpleNamespace(sdr_agent=cfg) _extract_tx_params(tx) assert "node_id" in cfg def test_full_sdr_agent_config(self): tx = SimpleNamespace(sdr_agent={ "modulation": "16QAM", "order": 4, "symbol_rate": 5e6, "center_frequency": 915e6, "filter": "rrc", "rolloff": 0.35, "node_id": "node_xyz", "session_code": "some-code", }) result = _extract_tx_params(tx) assert result == { "modulation": "16QAM", "order": 4, "symbol_rate": 5e6, "center_frequency": 915e6, "filter": "rrc", "rolloff": 0.35, }