"""Tests for orchestration QA metrics.""" import numpy as np import pytest from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.orchestration.campaign import QAConfig from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording: return Recording( signal.astype(np.complex64), metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9}, ) def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray: t = np.arange(n) / sr return (np.exp(1j * 2 * np.pi * freq_hz * t) * amplitude).astype(np.complex64) def _noise(n: int, amplitude: float = 0.001) -> np.ndarray: rng = np.random.default_rng(42) return ((rng.standard_normal(n) + 1j * rng.standard_normal(n)) * amplitude).astype(np.complex64) DEFAULT_QA = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True) # --------------------------------------------------------------------------- # estimate_snr_db # --------------------------------------------------------------------------- class TestEstimateSnrDb: def test_high_snr_tone(self): sr = 1e6 samples = _tone(int(sr * 1), sr) snr = estimate_snr_db(samples) assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB" def test_pure_noise_low_snr(self): sr = 1e6 rng = np.random.default_rng(0) samples = (rng.standard_normal(int(sr)) + 1j * rng.standard_normal(int(sr))).astype(np.complex64) snr = estimate_snr_db(samples) # Pure noise should yield a low (possibly negative) SNR assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB" def test_snr_increases_with_amplitude(self): sr = 1e6 n = int(sr) rng = np.random.default_rng(1) noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01 t = np.arange(n) / sr tone = np.exp(1j * 2 * np.pi * 100e3 * t).astype(np.complex64) low_snr = estimate_snr_db(noise + tone * 0.1) high_snr = estimate_snr_db(noise + tone * 1.0) assert high_snr > low_snr def test_short_input_still_works(self): # Input shorter than n_fft=4096 should not raise samples = _tone(512, 1e6) snr = estimate_snr_db(samples) assert np.isfinite(snr) # --------------------------------------------------------------------------- # check_recording — pass cases # --------------------------------------------------------------------------- class TestCheckRecordingPass: def test_clean_tone_passes(self): sr = 1e6 rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr)) result = check_recording(rec, DEFAULT_QA) assert result.passed is True assert result.flagged is False assert result.snr_db > 10.0 assert abs(result.duration_s - 30.0) < 0.1 def test_duration_exactly_at_threshold(self): sr = 1e6 n = int(sr * 25) # exactly at min_duration_s rec = _make_recording(n, sr, _tone(n, sr)) result = check_recording(rec, DEFAULT_QA) assert result.flagged is False def test_issues_empty_when_passing(self): sr = 1e6 rec = _make_recording(int(sr * 30), sr, _tone(int(sr * 30), sr)) result = check_recording(rec, DEFAULT_QA) assert result.issues == [] # --------------------------------------------------------------------------- # check_recording — flag cases # --------------------------------------------------------------------------- class TestCheckRecordingFlag: def test_short_recording_flagged(self): sr = 1e6 n = int(sr * 10) # shorter than 25s min rec = _make_recording(n, sr, _tone(n, sr)) result = check_recording(rec, DEFAULT_QA) assert result.flagged is True assert any("Duration" in issue for issue in result.issues) def test_low_snr_flagged(self): sr = 1e6 n = int(sr * 30) rec = _make_recording(n, sr, _noise(n, amplitude=0.001)) result = check_recording(rec, DEFAULT_QA) assert result.flagged is True assert any("SNR" in issue for issue in result.issues) def test_flag_for_review_still_passes(self): """With flag_for_review=True, flagged recordings are still marked passed.""" sr = 1e6 n = int(sr * 10) # short → will be flagged rec = _make_recording(n, sr, _tone(n, sr)) qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True) result = check_recording(rec, qa) assert result.flagged is True assert result.passed is True # human review, not auto-reject def test_flag_for_review_false_fails(self): """With flag_for_review=False, a flagged recording is also marked failed.""" sr = 1e6 n = int(sr * 10) rec = _make_recording(n, sr, _tone(n, sr)) qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False) result = check_recording(rec, qa) assert result.flagged is True assert result.passed is False def test_multiple_issues_reported(self): """Both short duration AND low SNR should both appear in issues list.""" sr = 1e6 n = int(sr * 5) # very short rec = _make_recording(n, sr, _noise(n, amplitude=0.0001)) result = check_recording(rec, DEFAULT_QA) assert result.flagged is True assert len(result.issues) >= 2 # --------------------------------------------------------------------------- # check_recording — multichannel input # --------------------------------------------------------------------------- class TestCheckRecordingMultichannel: def test_multichannel_recording(self): """2-channel recording should evaluate channel 0 without error.""" sr = 1e6 n = int(sr * 30) ch0 = _tone(n, sr) ch1 = _tone(n, sr, freq_hz=200e3) data = np.stack([ch0, ch1]) # shape (2, N) rec = Recording(data, metadata={"sample_rate": sr, "center_frequency": 2.45e9}) result = check_recording(rec, DEFAULT_QA) assert result.passed is True assert result.flagged is False # --------------------------------------------------------------------------- # QAResult.to_dict # --------------------------------------------------------------------------- class TestQAResultToDict: def test_to_dict_keys(self): r = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0) d = r.to_dict() assert set(d.keys()) == {"passed", "flagged", "snr_db", "duration_s", "issues"} def test_to_dict_values(self): r = QAResult(passed=False, flagged=True, snr_db=7.5, duration_s=10.2, issues=["SNR below threshold"]) d = r.to_dict() assert d["passed"] is False assert d["flagged"] is True assert d["snr_db"] == pytest.approx(7.5, abs=0.01) assert d["duration_s"] == pytest.approx(10.2, abs=0.01) assert d["issues"] == ["SNR below threshold"]