# ria-toolkit-oss/tests/ria_toolkit_oss_cli/test_transmit.py (374 lines, 13 KiB, Python)

"""Tests for transmit command."""
import os
import tempfile
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from click.testing import CliRunner
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
from ria_toolkit_oss_cli.ria_toolkit_oss.transmit import (
auto_select_tx_device,
check_sample_rate_mismatch,
load_input_file,
transmit,
validate_tx_gain,
)
class TestGetTxDevice:
    """Tests for get_sdr_device function."""

    @staticmethod
    def _resolve_with_stub(module_name, class_name, device_type):
        """Call get_sdr_device while *module_name* is replaced by a stub module.

        The stub module exposes *class_name* as a mock class whose return value
        is a dedicated mock instance.

        Returns:
            A ``(device, expected_instance)`` pair, where ``device`` is what
            get_sdr_device returned and ``expected_instance`` is the object the
            stub class hands out when called.
        """
        expected_instance = MagicMock()
        driver_class = MagicMock()
        driver_class.return_value = expected_instance
        stub_module = MagicMock(**{class_name: driver_class})
        with patch.dict("sys.modules", {module_name: stub_module}):
            device = get_sdr_device(device_type)
        return device, expected_instance

    def test_get_pluto_device(self):
        """Test getting PlutoSDR device."""
        device, expected = self._resolve_with_stub("ria_toolkit_oss.sdr.pluto", "Pluto", "pluto")
        assert device is expected

    def test_get_hackrf_device(self):
        """Test getting HackRF device."""
        device, expected = self._resolve_with_stub("ria_toolkit_oss.sdr.hackrf", "HackRF", "hackrf")
        assert device is expected

    def test_get_unknown_device(self):
        """Test getting unknown device type."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException, match="Unknown device type"):
            get_sdr_device("unknown_device")
class TestAutoSelectTxDevice:
    """Tests for auto_select_tx_device function."""

    @staticmethod
    def _patch_discovery(uhd=(), pluto=(), hackrf=(), bladerf=()):
        """Patch driver loading and every device finder in the transmit module.

        Each argument is the sequence of device-info dicts the matching
        ``find_*_devices`` function should return. All default to "no devices".

        Returns:
            A context manager that applies all patches together.
        """
        return patch.multiple(
            "ria_toolkit_oss_cli.ria_toolkit_oss.transmit",
            load_sdr_drivers=MagicMock(),
            find_uhd_devices=MagicMock(return_value=list(uhd)),
            find_pluto_devices=MagicMock(return_value=list(pluto)),
            find_hackrf_devices=MagicMock(return_value=list(hackrf)),
            find_bladerf_devices=MagicMock(return_value=list(bladerf)),
        )

    def test_auto_select_no_devices(self):
        """Test auto-select with no TX devices found."""
        from click.exceptions import ClickException

        with self._patch_discovery():
            with pytest.raises(ClickException) as exc_info:
                auto_select_tx_device()
        assert "No TX-capable SDR devices found" in str(exc_info.value)

    def test_auto_select_single_device(self):
        """Test auto-select with single TX device."""
        with self._patch_discovery(hackrf=[{"type": "HackRF One", "serial": "123456"}]):
            device_type = auto_select_tx_device(quiet=True)
        assert device_type == "hackrf"

    def test_auto_select_multiple_devices(self):
        """Test auto-select with multiple TX devices raises error."""
        from click.exceptions import ClickException

        with self._patch_discovery(
            pluto=[{"type": "PlutoSDR", "uri": "ip:pluto.local"}],
            hackrf=[{"type": "HackRF One", "serial": "123456"}],
        ):
            with pytest.raises(ClickException) as exc_info:
                auto_select_tx_device()
        assert "Multiple TX-capable devices found" in str(exc_info.value)

    @pytest.mark.parametrize(
        ("device_name", "expected_type"),
        [
            ("PlutoSDR", "pluto"),
            ("HackRF One", "hackrf"),
            ("BladeRF", "bladerf"),
            ("b200", "usrp"),
            ("B210", "usrp"),
        ],
    )
    def test_auto_select_device_mapping(self, device_name, expected_type):
        """Test device type name mapping.

        Parametrized so that each name is reported as its own test case and a
        failure names the offending mapping.
        """
        # Every name is reported through the bladerf finder, so this test relies
        # on the mapping being derived from the "type" string alone.
        with self._patch_discovery(bladerf=[{"type": device_name}]):
            device_type = auto_select_tx_device(quiet=True)
        assert device_type == expected_type
class TestLoadInputFile:
    """Tests for load_input_file function.

    Uses pytest's ``tmp_path`` fixture for scratch files, so pytest handles
    cleanup and no test depends on the current working directory.
    """

    def test_load_file_not_found(self, tmp_path):
        """Test loading non-existent file."""
        from click.exceptions import ClickException

        # tmp_path is a fresh, empty directory, so this path cannot exist
        # (a bare relative name would depend on the CWD).
        missing = tmp_path / "nonexistent.sigmf"
        with pytest.raises(ClickException) as exc_info:
            load_input_file(str(missing))
        assert "Input file not found" in str(exc_info.value)

    def test_load_sigmf_file(self, tmp_path):
        """Test loading SigMF file."""
        test_file = tmp_path / "recording.sigmf-data"
        test_file.touch()
        mock_recording = MagicMock()
        with patch(
            "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
            return_value=mock_recording,
        ) as mock_load:
            recording = load_input_file(str(test_file), legacy=False)
        assert recording is mock_recording
        mock_load.assert_called_once()

    def test_load_legacy_npy_file(self, tmp_path):
        """Test loading legacy NPY file."""
        test_file = tmp_path / "recording.npy"
        test_file.touch()
        mock_recording = MagicMock()
        with patch(
            "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy",
            return_value=mock_recording,
        ) as mock_legacy:
            recording = load_input_file(str(test_file), legacy=True)
        assert recording is mock_recording
        mock_legacy.assert_called_once()

    def test_load_unsupported_format(self, tmp_path):
        """Test loading unsupported file format."""
        from click.exceptions import ClickException

        test_file = tmp_path / "recording.bin"
        test_file.touch()
        with patch(
            "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
            side_effect=Exception("Unsupported format"),
        ):
            with pytest.raises(ClickException) as exc_info:
                load_input_file(str(test_file))
        message = str(exc_info.value)
        assert "Could not load" in message
        assert "Supported formats" in message
class TestValidateTxGain:
    """Tests for validate_tx_gain function."""

    @pytest.mark.parametrize("gain", [-30, 0, -89])
    def test_valid_pluto_gain(self, gain):
        """Test valid PlutoSDR gain is accepted without raising."""
        validate_tx_gain("pluto", gain)

    def test_invalid_pluto_gain_too_high(self):
        """Test PlutoSDR gain too high."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException, match="out of range"):
            validate_tx_gain("pluto", 10)

    def test_invalid_pluto_gain_too_low(self):
        """Test PlutoSDR gain too low."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException, match="out of range"):
            validate_tx_gain("pluto", -100)

    @pytest.mark.parametrize("gain", [0, 20, 47])
    def test_valid_hackrf_gain(self, gain):
        """Test valid HackRF gain is accepted without raising."""
        validate_tx_gain("hackrf", gain)

    @pytest.mark.parametrize("gain", [-10, 50])
    def test_invalid_hackrf_gain(self, gain):
        """Test invalid HackRF gain is rejected."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException):
            validate_tx_gain("hackrf", gain)

    def test_high_gain_warning(self):
        """Test warning for high gain levels.

        Inspects every ``click.echo`` call rather than only the last one, so
        the test still passes if other output follows the warning.
        """
        with patch("ria_toolkit_oss_cli.ria_toolkit_oss.transmit.click.echo") as mock_echo:
            validate_tx_gain("hackrf", 45)
        mock_echo.assert_called()
        messages = [str(c) for c in mock_echo.call_args_list]
        assert any("WARNING" in m and "high gain" in m.lower() for m in messages), messages
class TestCheckSampleRateMismatch:
    """Tests for check_sample_rate_mismatch function."""

    # click.echo as seen from the transmit module; patched to capture output.
    _ECHO_TARGET = "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.click.echo"

    @staticmethod
    def _recording_with(metadata):
        """Return a stand-in recording whose ``metadata`` attribute is *metadata*."""
        recording = MagicMock()
        recording.metadata = metadata
        return recording

    def test_no_mismatch(self):
        """Test when sample rates match."""
        recording = self._recording_with({"sample_rate": 2e6})
        with patch(self._ECHO_TARGET) as echo:
            check_sample_rate_mismatch(recording, 2e6, quiet=False)
        echo.assert_not_called()

    def test_mismatch_warning(self):
        """Test warning when sample rates differ."""
        recording = self._recording_with({"sample_rate": 1e6})
        with patch(self._ECHO_TARGET) as echo:
            check_sample_rate_mismatch(recording, 2e6, quiet=False)
        echo.assert_called_once()
        printed = str(echo.call_args)
        assert "Warning" in printed
        assert "differs" in printed

    def test_mismatch_quiet_mode(self):
        """Test no warning in quiet mode."""
        recording = self._recording_with({"sample_rate": 1e6})
        with patch(self._ECHO_TARGET) as echo:
            check_sample_rate_mismatch(recording, 2e6, quiet=True)
        echo.assert_not_called()

    def test_no_metadata(self):
        """Test when recording has no metadata."""
        recording = self._recording_with(None)
        with patch(self._ECHO_TARGET) as echo:
            check_sample_rate_mismatch(recording, 2e6, quiet=False)
        echo.assert_not_called()
class TestTransmitCommand:
    """Tests for transmit CLI command."""

    def setup_method(self):
        """Create a CLI runner and a scratch directory for input files."""
        self.runner = CliRunner()
        self.temp_dir = tempfile.mkdtemp()

    def teardown_method(self):
        """Remove the scratch directory created in setup_method."""
        import shutil

        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def test_transmit_basic(self):
        """Test basic transmit command."""
        # Empty placeholder input; its contents are never read because
        # load_input_file is patched below.
        test_file = os.path.join(self.temp_dir, "test.npy")
        with open(test_file, "w"):
            pass

        mock_sdr = MagicMock()
        mock_recording = MagicMock()
        mock_recording.data = np.array([[0.1 + 0.1j] * 1000])
        mock_recording.metadata = {}
        with (
            patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr
            ),
            patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file",
                return_value=mock_recording,
            ),
        ):
            result = self.runner.invoke(
                transmit,
                [
                    "--device",
                    "hackrf",
                    "--sample-rate",
                    "2e6",
                    "--center-frequency",
                    "915M",
                    "--gain",
                    "10",
                    "--input",
                    test_file,
                    "--quiet",
                ],
            )
        # CliRunner captures output and exceptions; surface them on failure
        # instead of reporting a bare exit-code mismatch.
        assert result.exit_code == 0, (
            f"transmit exited with {result.exit_code}\n"
            f"output:\n{result.output}\n"
            f"exception: {result.exception!r}"
        )
        mock_sdr.tx_recording.assert_called_once()
        mock_sdr.close.assert_called_once()