"""Tests for transmit command."""

import os
import shutil
import tempfile
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from click.exceptions import ClickException
from click.testing import CliRunner

from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit import (
    auto_select_tx_device,
    check_sample_rate_mismatch,
    load_input_file,
    transmit,
    validate_tx_gain,
)

# Dotted path of the module under test; used as the prefix for every patch target.
TRANSMIT = "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit"


@contextmanager
def _patched_tx_discovery(uhd=(), pluto=(), hackrf=(), bladerf=()):
    """Stub driver loading and every ``find_*_devices`` helper in the transmit module.

    Args:
        uhd, pluto, hackrf, bladerf: Device dicts each finder should report.
            Each defaults to no devices. Immutable defaults avoid shared state
            between calls.
    """
    with (
        patch(f"{TRANSMIT}.load_sdr_drivers"),
        patch(f"{TRANSMIT}.find_uhd_devices", return_value=list(uhd)),
        patch(f"{TRANSMIT}.find_pluto_devices", return_value=list(pluto)),
        patch(f"{TRANSMIT}.find_hackrf_devices", return_value=list(hackrf)),
        patch(f"{TRANSMIT}.find_bladerf_devices", return_value=list(bladerf)),
    ):
        yield


class TestGetTxDevice:
    """Tests for get_sdr_device function."""

    # NOTE(review): the stubbed sys.modules keys use a "src." prefix, while this file
    # imports the package as "ria_toolkit_oss...". Confirm the keys match the import
    # path used inside get_sdr_device; otherwise the stub is never consulted.

    def test_get_pluto_device(self):
        """Test getting PlutoSDR device."""
        mock_sdr_class = MagicMock()
        with patch.dict("sys.modules", {"src.ria_toolkit_oss.sdr.pluto": MagicMock(Pluto=mock_sdr_class)}):
            device = get_sdr_device("pluto")
        assert device is mock_sdr_class.return_value

    def test_get_hackrf_device(self):
        """Test getting HackRF device."""
        mock_sdr_class = MagicMock()
        with patch.dict("sys.modules", {"src.ria_toolkit_oss.sdr.hackrf": MagicMock(HackRF=mock_sdr_class)}):
            device = get_sdr_device("hackrf")
        assert device is mock_sdr_class.return_value

    def test_get_unknown_device(self):
        """Test getting unknown device type."""
        with pytest.raises(ClickException) as exc_info:
            get_sdr_device("unknown_device")
        assert "Unknown device type" in str(exc_info.value)


class TestAutoSelectTxDevice:
    """Tests for auto_select_tx_device function."""

    def test_auto_select_no_devices(self):
        """Test auto-select with no TX devices found."""
        with _patched_tx_discovery(), pytest.raises(ClickException) as exc_info:
            auto_select_tx_device()
        assert "No TX-capable SDR devices found" in str(exc_info.value)

    def test_auto_select_single_device(self):
        """Test auto-select with single TX device."""
        with _patched_tx_discovery(hackrf=[{"type": "HackRF One", "serial": "123456"}]):
            device_type = auto_select_tx_device(quiet=True)
        assert device_type == "hackrf"

    def test_auto_select_multiple_devices(self):
        """Test auto-select with multiple TX devices raises error."""
        with (
            _patched_tx_discovery(
                pluto=[{"type": "PlutoSDR", "uri": "ip:pluto.local"}],
                hackrf=[{"type": "HackRF One", "serial": "123456"}],
            ),
            pytest.raises(ClickException) as exc_info,
        ):
            auto_select_tx_device()
        assert "Multiple TX-capable devices found" in str(exc_info.value)

    # Parametrized so every mapping is reported separately instead of stopping at the first failure.
    @pytest.mark.parametrize(
        ("device_name", "expected_type"),
        [
            ("PlutoSDR", "pluto"),
            ("HackRF One", "hackrf"),
            ("BladeRF", "bladerf"),
            ("b200", "usrp"),
            ("B210", "usrp"),
        ],
    )
    def test_auto_select_device_mapping(self, device_name, expected_type):
        """Test device type name mapping."""
        # Every name is reported through find_bladerf_devices; the expected type is
        # asserted from the reported "type" string alone.
        with _patched_tx_discovery(bladerf=[{"type": device_name}]):
            device_type = auto_select_tx_device(quiet=True)
        assert device_type == expected_type


class TestLoadInputFile:
    """Tests for load_input_file function."""

    def test_load_file_not_found(self, tmp_path):
        """Test loading non-existent file."""
        # Use a path inside a fresh temp dir so it is guaranteed not to exist.
        with pytest.raises(ClickException) as exc_info:
            load_input_file(str(tmp_path / "nonexistent.sigmf"))
        assert "Input file not found" in str(exc_info.value)

    def test_load_sigmf_file(self, tmp_path):
        """Test loading SigMF file."""
        test_file = tmp_path / "test.sigmf-data"
        test_file.touch()
        mock_recording = MagicMock()
        with patch(f"{TRANSMIT}.load_recording", return_value=mock_recording):
            recording = load_input_file(str(test_file), legacy=False)
        assert recording is mock_recording

    def test_load_legacy_npy_file(self, tmp_path):
        """Test loading legacy NPY file."""
        test_file = tmp_path / "test.npy"
        test_file.touch()
        mock_recording = MagicMock()
        with patch(f"{TRANSMIT}.from_npy_legacy", return_value=mock_recording):
            recording = load_input_file(str(test_file), legacy=True)
        assert recording is mock_recording

    def test_load_unsupported_format(self, tmp_path):
        """Test loading unsupported file format."""
        test_file = tmp_path / "test.bin"
        test_file.touch()
        with (
            patch(f"{TRANSMIT}.load_recording", side_effect=Exception("Unsupported format")),
            pytest.raises(ClickException) as exc_info,
        ):
            load_input_file(str(test_file))
        message = str(exc_info.value)
        assert "Could not load" in message
        assert "Supported formats" in message


class TestValidateTxGain:
    """Tests for validate_tx_gain function."""

    def test_valid_pluto_gain(self):
        """Test valid PlutoSDR gain."""
        validate_tx_gain("pluto", -30)
        validate_tx_gain("pluto", 0)
        validate_tx_gain("pluto", -89)

    def test_invalid_pluto_gain_too_high(self):
        """Test PlutoSDR gain too high."""
        with pytest.raises(ClickException) as exc_info:
            validate_tx_gain("pluto", 10)
        assert "out of range" in str(exc_info.value)

    def test_invalid_pluto_gain_too_low(self):
        """Test PlutoSDR gain too low."""
        with pytest.raises(ClickException) as exc_info:
            validate_tx_gain("pluto", -100)
        assert "out of range" in str(exc_info.value)

    def test_valid_hackrf_gain(self):
        """Test valid HackRF gain."""
        validate_tx_gain("hackrf", 0)
        validate_tx_gain("hackrf", 20)
        validate_tx_gain("hackrf", 47)

    def test_invalid_hackrf_gain(self):
        """Test invalid HackRF gain."""
        with pytest.raises(ClickException):
            validate_tx_gain("hackrf", -10)
        with pytest.raises(ClickException):
            validate_tx_gain("hackrf", 50)

    def test_high_gain_warning(self):
        """Test warning for high gain levels."""
        with patch(f"{TRANSMIT}.click.echo") as mock_echo:
            validate_tx_gain("hackrf", 45)
        mock_echo.assert_called()
        # Accept the warning on any echo call, so extra output lines do not break the test.
        messages = [str(c) for c in mock_echo.call_args_list]
        assert any("WARNING" in m and "high gain" in m.lower() for m in messages), messages


class TestCheckSampleRateMismatch:
    """Tests for check_sample_rate_mismatch function."""

    def test_no_mismatch(self):
        """Test when sample rates match."""
        mock_recording = MagicMock()
        mock_recording.metadata = {"sample_rate": 2e6}
        with patch(f"{TRANSMIT}.click.echo") as mock_echo:
            check_sample_rate_mismatch(mock_recording, 2e6, quiet=False)
        mock_echo.assert_not_called()

    def test_mismatch_warning(self):
        """Test warning when sample rates differ."""
        mock_recording = MagicMock()
        mock_recording.metadata = {"sample_rate": 1e6}
        with patch(f"{TRANSMIT}.click.echo") as mock_echo:
            check_sample_rate_mismatch(mock_recording, 2e6, quiet=False)
        mock_echo.assert_called_once()
        args = str(mock_echo.call_args)
        assert "Warning" in args
        assert "differs" in args

    def test_mismatch_quiet_mode(self):
        """Test no warning in quiet mode."""
        mock_recording = MagicMock()
        mock_recording.metadata = {"sample_rate": 1e6}
        with patch(f"{TRANSMIT}.click.echo") as mock_echo:
            check_sample_rate_mismatch(mock_recording, 2e6, quiet=True)
        mock_echo.assert_not_called()

    def test_no_metadata(self):
        """Test when recording has no metadata."""
        mock_recording = MagicMock()
        mock_recording.metadata = None
        with patch(f"{TRANSMIT}.click.echo") as mock_echo:
            check_sample_rate_mismatch(mock_recording, 2e6, quiet=False)
        mock_echo.assert_not_called()


class TestTransmitCommand:
    """Tests for transmit CLI command."""

    def setup_method(self):
        """Create a CLI runner and a scratch directory for each test."""
        self.runner = CliRunner()
        self.temp_dir = tempfile.mkdtemp()

    def teardown_method(self):
        """Remove the scratch directory created in setup_method."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def test_transmit_basic(self):
        """Test basic transmit command."""
        test_file = os.path.join(self.temp_dir, "test.npy")
        # Empty placeholder so the --input path exists. Its contents should not
        # matter because load_input_file is patched below.
        open(test_file, "w").close()

        mock_sdr = MagicMock()
        mock_recording = MagicMock()
        mock_recording.data = np.array([[0.1 + 0.1j] * 1000])  # shape (1, 1000), complex
        mock_recording.metadata = {}

        with (
            patch(f"{TRANSMIT}.get_sdr_device", return_value=mock_sdr),
            patch(f"{TRANSMIT}.load_input_file", return_value=mock_recording),
        ):
            result = self.runner.invoke(
                transmit,
                [
                    "--device",
                    "hackrf",
                    "--sample-rate",
                    "2e6",
                    "--center-frequency",
                    "915M",
                    "--gain",
                    "10",
                    "--input",
                    test_file,
                    "--quiet",
                ],
            )

        # Include CLI output in the failure message to make a non-zero exit debuggable.
        assert result.exit_code == 0, result.output
        mock_sdr.tx_recording.assert_called_once()
        mock_sdr.close.assert_called_once()