193 lines
7.2 KiB
Python
193 lines
7.2 KiB
Python
"""Tests for orchestration QA metrics."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from ria_toolkit_oss.datatypes.recording import Recording
|
|
from ria_toolkit_oss.orchestration.campaign import QAConfig
|
|
from ria_toolkit_oss.orchestration.qa import QAResult, check_recording, estimate_snr_db
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_recording(n_samples: int, sample_rate: float, signal: np.ndarray) -> Recording:
    """Wrap *signal* in a ``Recording`` carrying the metadata the QA checks read.

    Args:
        n_samples: Expected number of samples along the last axis of *signal*.
            This was previously accepted but ignored. It is now checked so that
            a test whose length bookkeeping drifts from the generated signal
            fails loudly instead of silently testing a different duration.
        sample_rate: Stored as ``metadata["sample_rate"]``; the duration tests
            treat it as samples per second.
        signal: Sample array, cast to ``complex64`` before wrapping.

    Returns:
        A ``Recording`` whose metadata also holds a fixed
        ``center_frequency`` of 2.45e9.

    Raises:
        ValueError: If the last-axis length of *signal* differs from *n_samples*.
    """
    actual = signal.shape[-1]
    if actual != n_samples:
        raise ValueError(f"signal has {actual} samples, expected n_samples={n_samples}")
    return Recording(
        signal.astype(np.complex64),
        metadata={"sample_rate": sample_rate, "center_frequency": 2.45e9},
    )
|
|
|
|
|
|
def _tone(n: int, sr: float, freq_hz: float = 100e3, amplitude: float = 0.5) -> np.ndarray:
    """Return *n* samples of a complex exponential at *freq_hz* sampled at *sr*.

    The result has constant magnitude *amplitude* and dtype ``complex64``.
    """
    sample_times = np.arange(n) / sr
    phasor = np.exp(1j * 2 * np.pi * freq_hz * sample_times)
    scaled = phasor * amplitude
    return scaled.astype(np.complex64)
|
|
|
|
|
|
def _noise(n: int, amplitude: float = 0.001, *, seed: int = 42) -> np.ndarray:
    """Return *n* samples of complex Gaussian noise as ``complex64``.

    Real and imaginary parts are independent standard normals scaled by
    *amplitude*. The generator is seeded so results are reproducible. The
    default ``seed=42`` reproduces the historical output exactly; other seeds
    give independent realisations, where previously every call returned the
    same noise.
    """
    rng = np.random.default_rng(seed)
    # Draw the real parts first, then the imaginary parts, so a given seed
    # yields the same sequence as the original one-expression form.
    real = rng.standard_normal(n)
    imag = rng.standard_normal(n)
    return ((real + 1j * imag) * amplitude).astype(np.complex64)
|
|
|
|
|
|
# Shared QA settings for most tests: a 10 dB SNR threshold, a 25 s minimum
# duration, and flag_for_review=True so flagged recordings go to human review
# instead of being auto-rejected (see test_flag_for_review_still_passes).
DEFAULT_QA = QAConfig(
    snr_threshold_db=10.0,
    min_duration_s=25.0,
    flag_for_review=True,
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# estimate_snr_db
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEstimateSnrDb:
    """Sanity checks for ``estimate_snr_db`` on synthetic signals."""

    def test_high_snr_tone(self):
        sr = 1e6
        samples = _tone(int(sr), sr)
        snr = estimate_snr_db(samples)
        assert snr > 20.0, f"Expected high SNR for clean tone, got {snr:.1f} dB"

    def test_pure_noise_low_snr(self):
        sr = 1e6
        rng = np.random.default_rng(0)
        samples = (rng.standard_normal(int(sr)) + 1j * rng.standard_normal(int(sr))).astype(np.complex64)
        snr = estimate_snr_db(samples)
        # Pure noise should yield a low (possibly negative) SNR
        assert snr < 15.0, f"Expected low SNR for noise, got {snr:.1f} dB"

    def test_snr_increases_with_amplitude(self):
        """A louder tone over the same noise floor must estimate a higher SNR."""
        sr = 1e6
        n = int(sr)
        rng = np.random.default_rng(1)
        noise = (rng.standard_normal(n) + 1j * rng.standard_normal(n)).astype(np.complex64) * 0.01
        # Reuse the module's tone helper (default 100 kHz) instead of
        # re-deriving the complex exponential inline.
        low_snr = estimate_snr_db(noise + _tone(n, sr, amplitude=0.1))
        high_snr = estimate_snr_db(noise + _tone(n, sr, amplitude=1.0))
        assert high_snr > low_snr

    def test_short_input_still_works(self):
        # Input shorter than n_fft=4096 should not raise
        samples = _tone(512, 1e6)
        snr = estimate_snr_db(samples)
        assert np.isfinite(snr)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# check_recording — pass cases
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCheckRecordingPass:
    """Cases where ``check_recording`` should not flag the recording."""

    def test_clean_tone_passes(self):
        sample_rate = 1e6
        n = int(sample_rate * 30)
        recording = _make_recording(n, sample_rate, _tone(n, sample_rate))
        outcome = check_recording(recording, DEFAULT_QA)
        assert outcome.passed is True
        assert outcome.flagged is False
        assert outcome.snr_db > 10.0
        assert abs(outcome.duration_s - 30.0) < 0.1

    def test_duration_exactly_at_threshold(self):
        """A recording exactly min_duration_s long must not be flagged."""
        sample_rate = 1e6
        n = int(sample_rate * 25)  # exactly DEFAULT_QA.min_duration_s
        recording = _make_recording(n, sample_rate, _tone(n, sample_rate))
        outcome = check_recording(recording, DEFAULT_QA)
        assert outcome.flagged is False

    def test_issues_empty_when_passing(self):
        sample_rate = 1e6
        n = int(sample_rate * 30)
        recording = _make_recording(n, sample_rate, _tone(n, sample_rate))
        outcome = check_recording(recording, DEFAULT_QA)
        assert outcome.issues == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# check_recording — flag cases
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCheckRecordingFlag:
    """Recordings that ``check_recording`` should flag, and how
    ``flag_for_review`` decides whether a flagged recording still passes."""

    def test_short_recording_flagged(self):
        sr = 1e6
        n = int(sr * 10)  # shorter than 25s min
        rec = _make_recording(n, sr, _tone(n, sr))
        result = check_recording(rec, DEFAULT_QA)
        assert result.flagged is True
        assert any("Duration" in issue for issue in result.issues)

    def test_low_snr_flagged(self):
        sr = 1e6
        n = int(sr * 30)
        rec = _make_recording(n, sr, _noise(n, amplitude=0.001))
        result = check_recording(rec, DEFAULT_QA)
        assert result.flagged is True
        assert any("SNR" in issue for issue in result.issues)

    def test_flag_for_review_still_passes(self):
        """With flag_for_review=True, flagged recordings are still marked passed."""
        sr = 1e6
        n = int(sr * 10)  # short → will be flagged
        rec = _make_recording(n, sr, _tone(n, sr))
        qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=True)
        result = check_recording(rec, qa)
        assert result.flagged is True
        assert result.passed is True  # human review, not auto-reject

    def test_flag_for_review_false_fails(self):
        """With flag_for_review=False, a flagged recording is also marked failed."""
        sr = 1e6
        n = int(sr * 10)
        rec = _make_recording(n, sr, _tone(n, sr))
        qa = QAConfig(snr_threshold_db=10.0, min_duration_s=25.0, flag_for_review=False)
        result = check_recording(rec, qa)
        assert result.flagged is True
        assert result.passed is False

    def test_multiple_issues_reported(self):
        """Short duration AND low SNR should each appear in the issues list."""
        sr = 1e6
        n = int(sr * 5)  # very short
        rec = _make_recording(n, sr, _noise(n, amplitude=0.0001))
        result = check_recording(rec, DEFAULT_QA)
        assert result.flagged is True
        assert len(result.issues) >= 2
        # A bare count would also pass if a single check emitted two messages;
        # require that each failing check is actually named.
        assert any("Duration" in issue for issue in result.issues)
        assert any("SNR" in issue for issue in result.issues)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# check_recording — multichannel input
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCheckRecordingMultichannel:
    """``check_recording`` on recordings with more than one channel."""

    def test_multichannel_recording(self):
        """A 2-channel recording should be checked without error and pass.

        The intent is that QA evaluates channel 0. Both channels here are clean
        tones, so this test does not verify which channel is actually used.
        """
        sr = 1e6
        n = int(sr * 30)
        channels = [_tone(n, sr), _tone(n, sr, freq_hz=200e3)]
        data = np.stack(channels)  # shape (2, N)
        metadata = {"sample_rate": sr, "center_frequency": 2.45e9}
        result = check_recording(Recording(data, metadata=metadata), DEFAULT_QA)
        assert result.passed is True
        assert result.flagged is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# QAResult.to_dict
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQAResultToDict:
    """Serialisation of ``QAResult`` through ``to_dict``."""

    def test_to_dict_keys(self):
        expected_keys = {"passed", "flagged", "snr_db", "duration_s", "issues"}
        qa_result = QAResult(passed=True, flagged=False, snr_db=18.3, duration_s=30.0)
        assert set(qa_result.to_dict().keys()) == expected_keys

    def test_to_dict_values(self):
        qa_result = QAResult(
            passed=False,
            flagged=True,
            snr_db=7.5,
            duration_s=10.2,
            issues=["SNR below threshold"],
        )
        as_dict = qa_result.to_dict()
        assert as_dict["passed"] is False
        assert as_dict["flagged"] is True
        assert as_dict["snr_db"] == pytest.approx(7.5, abs=0.01)
        assert as_dict["duration_s"] == pytest.approx(10.2, abs=0.01)
        assert as_dict["issues"] == ["SNR below threshold"]
|