"""Tests for the RT-OSS HTTP server. Covers: auth, inference lifecycle (without SDR/ONNX hardware), orchestrator lifecycle (with mocked executor), and state helpers. ``start_inference`` and ``_inference_loop`` require real SDR hardware and an ONNX model file — those are integration tests left for hardware-in-the-loop CI. """ from __future__ import annotations import threading from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient import ria_toolkit_oss.server.state as state_module from ria_toolkit_oss.server.app import create_app from ria_toolkit_oss.server.state import CampaignState, InferenceState, set_inference # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def reset_state(): """Wipe global server state before and after every test.""" state_module._inference = None state_module._campaigns.clear() yield state_module._inference = None state_module._campaigns.clear() @pytest.fixture def client(): """Unauthenticated client (dev mode — no API key configured).""" return TestClient(create_app(api_key="")) @pytest.fixture def auth_client(): """Client for an app configured with API key 'test-secret'.""" return TestClient(create_app(api_key="test-secret")) def _mock_inference_state(**kwargs) -> InferenceState: """Return a minimal InferenceState with a fake ONNX session.""" session = MagicMock() defaults = dict( model_path="/models/test.onnx", label_map={"iphone13": 0, "noise": 1}, index_to_label={0: "iphone13", 1: "noise"}, session=session, ) defaults.update(kwargs) return InferenceState(**defaults) # --------------------------------------------------------------------------- # Health check # --------------------------------------------------------------------------- class TestHealth: def test_health_returns_ok(self, client): resp = client.get("/health") assert resp.status_code == 200 assert resp.json() == {"status": "ok"} def test_health_requires_no_auth(self, auth_client): # /health has no auth dependency — should be 200 even without a key resp = auth_client.get("/health") assert resp.status_code == 200 # --------------------------------------------------------------------------- # Authentication # --------------------------------------------------------------------------- class TestAuth: def test_missing_key_rejected(self, auth_client): resp = auth_client.get("/inference/status") assert resp.status_code == 403 def test_wrong_key_rejected(self, auth_client): resp = auth_client.get("/inference/status", headers={"X-API-Key": "wrong"}) assert resp.status_code == 403 def test_correct_key_accepted(self, auth_client): resp = auth_client.get("/inference/status", headers={"X-API-Key": "test-secret"}) # 200 null is fine here — no model loaded yet assert resp.status_code == 200 def test_dev_mode_no_key_required(self, client): resp = client.get("/inference/status") assert resp.status_code == 200 # --------------------------------------------------------------------------- # POST /inference/load # --------------------------------------------------------------------------- class TestInferenceLoad: def test_load_returns_loaded_true(self, client): mock_session = MagicMock() with patch("ria_toolkit_oss.server.routers.inference._load_onnx_session", return_value=mock_session): resp = client.post( "/inference/load", json={"model_path": "/models/m.onnx", "label_map": {"iphone13": 0, "noise": 1}}, ) assert resp.status_code == 200 body = resp.json() assert body["loaded"] is True assert body["model_path"] == "/models/m.onnx" assert body["num_classes"] == 2 def test_load_stores_state(self, client): mock_session = MagicMock() with patch("ria_toolkit_oss.server.routers.inference._load_onnx_session", return_value=mock_session): client.post( "/inference/load", json={"model_path": "/models/m.onnx", "label_map": {"zone_a": 0}}, ) assert state_module._inference is not None assert state_module._inference.model_path == "/models/m.onnx" def test_load_builds_reverse_index(self, client): mock_session = MagicMock() with patch("ria_toolkit_oss.server.routers.inference._load_onnx_session", return_value=mock_session): client.post( "/inference/load", json={"model_path": "/m.onnx", "label_map": {"cat": 0, "dog": 1}}, ) assert state_module._inference.index_to_label == {0: "cat", 1: "dog"} def test_load_503_when_onnxruntime_missing(self, client): from fastapi import HTTPException as FastAPIHTTPException with patch( "ria_toolkit_oss.server.routers.inference._load_onnx_session", side_effect=FastAPIHTTPException(status_code=503, detail="onnxruntime not installed"), ): resp = client.post( "/inference/load", json={"model_path": "/m.onnx", "label_map": {}}, ) assert resp.status_code == 503 # --------------------------------------------------------------------------- # GET /inference/status # --------------------------------------------------------------------------- class TestInferenceStatus: def test_returns_null_when_no_model_loaded(self, client): resp = client.get("/inference/status") assert resp.status_code == 200 assert resp.json() is None def test_returns_null_when_model_loaded_but_no_result_yet(self, client): set_inference(_mock_inference_state()) resp = client.get("/inference/status") assert resp.status_code == 200 assert resp.json() is None def test_returns_latest_result(self, client): state = _mock_inference_state() state.set_latest( { "timestamp": 1234567890.0, "idle": False, "device_id": "iphone13", "confidence": 0.94, "snr_db": 18.5, } ) set_inference(state) resp = client.get("/inference/status") assert resp.status_code == 200 body = resp.json() assert body["device_id"] == "iphone13" assert body["confidence"] == 0.94 assert body["idle"] is False def test_idle_result_returned(self, client): state = _mock_inference_state() state.set_latest( { "timestamp": 1234567890.0, "idle": True, "device_id": None, "confidence": 0.55, "snr_db": 2.1, } ) set_inference(state) resp = client.get("/inference/status") assert resp.status_code == 200 assert resp.json()["idle"] is True assert resp.json()["device_id"] is None # --------------------------------------------------------------------------- # POST /inference/configure # --------------------------------------------------------------------------- class TestInferenceConfigure: def test_configure_409_when_no_model_loaded(self, client): resp = client.post("/inference/configure", json={"center_freq": 2450000000}) assert resp.status_code == 409 def test_configure_stores_pending_config(self, client): set_inference(_mock_inference_state()) resp = client.post( "/inference/configure", json={"center_freq": 915000000, "gain": 30}, ) assert resp.status_code == 200 assert resp.json()["configured"] is True pending = state_module._inference.pop_pending_config() assert pending["center_freq"] == 915000000 assert pending["gain"] == 30 def test_configure_empty_body_returns_configured_false(self, client): set_inference(_mock_inference_state()) resp = client.post("/inference/configure", json={}) assert resp.status_code == 200 assert resp.json()["configured"] is False def test_configure_only_sends_provided_fields(self, client): set_inference(_mock_inference_state()) client.post("/inference/configure", json={"sample_rate": 20000000}) pending = state_module._inference.pop_pending_config() assert "sample_rate" in pending assert "center_freq" not in pending assert "gain" not in pending # --------------------------------------------------------------------------- # POST /inference/stop # --------------------------------------------------------------------------- class TestInferenceStop: def test_stop_returns_false_when_not_running(self, client): resp = client.post("/inference/stop") assert resp.status_code == 200 assert resp.json()["stopped"] is False def test_stop_returns_false_when_model_loaded_but_not_started(self, client): set_inference(_mock_inference_state()) resp = client.post("/inference/stop") assert resp.status_code == 200 assert resp.json()["stopped"] is False def test_stop_signals_running_thread(self, client): state = _mock_inference_state() state.running = True # Thread that waits for stop_event barrier = threading.Event() def _dummy_loop(): barrier.set() state.stop_event.wait(timeout=2) state.running = False state.thread = threading.Thread(target=_dummy_loop, daemon=True) state.thread.start() barrier.wait(timeout=1) set_inference(state) resp = client.post("/inference/stop") assert resp.status_code == 200 assert resp.json()["stopped"] is True assert state.stop_event.is_set() # --------------------------------------------------------------------------- # POST /orchestrator/deploy # --------------------------------------------------------------------------- class TestOrchestratorDeploy: def test_deploy_422_on_invalid_config(self, client): with patch( "ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", side_effect=ValueError("missing required field 'name'"), ): resp = client.post("/orchestrator/deploy", json={"config": {}}) assert resp.status_code == 422 def test_deploy_returns_campaign_id(self, client): mock_cfg = MagicMock() mock_cfg.name = "test_campaign" mock_cfg.total_steps.return_value = 5 mock_executor = MagicMock() mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {}) with ( patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg), patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor), ): resp = client.post("/orchestrator/deploy", json={"config": {"name": "test_campaign"}}) assert resp.status_code == 200 body = resp.json() assert "campaign_id" in body assert len(body["campaign_id"]) > 0 def test_deploy_registers_campaign_in_state(self, client): mock_cfg = MagicMock() mock_cfg.name = "test_campaign" mock_cfg.total_steps.return_value = 3 mock_executor = MagicMock() mock_executor.return_value.run.return_value = MagicMock(to_dict=lambda: {}) with ( patch("ria_toolkit_oss.server.routers.orchestrator.CampaignConfig.from_dict", return_value=mock_cfg), patch("ria_toolkit_oss.server.routers.orchestrator.CampaignExecutor", mock_executor), ): resp = client.post("/orchestrator/deploy", json={"config": {}}) campaign_id = resp.json()["campaign_id"] assert state_module._campaigns.get(campaign_id) is not None # --------------------------------------------------------------------------- # GET /orchestrator/status/{campaign_id} # --------------------------------------------------------------------------- class TestOrchestratorStatus: def test_status_404_for_unknown_id(self, client): resp = client.get("/orchestrator/status/nonexistent-id") assert resp.status_code == 404 def test_status_returns_campaign_state(self, client): cancel_event = threading.Event() state = CampaignState( campaign_id="abc-123", status="running", config_name="test", cancel_event=cancel_event, thread=MagicMock(), total_steps=10, progress=3, ) state_module._campaigns["abc-123"] = state resp = client.get("/orchestrator/status/abc-123") assert resp.status_code == 200 body = resp.json() assert body["campaign_id"] == "abc-123" assert body["status"] == "running" assert body["progress"] == 3 assert body["total_steps"] == 10 # --------------------------------------------------------------------------- # POST /orchestrator/cancel/{campaign_id} # --------------------------------------------------------------------------- class TestOrchestratorCancel: def test_cancel_404_for_unknown_id(self, client): resp = client.post("/orchestrator/cancel/no-such-id") assert resp.status_code == 404 def test_cancel_sets_cancel_event(self, client): cancel_event = threading.Event() state = CampaignState( campaign_id="camp-to-cancel", status="running", config_name="test", cancel_event=cancel_event, thread=MagicMock(), ) state_module._campaigns["camp-to-cancel"] = state resp = client.post("/orchestrator/cancel/camp-to-cancel") assert resp.status_code == 200 assert resp.json()["cancelled"] is True assert cancel_event.is_set() def test_cancel_already_completed_returns_false(self, client): cancel_event = threading.Event() state = CampaignState( campaign_id="done", status="completed", config_name="test", cancel_event=cancel_event, thread=MagicMock(), ) state_module._campaigns["done"] = state resp = client.post("/orchestrator/cancel/done") assert resp.status_code == 200 assert resp.json()["cancelled"] is False assert not cancel_event.is_set() # --------------------------------------------------------------------------- # State helpers # --------------------------------------------------------------------------- class TestInferenceStateHelpers: def test_set_and_get_latest(self): state = _mock_inference_state() payload = {"timestamp": 1.0, "idle": False, "device_id": "dev1", "confidence": 0.9, "snr_db": 15.0} state.set_latest(payload) assert state.get_latest() == payload def test_get_latest_returns_none_initially(self): state = _mock_inference_state() assert state.get_latest() is None def test_set_and_pop_pending_config(self): state = _mock_inference_state() state.set_pending_config({"center_freq": 915e6}) popped = state.pop_pending_config() assert popped == {"center_freq": 915e6} assert state.pop_pending_config() is None # cleared after pop def test_pending_config_overwrite(self): state = _mock_inference_state() state.set_pending_config({"center_freq": 915e6}) state.set_pending_config({"center_freq": 2450e6, "gain": 40}) assert state.pop_pending_config()["center_freq"] == 2450e6 def test_thread_safety_latest(self): """Multiple threads writing latest; final read should not raise.""" state = _mock_inference_state() results = [] def writer(val): for _ in range(100): state.set_latest({"v": val}) def reader(): for _ in range(100): results.append(state.get_latest()) threads = [threading.Thread(target=writer, args=(i,)) for i in range(4)] threads.append(threading.Thread(target=reader)) for t in threads: t.start() for t in threads: t.join(timeout=5) # No exception raised and reader got non-None values non_none = [r for r in results if r is not None] assert len(non_none) > 0