ria-toolkit-oss/tests/test_agent.py
2026-04-20 12:33:14 -04:00

250 lines
7.8 KiB
Python

"""Tests for NodeAgent — TX role, session code, and TX command dispatch."""
from __future__ import annotations
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
from ria_toolkit_oss.agent import NodeAgent
def _agent(role="general", session_code=None, **kwargs):
    """Build a ``NodeAgent`` with fixed test settings.

    The hub URL, API key, node name and SDR device are hard-coded test
    values; ``role`` and ``session_code`` are chosen by the caller. Any
    additional keyword arguments are forwarded to the ``NodeAgent``
    constructor unchanged.
    """
    node = NodeAgent(
        role=role,
        session_code=session_code,
        hub_url="http://hub.test",
        api_key="test-key",
        name="test-node",
        sdr_device="mock",
        **kwargs,
    )
    return node
def _mock_register(agent, node_id="node_abc123"):
    """Swap ``agent._post`` for a mock that answers with a fake node id.

    The mocked response yields ``{"node_id": node_id}`` from ``json()`` and
    ``None`` from ``raise_for_status()``. The installed mock is returned so
    the caller can inspect how it was invoked.
    """
    # Dotted keys configure child mocks directly in the constructor.
    fake_response = MagicMock(
        **{
            "json.return_value": {"node_id": node_id},
            "raise_for_status.return_value": None,
        }
    )
    post_mock = MagicMock(return_value=fake_response)
    agent._post = post_mock
    return post_mock
# ---------------------------------------------------------------------------
# Initialisation
# ---------------------------------------------------------------------------
class TestNodeAgentInit:
    """Checks on the attributes of a freshly constructed ``NodeAgent``."""

    def test_stores_role_general(self):
        agent = _agent(role="general")
        assert agent.role == "general"

    def test_stores_role_tx(self):
        agent = _agent(role="tx")
        assert agent.role == "tx"

    def test_stores_role_rx(self):
        agent = _agent(role="rx")
        assert agent.role == "rx"

    def test_session_code_stored(self):
        code = "amber-peak-transmit"
        agent = _agent(session_code="amber-peak-transmit")
        assert agent.session_code == code

    def test_session_code_none_by_default(self):
        agent = _agent()
        assert agent.session_code is None

    def test_tx_stop_event_created(self):
        assert isinstance(_agent()._tx_stop, threading.Event)

    def test_tx_thread_none_initially(self):
        agent = _agent()
        assert agent._tx_thread is None

    def test_hub_url_trailing_slash_stripped(self):
        # Built directly (not via _agent) so the URL with a trailing slash
        # is the one under test.
        agent = NodeAgent(name="n", api_key="k", hub_url="http://hub.test/")
        assert agent.hub_url == "http://hub.test"
# ---------------------------------------------------------------------------
# _register payload
# ---------------------------------------------------------------------------
class TestNodeAgentRegisterPayload:
    """Checks on the JSON body that ``_register()`` passes to ``_post``."""

    def _payload(self, agent):
        """Run ``_register()`` against a mocked ``_post`` and return the
        value of the ``json`` keyword argument it was called with."""
        post = _mock_register(agent)
        agent._register()
        # call_args is an (args, kwargs) pair; index 1 is the kwargs dict.
        return post.call_args[1]["json"]

    def test_general_role_in_payload(self):
        body = self._payload(_agent(role="general"))
        assert body["role"] == "general"

    def test_tx_role_in_payload(self):
        body = self._payload(_agent(role="tx"))
        assert body["role"] == "tx"

    def test_tx_role_adds_transmit_capability(self):
        body = self._payload(_agent(role="tx"))
        assert "transmit" in body["capabilities"]

    def test_general_role_omits_transmit_capability(self):
        body = self._payload(_agent(role="general"))
        capabilities = body.get("capabilities", [])
        assert "transmit" not in capabilities

    def test_session_code_included_when_set(self):
        agent = _agent(role="tx", session_code="amber-peak-transmit")
        body = self._payload(agent)
        assert body["session_code"] == "amber-peak-transmit"

    def test_session_code_omitted_when_none(self):
        body = self._payload(_agent())
        assert "session_code" not in body

    def test_register_stores_returned_node_id(self):
        agent = _agent()
        _mock_register(agent, node_id="node_xyz999")
        agent._register()
        assert agent.node_id == "node_xyz999"

    def test_name_in_payload(self):
        # Built directly so the node name differs from the _agent default.
        agent = NodeAgent(hub_url="http://h", api_key="k", name="my-bench")
        assert self._payload(agent)["name"] == "my-bench"

    def test_sdr_device_in_payload(self):
        assert self._payload(_agent())["sdr_device"] == "mock"

    def test_campaign_capability_always_present(self):
        for node_role in ("general", "rx", "tx"):
            body = self._payload(_agent(role=node_role))
            assert "campaign" in body["capabilities"]
# ---------------------------------------------------------------------------
# _dispatch — TX commands
# ---------------------------------------------------------------------------
class TestNodeAgentDispatch:
    """Checks on how ``_dispatch`` reacts to TX and campaign commands.

    The transmit tests replace ``TxExecutor`` with a fake whose ``run()``
    blocks on an event, so the test decides when the worker finishes. Each
    test releases that event in a ``finally`` clause so a failed assertion
    never leaves a worker blocked until its two-second timeout.
    """

    # Patch target for the executor used by the start_transmit command.
    _TX_EXECUTOR = "ria_toolkit_oss.orchestration.tx_executor.TxExecutor"

    def _make_agent(self):
        """Return a TX-role agent with a node id and a mocked status reporter."""
        a = _agent(role="tx")
        a.node_id = "node_abc"
        a._report_campaign_status = MagicMock()
        return a

    @staticmethod
    def _blocking_executor(release, run_calls=None):
        """Return a fake executor whose ``run()`` blocks until *release* is set.

        When *run_calls* is a list, every ``run()`` invocation appends ``1``
        to it so a test can count how many times the executor was started.
        The wait is capped at two seconds so a stuck test cannot hang forever.
        """

        class _FakeExecutor:
            def run(self_):
                if run_calls is not None:
                    run_calls.append(1)
                release.wait(timeout=2)

        return _FakeExecutor()

    def test_start_transmit_spawns_thread(self):
        a = self._make_agent()
        done = threading.Event()
        try:
            with patch(self._TX_EXECUTOR, return_value=self._blocking_executor(done)):
                a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
                # Give the dispatched work a moment to start while patched.
                time.sleep(0.05)
            assert a._tx_thread is not None
        finally:
            done.set()

    def test_start_transmit_clears_stop_event(self):
        a = self._make_agent()
        a._tx_stop.set()  # pre-set
        done = threading.Event()
        try:
            with patch(self._TX_EXECUTOR, return_value=self._blocking_executor(done)):
                a._dispatch({"command": "start_transmit", "sdr_agent": {}, "schedule": []})
                time.sleep(0.05)
            assert not a._tx_stop.is_set()
        finally:
            done.set()

    def test_stop_transmit_sets_stop_event(self):
        a = self._make_agent()
        a._dispatch({"command": "stop_transmit"})
        assert a._tx_stop.is_set()

    def test_configure_transmit_does_not_raise(self):
        a = self._make_agent()
        a._dispatch({"command": "configure_transmit", "modulation": "BPSK"})

    def test_unknown_command_is_ignored(self):
        a = self._make_agent()
        a._dispatch({"command": "frobnicate_xyz"})

    def test_duplicate_start_transmit_ignored_while_running(self):
        a = self._make_agent()
        done = threading.Event()
        run_calls = []
        try:
            with patch(
                self._TX_EXECUTOR,
                return_value=self._blocking_executor(done, run_calls),
            ):
                a._dispatch({"command": "start_transmit"})
                time.sleep(0.05)
                a._dispatch({"command": "start_transmit"})  # second while first alive
                done.set()
                # Leave time for any (unexpected) second run() to be recorded.
                time.sleep(0.05)
            assert len(run_calls) == 1
        finally:
            done.set()

    def test_run_campaign_dispatched_in_thread(self):
        a = self._make_agent()
        done = threading.Event()
        with patch("ria_toolkit_oss.agent.NodeAgent._run_campaign") as mock_run:
            # Accept any call signature so the event is set however
            # _run_campaign happens to be invoked.
            mock_run.side_effect = lambda *_args, **_kwargs: done.set()
            a._dispatch({"command": "run_campaign", "campaign_id": "c1", "payload": {}})
            assert done.wait(timeout=2), "_run_campaign was not invoked within 2 seconds"
        assert mock_run.called
# ---------------------------------------------------------------------------
# _stop_transmit
# ---------------------------------------------------------------------------
class TestStopTransmit:
    """Checks on ``_stop_transmit`` with and without a TX thread attached."""

    def test_no_thread_noop(self):
        a = _agent()
        a._stop_transmit()  # must not raise

    def test_sets_stop_event(self):
        a = _agent()
        a._stop_transmit()
        assert a._tx_stop.is_set()

    def test_joins_live_thread(self):
        a = _agent()
        finished = threading.Event()
        unblock = threading.Event()

        def _task():
            # Capped wait so a broken test cannot hang the suite.
            unblock.wait(timeout=2)
            finished.set()

        t = threading.Thread(target=_task, daemon=True)
        t.start()
        a._tx_thread = t
        # Signal stop and trigger thread exit
        a._tx_stop.set()
        unblock.set()
        a._stop_transmit()
        assert not t.is_alive()
        # The worker sets this as its last step, so it confirms the thread
        # ran to completion rather than merely being reported as dead.
        assert finished.is_set()