374 lines
13 KiB
Python
374 lines
13 KiB
Python
"""Tests for transmit command."""
|
|
|
|
import os
|
|
import tempfile
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from click.testing import CliRunner
|
|
|
|
from ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
|
|
from ria_toolkit_oss_cli.ria_toolkit_oss.transmit import (
|
|
auto_select_tx_device,
|
|
check_sample_rate_mismatch,
|
|
load_input_file,
|
|
transmit,
|
|
validate_tx_gain,
|
|
)
|
|
|
|
|
|
class TestGetTxDevice:
    """Exercise ``get_sdr_device`` for two known device names and one unknown name."""

    def test_get_pluto_device(self):
        """Requesting "pluto" returns whatever the patched ``Pluto`` class constructs."""
        pluto_instance = MagicMock()
        pluto_class = MagicMock(return_value=pluto_instance)
        # Stand-in driver module exposing the mocked class under the name ``Pluto``.
        fake_driver_module = MagicMock(Pluto=pluto_class)

        with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.pluto": fake_driver_module}):
            device = get_sdr_device("pluto")
            assert device is pluto_instance

    def test_get_hackrf_device(self):
        """Requesting "hackrf" returns whatever the patched ``HackRF`` class constructs."""
        hackrf_instance = MagicMock()
        hackrf_class = MagicMock(return_value=hackrf_instance)
        # Stand-in driver module exposing the mocked class under the name ``HackRF``.
        fake_driver_module = MagicMock(HackRF=hackrf_class)

        with patch.dict("sys.modules", {"ria_toolkit_oss.sdr.hackrf": fake_driver_module}):
            device = get_sdr_device("hackrf")
            assert device is hackrf_instance

    def test_get_unknown_device(self):
        """An unrecognised device name raises a ``ClickException`` naming the problem."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException) as raised:
            get_sdr_device("unknown_device")

        message = str(raised.value)
        assert "Unknown device type" in message
|
|
|
|
|
|
class TestAutoSelectTxDevice:
    """Tests for auto_select_tx_device function."""

    # Module whose driver-loading and discovery functions are stubbed in every test.
    _TRANSMIT_MODULE = "ria_toolkit_oss_cli.ria_toolkit_oss.transmit"

    @classmethod
    def _patched_discovery(cls, uhd=(), pluto=(), hackrf=(), bladerf=()):
        """Return a context manager that stubs driver loading and device discovery.

        Each keyword supplies the device dicts that the matching
        ``find_*_devices`` function of the transmit module returns while the
        patch is active; every finder defaults to reporting no devices.
        ``load_sdr_drivers`` is replaced by a plain mock.

        Args:
            uhd: Devices returned by ``find_uhd_devices``.
            pluto: Devices returned by ``find_pluto_devices``.
            hackrf: Devices returned by ``find_hackrf_devices``.
            bladerf: Devices returned by ``find_bladerf_devices``.

        Returns:
            The ``patch.multiple`` patcher, usable in a ``with`` statement.
        """
        return patch.multiple(
            cls._TRANSMIT_MODULE,
            load_sdr_drivers=MagicMock(),
            find_uhd_devices=MagicMock(return_value=list(uhd)),
            find_pluto_devices=MagicMock(return_value=list(pluto)),
            find_hackrf_devices=MagicMock(return_value=list(hackrf)),
            find_bladerf_devices=MagicMock(return_value=list(bladerf)),
        )

    def test_auto_select_no_devices(self):
        """Test auto-select with no TX devices found."""
        from click.exceptions import ClickException

        with self._patched_discovery():
            with pytest.raises(ClickException) as exc_info:
                auto_select_tx_device()

        assert "No TX-capable SDR devices found" in str(exc_info.value)

    def test_auto_select_single_device(self):
        """Test auto-select with single TX device."""
        with self._patched_discovery(hackrf=[{"type": "HackRF One", "serial": "123456"}]):
            device_type = auto_select_tx_device(quiet=True)

        assert device_type == "hackrf"

    def test_auto_select_multiple_devices(self):
        """Test auto-select with multiple TX devices raises error."""
        from click.exceptions import ClickException

        with self._patched_discovery(
            pluto=[{"type": "PlutoSDR", "uri": "ip:pluto.local"}],
            hackrf=[{"type": "HackRF One", "serial": "123456"}],
        ):
            with pytest.raises(ClickException) as exc_info:
                auto_select_tx_device()

        assert "Multiple TX-capable devices found" in str(exc_info.value)

    def test_auto_select_device_mapping(self):
        """Test device type name mapping."""
        test_cases = [
            ("PlutoSDR", "pluto"),
            ("HackRF One", "hackrf"),
            ("BladeRF", "bladerf"),
            ("b200", "usrp"),
            ("B210", "usrp"),
        ]

        for device_name, expected_type in test_cases:
            # Each reported type name is fed through the bladerf finder, as in
            # the original test; only the "type" value varies between cases.
            with self._patched_discovery(bladerf=[{"type": device_name}]):
                device_type = auto_select_tx_device(quiet=True)

            # Name the failing case: a bare assert inside a loop would not say
            # which reported device name produced the wrong mapping.
            assert device_type == expected_type, (
                f"{device_name!r} mapped to {device_type!r}, expected {expected_type!r}"
            )
|
|
|
|
|
|
class TestLoadInputFile:
    """Exercise ``load_input_file`` with missing, SigMF, legacy NPY and unloadable inputs."""

    @staticmethod
    def _empty_temp_file(suffix):
        """Create an empty temporary file ending in *suffix* and return its path.

        The file is created with ``delete=False``, so the caller must remove it.
        """
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
            return handle.name

    def test_load_file_not_found(self):
        """A path that does not exist raises a ``ClickException``."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException) as raised:
            load_input_file("nonexistent.sigmf")

        assert "Input file not found" in str(raised.value)

    def test_load_sigmf_file(self):
        """With ``legacy=False`` the result is what the patched ``load_recording`` returns."""
        sigmf_path = self._empty_temp_file(".sigmf-data")

        try:
            expected = MagicMock()

            with patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
                return_value=expected,
            ):
                loaded = load_input_file(sigmf_path, legacy=False)
                assert loaded == expected
        finally:
            os.unlink(sigmf_path)

    def test_load_legacy_npy_file(self):
        """With ``legacy=True`` the result is what the patched ``from_npy_legacy`` returns."""
        npy_path = self._empty_temp_file(".npy")

        try:
            expected = MagicMock()

            with patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy",
                return_value=expected,
            ):
                loaded = load_input_file(npy_path, legacy=True)
                assert loaded == expected
        finally:
            os.unlink(npy_path)

    def test_load_unsupported_format(self):
        """A loader failure surfaces as a ``ClickException`` listing supported formats."""
        from click.exceptions import ClickException

        bin_path = self._empty_temp_file(".bin")

        try:
            # Make the patched loader fail so the error path is exercised.
            with patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
                side_effect=Exception("Unsupported format"),
            ):
                with pytest.raises(ClickException) as raised:
                    load_input_file(bin_path)

                message = str(raised.value)
                assert "Could not load" in message
                assert "Supported formats" in message
        finally:
            os.unlink(bin_path)
|
|
|
|
|
|
class TestValidateTxGain:
    """Exercise ``validate_tx_gain`` with accepted, rejected and warning-level gains."""

    def test_valid_pluto_gain(self):
        """PlutoSDR gains of -30, 0 and -89 are accepted without raising."""
        for gain in (-30, 0, -89):
            validate_tx_gain("pluto", gain)

    def test_invalid_pluto_gain_too_high(self):
        """A PlutoSDR gain of 10 is rejected as out of range."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException) as raised:
            validate_tx_gain("pluto", 10)

        assert "out of range" in str(raised.value)

    def test_invalid_pluto_gain_too_low(self):
        """A PlutoSDR gain of -100 is rejected as out of range."""
        from click.exceptions import ClickException

        with pytest.raises(ClickException) as raised:
            validate_tx_gain("pluto", -100)

        assert "out of range" in str(raised.value)

    def test_valid_hackrf_gain(self):
        """HackRF gains of 0, 20 and 47 are accepted without raising."""
        for gain in (0, 20, 47):
            validate_tx_gain("hackrf", gain)

    def test_invalid_hackrf_gain(self):
        """HackRF gains of -10 and 50 each raise a ``ClickException``."""
        from click.exceptions import ClickException

        for bad_gain in (-10, 50):
            with pytest.raises(ClickException):
                validate_tx_gain("hackrf", bad_gain)

    def test_high_gain_warning(self):
        """A HackRF gain of 45 echoes a message containing a high-gain WARNING."""
        import click

        with patch.object(click, "echo") as echo_spy:
            validate_tx_gain("hackrf", 45)
            echo_spy.assert_called()

            # Only the most recent echo call is inspected.
            rendered_call = str(echo_spy.call_args)
            assert "WARNING" in rendered_call
            assert "high gain" in rendered_call.lower()
|
|
|
|
|
|
class TestCheckSampleRateMismatch:
    """Exercise ``check_sample_rate_mismatch`` by observing calls to ``click.echo``."""

    # Patch target used by every test to capture echoed output.
    _ECHO_TARGET = "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.click.echo"

    @staticmethod
    def _recording_with(metadata):
        """Return a mock recording whose ``metadata`` attribute is *metadata*."""
        recording = MagicMock()
        recording.metadata = metadata
        return recording

    def test_no_mismatch(self):
        """Nothing is echoed when the recorded and requested rates are both 2e6."""
        recording = self._recording_with({"sample_rate": 2e6})

        with patch(self._ECHO_TARGET) as echo_spy:
            check_sample_rate_mismatch(recording, 2e6, quiet=False)
            echo_spy.assert_not_called()

    def test_mismatch_warning(self):
        """Exactly one warning is echoed when 1e6 was recorded but 2e6 is requested."""
        recording = self._recording_with({"sample_rate": 1e6})

        with patch(self._ECHO_TARGET) as echo_spy:
            check_sample_rate_mismatch(recording, 2e6, quiet=False)
            echo_spy.assert_called_once()

            rendered_call = str(echo_spy.call_args)
            assert "Warning" in rendered_call
            assert "differs" in rendered_call

    def test_mismatch_quiet_mode(self):
        """With ``quiet=True`` a rate mismatch echoes nothing."""
        recording = self._recording_with({"sample_rate": 1e6})

        with patch(self._ECHO_TARGET) as echo_spy:
            check_sample_rate_mismatch(recording, 2e6, quiet=True)
            echo_spy.assert_not_called()

    def test_no_metadata(self):
        """A recording whose metadata is ``None`` echoes nothing."""
        recording = self._recording_with(None)

        with patch(self._ECHO_TARGET) as echo_spy:
            check_sample_rate_mismatch(recording, 2e6, quiet=False)
            echo_spy.assert_not_called()
|
|
|
|
|
|
class TestTransmitCommand:
    """Exercise the ``transmit`` click command through ``CliRunner``."""

    def setup_method(self):
        """Create a CLI runner and a scratch directory for each test."""
        self.temp_dir = tempfile.mkdtemp()
        self.runner = CliRunner()

    def teardown_method(self):
        """Remove the scratch directory if it still exists."""
        from shutil import rmtree

        if not os.path.exists(self.temp_dir):
            return
        rmtree(self.temp_dir)

    def test_transmit_basic(self):
        """A basic invocation exits 0, transmits once and closes the device once."""
        input_path = os.path.join(self.temp_dir, "test.npy")
        # Create an empty file at the path passed to --input.
        with open(input_path, "w"):
            pass

        fake_sdr = MagicMock()
        fake_recording = MagicMock()
        fake_recording.data = np.array([[0.1 + 0.1j] * 1000])
        fake_recording.metadata = {}

        cli_args = [
            "--device",
            "hackrf",
            "--sample-rate",
            "2e6",
            "--center-frequency",
            "915M",
            "--gain",
            "10",
            "--input",
            input_path,
            "--quiet",
        ]

        # Both the device factory and the file loader are replaced with mocks.
        with (
            patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device",
                return_value=fake_sdr,
            ),
            patch(
                "ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file",
                return_value=fake_recording,
            ),
        ):
            result = self.runner.invoke(transmit, cli_args)

        assert result.exit_code == 0
        fake_sdr.tx_recording.assert_called_once()
        fake_sdr.close.assert_called_once()
|