"""Tests for RemoteTransmitterController — mocks paramiko and ZMQ entirely. paramiko and zmq are optional runtime deps; these tests inject fakes into sys.modules so they run regardless of whether the packages are installed. """ from __future__ import annotations import json import sys import threading import time from types import ModuleType from unittest.mock import MagicMock, patch import pytest # --------------------------------------------------------------------------- # Fake modules injected into sys.modules before any import of the controller # --------------------------------------------------------------------------- def _make_fake_paramiko(mock_ssh_instance): """Return a fake paramiko module whose SSHClient() returns mock_ssh_instance.""" mod = MagicMock(spec=ModuleType) mod.SSHClient = MagicMock(return_value=mock_ssh_instance) mod.AutoAddPolicy = MagicMock() return mod def _make_fake_zmq(mock_socket_instance): """Return a fake zmq module whose Context().socket() returns mock_socket_instance.""" mock_context = MagicMock() mock_context.socket.return_value = mock_socket_instance mod = MagicMock(spec=ModuleType) mod.Context = MagicMock(return_value=mock_context) mod.REQ = "REQ" return mod, mock_context # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- def _ok_response(fn="set_radio") -> bytes: return json.dumps({"status": True, "message": "", "error_message": ""}).encode() def _err_response(fn="set_radio", msg="boom") -> bytes: return json.dumps({"status": False, "message": "", "error_message": msg}).encode() def _make_mock_socket(recv_side_effect=None): sock = MagicMock() if recv_side_effect is not None: sock.recv.side_effect = recv_side_effect else: sock.recv.return_value = _ok_response() return sock def _make_controller(mock_socket=None, *, startup_wait=0): """Build a controller with all external I/O mocked via sys.modules injection.""" 
mock_sock = mock_socket or _make_mock_socket() mock_ssh = MagicMock() mock_stdout = MagicMock() mock_stdout.channel = MagicMock() mock_ssh.exec_command.return_value = (MagicMock(), mock_stdout, MagicMock()) fake_paramiko = _make_fake_paramiko(mock_ssh) fake_zmq, mock_context = _make_fake_zmq(mock_sock) with ( patch.dict("sys.modules", {"paramiko": fake_paramiko, "zmq": fake_zmq}), patch( "ria_toolkit_oss.remote_control.remote_transmitter_controller._STARTUP_WAIT_S", startup_wait, ), ): from ria_toolkit_oss.remote_control.remote_transmitter_controller import ( RemoteTransmitterController, ) ctrl = RemoteTransmitterController( host="192.168.1.10", ssh_user="ubuntu", ssh_key_path="/home/user/.ssh/id_rsa", zmq_port=5556, ) ctrl._mock_ssh = mock_ssh ctrl._mock_socket = mock_sock ctrl._mock_context = mock_context ctrl._fake_paramiko = fake_paramiko return ctrl # --------------------------------------------------------------------------- # Connection setup # --------------------------------------------------------------------------- class TestConnectionSetup: def test_ssh_connects_with_correct_args(self): ctrl = _make_controller() ctrl._mock_ssh.connect.assert_called_once_with( hostname="192.168.1.10", username="ubuntu", key_filename="/home/user/.ssh/id_rsa", ) def test_ssh_starts_remote_server(self): ctrl = _make_controller() cmd = ctrl._mock_ssh.exec_command.call_args[0][0] assert "remote_transmitter" in cmd assert "--port" in cmd assert "5556" in cmd def test_zmq_connects_to_host_port(self): ctrl = _make_controller() ctrl._mock_socket.connect.assert_called_once_with("tcp://192.168.1.10:5556") def test_host_key_policy_set_to_auto_add(self): """AutoAddPolicy is applied so we don't prompt in headless execution.""" ctrl = _make_controller() ctrl._mock_ssh.set_missing_host_key_policy.assert_called_once() # --------------------------------------------------------------------------- # ZMQ message format # 
# ---------------------------------------------------------------------------


def _last_sent(ctrl) -> dict:
    """Decode the JSON request passed to the most recent ``socket.send`` call."""
    payload = ctrl._mock_socket.send.call_args[0][0]
    return json.loads(payload.decode())


class TestSendFormat:
    def test_set_radio_sends_correct_dict(self):
        ctrl = _make_controller()
        ctrl.set_radio("pluto", "ip:192.168.2.1")
        sent = _last_sent(ctrl)
        assert sent["function_name"] == "set_radio"
        assert sent["radio_str"] == "pluto"
        assert sent["identifier"] == "ip:192.168.2.1"

    def test_set_radio_default_identifier(self):
        ctrl = _make_controller()
        ctrl.set_radio("hackrf")
        sent = _last_sent(ctrl)
        assert sent["identifier"] == ""

    def test_init_tx_sends_correct_dict(self):
        ctrl = _make_controller()
        ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=30, channel=1)
        sent = _last_sent(ctrl)
        assert sent["function_name"] == "init_tx"
        assert sent["center_frequency"] == pytest.approx(2.4e9)
        assert sent["sample_rate"] == pytest.approx(20e6)
        assert sent["gain"] == pytest.approx(30)
        assert sent["channel"] == 1
        assert sent["gain_mode"] == "absolute"

    def test_init_tx_default_channel_zero(self):
        ctrl = _make_controller()
        ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=0)
        sent = _last_sent(ctrl)
        assert sent["channel"] == 0

    def test_stop_sends_correct_dict(self):
        ctrl = _make_controller()
        ctrl.stop()
        sent = _last_sent(ctrl)
        assert sent["function_name"] == "stop"


# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------


class TestErrorHandling:
    def test_error_response_raises_runtime_error(self):
        sock = _make_mock_socket()
        sock.recv.return_value = _err_response(msg="radio not found")
        ctrl = _make_controller(mock_socket=sock)
        with pytest.raises(RuntimeError, match="radio not found"):
            ctrl.set_radio("pluto")

    def test_error_message_included_in_exception(self):
        sock = _make_mock_socket()
        sock.recv.return_value = _err_response(msg="gain out of range")
        ctrl = _make_controller(mock_socket=sock)
        with pytest.raises(RuntimeError, match="gain out of range"):
            ctrl.init_tx(center_frequency=2.4e9, sample_rate=20e6, gain=999)

    def test_send_on_closed_controller_raises(self):
        ctrl = _make_controller()
        ctrl.close()
        with pytest.raises(RuntimeError, match="closed"):
            ctrl._send({"function_name": "set_radio", "radio_str": "pluto", "identifier": ""})

    def test_missing_paramiko_raises_runtime_error(self):
        """If paramiko is absent, constructing the controller fails loudly.

        A RuntimeError with a clear message is the intended behaviour, but a
        raw ImportError is also accepted.
        """
        import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod

        # Fake zmq so paramiko is the *only* missing dependency; otherwise an
        # absent zmq package could satisfy pytest.raises on its own.
        fake_zmq, _ = _make_fake_zmq(_make_mock_socket())
        # A None entry in sys.modules makes ``import paramiko`` raise ImportError.
        with (
            patch.dict("sys.modules", {"paramiko": None, "zmq": fake_zmq}),
            patch.object(mod, "_STARTUP_WAIT_S", 0),
        ):
            with pytest.raises((RuntimeError, ImportError)):
                mod.RemoteTransmitterController(host="h", ssh_user="u", ssh_key_path="/k")


# ---------------------------------------------------------------------------
# transmit_async / wait_transmit
# ---------------------------------------------------------------------------


class TestTransmitAsync:
    def test_transmit_async_returns_immediately(self):
        """transmit_async must not block — the ZMQ recv may take duration_s seconds."""

        def slow_recv(*_args, **_kwargs):
            # Accept whatever arguments recv() is called with (e.g. flags).
            time.sleep(0.1)
            return _ok_response("transmit")

        sock = _make_mock_socket(recv_side_effect=slow_recv)
        ctrl = _make_controller(mock_socket=sock)

        t0 = time.monotonic()
        ctrl.transmit_async(duration_s=5.0)
        elapsed = time.monotonic() - t0

        assert elapsed < 0.05, "transmit_async must not block"
        ctrl.wait_transmit(timeout=2.0)

    def test_transmit_async_sends_correct_duration(self):
        ctrl = _make_controller()
        ctrl.transmit_async(duration_s=12.5)
        ctrl.wait_transmit(timeout=1.0)
        sent = _last_sent(ctrl)
        assert sent["function_name"] == "transmit"
        assert sent["duration_s"] == pytest.approx(12.5)

    def test_wait_transmit_joins_thread(self):
        ctrl = _make_controller()
        ctrl.transmit_async(duration_s=0.01)
        ctrl.wait_transmit(timeout=2.0)
        assert ctrl._tx_thread is None

    def test_wait_transmit_noop_if_no_thread(self):
        ctrl = _make_controller()
        ctrl.wait_transmit()  # should not raise

    def test_transmit_async_error_is_logged_not_raised(self):
        """Background thread errors must not propagate to caller.

        Only the non-propagation is asserted here; the logging itself is not
        checked.
        """
        sock = _make_mock_socket()
        sock.recv.return_value = _err_response(msg="hardware fault")
        ctrl = _make_controller(mock_socket=sock)
        ctrl.transmit_async(duration_s=0.01)
        ctrl.wait_transmit(timeout=2.0)  # should not raise


# ---------------------------------------------------------------------------
# close / teardown
# ---------------------------------------------------------------------------


class TestClose:
    def test_close_terminates_zmq_context(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl._mock_context.term.assert_called_once()

    def test_close_closes_zmq_socket(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl._mock_socket.close.assert_called_once()

    def test_close_closes_ssh(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl._mock_ssh.close.assert_called_once()

    def test_close_is_idempotent(self):
        ctrl = _make_controller()
        ctrl.close()
        ctrl.close()  # second call must not raise

    def test_stop_calls_close(self):
        """stop() also tears down: the socket and SSH handles are cleared."""
        ctrl = _make_controller()
        ctrl.stop()
        assert ctrl._socket is None
        assert ctrl._ssh is None