From a268f2ab25b81dbbc5411af578ba680c3163a1a0 Mon Sep 17 00:00:00 2001 From: jonny Date: Mon, 13 Apr 2026 11:48:15 -0400 Subject: [PATCH 01/12] Add changes for screens agent connections. --- poetry.lock | 159 +++++++------ pyproject.toml | 3 +- src/ria_toolkit_oss/agent/__init__.py | 26 +++ src/ria_toolkit_oss/agent/cli.py | 131 +++++++++++ src/ria_toolkit_oss/agent/config.py | 63 +++++ src/ria_toolkit_oss/agent/hardware.py | 22 ++ .../{agent.py => agent/legacy_executor.py} | 0 src/ria_toolkit_oss/agent/streamer.py | 221 ++++++++++++++++++ src/ria_toolkit_oss/agent/ws_client.py | 117 ++++++++++ src/ria_toolkit_oss/sdr/__init__.py | 42 +++- src/ria_toolkit_oss/sdr/pluto.py | 21 +- src/ria_toolkit_oss/sdr/sdr.py | 48 ++++ tests/agent/__init__.py | 0 tests/agent/test_config.py | 33 +++ tests/agent/test_disconnect.py | 81 +++++++ tests/agent/test_hardware.py | 29 +++ tests/agent/test_integration.py | 100 ++++++++ tests/agent/test_legacy.py | 19 ++ tests/agent/test_streamer.py | 124 ++++++++++ tests/agent/test_ws_client.py | 161 +++++++++++++ 20 files changed, 1329 insertions(+), 71 deletions(-) create mode 100644 src/ria_toolkit_oss/agent/__init__.py create mode 100644 src/ria_toolkit_oss/agent/cli.py create mode 100644 src/ria_toolkit_oss/agent/config.py create mode 100644 src/ria_toolkit_oss/agent/hardware.py rename src/ria_toolkit_oss/{agent.py => agent/legacy_executor.py} (100%) create mode 100644 src/ria_toolkit_oss/agent/streamer.py create mode 100644 src/ria_toolkit_oss/agent/ws_client.py create mode 100644 tests/agent/__init__.py create mode 100644 tests/agent/test_config.py create mode 100644 tests/agent/test_disconnect.py create mode 100644 tests/agent/test_hardware.py create mode 100644 tests/agent/test_integration.py create mode 100644 tests/agent/test_legacy.py create mode 100644 tests/agent/test_streamer.py create mode 100644 tests/agent/test_ws_client.py diff --git a/poetry.lock b/poetry.lock index d2ddd55..cb7a9f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand. [[package]] name = "alabaster" @@ -1096,7 +1096,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.03.6" +jsonschema-specifications = ">=2023.3.6" referencing = ">=0.28.4" rpds-py = ">=0.25.0" @@ -3451,76 +3451,101 @@ anyio = ">=3.0.0" [[package]] name = "websockets" -version = "16.0" +version = "13.1" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false -python-versions = ">=3.10" -groups = ["docs", "server", "test"] +python-versions = ">=3.8" +groups = ["agent", "docs", "server", "test"] files = [ - {file = "websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a"}, - {file = "websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0"}, - {file = "websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957"}, - {file = "websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72"}, - {file = "websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde"}, - {file = "websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3"}, - {file = "websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3"}, - {file = "websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9"}, - {file = "websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35"}, - {file = "websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8"}, - {file = "websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad"}, - {file = "websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d"}, - {file = "websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe"}, - {file = "websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b"}, - {file = "websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5"}, - {file = "websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64"}, - {file = "websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6"}, - {file = "websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac"}, - {file = "websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00"}, - {file = "websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79"}, - {file = "websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39"}, - {file = "websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c"}, - {file = "websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f"}, - {file = "websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1"}, - {file = "websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2"}, - {file = "websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89"}, - {file = "websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea"}, - {file = "websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9"}, - {file = "websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230"}, - {file = "websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c"}, - {file = "websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5"}, - {file = "websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82"}, - {file = "websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8"}, - {file = "websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f"}, - {file = "websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a"}, - {file = "websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156"}, - {file = "websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0"}, - {file = "websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904"}, - {file = "websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4"}, - {file = "websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e"}, - {file = "websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4"}, - {file = "websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1"}, - {file = "websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3"}, - {file = "websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8"}, - {file = "websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d"}, - {file = "websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244"}, - {file = "websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e"}, - {file = "websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641"}, - {file = "websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8"}, - {file = "websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e"}, - {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944"}, - {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206"}, - {file = "websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6"}, - {file = "websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd"}, - {file = "websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d"}, - {file = "websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03"}, - {file = "websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da"}, - {file = "websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c"}, - {file = "websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767"}, - {file = "websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec"}, - {file = "websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"}, + {file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"}, + {file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"}, + {file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"}, + {file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"}, + {file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"}, + {file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"}, + {file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"}, + {file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"}, + {file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"}, + {file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"}, + {file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"}, + {file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"}, + {file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"}, + {file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"}, + {file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"}, + {file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"}, + {file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"}, + {file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"}, + {file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"}, + {file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"}, + {file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"}, + {file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"}, + {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, ] [metadata] lock-version = "2.1" python-versions = ">=3.10" -content-hash = "b1e5ddd7284aecf49624e51740b7a4c31bc8d0e703c255126ba5d9b2a4a0e519" +content-hash = "7ddbf7d85e9ae7bd3a1b99ae481df20aaf6fd185d5f628b0fdf9b7bd278730ed" diff --git a/pyproject.toml b/pyproject.toml index 8db3469..a0bd664 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ optional = true [tool.poetry.group.agent.dependencies] requests = ">=2.28,<3.0" +websockets = ">=12.0,<14.0" [tool.poetry.group.dev.dependencies] flake8 = "^7.1.0" @@ -116,7 +117,7 @@ pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams ria = "ria_toolkit_oss_cli.cli:cli" ria-tools = "ria_toolkit_oss_cli.cli:cli" ria-server = "ria_toolkit_oss.server.cli:serve" -ria-agent = "ria_toolkit_oss.agent:main" +ria-agent = "ria_toolkit_oss.agent.cli:main" [tool.poetry.group.server.dependencies] fastapi = ">=0.111,<1.0" diff --git a/src/ria_toolkit_oss/agent/__init__.py b/src/ria_toolkit_oss/agent/__init__.py new file mode 100644 index 0000000..11647ef --- /dev/null +++ b/src/ria_toolkit_oss/agent/__init__.py @@ -0,0 +1,26 @@ +"""RIA Toolkit agent package. + +Provides two execution modes: + +- **Legacy long-poll executor** (`NodeAgent` in :mod:`legacy_executor`) — an + HTTP long-polling agent that runs ONNX inference locally on the host. +- **Streamer** (:mod:`streamer`) — a thin WebSocket client that opens an SDR + and streams raw IQ to the RIA Hub server, which performs all inference. + +Back-compat: ``from ria_toolkit_oss.agent import NodeAgent`` and the ``main`` +entry point continue to work. +""" + +from __future__ import annotations + +from .legacy_executor import NodeAgent +from .legacy_executor import main as _legacy_main + +__all__ = ["NodeAgent", "main"] + + +def main() -> None: + """Unified CLI entry point. Dispatches to streamer/legacy subcommands.""" + from .cli import main as _cli_main + + _cli_main() diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py new file mode 100644 index 0000000..2873293 --- /dev/null +++ b/src/ria_toolkit_oss/agent/cli.py @@ -0,0 +1,131 @@ +"""Unified ``ria-agent`` CLI. + +Subcommands: + +- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged). +- ``ria-agent stream`` — new WebSocket-based IQ streamer. +- ``ria-agent detect`` — print SDR drivers whose modules import cleanly. +- ``ria-agent register --url URL --token TOKEN`` — save credentials to + ``~/.ria/agent.json``. + +Invoking ``ria-agent`` with no subcommand falls through to the legacy +long-poll behavior for back-compatibility with existing deployments. +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys + +from . import config as _config +from .hardware import available_devices +from .legacy_executor import main as _legacy_main + +_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"} + + +def _cmd_detect(_args: argparse.Namespace) -> int: + devices = available_devices() + if not devices: + print("No SDR drivers available (install ria-toolkit-oss[all-sdr] or per-driver extras).") + return 0 + for name in devices: + print(name) + return 0 + + +def _cmd_register(args: argparse.Namespace) -> int: + cfg = _config.load() + cfg.hub_url = args.url + cfg.token = args.token + if args.name: + cfg.name = args.name + if args.agent_id: + cfg.agent_id = args.agent_id + cfg.insecure = bool(args.insecure) + path = _config.save(cfg) + print(f"Saved agent credentials to {path}") + return 0 + + +def _cmd_stream(args: argparse.Namespace) -> int: + from .streamer import run_streamer + + cfg = _config.load() + url = args.url or _derive_ws_url(cfg.hub_url, cfg.agent_id) + token = args.token or cfg.token + if not url: + print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr) + return 2 + try: + asyncio.run(run_streamer(url, token)) + except KeyboardInterrupt: + pass + return 0 + + +def _derive_ws_url(hub_url: str, agent_id: str) -> str: + if not hub_url: + return "" + base = hub_url.rstrip("/") + if base.startswith("https://"): + base = "wss://" + base[len("https://"):] + elif base.startswith("http://"): + base = "ws://" + base[len("http://"):] + suffix = f"/api/agent/ws/{agent_id}" if agent_id else "/api/agent/ws" + return base + suffix + + +def main() -> None: + # Back-compat: if the first non-flag token matches a known legacy flag, + # or there is no subcommand at all, dispatch to the legacy CLI. + argv = sys.argv[1:] + if not argv or (argv[0].startswith("--") and argv[0] in _LEGACY_ALIASES): + _legacy_main() + return + + parser = argparse.ArgumentParser(prog="ria-agent") + sub = parser.add_subparsers(dest="command", required=True) + + sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)") + sub.add_parser("detect", help="List available SDR drivers") + + p_reg = sub.add_parser("register", help="Save agent credentials to ~/.ria/agent.json") + p_reg.add_argument("--url", required=True, help="RIA Hub base URL") + p_reg.add_argument("--token", required=True, help="Agent registration token") + p_reg.add_argument("--name", default=None) + p_reg.add_argument("--agent-id", dest="agent_id", default=None) + p_reg.add_argument("--insecure", action="store_true") + + p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") + p_stream.add_argument("--url", default=None, help="Override WebSocket URL") + p_stream.add_argument("--token", default=None, help="Override bearer token") + p_stream.add_argument("--log-level", default="INFO") + + # Unknown extras are forwarded to the legacy CLI when command == "run". + args, extras = parser.parse_known_args(argv) + + logging.basicConfig( + level=getattr(logging, getattr(args, "log_level", "INFO"), logging.INFO), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + if args.command == "run": + sys.argv = [sys.argv[0], *extras] + _legacy_main() + return + if args.command == "detect": + sys.exit(_cmd_detect(args)) + if args.command == "register": + sys.exit(_cmd_register(args)) + if args.command == "stream": + sys.exit(_cmd_stream(args)) + + parser.error(f"unknown command: {args.command}") + + +if __name__ == "__main__": + main() diff --git a/src/ria_toolkit_oss/agent/config.py b/src/ria_toolkit_oss/agent/config.py new file mode 100644 index 0000000..01f99ba --- /dev/null +++ b/src/ria_toolkit_oss/agent/config.py @@ -0,0 +1,63 @@ +"""Agent configuration stored at ``~/.ria/agent.json``. + +Schema:: + + { + "hub_url": "https://riahub.example.com", + "agent_id": "agent-abc123", + "token": "rha_xxxx", + "name": "lab-bench-1", + "insecure": false + } +""" + +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass, field +from pathlib import Path + +_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) + + +@dataclass +class AgentConfig: + hub_url: str = "" + agent_id: str = "" + token: str = "" + name: str = "" + insecure: bool = False + extra: dict = field(default_factory=dict) + + +def default_path() -> Path: + return _DEFAULT_PATH + + +def load(path: Path | None = None) -> AgentConfig: + p = path or _DEFAULT_PATH + if not p.exists(): + return AgentConfig() + data = json.loads(p.read_text()) + known = {f for f in AgentConfig.__dataclass_fields__ if f != "extra"} + extra = {k: v for k, v in data.items() if k not in known} + return AgentConfig( + hub_url=data.get("hub_url", ""), + agent_id=data.get("agent_id", ""), + token=data.get("token", ""), + name=data.get("name", ""), + insecure=bool(data.get("insecure", False)), + extra=extra, + ) + + +def save(cfg: AgentConfig, path: Path | None = None) -> Path: + p = path or _DEFAULT_PATH + p.parent.mkdir(parents=True, exist_ok=True) + data = asdict(cfg) + extra = data.pop("extra", {}) or {} + data.update(extra) + p.write_text(json.dumps(data, indent=2)) + os.chmod(p, 0o600) + return p diff --git a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py new file mode 100644 index 0000000..417bf1c --- /dev/null +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -0,0 +1,22 @@ +"""Hardware detection and heartbeat payload construction for the streamer.""" + +from __future__ import annotations + +from ria_toolkit_oss.sdr import detect_available + + +def available_devices() -> list[str]: + """Return a sorted list of device names whose driver modules import cleanly.""" + return sorted(detect_available().keys()) + + +def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict: + """Build the JSON body of a periodic heartbeat frame.""" + payload: dict = { + "type": "heartbeat", + "hardware": available_devices(), + "status": status, + } + if app_id: + payload["app_id"] = app_id + return payload diff --git a/src/ria_toolkit_oss/agent.py b/src/ria_toolkit_oss/agent/legacy_executor.py similarity index 100% rename from src/ria_toolkit_oss/agent.py rename to src/ria_toolkit_oss/agent/legacy_executor.py diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py new file mode 100644 index 0000000..4d89743 --- /dev/null +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -0,0 +1,221 @@ +"""Thin IQ-streaming agent. + +Listens for control messages from the RIA Hub over a persistent WebSocket. +When the server sends ``start``, opens the SDR described in ``radio_config``, +loops over ``sdr.rx(buffer_size)``, and sends each buffer as raw +interleaved float32 bytes. ``stop`` closes the SDR; ``configure`` applies +parameter updates at the next capture boundary. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +import numpy as np + +from .hardware import heartbeat_payload +from .ws_client import WsClient + +logger = logging.getLogger("ria_agent.streamer") + +_DEFAULT_BUFFER_SIZE = 1024 + + +class Streamer: + """Main streamer loop. + + Parameters + ---------- + ws: + Connected :class:`WsClient`. + sdr_factory: + Callable ``(device, identifier) -> SDR``. Defaults to + :func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests. + """ + + def __init__(self, ws: WsClient, sdr_factory=None) -> None: + self.ws = ws + self._sdr_factory = sdr_factory + self._app_id: str | None = None + self._sdr: Any = None + self._pending_config: dict = {} + self._capture_task: asyncio.Task | None = None + self._status = "idle" + + # ------------------------------------------------------------------ + # WsClient wiring + + def build_heartbeat(self) -> dict: + return heartbeat_payload(status=self._status, app_id=self._app_id) + + async def on_message(self, msg: dict) -> None: + t = msg.get("type") + if t == "start": + await self._handle_start(msg) + elif t == "stop": + await self._handle_stop(msg) + elif t == "configure": + self._pending_config.update(msg.get("radio_config") or {}) + logger.debug("Queued configure: %s", self._pending_config) + else: + logger.warning("Unknown server message type: %r", t) + + # ------------------------------------------------------------------ + async def _handle_start(self, msg: dict) -> None: + if self._capture_task is not None and not self._capture_task.done(): + logger.warning("start received while already streaming — ignoring") + return + + self._app_id = msg.get("app_id") + radio_config = dict(msg.get("radio_config") or {}) + device = radio_config.pop("device", None) + identifier = radio_config.pop("identifier", None) + buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) + if not device: + await self._send_error("start missing radio_config.device") + return + + try: + factory = self._sdr_factory or _default_sdr_factory + self._sdr = factory(device, identifier) + _apply_sdr_config(self._sdr, radio_config) + except Exception as exc: + logger.exception("Failed to open SDR %r", device) + await self._send_error(f"SDR init failed: {exc}") + return + + self._status = "streaming" + await self._send_status("streaming") + self._capture_task = asyncio.create_task( + self._capture_loop(buffer_size), name="ria-streamer-capture" + ) + + async def _handle_stop(self, msg: dict) -> None: + if self._capture_task is not None: + self._capture_task.cancel() + try: + await self._capture_task + except (asyncio.CancelledError, Exception): + pass + self._capture_task = None + self._close_sdr() + self._app_id = None + self._status = "idle" + await self._send_status("idle") + + async def _capture_loop(self, buffer_size: int) -> None: + loop = asyncio.get_running_loop() + try: + while True: + if self._pending_config: + cfg = self._pending_config + self._pending_config = {} + try: + _apply_sdr_config(self._sdr, cfg) + except Exception as exc: + logger.warning("Applying configure failed: %s", exc) + + try: + samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size) + except Exception as exc: + from ria_toolkit_oss.sdr import SdrDisconnectedError + + if isinstance(exc, SdrDisconnectedError): + logger.warning("SDR disconnected: %s", exc) + await self._send_error(f"SDR disconnected: {exc}") + else: + logger.exception("SDR rx error") + await self._send_error(f"SDR capture failed: {exc}") + break + + payload = _samples_to_interleaved_float32(samples) + try: + await self.ws.send_bytes(payload) + except Exception as exc: + logger.warning("Send failed: %s — ending capture", exc) + break + except asyncio.CancelledError: + raise + finally: + self._close_sdr() + + def _close_sdr(self) -> None: + if self._sdr is None: + return + try: + self._sdr.close() + except Exception: + pass + self._sdr = None + + async def _send_status(self, status: str) -> None: + try: + await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id}) + except Exception as exc: + logger.debug("Status send failed: %s", exc) + + async def _send_error(self, message: str) -> None: + try: + await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message}) + except Exception as exc: + logger.debug("Error-frame send failed: %s", exc) + + +# --------------------------------------------------------------------------- +# Helpers + +_CONFIG_ATTR_MAP = { + "sample_rate": ("sample_rate", "rx_sample_rate"), + "center_frequency": ("center_freq", "rx_center_frequency"), + "center_freq": ("center_freq", "rx_center_frequency"), + "gain": ("gain", "rx_gain"), + "bandwidth": ("bandwidth", "rx_bandwidth"), +} + + +def _apply_sdr_config(sdr: Any, cfg: dict) -> None: + """Apply a radio_config dict to an SDR, trying multiple attribute aliases.""" + for key, value in cfg.items(): + if value is None: + continue + attrs = _CONFIG_ATTR_MAP.get(key, (key,)) + applied = False + for attr in attrs: + if hasattr(sdr, attr): + try: + setattr(sdr, attr, value) + applied = True + break + except Exception as exc: + logger.debug("setattr %s=%r failed: %s", attr, value, exc) + if not applied: + logger.debug("radio_config key %r ignored (no matching attr)", key) + + +def _samples_to_interleaved_float32(samples: Any) -> bytes: + """Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes.""" + arr = np.asarray(samples) + if np.iscomplexobj(arr): + interleaved = np.empty(arr.size * 2, dtype=np.float32) + interleaved[0::2] = arr.real.astype(np.float32, copy=False).ravel() + interleaved[1::2] = arr.imag.astype(np.float32, copy=False).ravel() + return interleaved.tobytes() + return arr.astype(np.float32, copy=False).tobytes() + + +def _default_sdr_factory(device: str, identifier: str | None): + from ria_toolkit_oss.sdr import get_sdr_device + + return get_sdr_device(device, ident=identifier) + + +# --------------------------------------------------------------------------- +# Top-level entry + +async def run_streamer(ws_url: str, token: str) -> None: + """Connect to *ws_url* and run the streamer loop until cancelled.""" + ws = WsClient(ws_url, token) + streamer = Streamer(ws) + await ws.run(streamer.on_message, streamer.build_heartbeat) diff --git a/src/ria_toolkit_oss/agent/ws_client.py b/src/ria_toolkit_oss/agent/ws_client.py new file mode 100644 index 0000000..1bc66f6 --- /dev/null +++ b/src/ria_toolkit_oss/agent/ws_client.py @@ -0,0 +1,117 @@ +"""Persistent WebSocket client for the streamer agent. + +Handles connection lifecycle: connect, heartbeat, auto-reconnect on drop. +The caller drives the I/O loop via ``run()`` with a message handler callback. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Awaitable, Callable + +logger = logging.getLogger("ria_agent.ws") + +MessageHandler = Callable[[dict], Awaitable[None]] +HeartbeatBuilder = Callable[[], dict] + + +class WsClient: + """Persistent WebSocket connection with heartbeat and auto-reconnect. + + ``url`` should be a full ``wss://host/path`` (or ``ws://``) URL. ``token`` + is sent as a bearer in the ``Authorization`` header on connect. + """ + + def __init__( + self, + url: str, + token: str, + *, + heartbeat_interval: float = 30.0, + reconnect_pause: float = 5.0, + ) -> None: + self.url = url + self.token = token + self.heartbeat_interval = heartbeat_interval + self.reconnect_pause = reconnect_pause + self._ws = None + self._stop = asyncio.Event() + + # ------------------------------------------------------------------ + async def _connect(self): + import websockets + + headers = [("Authorization", f"Bearer {self.token}")] if self.token else None + # websockets >= 12 accepts additional_headers; fall back to extra_headers for older versions. + try: + return await websockets.connect(self.url, additional_headers=headers) + except TypeError: + return await websockets.connect(self.url, extra_headers=headers) + + # ------------------------------------------------------------------ + async def send_json(self, payload: dict) -> None: + if self._ws is None: + raise ConnectionError("WebSocket is not connected") + await self._ws.send(json.dumps(payload)) + + async def send_bytes(self, data: bytes) -> None: + if self._ws is None: + raise ConnectionError("WebSocket is not connected") + await self._ws.send(data) + + def stop(self) -> None: + self._stop.set() + + # ------------------------------------------------------------------ + async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None: + """Main loop: connect, heartbeat, dispatch messages, reconnect on drop.""" + while not self._stop.is_set(): + try: + self._ws = await self._connect() + logger.info("Connected to %s", self.url) + hb_task = asyncio.create_task(self._heartbeat_loop(heartbeat)) + try: + async for raw in self._ws: + if isinstance(raw, bytes): + # Server shouldn't send binary to the agent; log and drop. + logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) + continue + try: + msg = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Malformed control frame: %r", raw[:200]) + continue + await on_message(msg) + finally: + hb_task.cancel() + try: + await hb_task + except (asyncio.CancelledError, Exception): + pass + except asyncio.CancelledError: + raise + except Exception as exc: + if self._stop.is_set(): + break + logger.warning("WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause) + finally: + try: + if self._ws is not None: + await self._ws.close() + except Exception: + pass + self._ws = None + if self._stop.is_set(): + break + await asyncio.sleep(self.reconnect_pause) + + async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None: + while True: + try: + await self.send_json(heartbeat()) + except Exception as exc: + logger.debug("Heartbeat send failed: %s", exc) + return + await asyncio.sleep(self.heartbeat_interval) diff --git a/src/ria_toolkit_oss/sdr/__init__.py b/src/ria_toolkit_oss/sdr/__init__.py index 78a13a9..4b327a2 100644 --- a/src/ria_toolkit_oss/sdr/__init__.py +++ b/src/ria_toolkit_oss/sdr/__init__.py @@ -4,10 +4,48 @@ It streamlines tasks involving signal reception and transmission, as well as com operations such as detecting and configuring available devices. """ -__all__ = ["SDR", "SDRError", "SDRParameterError", "MockSDR", "get_sdr_device"] +__all__ = [ + "SDR", + "SDRError", + "SDRParameterError", + "SdrDisconnectedError", + "MockSDR", + "get_sdr_device", + "detect_available", +] from .mock import MockSDR -from .sdr import SDR, SDRError, SDRParameterError +from .sdr import SDR, SDRError, SdrDisconnectedError, SDRParameterError, translate_disconnect # noqa: F401 + + +_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = ( + ("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"), + ("pluto", "ria_toolkit_oss.sdr.pluto", "Pluto"), + ("hackrf", "ria_toolkit_oss.sdr.hackrf", "HackRF"), + ("rtlsdr", "ria_toolkit_oss.sdr.rtlsdr", "RTLSDR"), + ("usrp", "ria_toolkit_oss.sdr.usrp", "USRP"), + ("blade", "ria_toolkit_oss.sdr.blade", "Blade"), + ("thinkrf", "ria_toolkit_oss.sdr.thinkrf", "ThinkRF"), +) + + +def detect_available() -> dict[str, type]: + """Return ``{device_name: driver_class}`` for every driver whose module imports cleanly. + + Importability is a proxy for "the user has installed this driver's optional dependency". + It does not probe for physical hardware presence — that requires actually instantiating + the driver, which can be slow and side-effectful. + """ + import importlib + + out: dict[str, type] = {} + for name, module_path, cls_name in _DRIVER_CANDIDATES: + try: + mod = importlib.import_module(module_path) + out[name] = getattr(mod, cls_name) + except Exception: + continue + return out def get_sdr_device(device_type: str, ident: str | None = None, tx: bool = False) -> SDR: diff --git a/src/ria_toolkit_oss/sdr/pluto.py b/src/ria_toolkit_oss/sdr/pluto.py index 68b3973..7ed3be0 100644 --- a/src/ria_toolkit_oss/sdr/pluto.py +++ b/src/ria_toolkit_oss/sdr/pluto.py @@ -8,7 +8,7 @@ import adi import numpy as np from ria_toolkit_oss.datatypes.recording import Recording -from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError +from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError, translate_disconnect class Pluto(SDR): @@ -164,6 +164,25 @@ class Pluto(SDR): # send callback complex signal callback(buffer=signal, metadata=None) + def rx(self, num_samples: Optional[int] = None) -> np.ndarray: + """PlutoSDR-style single-buffer capture returning a complex64 array. + + Sets the radio buffer size to *num_samples* (if given) and returns one + buffer directly from ``self.radio.rx()``. Raises + :class:`SdrDisconnectedError` on USB/device drop so callers (e.g. the + streamer) can report the failure and stop cleanly instead of crashing. + """ + if num_samples is not None: + try: + self.set_rx_buffer_size(buffer_size=int(num_samples)) + except Exception as exc: + raise translate_disconnect(exc) from exc + try: + samples = self.radio.rx() + except Exception as exc: + raise translate_disconnect(exc) from exc + return np.asarray(samples) + def _record_fast(self, num_samples): """Optimized single-buffer capture for ≤16M samples.""" diff --git a/src/ria_toolkit_oss/sdr/sdr.py b/src/ria_toolkit_oss/sdr/sdr.py index 36e26f7..abab125 100644 --- a/src/ria_toolkit_oss/sdr/sdr.py +++ b/src/ria_toolkit_oss/sdr/sdr.py @@ -528,3 +528,51 @@ class SDROverflowError(SDRError): """Buffer overflow detected.""" pass + + +class SdrDisconnectedError(SDRError): + """Raised when the SDR device disappears mid-operation (USB unplug, network drop).""" + + pass + + +# Substrings that strongly indicate a device has disappeared rather than a +# transient / recoverable error. Checked case-insensitively against str(exc). +_DISCONNECT_MARKERS = ( + "no such device", + "device not found", + "not found", + "broken pipe", + "disconnected", + "no device", + "device unplugged", + "usb", + "i/o error", + "input/output error", + "errno 19", # ENODEV + "errno 5", # EIO +) + + +def translate_disconnect(exc: BaseException) -> BaseException: + """Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*. + + Drivers wrap their native-API calls with:: + + try: + return self.radio.rx() + except Exception as exc: + raise translate_disconnect(exc) from exc + + The caller (e.g. the streamer) can then catch ``SdrDisconnectedError`` + specifically and report it to the hub rather than crashing the loop. + """ + if isinstance(exc, SdrDisconnectedError): + return exc + msg = str(exc).lower() + if any(marker in msg for marker in _DISCONNECT_MARKERS): + return SdrDisconnectedError(str(exc)) + # OSError subclass with ENODEV / EIO errno is also a disconnect signal. + if isinstance(exc, OSError) and getattr(exc, "errno", None) in (5, 19): + return SdrDisconnectedError(str(exc)) + return exc diff --git a/tests/agent/__init__.py b/tests/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agent/test_config.py b/tests/agent/test_config.py new file mode 100644 index 0000000..2532abd --- /dev/null +++ b/tests/agent/test_config.py @@ -0,0 +1,33 @@ +from ria_toolkit_oss.agent import config as agent_config + + +def test_round_trip(tmp_path): + p = tmp_path / "agent.json" + cfg = agent_config.AgentConfig( + hub_url="https://hub.example.com", + agent_id="agent-1", + token="t", + name="bench", + insecure=True, + ) + agent_config.save(cfg, path=p) + loaded = agent_config.load(path=p) + assert loaded == cfg + + +def test_load_missing_returns_empty(tmp_path): + loaded = agent_config.load(path=tmp_path / "none.json") + assert loaded == agent_config.AgentConfig() + + +def test_extra_keys_preserved(tmp_path): + p = tmp_path / "agent.json" + p.write_text('{"hub_url": "x", "custom": 42}') + cfg = agent_config.load(path=p) + assert cfg.hub_url == "x" + assert cfg.extra == {"custom": 42} + agent_config.save(cfg, path=p) + import json + + data = json.loads(p.read_text()) + assert data["custom"] == 42 diff --git a/tests/agent/test_disconnect.py b/tests/agent/test_disconnect.py new file mode 100644 index 0000000..f063e3a --- /dev/null +++ b/tests/agent/test_disconnect.py @@ -0,0 +1,81 @@ +"""SdrDisconnectedError translation + streamer handling.""" + +from __future__ import annotations + +import asyncio + +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr import SdrDisconnectedError +from ria_toolkit_oss.sdr.sdr import translate_disconnect + + +def test_translate_disconnect_usb_message(): + exc = RuntimeError("libiio: Input/output error (errno 5)") + out = translate_disconnect(exc) + assert isinstance(out, SdrDisconnectedError) + + +def test_translate_disconnect_enodev_oserror(): + exc = OSError(19, "No such device") + assert isinstance(translate_disconnect(exc), SdrDisconnectedError) + + +def test_translate_disconnect_passes_through_unrelated(): + exc = ValueError("bad sample rate") + assert translate_disconnect(exc) is exc + + +def test_translate_disconnect_preserves_sdr_disconnected(): + original = SdrDisconnectedError("already typed") + assert translate_disconnect(original) is original + + +class _FlakySdr: + """SDR that raises SdrDisconnectedError on the first rx() call.""" + + def __init__(self) -> None: + self.closed = False + + def rx(self, n): # noqa: D401 - trivial + raise SdrDisconnectedError("usb gone") + + def close(self): + self.closed = True + + +class _Ws: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def test_streamer_reports_disconnected_and_ends_capture(): + async def scenario(): + ws = _Ws() + sdr = _FlakySdr() + streamer = Streamer(ws=ws, sdr_factory=lambda d, i: sdr) + await streamer.on_message( + { + "type": "start", + "app_id": "a", + "radio_config": {"device": "fake", "buffer_size": 8}, + } + ) + # Wait for the capture task to fail out. + for _ in range(50): + if streamer._capture_task and streamer._capture_task.done(): + break + await asyncio.sleep(0.01) + return ws, sdr, streamer + + ws, sdr, streamer = asyncio.run(scenario()) + assert sdr.closed + errors = [m for m in ws.json_sent if m.get("type") == "error"] + assert errors, "expected an error frame" + assert "disconnected" in errors[-1]["message"].lower() diff --git a/tests/agent/test_hardware.py b/tests/agent/test_hardware.py new file mode 100644 index 0000000..ab9fcdf --- /dev/null +++ b/tests/agent/test_hardware.py @@ -0,0 +1,29 @@ +from ria_toolkit_oss.agent import hardware +from ria_toolkit_oss.sdr import detect_available + + +def test_detect_available_includes_mock(): + drivers = detect_available() + assert "mock" in drivers + from ria_toolkit_oss.sdr.mock import MockSDR + + assert drivers["mock"] is MockSDR + + +def test_available_devices_sorted_list(): + devices = hardware.available_devices() + assert isinstance(devices, list) + assert devices == sorted(devices) + assert "mock" in devices + + +def test_heartbeat_payload_shape(): + p = hardware.heartbeat_payload() + assert p["type"] == "heartbeat" + assert p["status"] == "idle" + assert "mock" in p["hardware"] + assert "app_id" not in p + + p2 = hardware.heartbeat_payload(status="streaming", app_id="abc") + assert p2["status"] == "streaming" + assert p2["app_id"] == "abc" diff --git a/tests/agent/test_integration.py b/tests/agent/test_integration.py new file mode 100644 index 0000000..168e7a6 --- /dev/null +++ b/tests/agent/test_integration.py @@ -0,0 +1,100 @@ +"""End-to-end: local websockets server drives a Streamer with a MockSDR.""" + +from __future__ import annotations + +import asyncio +import json + +import websockets + +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.agent.ws_client import WsClient +from ria_toolkit_oss.sdr.mock import MockSDR + + +def test_server_start_stream_stop_cycle_over_real_ws(): + async def scenario(): + control_frames: list[dict] = [] + binary_frames: list[bytes] = [] + ready = asyncio.Event() + stopped = asyncio.Event() + + async def server_handler(ws): + # Agent will open the connection; wait for heartbeat first. + try: + first = await asyncio.wait_for(ws.recv(), timeout=2.0) + control_frames.append(json.loads(first)) + await ws.send( + json.dumps( + { + "type": "start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": 32, + "sample_rate": 1_000_000, + "center_frequency": 2.45e9, + }, + } + ) + ) + while len(binary_frames) < 3: + msg = await asyncio.wait_for(ws.recv(), timeout=2.0) + if isinstance(msg, bytes): + binary_frames.append(msg) + else: + control_frames.append(json.loads(msg)) + ready.set() + await ws.send(json.dumps({"type": "stop", "app_id": "app-1"})) + # Drain final status frame. + try: + while True: + msg = await asyncio.wait_for(ws.recv(), timeout=0.5) + if isinstance(msg, bytes): + binary_frames.append(msg) + else: + control_frames.append(json.loads(msg)) + except (asyncio.TimeoutError, Exception): + pass + stopped.set() + except Exception: + stopped.set() + + server = await websockets.serve(server_handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + streamer = Streamer(ws=client, sdr_factory=lambda d, i: MockSDR(buffer_size=32, seed=0)) + task = asyncio.create_task( + client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat) + ) + await asyncio.wait_for(ready.wait(), timeout=3.0) + await asyncio.wait_for(stopped.wait(), timeout=3.0) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return control_frames, binary_frames + + controls, binaries = asyncio.run(scenario()) + + # Heartbeat reached the server. + assert any(f.get("type") == "heartbeat" for f in controls) + # Status transitioned idle -> streaming -> idle. + statuses = [f["status"] for f in controls if f.get("type") == "status"] + assert "streaming" in statuses + assert statuses[-1] == "idle" + # Three binary IQ frames of 32 samples × 2 floats × 4 bytes. + assert len(binaries) >= 3 + for b in binaries[:3]: + assert len(b) == 32 * 2 * 4 diff --git a/tests/agent/test_legacy.py b/tests/agent/test_legacy.py new file mode 100644 index 0000000..36e4ea0 --- /dev/null +++ b/tests/agent/test_legacy.py @@ -0,0 +1,19 @@ +"""Regression: legacy NodeAgent still importable after the package move.""" + + +def test_import_node_agent_from_package(): + from ria_toolkit_oss.agent import NodeAgent + + assert NodeAgent.__name__ == "NodeAgent" + + +def test_main_entry_point_exists(): + from ria_toolkit_oss.agent import main + + assert callable(main) + + +def test_legacy_module_still_direct_importable(): + from ria_toolkit_oss.agent.legacy_executor import NodeAgent as LegacyNodeAgent + + assert LegacyNodeAgent.__name__ == "NodeAgent" diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py new file mode 100644 index 0000000..1bb2081 --- /dev/null +++ b/tests/agent/test_streamer.py @@ -0,0 +1,124 @@ +"""Unit tests for the streamer: drive it with a fake WsClient + MockSDR.""" + +from __future__ import annotations + +import asyncio + +import numpy as np + +from ria_toolkit_oss.agent.streamer import ( + Streamer, + _apply_sdr_config, + _samples_to_interleaved_float32, +) +from ria_toolkit_oss.sdr.mock import MockSDR + + +class FakeWs: + def __init__(self): + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + + async def send_json(self, payload: dict) -> None: + self.json_sent.append(payload) + + async def send_bytes(self, data: bytes) -> None: + self.bytes_sent.append(data) + + +def _factory(device: str, identifier): + return MockSDR(buffer_size=32, seed=0) + + +def test_samples_to_interleaved_float32_roundtrip(): + c = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) + raw = _samples_to_interleaved_float32(c) + arr = np.frombuffer(raw, dtype=np.float32) + assert arr.tolist() == [1.0, 2.0, 3.0, 4.0] + + +def test_apply_sdr_config_sets_attributes(): + sdr = MockSDR(buffer_size=16) + _apply_sdr_config(sdr, {"sample_rate": 2e6, "center_frequency": 9.15e8, "gain": 30}) + assert sdr.sample_rate == 2e6 + assert sdr.center_freq == 9.15e8 + assert sdr.gain == 30 + + +def test_heartbeat_reflects_status_and_app(): + s = Streamer(ws=FakeWs(), sdr_factory=_factory) + hb = s.build_heartbeat() + assert hb["type"] == "heartbeat" + assert hb["status"] == "idle" + s._status = "streaming" + s._app_id = "app-42" + hb2 = s.build_heartbeat() + assert hb2["status"] == "streaming" + assert hb2["app_id"] == "app-42" + + +def test_full_start_stream_stop_cycle(): + async def scenario(): + ws = FakeWs() + streamer = Streamer(ws=ws, sdr_factory=_factory) + + await streamer.on_message( + { + "type": "start", + "app_id": "abc", + "radio_config": { + "device": "mock", + "sample_rate": 1_000_000, + "center_frequency": 2.45e9, + "gain": 40, + "buffer_size": 64, + }, + } + ) + for _ in range(30): + if len(ws.bytes_sent) >= 2: + break + await asyncio.sleep(0.02) + await streamer.on_message({"type": "stop", "app_id": "abc"}) + return ws, streamer + + ws, streamer = asyncio.run(scenario()) + assert len(ws.bytes_sent) >= 1 + for frame in ws.bytes_sent: + assert len(frame) == 64 * 2 * 4 # 64 samples × (I,Q) × float32 + statuses = [m for m in ws.json_sent if m.get("type") == "status"] + assert statuses[0]["status"] == "streaming" + assert statuses[-1]["status"] == "idle" + assert streamer._sdr is None + + +def test_start_without_device_emits_error(): + async def scenario(): + ws = FakeWs() + streamer = Streamer(ws=ws, sdr_factory=_factory) + await streamer.on_message({"type": "start", "app_id": "x", "radio_config": {}}) + return ws + + ws = asyncio.run(scenario()) + errors = [m for m in ws.json_sent if m.get("type") == "error"] + assert errors and "device" in errors[0]["message"] + + +def test_configure_queues_update(): + async def scenario(): + streamer = Streamer(ws=FakeWs(), sdr_factory=_factory) + await streamer.on_message( + {"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}} + ) + return streamer._pending_config + + pending = asyncio.run(scenario()) + assert pending == {"center_frequency": 915e6} + + +def test_unknown_message_type_is_ignored(): + async def scenario(): + s = Streamer(ws=FakeWs(), sdr_factory=_factory) + await s.on_message({"type": "nope"}) + + asyncio.run(scenario()) diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py new file mode 100644 index 0000000..0994a5b --- /dev/null +++ b/tests/agent/test_ws_client.py @@ -0,0 +1,161 @@ +"""Reconnect + heartbeat timing against a real local websockets server.""" + +from __future__ import annotations + +import asyncio +import json + +import pytest +import websockets + +from ria_toolkit_oss.agent.ws_client import WsClient + + +async def _recv_json(ws) -> dict: + raw = await ws.recv() + return json.loads(raw) + + +async def _open_server(handler): + # websockets 13 ignores extra positional args; bind to localhost:0 for an + # ephemeral port and return both the server and the port. + server = await websockets.serve(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + return server, port + + +def test_heartbeat_sent_on_connect(): + async def scenario(): + received: list[dict] = [] + connected = asyncio.Event() + + async def handler(ws): + connected.set() + msg = await _recv_json(ws) + received.append(msg) + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=0.05, + reconnect_pause=0.05, + ) + task = asyncio.create_task( + client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1}) + ) + await asyncio.wait_for(connected.wait(), timeout=2.0) + for _ in range(50): + if received: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return received + + received = asyncio.run(scenario()) + assert received and received[0]["type"] == "heartbeat" + + +def test_reconnects_after_server_drop(): + async def scenario(): + connections = 0 + first_dropped = asyncio.Event() + + async def handler(ws): + nonlocal connections + connections += 1 + if connections == 1: + await ws.close() + first_dropped.set() + else: + try: + await ws.recv() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + task = asyncio.create_task( + client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"}) + ) + await asyncio.wait_for(first_dropped.wait(), timeout=2.0) + for _ in range(100): + if connections >= 2: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return connections + + n = asyncio.run(scenario()) + assert n >= 2 + + +def test_malformed_control_frame_does_not_crash(): + async def scenario(): + handled: list[dict] = [] + done = asyncio.Event() + + async def handler(ws): + await ws.send("not json") + await ws.send(json.dumps({"type": "ping"})) + done.set() + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_msg(m): + handled.append(m) + + task = asyncio.create_task( + client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) + ) + for _ in range(50): + if handled: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return handled + + handled = asyncio.run(scenario()) + assert handled and handled[0] == {"type": "ping"} From 84f3a63e8b89efacfc5dfe6ab3f2269676286328 Mon Sep 17 00:00:00 2001 From: jonny Date: Mon, 13 Apr 2026 12:54:05 -0400 Subject: [PATCH 02/12] simplifying cli --- src/ria_toolkit_oss/agent/cli.py | 48 +++++++++++++++++++++-------- src/ria_toolkit_oss/agent/config.py | 2 ++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py index 2873293..0b06e72 100644 --- a/src/ria_toolkit_oss/agent/cli.py +++ b/src/ria_toolkit_oss/agent/cli.py @@ -38,16 +38,41 @@ def _cmd_detect(_args: argparse.Namespace) -> int: def _cmd_register(args: argparse.Namespace) -> int: + import urllib.request + + hub_url = args.hub.rstrip("/") + url = f"{hub_url}/screens/agents/register" + body = json.dumps({"name": args.name or ""}).encode() + req = urllib.request.Request( + url, + data=body, + headers={ + "Content-Type": "application/json", + "X-API-Key": args.api_key, + }, + ) + try: + with urllib.request.urlopen(req) as resp: + data = json.loads(resp.read()) + except Exception as e: + print(f"error: registration failed: {e}", file=sys.stderr) + return 1 + + agent_id = data["agent_id"] + token = data["token"] + cfg = _config.load() - cfg.hub_url = args.url - cfg.token = args.token + cfg.hub_url = hub_url + cfg.agent_id = agent_id + cfg.token = token + cfg.api_key = args.api_key if args.name: cfg.name = args.name - if args.agent_id: - cfg.agent_id = args.agent_id cfg.insecure = bool(args.insecure) path = _config.save(cfg) - print(f"Saved agent credentials to {path}") + + print(f"Registered agent: {agent_id}") + print(f"Credentials saved to {path}") return 0 @@ -75,7 +100,7 @@ def _derive_ws_url(hub_url: str, agent_id: str) -> str: base = "wss://" + base[len("https://"):] elif base.startswith("http://"): base = "ws://" + base[len("http://"):] - suffix = f"/api/agent/ws/{agent_id}" if agent_id else "/api/agent/ws" + suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws" return base + suffix @@ -93,12 +118,11 @@ def main() -> None: sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)") sub.add_parser("detect", help="List available SDR drivers") - p_reg = sub.add_parser("register", help="Save agent credentials to ~/.ria/agent.json") - p_reg.add_argument("--url", required=True, help="RIA Hub base URL") - p_reg.add_argument("--token", required=True, help="Agent registration token") - p_reg.add_argument("--name", default=None) - p_reg.add_argument("--agent-id", dest="agent_id", default=None) - p_reg.add_argument("--insecure", action="store_true") + p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials") + p_reg.add_argument("--hub", required=True, help="RIA Hub URL (e.g. http://whitehorse:3005)") + p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key") + p_reg.add_argument("--name", default=None, help="Human-friendly agent name") + p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification") p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") p_stream.add_argument("--url", default=None, help="Override WebSocket URL") diff --git a/src/ria_toolkit_oss/agent/config.py b/src/ria_toolkit_oss/agent/config.py index 01f99ba..d1f0e00 100644 --- a/src/ria_toolkit_oss/agent/config.py +++ b/src/ria_toolkit_oss/agent/config.py @@ -28,6 +28,7 @@ class AgentConfig: token: str = "" name: str = "" insecure: bool = False + api_key: str = "" extra: dict = field(default_factory=dict) @@ -48,6 +49,7 @@ def load(path: Path | None = None) -> AgentConfig: token=data.get("token", ""), name=data.get("name", ""), insecure=bool(data.get("insecure", False)), + api_key=data.get("api_key", ""), extra=extra, ) From 87bc78e063ef04f2b32c560d2db683bcb0b5eeb4 Mon Sep 17 00:00:00 2001 From: jonny Date: Tue, 14 Apr 2026 13:03:26 -0400 Subject: [PATCH 03/12] new commands --- pyproject.toml | 1 + src/ria_toolkit_oss/app/__init__.py | 1 + src/ria_toolkit_oss/app/cli.py | 242 ++++++++++++++++++++++++++++ src/ria_toolkit_oss/app/config.py | 49 ++++++ 4 files changed, 293 insertions(+) create mode 100644 src/ria_toolkit_oss/app/__init__.py create mode 100644 src/ria_toolkit_oss/app/cli.py create mode 100644 src/ria_toolkit_oss/app/config.py diff --git a/pyproject.toml b/pyproject.toml index a0bd664..3dc1fd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ ria = "ria_toolkit_oss_cli.cli:cli" ria-tools = "ria_toolkit_oss_cli.cli:cli" ria-server = "ria_toolkit_oss.server.cli:serve" ria-agent = "ria_toolkit_oss.agent.cli:main" +ria-app = "ria_toolkit_oss.app.cli:main" [tool.poetry.group.server.dependencies] fastapi = ">=0.111,<1.0" diff --git a/src/ria_toolkit_oss/app/__init__.py b/src/ria_toolkit_oss/app/__init__.py new file mode 100644 index 0000000..465659f --- /dev/null +++ b/src/ria_toolkit_oss/app/__init__.py @@ -0,0 +1 @@ +"""App runner: pull and run containerized RIA applications.""" diff --git a/src/ria_toolkit_oss/app/cli.py b/src/ria_toolkit_oss/app/cli.py new file mode 100644 index 0000000..c70eb16 --- /dev/null +++ b/src/ria_toolkit_oss/app/cli.py @@ -0,0 +1,242 @@ +"""Unified ``ria-app`` CLI. + +Subcommands: + +- ``ria-app pull [:tag]`` — pull a RIA app image from the configured registry. +- ``ria-app run [:tag]`` — pull (if needed) and run, auto-configuring + GPU/USB/network flags from image labels set by CI. +- ``ria-app list`` — list locally cached RIA app images. +- ``ria-app stop `` — stop a running app container. +- ``ria-app logs `` — tail logs of a running app container. +- ``ria-app configure`` — set default registry/namespace. + +Image references resolve as:: + + my-classifier -> {registry}/{namespace}/my-classifier:latest + group/my-classifier -> {registry}/group/my-classifier:latest + host/group/app:tag -> host/group/app:tag (fully-qualified passthrough) +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import subprocess +import sys + +from . import config as _config + +_LABEL_PROFILE = "ria.profile" +_LABEL_HARDWARE = "ria.hardware" +_LABEL_APP = "ria.app" + + +def _engine() -> str: + for exe in ("docker", "podman"): + if shutil.which(exe): + return exe + print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr) + sys.exit(2) + + +def _resolve_ref(app: str, cfg: _config.AppConfig) -> str: + ref = app if ":" in app.split("/")[-1] else f"{app}:latest" + slashes = ref.count("/") + if slashes >= 2: + return ref + if slashes == 1: + return f"{cfg.registry}/{ref}" if cfg.registry else ref + if not cfg.registry or not cfg.namespace: + print( + "error: app is not fully qualified and no default registry/namespace configured. " + "Run `ria-app configure` or pass a full image reference (registry/namespace/app:tag).", + file=sys.stderr, + ) + sys.exit(2) + return f"{cfg.registry}/{cfg.namespace}/{ref}" + + +def _container_name(ref: str) -> str: + name = ref.rsplit("/", 1)[-1].split(":", 1)[0] + return f"ria-app-{name}" + + +def _inspect_labels(engine: str, ref: str) -> dict: + try: + out = subprocess.check_output( + [engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref], + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError: + return {} + try: + return json.loads(out.decode().strip()) or {} + except json.JSONDecodeError: + return {} + + +def _hardware_flags(labels: dict) -> list[str]: + flags: list[str] = [] + profile = (labels.get(_LABEL_PROFILE) or "").lower() + hardware = (labels.get(_LABEL_HARDWARE) or "").lower() + hw_items = {h.strip() for h in hardware.split(",") if h.strip()} + + if "nvidia" in profile or "holoscan" in profile or "cuda" in profile: + flags += ["--gpus", "all"] + + needs_usb = hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} + if needs_usb: + flags += ["--device", "/dev/bus/usb"] + + needs_net = hw_items & {"usrp", "thinkrf", "pluto"} + if needs_net: + flags += ["--net", "host"] + + return flags + + +def _cmd_configure(args: argparse.Namespace) -> int: + cfg = _config.load() + if args.registry: + cfg.registry = args.registry + if args.namespace: + cfg.namespace = args.namespace + path = _config.save(cfg) + print(f"Saved app config to {path}") + print(f" registry: {cfg.registry or '(unset)'}") + print(f" namespace: {cfg.namespace or '(unset)'}") + return 0 + + +def _cmd_pull(args: argparse.Namespace) -> int: + engine = _engine() + cfg = _config.load() + ref = _resolve_ref(args.app, cfg) + print(f"Pulling {ref}") + return subprocess.call([engine, "pull", ref]) + + +def _cmd_run(args: argparse.Namespace) -> int: + engine = _engine() + cfg = _config.load() + ref = _resolve_ref(args.app, cfg) + + if not _inspect_labels(engine, ref): + rc = subprocess.call([engine, "pull", ref]) + if rc != 0: + return rc + + labels = _inspect_labels(engine, ref) + hw_flags = _hardware_flags(labels) + + cmd = [engine, "run", "--rm"] + if not args.foreground: + cmd += ["-d"] + cmd += ["--name", args.name or _container_name(ref)] + cmd += hw_flags + + if args.config: + cmd += ["-v", f"{args.config}:/config/config.yaml:ro", "-e", "RIA_CONFIG=/config/config.yaml"] + + for env in args.env or []: + cmd += ["-e", env] + for vol in args.volume or []: + cmd += ["-v", vol] + for port in args.publish or []: + cmd += ["-p", port] + + cmd += list(args.docker_args or []) + cmd += [ref] + cmd += list(args.app_args or []) + + if args.dry_run: + print(" ".join(cmd)) + return 0 + + label_str = ", ".join(f"{k}={v}" for k, v in labels.items() if k.startswith("ria.")) or "(no ria.* labels)" + print(f"Running {ref} [{label_str}]") + if hw_flags: + print(f" auto flags: {' '.join(hw_flags)}") + return subprocess.call(cmd) + + +def _cmd_list(_args: argparse.Namespace) -> int: + engine = _engine() + return subprocess.call( + [ + engine, + "images", + "--filter", + f"label={_LABEL_APP}", + "--format", + "table {{.Repository}}:{{.Tag}}\t{{.ID}}\t{{.Size}}", + ] + ) + + +def _cmd_stop(args: argparse.Namespace) -> int: + engine = _engine() + name = args.name or _container_name(_resolve_ref(args.app, _config.load())) + return subprocess.call([engine, "stop", name]) + + +def _cmd_logs(args: argparse.Namespace) -> int: + engine = _engine() + name = args.name or _container_name(_resolve_ref(args.app, _config.load())) + cmd = [engine, "logs"] + if args.follow: + cmd += ["-f"] + cmd += [name] + return subprocess.call(cmd) + + +def main() -> None: + parser = argparse.ArgumentParser(prog="ria-app") + sub = parser.add_subparsers(dest="command", required=True) + + p_cfg = sub.add_parser("configure", help="Set default registry/namespace") + p_cfg.add_argument("--registry", default=None, help="Default container registry (e.g. registry.riahub.ai)") + p_cfg.add_argument("--namespace", default=None, help="Default namespace (e.g. qoherent)") + + p_pull = sub.add_parser("pull", help="Pull an app image") + p_pull.add_argument("app", help="App name or image reference") + + p_run = sub.add_parser("run", help="Run an app, auto-detecting hardware flags") + p_run.add_argument("app", help="App name or image reference") + p_run.add_argument("--name", default=None, help="Container name (default: ria-app-)") + p_run.add_argument("--config", default=None, help="Path to config.yaml to mount into the container") + p_run.add_argument("-e", "--env", action="append", help="Extra env var (KEY=VALUE)") + p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount") + p_run.add_argument("-p", "--publish", action="append", help="Publish port") + p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)") + p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit") + p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run") + p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint") + + sub.add_parser("list", help="List locally cached RIA app images") + + p_stop = sub.add_parser("stop", help="Stop a running app") + p_stop.add_argument("app", help="App name or image reference") + p_stop.add_argument("--name", default=None, help="Container name override") + + p_logs = sub.add_parser("logs", help="Tail logs of a running app") + p_logs.add_argument("app", help="App name or image reference") + p_logs.add_argument("--name", default=None, help="Container name override") + p_logs.add_argument("-f", "--follow", action="store_true", help="Follow log output") + + args = parser.parse_args() + + dispatch = { + "configure": _cmd_configure, + "pull": _cmd_pull, + "run": _cmd_run, + "list": _cmd_list, + "stop": _cmd_stop, + "logs": _cmd_logs, + } + sys.exit(dispatch[args.command](args)) + + +if __name__ == "__main__": + main() diff --git a/src/ria_toolkit_oss/app/config.py b/src/ria_toolkit_oss/app/config.py new file mode 100644 index 0000000..2594761 --- /dev/null +++ b/src/ria_toolkit_oss/app/config.py @@ -0,0 +1,49 @@ +"""App runner configuration at ``~/.ria/toolkit.json``. + +Schema:: + + { + "registry": "registry.riahub.ai", + "namespace": "qoherent" + } +""" + +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass +from pathlib import Path + +_DEFAULT_PATH = Path(os.environ.get("RIA_TOOLKIT_CONFIG", str(Path.home() / ".ria" / "toolkit.json"))) + + +@dataclass +class AppConfig: + registry: str = "" + namespace: str = "" + + +def default_path() -> Path: + return _DEFAULT_PATH + + +def load(path: Path | None = None) -> AppConfig: + p = path or _DEFAULT_PATH + if not p.exists(): + return AppConfig( + registry=os.environ.get("RIA_REGISTRY", ""), + namespace=os.environ.get("RIA_NAMESPACE", ""), + ) + data = json.loads(p.read_text()) + return AppConfig( + registry=data.get("registry", "") or os.environ.get("RIA_REGISTRY", ""), + namespace=data.get("namespace", "") or os.environ.get("RIA_NAMESPACE", ""), + ) + + +def save(cfg: AppConfig, path: Path | None = None) -> Path: + p = path or _DEFAULT_PATH + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(json.dumps(asdict(cfg), indent=2)) + return p From 20fe86d399fa9c11e8904c3aa1eb2828a7ac9e39 Mon Sep 17 00:00:00 2001 From: jonny Date: Tue, 14 Apr 2026 13:18:34 -0400 Subject: [PATCH 04/12] allow sudo calls --- src/ria_toolkit_oss/app/cli.py | 51 ++++++++++++++++++++----------- src/ria_toolkit_oss/app/config.py | 2 ++ 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/ria_toolkit_oss/app/cli.py b/src/ria_toolkit_oss/app/cli.py index c70eb16..6cd0c1c 100644 --- a/src/ria_toolkit_oss/app/cli.py +++ b/src/ria_toolkit_oss/app/cli.py @@ -32,10 +32,11 @@ _LABEL_HARDWARE = "ria.hardware" _LABEL_APP = "ria.app" -def _engine() -> str: +def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]: for exe in ("docker", "podman"): if shutil.which(exe): - return exe + use_sudo = sudo_override or cfg.sudo + return (["sudo", exe] if use_sudo else [exe]) print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr) sys.exit(2) @@ -62,10 +63,10 @@ def _container_name(ref: str) -> str: return f"ria-app-{name}" -def _inspect_labels(engine: str, ref: str) -> dict: +def _inspect_labels(engine: list[str], ref: str) -> dict: try: out = subprocess.check_output( - [engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref], + [*engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref], stderr=subprocess.DEVNULL, ) except subprocess.CalledProcessError: @@ -102,35 +103,38 @@ def _cmd_configure(args: argparse.Namespace) -> int: cfg.registry = args.registry if args.namespace: cfg.namespace = args.namespace + if args.sudo is not None: + cfg.sudo = args.sudo path = _config.save(cfg) print(f"Saved app config to {path}") print(f" registry: {cfg.registry or '(unset)'}") print(f" namespace: {cfg.namespace or '(unset)'}") + print(f" sudo: {cfg.sudo}") return 0 def _cmd_pull(args: argparse.Namespace) -> int: - engine = _engine() cfg = _config.load() + engine = _engine(cfg, args.sudo) ref = _resolve_ref(args.app, cfg) print(f"Pulling {ref}") - return subprocess.call([engine, "pull", ref]) + return subprocess.call([*engine, "pull", ref]) def _cmd_run(args: argparse.Namespace) -> int: - engine = _engine() cfg = _config.load() + engine = _engine(cfg, args.sudo) ref = _resolve_ref(args.app, cfg) if not _inspect_labels(engine, ref): - rc = subprocess.call([engine, "pull", ref]) + rc = subprocess.call([*engine, "pull", ref]) if rc != 0: return rc labels = _inspect_labels(engine, ref) hw_flags = _hardware_flags(labels) - cmd = [engine, "run", "--rm"] + cmd = [*engine, "run", "--rm"] if not args.foreground: cmd += ["-d"] cmd += ["--name", args.name or _container_name(ref)] @@ -161,11 +165,12 @@ def _cmd_run(args: argparse.Namespace) -> int: return subprocess.call(cmd) -def _cmd_list(_args: argparse.Namespace) -> int: - engine = _engine() +def _cmd_list(args: argparse.Namespace) -> int: + cfg = _config.load() + engine = _engine(cfg, args.sudo) return subprocess.call( [ - engine, + *engine, "images", "--filter", f"label={_LABEL_APP}", @@ -176,15 +181,17 @@ def _cmd_list(_args: argparse.Namespace) -> int: def _cmd_stop(args: argparse.Namespace) -> int: - engine = _engine() - name = args.name or _container_name(_resolve_ref(args.app, _config.load())) - return subprocess.call([engine, "stop", name]) + cfg = _config.load() + engine = _engine(cfg, args.sudo) + name = args.name or _container_name(_resolve_ref(args.app, cfg)) + return subprocess.call([*engine, "stop", name]) def _cmd_logs(args: argparse.Namespace) -> int: - engine = _engine() - name = args.name or _container_name(_resolve_ref(args.app, _config.load())) - cmd = [engine, "logs"] + cfg = _config.load() + engine = _engine(cfg, args.sudo) + name = args.name or _container_name(_resolve_ref(args.app, cfg)) + cmd = [*engine, "logs"] if args.follow: cmd += ["-f"] cmd += [name] @@ -193,11 +200,19 @@ def _cmd_logs(args: argparse.Namespace) -> int: def main() -> None: parser = argparse.ArgumentParser(prog="ria-app") + parser.add_argument("--sudo", action="store_true", default=False, help="Run docker/podman via sudo") sub = parser.add_subparsers(dest="command", required=True) p_cfg = sub.add_parser("configure", help="Set default registry/namespace") p_cfg.add_argument("--registry", default=None, help="Default container registry (e.g. registry.riahub.ai)") p_cfg.add_argument("--namespace", default=None, help="Default namespace (e.g. qoherent)") + p_cfg.add_argument( + "--sudo", + dest="sudo", + action=argparse.BooleanOptionalAction, + default=None, + help="Persist sudo default (--sudo / --no-sudo)", + ) p_pull = sub.add_parser("pull", help="Pull an app image") p_pull.add_argument("app", help="App name or image reference") diff --git a/src/ria_toolkit_oss/app/config.py b/src/ria_toolkit_oss/app/config.py index 2594761..8bff807 100644 --- a/src/ria_toolkit_oss/app/config.py +++ b/src/ria_toolkit_oss/app/config.py @@ -22,6 +22,7 @@ _DEFAULT_PATH = Path(os.environ.get("RIA_TOOLKIT_CONFIG", str(Path.home() / ".ri class AppConfig: registry: str = "" namespace: str = "" + sudo: bool = False def default_path() -> Path: @@ -39,6 +40,7 @@ def load(path: Path | None = None) -> AppConfig: return AppConfig( registry=data.get("registry", "") or os.environ.get("RIA_REGISTRY", ""), namespace=data.get("namespace", "") or os.environ.get("RIA_NAMESPACE", ""), + sudo=bool(data.get("sudo", False)) or os.environ.get("RIA_DOCKER_SUDO", "") not in ("", "0", "false"), ) From b955256479f21bebd4594b373a19b31f359bffee Mon Sep 17 00:00:00 2001 From: jonny Date: Thu, 16 Apr 2026 11:13:43 -0400 Subject: [PATCH 05/12] Pluto TX streaming functionality base --- src/ria_toolkit_oss/agent/cli.py | 62 ++- src/ria_toolkit_oss/agent/config.py | 33 +- src/ria_toolkit_oss/agent/hardware.py | 26 +- src/ria_toolkit_oss/agent/streamer.py | 607 ++++++++++++++++++++++--- src/ria_toolkit_oss/agent/ws_client.py | 17 +- src/ria_toolkit_oss/app/cli.py | 37 +- tests/agent/test_cli_tx.py | 111 +++++ tests/agent/test_config.py | 30 ++ tests/agent/test_disconnect.py | 8 +- tests/agent/test_full_duplex.py | 133 ++++++ tests/agent/test_hardware.py | 17 + tests/agent/test_integration_tx.py | 144 ++++++ tests/agent/test_streamer.py | 88 +++- tests/agent/test_streamer_tx.py | 133 ++++++ tests/agent/test_tx_safety.py | 167 +++++++ tests/agent/test_tx_underrun.py | 136 ++++++ tests/agent/test_ws_client.py | 103 +++++ 17 files changed, 1752 insertions(+), 100 deletions(-) create mode 100644 tests/agent/test_cli_tx.py create mode 100644 tests/agent/test_full_duplex.py create mode 100644 tests/agent/test_integration_tx.py create mode 100644 tests/agent/test_streamer_tx.py create mode 100644 tests/agent/test_tx_safety.py create mode 100644 tests/agent/test_tx_underrun.py diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py index 0b06e72..6b88473 100644 --- a/src/ria_toolkit_oss/agent/cli.py +++ b/src/ria_toolkit_oss/agent/cli.py @@ -5,8 +5,8 @@ Subcommands: - ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged). - ``ria-agent stream`` — new WebSocket-based IQ streamer. - ``ria-agent detect`` — print SDR drivers whose modules import cleanly. -- ``ria-agent register --url URL --token TOKEN`` — save credentials to - ``~/.ria/agent.json``. +- ``ria-agent register --hub URL --api-key KEY`` — register with the hub and + save credentials (and optional TX interlocks) to ``~/.ria/agent.json``. Invoking ``ria-agent`` with no subcommand falls through to the legacy long-poll behavior for back-compatibility with existing deployments. @@ -69,9 +69,27 @@ def _cmd_register(args: argparse.Namespace) -> int: if args.name: cfg.name = args.name cfg.insecure = bool(args.insecure) + cfg.tx_enabled = bool(getattr(args, "allow_tx", False)) + if (v := getattr(args, "tx_max_gain_db", None)) is not None: + cfg.tx_max_gain_db = float(v) + if (v := getattr(args, "tx_max_duration_s", None)) is not None: + cfg.tx_max_duration_s = float(v) + freq_ranges = getattr(args, "tx_freq_range", None) or [] + if freq_ranges: + cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges] path = _config.save(cfg) print(f"Registered agent: {agent_id}") + if cfg.tx_enabled: + caps: list[str] = [] + if cfg.tx_max_gain_db is not None: + caps.append(f"gain<={cfg.tx_max_gain_db} dB") + if cfg.tx_max_duration_s is not None: + caps.append(f"duration<={cfg.tx_max_duration_s} s") + if cfg.tx_allowed_freq_ranges: + caps.append(f"freq in {cfg.tx_allowed_freq_ranges}") + tail = f" ({', '.join(caps)})" if caps else "" + print(f"TX enabled{tail}") print(f"Credentials saved to {path}") return 0 @@ -85,8 +103,10 @@ def _cmd_stream(args: argparse.Namespace) -> int: if not url: print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr) return 2 + if getattr(args, "allow_tx", False): + cfg.tx_enabled = True try: - asyncio.run(run_streamer(url, token)) + asyncio.run(run_streamer(url, token, cfg=cfg)) except KeyboardInterrupt: pass return 0 @@ -123,11 +143,47 @@ def main() -> None: p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key") p_reg.add_argument("--name", default=None, help="Human-friendly agent name") p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification") + p_reg.add_argument( + "--allow-tx", + dest="allow_tx", + action="store_true", + help="Opt this agent in to TX (required for any transmission from the hub)", + ) + p_reg.add_argument( + "--tx-max-gain-db", + dest="tx_max_gain_db", + type=float, + default=None, + help="Reject tx_start frames whose tx_gain exceeds this cap (dB)", + ) + p_reg.add_argument( + "--tx-max-duration-s", + dest="tx_max_duration_s", + type=float, + default=None, + help="Auto-stop any TX session after this many seconds", + ) + p_reg.add_argument( + "--tx-freq-range", + dest="tx_freq_range", + type=float, + nargs=2, + action="append", + metavar=("LO", "HI"), + default=None, + help="Allowed TX center-frequency range in Hz (repeat for multiple bands)", + ) p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer") p_stream.add_argument("--url", default=None, help="Override WebSocket URL") p_stream.add_argument("--token", default=None, help="Override bearer token") p_stream.add_argument("--log-level", default="INFO") + p_stream.add_argument( + "--allow-tx", + dest="allow_tx", + action="store_true", + help="Runtime override: enable TX for this process without writing config", + ) # Unknown extras are forwarded to the legacy CLI when command == "run". args, extras = parser.parse_known_args(argv) diff --git a/src/ria_toolkit_oss/agent/config.py b/src/ria_toolkit_oss/agent/config.py index d1f0e00..431094a 100644 --- a/src/ria_toolkit_oss/agent/config.py +++ b/src/ria_toolkit_oss/agent/config.py @@ -7,7 +7,11 @@ Schema:: "agent_id": "agent-abc123", "token": "rha_xxxx", "name": "lab-bench-1", - "insecure": false + "insecure": false, + "tx_enabled": false, + "tx_max_gain_db": null, + "tx_max_duration_s": null, + "tx_allowed_freq_ranges": null } """ @@ -18,7 +22,8 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path -_DEFAULT_PATH = Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) +def _resolve_default_path() -> Path: + return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) @dataclass @@ -29,15 +34,29 @@ class AgentConfig: name: str = "" insecure: bool = False api_key: str = "" + tx_enabled: bool = False + tx_max_gain_db: float | None = None + tx_max_duration_s: float | None = None + tx_allowed_freq_ranges: list[list[float]] | None = None extra: dict = field(default_factory=dict) def default_path() -> Path: - return _DEFAULT_PATH + return _resolve_default_path() + + +def _coerce_ranges(raw) -> list[list[float]] | None: + if raw is None: + return None + out: list[list[float]] = [] + for pair in raw: + lo, hi = pair + out.append([float(lo), float(hi)]) + return out def load(path: Path | None = None) -> AgentConfig: - p = path or _DEFAULT_PATH + p = path or _resolve_default_path() if not p.exists(): return AgentConfig() data = json.loads(p.read_text()) @@ -50,12 +69,16 @@ def load(path: Path | None = None) -> AgentConfig: name=data.get("name", ""), insecure=bool(data.get("insecure", False)), api_key=data.get("api_key", ""), + tx_enabled=bool(data.get("tx_enabled", False)), + tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None), + tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None), + tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")), extra=extra, ) def save(cfg: AgentConfig, path: Path | None = None) -> Path: - p = path or _DEFAULT_PATH + p = path or _resolve_default_path() p.parent.mkdir(parents=True, exist_ok=True) data = asdict(cfg) extra = data.pop("extra", {}) or {} diff --git a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py index 417bf1c..32a65e5 100644 --- a/src/ria_toolkit_oss/agent/hardware.py +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -4,19 +4,41 @@ from __future__ import annotations from ria_toolkit_oss.sdr import detect_available +from .config import AgentConfig + def available_devices() -> list[str]: """Return a sorted list of device names whose driver modules import cleanly.""" return sorted(detect_available().keys()) -def heartbeat_payload(status: str = "idle", app_id: str | None = None) -> dict: - """Build the JSON body of a periodic heartbeat frame.""" +def heartbeat_payload( + status: str = "idle", + app_id: str | None = None, + *, + cfg: AgentConfig | None = None, + sessions: dict | None = None, +) -> dict: + """Build the JSON body of a periodic heartbeat frame. + + *cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not + supplied, the heartbeat advertises RX-only with ``tx_enabled=False`` — + matching the pre-TX shape. + """ + c = cfg or AgentConfig() + capabilities = ["rx"] + if c.tx_enabled: + capabilities.append("tx") + payload: dict = { "type": "heartbeat", "hardware": available_devices(), "status": status, + "capabilities": capabilities, + "tx_enabled": bool(c.tx_enabled), } if app_id: payload["app_id"] = app_id + if sessions: + payload["sessions"] = sessions return payload diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py index 4d89743..8570a73 100644 --- a/src/ria_toolkit_oss/agent/streamer.py +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -1,20 +1,33 @@ -"""Thin IQ-streaming agent. +"""IQ-streaming agent. Listens for control messages from the RIA Hub over a persistent WebSocket. -When the server sends ``start``, opens the SDR described in ``radio_config``, -loops over ``sdr.rx(buffer_size)``, and sends each buffer as raw -interleaved float32 bytes. ``stop`` closes the SDR; ``configure`` applies -parameter updates at the next capture boundary. +Supports: + +- An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens + the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ). +- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus + binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires + up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False; + Phase 4 implements the full TX loop. + +Both sessions can run concurrently on the same physical SDR (FDD) — a +ref-counted SDR registry shares one driver instance when RX and TX name the +same ``(device, identifier)``. """ from __future__ import annotations import asyncio import logging +import queue +import threading +import time +from dataclasses import dataclass, field from typing import Any import numpy as np +from .config import AgentConfig from .hardware import heartbeat_payload from .ws_client import WsClient @@ -23,6 +36,98 @@ logger = logging.getLogger("ria_agent.streamer") _DEFAULT_BUFFER_SIZE = 1024 +# --------------------------------------------------------------------------- +# Session dataclasses + + +@dataclass +class RxSession: + app_id: str + sdr: Any + device_key: tuple[str, str | None] + buffer_size: int + task: asyncio.Task | None = None + pending_config: dict = field(default_factory=dict) + + +@dataclass +class TxSession: + app_id: str + sdr: Any + device_key: tuple[str, str | None] + buffer_size: int + task: Any = None # concurrent.futures.Future from run_in_executor + pending_config: dict = field(default_factory=dict) + underrun_policy: str = "pause" + last_buffer: np.ndarray | None = None + stop_event: threading.Event = field(default_factory=threading.Event) + started_at: float = 0.0 + max_duration_s: float | None = None + state: str = "armed" + # Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so + # hub-side over-production triggers WS backpressure rather than memory + # growth in the agent. + in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8)) + # Set by the TX callback when it hits an underrun while policy=="pause"; + # asyncio side flips the session state and emits tx_status. + underrun_flag: threading.Event = field(default_factory=threading.Event) + + +# --------------------------------------------------------------------------- +# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously) + + +class _SdrRegistry: + def __init__(self, factory): + self._factory = factory + self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {} + self._lock = threading.Lock() + + def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]: + key = (device, identifier) + with self._lock: + if key in self._instances: + sdr, rc = self._instances[key] + self._instances[key] = (sdr, rc + 1) + return sdr, key + # Build outside the lock: driver init can be slow and we don't want to + # block concurrent releases on unrelated devices. + sdr = self._factory(device, identifier) + with self._lock: + if key in self._instances: + # Raced another acquirer; discard our duplicate and share theirs. + other_sdr, rc = self._instances[key] + try: + sdr.close() + except Exception: + pass + self._instances[key] = (other_sdr, rc + 1) + return other_sdr, key + self._instances[key] = (sdr, 1) + return sdr, key + + def release(self, key: tuple[str, str | None]) -> bool: + """Decrement refcount. Returns True if the caller owns the last reference + and should close the SDR.""" + with self._lock: + sdr, rc = self._instances.get(key, (None, 0)) + if sdr is None: + return False + if rc <= 1: + del self._instances[key] + return True + self._instances[key] = (sdr, rc - 1) + return False + + def refcount(self, key: tuple[str, str | None]) -> int: + with self._lock: + return self._instances.get(key, (None, 0))[1] + + +# --------------------------------------------------------------------------- +# Streamer + + class Streamer: """Main streamer loop. @@ -31,103 +136,186 @@ class Streamer: ws: Connected :class:`WsClient`. sdr_factory: - Callable ``(device, identifier) -> SDR``. Defaults to - :func:`ria_toolkit_oss.sdr.get_sdr_device`. Injectable for tests. + Callable ``(device, identifier) -> SDR``. Defaults to the helper in + :mod:`ria_toolkit_oss.sdr`. Injectable for tests. + cfg: + :class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and + heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which + leaves TX disabled. """ - def __init__(self, ws: WsClient, sdr_factory=None) -> None: + def __init__( + self, + ws, + sdr_factory=None, + cfg: AgentConfig | None = None, + ) -> None: self.ws = ws - self._sdr_factory = sdr_factory - self._app_id: str | None = None - self._sdr: Any = None - self._pending_config: dict = {} - self._capture_task: asyncio.Task | None = None - self._status = "idle" + self._cfg = cfg or AgentConfig() + self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory) + self._rx: RxSession | None = None + self._tx: TxSession | None = None + # Pending radio_config accepted via ``configure`` before ``start``. + self._standalone_pending_config: dict = {} + # Cached asyncio event loop, set the first time a handler runs. Used + # to schedule async callbacks from the TX executor thread. + self._loop: asyncio.AbstractEventLoop | None = None + + # ------------------------------------------------------------------ + # Back-compat read-only shims for callers that check ``._sdr`` etc. + # Writes to these attributes are not supported — use the session objects. + + @property + def _sdr(self): + return self._rx.sdr if self._rx is not None else None + + @property + def _pending_config(self) -> dict: + return self._rx.pending_config if self._rx is not None else self._standalone_pending_config # ------------------------------------------------------------------ # WsClient wiring def build_heartbeat(self) -> dict: - return heartbeat_payload(status=self._status, app_id=self._app_id) + status = "streaming" if (self._rx is not None or self._tx is not None) else "idle" + app_id: str | None = None + if self._rx is not None: + app_id = self._rx.app_id + elif self._tx is not None: + app_id = self._tx.app_id + + sessions: dict[str, dict] = {} + if self._rx is not None: + sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"} + if self._tx is not None: + sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state} + + return heartbeat_payload( + status=status, + app_id=app_id, + cfg=self._cfg, + sessions=sessions or None, + ) async def on_message(self, msg: dict) -> None: t = msg.get("type") - if t == "start": - await self._handle_start(msg) - elif t == "stop": - await self._handle_stop(msg) - elif t == "configure": - self._pending_config.update(msg.get("radio_config") or {}) - logger.debug("Queued configure: %s", self._pending_config) - else: + handler = { + "start": self._handle_rx_start, + "stop": self._handle_rx_stop, + "configure": self._handle_rx_configure, + "tx_start": self._handle_tx_start, + "tx_stop": self._handle_tx_stop, + "tx_configure": self._handle_tx_configure, + }.get(t) + if handler is None: logger.warning("Unknown server message type: %r", t) + return + await handler(msg) - # ------------------------------------------------------------------ - async def _handle_start(self, msg: dict) -> None: - if self._capture_task is not None and not self._capture_task.done(): + async def on_binary(self, data: bytes) -> None: + tx = self._tx + if tx is None: + logger.debug("Dropping %d-byte binary frame: no TX session", len(data)) + return + # Backpressure: if the TX queue is full, await briefly so the hub's + # ``await ws.send`` throttles naturally via TCP. We don't block + # indefinitely — a 2s stall means something else is wrong. + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0)) + except queue.Full: + logger.warning("TX queue stalled; dropping frame") + + # ================================================================== + # RX + + async def _handle_rx_start(self, msg: dict) -> None: + if self._rx is not None: logger.warning("start received while already streaming — ignoring") return - self._app_id = msg.get("app_id") + app_id = msg.get("app_id") or "" radio_config = dict(msg.get("radio_config") or {}) device = radio_config.pop("device", None) identifier = radio_config.pop("identifier", None) buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) if not device: - await self._send_error("start missing radio_config.device") + await self._send_error(app_id, "start missing radio_config.device") return try: - factory = self._sdr_factory or _default_sdr_factory - self._sdr = factory(device, identifier) - _apply_sdr_config(self._sdr, radio_config) + sdr, device_key = self._registry.acquire(device, identifier) + _apply_sdr_config(sdr, radio_config) except Exception as exc: logger.exception("Failed to open SDR %r", device) - await self._send_error(f"SDR init failed: {exc}") + await self._send_error(app_id, f"SDR init failed: {exc}") return - self._status = "streaming" - await self._send_status("streaming") - self._capture_task = asyncio.create_task( - self._capture_loop(buffer_size), name="ria-streamer-capture" + # Inherit any pending config that was queued before start. + pending = dict(self._standalone_pending_config) + self._standalone_pending_config = {} + + session = RxSession( + app_id=app_id, + sdr=sdr, + device_key=device_key, + buffer_size=buffer_size, + pending_config=pending, + ) + self._rx = session + await self._send_status("streaming", app_id) + session.task = asyncio.create_task( + self._capture_loop(session), name="ria-streamer-capture" ) - async def _handle_stop(self, msg: dict) -> None: - if self._capture_task is not None: - self._capture_task.cancel() + async def _handle_rx_stop(self, msg: dict) -> None: + session = self._rx + if session is None: + return + if session.task is not None: + session.task.cancel() try: - await self._capture_task + await session.task except (asyncio.CancelledError, Exception): pass - self._capture_task = None - self._close_sdr() - self._app_id = None - self._status = "idle" - await self._send_status("idle") + self._close_session_sdr(session) + app_id = session.app_id + self._rx = None + await self._send_status("idle", app_id) - async def _capture_loop(self, buffer_size: int) -> None: + async def _handle_rx_configure(self, msg: dict) -> None: + cfg = dict(msg.get("radio_config") or {}) + if self._rx is not None: + self._rx.pending_config.update(cfg) + else: + self._standalone_pending_config.update(cfg) + logger.debug("Queued configure: %s", cfg) + + async def _capture_loop(self, session: RxSession) -> None: loop = asyncio.get_running_loop() try: while True: - if self._pending_config: - cfg = self._pending_config - self._pending_config = {} + if session.pending_config: + cfg = session.pending_config + session.pending_config = {} try: - _apply_sdr_config(self._sdr, cfg) + _apply_sdr_config(session.sdr, cfg) except Exception as exc: logger.warning("Applying configure failed: %s", exc) try: - samples = await loop.run_in_executor(None, self._sdr.rx, buffer_size) + samples = await loop.run_in_executor( + None, session.sdr.rx, session.buffer_size + ) except Exception as exc: from ria_toolkit_oss.sdr import SdrDisconnectedError if isinstance(exc, SdrDisconnectedError): logger.warning("SDR disconnected: %s", exc) - await self._send_error(f"SDR disconnected: {exc}") + await self._send_error(session.app_id, f"SDR disconnected: {exc}") else: logger.exception("SDR rx error") - await self._send_error(f"SDR capture failed: {exc}") + await self._send_error(session.app_id, f"SDR capture failed: {exc}") break payload = _samples_to_interleaved_float32(samples) @@ -139,29 +327,305 @@ class Streamer: except asyncio.CancelledError: raise finally: - self._close_sdr() + self._close_session_sdr(session) + # If the loop died on its own (e.g. SDR disconnect), clear the + # session handle so future ``start`` messages can proceed. + if self._rx is session: + self._rx = None - def _close_sdr(self) -> None: - if self._sdr is None: + # ================================================================== + # TX + + async def _handle_tx_start(self, msg: dict) -> None: + app_id = msg.get("app_id") or "" + radio_config = dict(msg.get("radio_config") or {}) + + # --- interlocks (agent-enforced; never trust the hub alone) --- + if not self._cfg.tx_enabled: + await self._send_tx_status(app_id, "error", "tx disabled on this agent") return + tx_gain = radio_config.get("tx_gain") + if ( + self._cfg.tx_max_gain_db is not None + and tx_gain is not None + and float(tx_gain) > float(self._cfg.tx_max_gain_db) + ): + await self._send_tx_status( + app_id, + "error", + f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}", + ) + return + tx_freq = radio_config.get("tx_center_frequency") + if self._cfg.tx_allowed_freq_ranges and tx_freq is not None: + f = float(tx_freq) + if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges): + await self._send_tx_status( + app_id, + "error", + f"tx_center_frequency {tx_freq} outside allowed ranges", + ) + return + + if self._tx is not None: + await self._send_tx_status(app_id, "error", "tx already active on this agent") + return + + # --- device --- + device = radio_config.pop("device", None) + identifier = radio_config.pop("identifier", None) + buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) + underrun_policy = str(radio_config.pop("underrun_policy", "pause")) + if underrun_policy not in ("pause", "zero", "repeat"): + await self._send_tx_status( + app_id, "error", f"invalid underrun_policy {underrun_policy!r}" + ) + return + if not device: + await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device") + return + + device_key: tuple[str, str | None] | None = None + sdr: Any = None try: - self._sdr.close() + sdr, device_key = self._registry.acquire(device, identifier) + _apply_sdr_config(sdr, radio_config) + # Only call init_tx when the hub supplied the three required + # parameters. Drivers that gate _stream_tx on _tx_initialized + # (e.g. Pluto) need this; drivers that don't (e.g. Mock) tolerate + # its absence. + init_args = { + k: radio_config.get(f"tx_{k}") + for k in ("sample_rate", "center_frequency", "gain") + } + if hasattr(sdr, "init_tx") and all(v is not None for v in init_args.values()): + sdr.init_tx( + sample_rate=init_args["sample_rate"], + center_frequency=init_args["center_frequency"], + gain=init_args["gain"], + channel=radio_config.get("tx_channel", 0), + gain_mode=radio_config.get("tx_gain_mode", "manual"), + ) + except Exception as exc: + if device_key is not None: + if self._registry.release(device_key): + try: + sdr.close() + except Exception: + pass + logger.exception("Failed to init TX on %r", device) + await self._send_tx_status(app_id, "error", f"tx init failed: {exc}") + return + + self._loop = asyncio.get_running_loop() + session = TxSession( + app_id=app_id, + sdr=sdr, + device_key=device_key, + buffer_size=buffer_size, + underrun_policy=underrun_policy, + started_at=time.monotonic(), + max_duration_s=self._cfg.tx_max_duration_s, + ) + self._tx = session + await self._send_tx_status(app_id, "armed") + session.task = self._loop.run_in_executor(None, self._tx_executor_body, session) + # Spawn a small watchdog that transitions armed → transmitting when + # the first buffer has been consumed, and surfaces underrun / max- + # duration terminations back to the hub. + asyncio.create_task(self._tx_watchdog(session)) + + async def _handle_tx_stop(self, msg: dict) -> None: + session = self._tx + if session is None: + return + app_id = session.app_id + session.stop_event.set() + try: + session.sdr.pause_tx() + except Exception: + logger.debug("pause_tx raised during stop", exc_info=True) + # Wake the executor thread if it's blocked on ``queue.get``. + self._drain_tx_queue(session) + if session.task is not None: + try: + await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5) + except asyncio.TimeoutError: + logger.warning("TX executor did not exit within 1.5s after stop") + except Exception: + logger.debug("TX executor raised on shutdown", exc_info=True) + self._close_session_sdr(session) + self._tx = None + await self._send_tx_status(app_id, "done") + + async def _handle_tx_configure(self, msg: dict) -> None: + if self._tx is None: + return + self._tx.pending_config.update(msg.get("radio_config") or {}) + + # ------------------------------------------------------------------ + # TX executor & watchdog + + def _tx_executor_body(self, session: TxSession) -> None: + try: + session.sdr._stream_tx(lambda n: self._tx_callback(session, n)) + except Exception: + logger.exception("TX stream crashed") + self._schedule(self._send_tx_status(session.app_id, "error", "tx stream crashed")) + + def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray: + n = int(num_samples) + # Honor stop requests: return silence one last time and let the driver + # exit its loop on the next iteration (pause_tx flips _enable_tx). + if session.stop_event.is_set(): + return _silence(n) + + # Max-duration watchdog. + if ( + session.max_duration_s is not None + and (time.monotonic() - session.started_at) >= float(session.max_duration_s) + ): + session.stop_event.set() + try: + session.sdr.pause_tx() + except Exception: + pass + self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached")) + return _silence(n) + + # Apply queued configure at buffer boundary. + if session.pending_config: + cfg = session.pending_config + session.pending_config = {} + try: + _apply_sdr_config(session.sdr, cfg) + except Exception as exc: + logger.debug("tx_configure apply failed: %s", exc) + + try: + raw = session.in_queue.get(timeout=0.1) + except queue.Empty: + return self._underrun_fill(session, n) + + arr = np.frombuffer(raw, dtype=np.float32) + if arr.size < 2 or arr.size % 2 != 0: + logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size) + return self._underrun_fill(session, n) + samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64)) + if samples.size < n: + out = np.zeros(n, dtype=np.complex64) + out[: samples.size] = samples + session.last_buffer = out + return out + if samples.size > n: + samples = samples[:n] + session.last_buffer = samples + if session.state == "armed": + session.state = "transmitting" + self._schedule(self._send_tx_status(session.app_id, "transmitting")) + return samples + + def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray: + policy = session.underrun_policy + if policy == "zero": + return _silence(n) + if policy == "repeat" and session.last_buffer is not None: + buf = session.last_buffer + if buf.size == n: + return buf + if buf.size > n: + return buf[:n].copy() + out = np.zeros(n, dtype=np.complex64) + out[: buf.size] = buf + return out + # "pause" policy (default) or "repeat" before any buffer arrived. + if not session.underrun_flag.is_set(): + session.underrun_flag.set() + session.stop_event.set() + try: + session.sdr.pause_tx() except Exception: pass - self._sdr = None + return _silence(n) - async def _send_status(self, status: str) -> None: + async def _tx_watchdog(self, session: TxSession) -> None: + # Poll the underrun flag so we can emit status + tear down cleanly + # when the callback flips the flag from the executor thread. Check + # underrun_flag before stop_event, since the "pause" path sets both. + while session is self._tx: + if session.underrun_flag.is_set(): + await self._send_tx_status(session.app_id, "underrun") + await self._teardown_tx_after_underrun(session) + return + if session.stop_event.is_set(): + return + await asyncio.sleep(0.05) + + async def _teardown_tx_after_underrun(self, session: TxSession) -> None: + if self._tx is not session: + return + self._drain_tx_queue(session) + if session.task is not None: + try: + await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0) + except asyncio.TimeoutError: + logger.warning("TX executor did not exit within 1s after underrun") + except Exception: + logger.debug("TX executor raised during underrun teardown", exc_info=True) + self._close_session_sdr(session) + if self._tx is session: + self._tx = None + + def _drain_tx_queue(self, session: TxSession) -> None: try: - await self.ws.send_json({"type": "status", "status": status, "app_id": self._app_id}) + while True: + session.in_queue.get_nowait() + except queue.Empty: + pass + + def _schedule(self, coro) -> None: + loop = self._loop + if loop is None: + return + try: + asyncio.run_coroutine_threadsafe(coro, loop) + except Exception: + logger.debug("_schedule failed", exc_info=True) + + # ================================================================== + # Helpers + + def _close_session_sdr(self, session) -> None: + if session.sdr is None: + return + should_close = self._registry.release(session.device_key) + if should_close: + try: + session.sdr.close() + except Exception: + logger.debug("SDR close raised", exc_info=True) + + async def _send_status(self, status: str, app_id: str) -> None: + try: + await self.ws.send_json({"type": "status", "status": status, "app_id": app_id}) except Exception as exc: logger.debug("Status send failed: %s", exc) - async def _send_error(self, message: str) -> None: + async def _send_error(self, app_id: str, message: str) -> None: try: - await self.ws.send_json({"type": "error", "app_id": self._app_id, "message": message}) + await self.ws.send_json({"type": "error", "app_id": app_id, "message": message}) except Exception as exc: logger.debug("Error-frame send failed: %s", exc) + async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None: + payload: dict = {"type": "tx_status", "app_id": app_id, "state": state} + if message is not None: + payload["message"] = message + try: + await self.ws.send_json(payload) + except Exception as exc: + logger.debug("tx_status send failed: %s", exc) + # --------------------------------------------------------------------------- # Helpers @@ -172,6 +636,10 @@ _CONFIG_ATTR_MAP = { "center_freq": ("center_freq", "rx_center_frequency"), "gain": ("gain", "rx_gain"), "bandwidth": ("bandwidth", "rx_bandwidth"), + "tx_sample_rate": ("tx_sample_rate",), + "tx_center_frequency": ("tx_center_frequency", "tx_lo"), + "tx_gain": ("tx_gain",), + "tx_bandwidth": ("tx_bandwidth",), } @@ -194,6 +662,11 @@ def _apply_sdr_config(sdr: Any, cfg: dict) -> None: logger.debug("radio_config key %r ignored (no matching attr)", key) +def _silence(num_samples: int) -> np.ndarray: + """Return a ``num_samples``-length zero-filled complex64 buffer.""" + return np.zeros(int(num_samples), dtype=np.complex64) + + def _samples_to_interleaved_float32(samples: Any) -> bytes: """Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes.""" arr = np.asarray(samples) @@ -214,8 +687,12 @@ def _default_sdr_factory(device: str, identifier: str | None): # --------------------------------------------------------------------------- # Top-level entry -async def run_streamer(ws_url: str, token: str) -> None: +async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None: """Connect to *ws_url* and run the streamer loop until cancelled.""" ws = WsClient(ws_url, token) - streamer = Streamer(ws) - await ws.run(streamer.on_message, streamer.build_heartbeat) + streamer = Streamer(ws, cfg=cfg) + await ws.run( + streamer.on_message, + streamer.build_heartbeat, + on_binary=streamer.on_binary, + ) diff --git a/src/ria_toolkit_oss/agent/ws_client.py b/src/ria_toolkit_oss/agent/ws_client.py index 1bc66f6..a33991d 100644 --- a/src/ria_toolkit_oss/agent/ws_client.py +++ b/src/ria_toolkit_oss/agent/ws_client.py @@ -15,6 +15,7 @@ logger = logging.getLogger("ria_agent.ws") MessageHandler = Callable[[dict], Awaitable[None]] HeartbeatBuilder = Callable[[], dict] +BinaryHandler = Callable[[bytes], Awaitable[None]] class WsClient: @@ -65,7 +66,12 @@ class WsClient: self._stop.set() # ------------------------------------------------------------------ - async def run(self, on_message: MessageHandler, heartbeat: HeartbeatBuilder) -> None: + async def run( + self, + on_message: MessageHandler, + heartbeat: HeartbeatBuilder, + on_binary: BinaryHandler | None = None, + ) -> None: """Main loop: connect, heartbeat, dispatch messages, reconnect on drop.""" while not self._stop.is_set(): try: @@ -75,8 +81,13 @@ class WsClient: try: async for raw in self._ws: if isinstance(raw, bytes): - # Server shouldn't send binary to the agent; log and drop. - logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) + if on_binary is None: + logger.debug("Discarding unexpected %d-byte binary frame", len(raw)) + continue + try: + await on_binary(raw) + except Exception: + logger.exception("on_binary handler raised; dropping frame") continue try: msg = json.loads(raw) diff --git a/src/ria_toolkit_oss/app/cli.py b/src/ria_toolkit_oss/app/cli.py index 6cd0c1c..9bfb479 100644 --- a/src/ria_toolkit_oss/app/cli.py +++ b/src/ria_toolkit_oss/app/cli.py @@ -21,6 +21,7 @@ from __future__ import annotations import argparse import json +import os import shutil import subprocess import sys @@ -77,24 +78,33 @@ def _inspect_labels(engine: list[str], ref: str) -> dict: return {} -def _hardware_flags(labels: dict) -> list[str]: +def _gpu_available() -> bool: + if os.path.exists("/dev/nvidia0"): + return True + return shutil.which("nvidia-smi") is not None + + +def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]: flags: list[str] = [] + notes: list[str] = [] profile = (labels.get(_LABEL_PROFILE) or "").lower() hardware = (labels.get(_LABEL_HARDWARE) or "").lower() hw_items = {h.strip() for h in hardware.split(",") if h.strip()} - if "nvidia" in profile or "holoscan" in profile or "cuda" in profile: - flags += ["--gpus", "all"] + wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda")) + if wants_gpu and not no_gpu: + if _gpu_available(): + flags += ["--gpus", "all"] + else: + notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)") - needs_usb = hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} - if needs_usb: + if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb: flags += ["--device", "/dev/bus/usb"] - needs_net = hw_items & {"usrp", "thinkrf", "pluto"} - if needs_net: + if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net: flags += ["--net", "host"] - return flags + return flags, notes def _cmd_configure(args: argparse.Namespace) -> int: @@ -132,7 +142,10 @@ def _cmd_run(args: argparse.Namespace) -> int: return rc labels = _inspect_labels(engine, ref) - hw_flags = _hardware_flags(labels) + no_gpu = args.no_gpu and not args.force_gpu + hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net) + if args.force_gpu and "--gpus" not in hw_flags: + hw_flags = ["--gpus", "all", *hw_flags] cmd = [*engine, "run", "--rm"] if not args.foreground: @@ -162,6 +175,8 @@ def _cmd_run(args: argparse.Namespace) -> int: print(f"Running {ref} [{label_str}]") if hw_flags: print(f" auto flags: {' '.join(hw_flags)}") + for note in notes: + print(f" note: {note}") return subprocess.call(cmd) @@ -225,6 +240,10 @@ def main() -> None: p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount") p_run.add_argument("-p", "--publish", action="append", help="Publish port") p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)") + p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU") + p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected") + p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb") + p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host") p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit") p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run") p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint") diff --git a/tests/agent/test_cli_tx.py b/tests/agent/test_cli_tx.py new file mode 100644 index 0000000..1543d4c --- /dev/null +++ b/tests/agent/test_cli_tx.py @@ -0,0 +1,111 @@ +"""CLI flags for TX opt-in and interlocks.""" + +from __future__ import annotations + +import json +import sys +from unittest.mock import patch + +from ria_toolkit_oss.agent import cli as agent_cli +from ria_toolkit_oss.agent import config as agent_config + + +class _FakeResp: + def __init__(self, payload: dict): + self._payload = payload + + def read(self) -> bytes: + return json.dumps(self._payload).encode() + + def __enter__(self): + return self + + def __exit__(self, *_a): + return False + + +def _run_register(argv: list[str], cfg_path) -> int: + fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"}) + with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \ + patch("urllib.request.urlopen", return_value=fake_resp), \ + patch.object(sys, "argv", ["ria-agent", *argv]): + try: + agent_cli.main() + except SystemExit as exc: + return int(exc.code or 0) + return 0 + + +def test_register_without_allow_tx_keeps_tx_disabled(tmp_path): + cfg_path = tmp_path / "agent.json" + _run_register( + ["register", "--hub", "http://hub:3005", "--api-key", "K"], + cfg_path, + ) + cfg = agent_config.load(path=cfg_path) + assert cfg.agent_id == "agent-1" + assert cfg.tx_enabled is False + assert cfg.tx_max_gain_db is None + + +def test_register_with_allow_tx_and_caps(tmp_path): + cfg_path = tmp_path / "agent.json" + _run_register( + [ + "register", + "--hub", + "http://hub:3005", + "--api-key", + "K", + "--allow-tx", + "--tx-max-gain-db", + "-10", + "--tx-max-duration-s", + "60", + "--tx-freq-range", + "2.4e9", + "2.5e9", + "--tx-freq-range", + "5.7e9", + "5.8e9", + ], + cfg_path, + ) + cfg = agent_config.load(path=cfg_path) + assert cfg.tx_enabled is True + assert cfg.tx_max_gain_db == -10.0 + assert cfg.tx_max_duration_s == 60.0 + assert cfg.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]] + + +def test_stream_allow_tx_does_not_persist(tmp_path): + # Pre-register with tx_enabled=False, then simulate `stream --allow-tx`. + # The on-disk config must remain unchanged; the runtime flag is process-local. + cfg_path = tmp_path / "agent.json" + base = agent_config.AgentConfig( + hub_url="http://hub:3005", + agent_id="agent-1", + token="tok-abc", + tx_enabled=False, + ) + agent_config.save(base, path=cfg_path) + + captured: dict = {} + + async def _fake_run_streamer(url, token, *, cfg): + captured["cfg"] = cfg + return None + + with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \ + patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \ + patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]): + try: + agent_cli.main() + except SystemExit: + pass + + # Runtime cfg had TX flipped on + assert captured["cfg"].tx_enabled is True + # But the persisted file is untouched + on_disk = agent_config.load(path=cfg_path) + assert on_disk.tx_enabled is False diff --git a/tests/agent/test_config.py b/tests/agent/test_config.py index 2532abd..7d2a6b4 100644 --- a/tests/agent/test_config.py +++ b/tests/agent/test_config.py @@ -20,6 +20,36 @@ def test_load_missing_returns_empty(tmp_path): assert loaded == agent_config.AgentConfig() +def test_tx_fields_round_trip(tmp_path): + p = tmp_path / "agent.json" + cfg = agent_config.AgentConfig( + hub_url="https://hub.example.com", + agent_id="agent-1", + token="t", + tx_enabled=True, + tx_max_gain_db=-10.0, + tx_max_duration_s=60.0, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]], + ) + agent_config.save(cfg, path=p) + loaded = agent_config.load(path=p) + assert loaded.tx_enabled is True + assert loaded.tx_max_gain_db == -10.0 + assert loaded.tx_max_duration_s == 60.0 + assert loaded.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]] + + +def test_tx_fields_default_when_absent(tmp_path): + # Old configs written before TX existed should load cleanly with safe defaults. + p = tmp_path / "agent.json" + p.write_text('{"hub_url": "x", "agent_id": "a", "token": "t"}') + cfg = agent_config.load(path=p) + assert cfg.tx_enabled is False + assert cfg.tx_max_gain_db is None + assert cfg.tx_max_duration_s is None + assert cfg.tx_allowed_freq_ranges is None + + def test_extra_keys_preserved(tmp_path): p = tmp_path / "agent.json" p.write_text('{"hub_url": "x", "custom": 42}') diff --git a/tests/agent/test_disconnect.py b/tests/agent/test_disconnect.py index f063e3a..3063613 100644 --- a/tests/agent/test_disconnect.py +++ b/tests/agent/test_disconnect.py @@ -67,9 +67,9 @@ def test_streamer_reports_disconnected_and_ends_capture(): "radio_config": {"device": "fake", "buffer_size": 8}, } ) - # Wait for the capture task to fail out. - for _ in range(50): - if streamer._capture_task and streamer._capture_task.done(): + # Wait for the capture loop to emit its error frame and tear down the session. + for _ in range(100): + if any(m.get("type") == "error" for m in ws.json_sent) and streamer._rx is None: break await asyncio.sleep(0.01) return ws, sdr, streamer @@ -79,3 +79,5 @@ def test_streamer_reports_disconnected_and_ends_capture(): errors = [m for m in ws.json_sent if m.get("type") == "error"] assert errors, "expected an error frame" assert "disconnected" in errors[-1]["message"].lower() + # Session handle cleared so future starts can proceed. + assert streamer._rx is None diff --git a/tests/agent/test_full_duplex.py b/tests/agent/test_full_duplex.py new file mode 100644 index 0000000..6ad2f62 --- /dev/null +++ b/tests/agent/test_full_duplex.py @@ -0,0 +1,133 @@ +"""Concurrent RX + TX sessions on the same agent — shared SDR via registry.""" + +from __future__ import annotations + +import asyncio +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class FullDuplexMockSDR(MockSDR): + """MockSDR with a recording TX path so the test can assert both directions.""" + + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback): + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result).copy()) + time.sleep(0.005) + + +class FakeWs: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_rx_and_tx_share_one_sdr_instance(): + built: list[FullDuplexMockSDR] = [] + + def factory(device, identifier): + sdr = FullDuplexMockSDR(buffer_size=16) + built.append(sdr) + return sdr + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=factory, cfg=AgentConfig(tx_enabled=True)) + + # Start RX first. + await s.on_message( + { + "type": "start", + "app_id": "app-1", + "radio_config": {"device": "mock", "buffer_size": 16}, + } + ) + # Then start TX on the same device — should share the SDR handle. + await s.on_message( + { + "type": "tx_start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": 16, + "tx_gain": -20, + "tx_center_frequency": 2.45e9, + "underrun_policy": "zero", + }, + } + ) + + # Push a known TX buffer. + marker = np.arange(16, dtype=np.complex64) + 7 + await s.on_binary(_iq_frame(marker)) + + # Let both directions produce output. + for _ in range(80): + rx_ok = len(ws.bytes_sent) >= 2 + tx_ok = any(np.array_equal(b, marker) for b in built[0].tx_produced) if built else False + if rx_ok and tx_ok: + break + await asyncio.sleep(0.01) + + # Heartbeat should show both sessions. + hb = s.build_heartbeat() + + # Stop TX first, RX keeps running. + await s.on_message({"type": "tx_stop", "app_id": "app-1"}) + tx_after_stop = s._tx is None + rx_still_active = s._rx is not None + + # Now stop RX. + await s.on_message({"type": "stop", "app_id": "app-1"}) + + return ws, s, built, hb, tx_after_stop, rx_still_active + + ws, s, built, hb, tx_after_stop, rx_still_active = asyncio.run(scenario()) + + # One SDR was built and shared. + assert len(built) == 1, f"expected exactly one SDR instance, got {len(built)}" + + # Both directions produced output. + assert len(ws.bytes_sent) >= 1, "RX produced no IQ frames" + marker = np.arange(16, dtype=np.complex64) + 7 + assert any( + np.array_equal(b, marker) for b in built[0].tx_produced + ), "TX callback never saw the pushed marker buffer" + + # Heartbeat reflected both sessions while they were active. + assert hb["sessions"]["rx"]["app_id"] == "app-1" + assert hb["sessions"]["tx"]["app_id"] == "app-1" + + # Stopping TX does not tear down RX. + assert tx_after_stop + assert rx_still_active + + # After both stops, registry is empty. + assert s._registry.refcount(("mock", None)) == 0 + assert s._rx is None + assert s._tx is None diff --git a/tests/agent/test_hardware.py b/tests/agent/test_hardware.py index ab9fcdf..51b2e45 100644 --- a/tests/agent/test_hardware.py +++ b/tests/agent/test_hardware.py @@ -23,7 +23,24 @@ def test_heartbeat_payload_shape(): assert p["status"] == "idle" assert "mock" in p["hardware"] assert "app_id" not in p + # New fields, default shape + assert p["capabilities"] == ["rx"] + assert p["tx_enabled"] is False p2 = hardware.heartbeat_payload(status="streaming", app_id="abc") assert p2["status"] == "streaming" assert p2["app_id"] == "abc" + + +def test_heartbeat_payload_tx_capability_from_cfg(): + from ria_toolkit_oss.agent.config import AgentConfig + + p = hardware.heartbeat_payload(cfg=AgentConfig(tx_enabled=True)) + assert p["capabilities"] == ["rx", "tx"] + assert p["tx_enabled"] is True + + +def test_heartbeat_payload_sessions_field(): + sessions = {"rx": {"app_id": "a", "state": "streaming"}} + p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions) + assert p["sessions"] == sessions diff --git a/tests/agent/test_integration_tx.py b/tests/agent/test_integration_tx.py new file mode 100644 index 0000000..4fc13af --- /dev/null +++ b/tests/agent/test_integration_tx.py @@ -0,0 +1,144 @@ +"""End-to-end: local websockets server drives a Streamer's TX path.""" + +from __future__ import annotations + +import asyncio +import json +import time + +import numpy as np +import websockets + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.agent.ws_client import WsClient +from ria_toolkit_oss.sdr.mock import MockSDR + + +class RecordingMockSDR(MockSDR): + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback): + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result).copy()) + time.sleep(0.005) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_server_tx_start_binary_stop_cycle_over_real_ws(): + BUF = 16 + sdr = RecordingMockSDR(buffer_size=BUF) + marker = np.arange(BUF, dtype=np.complex64) + 1 + + async def scenario(): + control_frames: list[dict] = [] + done = asyncio.Event() + + async def server_handler(ws): + try: + # Drain initial heartbeat. + first = await asyncio.wait_for(ws.recv(), timeout=2.0) + control_frames.append(json.loads(first)) + + await ws.send( + json.dumps( + { + "type": "tx_start", + "app_id": "tx-app", + "radio_config": { + "device": "mock", + "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }, + } + ) + ) + + # Push a few binary IQ frames. + for _ in range(3): + await ws.send(_iq_frame(marker)) + + # Wait for at least "armed" + "transmitting" statuses. + for _ in range(100): + msg = await asyncio.wait_for(ws.recv(), timeout=2.0) + if isinstance(msg, str): + control_frames.append(json.loads(msg)) + if any( + f.get("type") == "tx_status" and f.get("state") == "transmitting" + for f in control_frames + ): + break + + await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"})) + + # Drain trailing statuses. + try: + while True: + msg = await asyncio.wait_for(ws.recv(), timeout=0.5) + if isinstance(msg, str): + control_frames.append(json.loads(msg)) + except (asyncio.TimeoutError, Exception): + pass + finally: + done.set() + + server = await websockets.serve(server_handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + streamer = Streamer( + ws=client, + sdr_factory=lambda d, i: sdr, + cfg=AgentConfig(tx_enabled=True), + ) + task = asyncio.create_task( + client.run( + on_message=streamer.on_message, + heartbeat=streamer.build_heartbeat, + on_binary=streamer.on_binary, + ) + ) + await asyncio.wait_for(done.wait(), timeout=5.0) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return control_frames, streamer + + controls, streamer = asyncio.run(scenario()) + + # Heartbeat reached the server. + assert any(f.get("type") == "heartbeat" for f in controls) + # tx_status lifecycle: armed → transmitting → done. + tx_states = [f["state"] for f in controls if f.get("type") == "tx_status"] + assert tx_states[0] == "armed" + assert "transmitting" in tx_states + assert tx_states[-1] == "done" + # TX callback saw our marker buffer at least once. + assert any(np.array_equal(b, marker) for b in sdr.tx_produced) + # Session cleared. + assert streamer._tx is None diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py index 1bb2081..2aa842e 100644 --- a/tests/agent/test_streamer.py +++ b/tests/agent/test_streamer.py @@ -46,15 +46,29 @@ def test_apply_sdr_config_sets_attributes(): def test_heartbeat_reflects_status_and_app(): - s = Streamer(ws=FakeWs(), sdr_factory=_factory) - hb = s.build_heartbeat() - assert hb["type"] == "heartbeat" - assert hb["status"] == "idle" - s._status = "streaming" - s._app_id = "app-42" - hb2 = s.build_heartbeat() - assert hb2["status"] == "streaming" - assert hb2["app_id"] == "app-42" + async def scenario(): + s = Streamer(ws=FakeWs(), sdr_factory=_factory) + hb = s.build_heartbeat() + assert hb["type"] == "heartbeat" + assert hb["status"] == "idle" + # capabilities default to rx-only + assert hb["capabilities"] == ["rx"] + assert hb["tx_enabled"] is False + + await s.on_message( + { + "type": "start", + "app_id": "app-42", + "radio_config": {"device": "mock", "buffer_size": 32}, + } + ) + hb2 = s.build_heartbeat() + assert hb2["status"] == "streaming" + assert hb2["app_id"] == "app-42" + assert hb2["sessions"]["rx"]["app_id"] == "app-42" + await s.on_message({"type": "stop", "app_id": "app-42"}) + + asyncio.run(scenario()) def test_full_start_stream_stop_cycle(): @@ -89,7 +103,7 @@ def test_full_start_stream_stop_cycle(): statuses = [m for m in ws.json_sent if m.get("type") == "status"] assert statuses[0]["status"] == "streaming" assert statuses[-1]["status"] == "idle" - assert streamer._sdr is None + assert streamer._rx is None def test_start_without_device_emits_error(): @@ -110,6 +124,7 @@ def test_configure_queues_update(): await streamer.on_message( {"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}} ) + # Before start(), pending config lives on the standalone dict exposed via the _pending_config shim. return streamer._pending_config pending = asyncio.run(scenario()) @@ -122,3 +137,56 @@ def test_unknown_message_type_is_ignored(): await s.on_message({"type": "nope"}) asyncio.run(scenario()) + + +def test_registry_shares_sdr_across_start_stop_cycles(): + # Two sequential start/stop cycles with the same (device, identifier) + # should hit the registry's cache path rather than constructing a new SDR. + built: list[MockSDR] = [] + + def counting_factory(device: str, identifier): + sdr = MockSDR(buffer_size=16, seed=0) + built.append(sdr) + return sdr + + async def scenario(): + s = Streamer(ws=FakeWs(), sdr_factory=counting_factory) + for _ in range(2): + await s.on_message( + { + "type": "start", + "app_id": "a", + "radio_config": {"device": "mock", "buffer_size": 16}, + } + ) + # Let one capture buffer flow before stopping so the loop is engaged. + await asyncio.sleep(0.02) + await s.on_message({"type": "stop", "app_id": "a"}) + + asyncio.run(scenario()) + # A new SDR per cycle (we fully close between starts) — registry refcount + # drops to zero on each stop. This test confirms close-and-rebuild works; + # the ref-counting share-while-open case is covered in the full-duplex tests. + assert len(built) == 2 + + +def test_tx_start_rejected_when_tx_disabled(): + from ria_toolkit_oss.agent.config import AgentConfig + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=_factory, cfg=AgentConfig(tx_enabled=False)) + await s.on_message( + { + "type": "tx_start", + "app_id": "a", + "radio_config": {"device": "mock", "tx_center_frequency": 2.45e9, "tx_gain": -20}, + } + ) + return ws + + ws = asyncio.run(scenario()) + tx_statuses = [m for m in ws.json_sent if m.get("type") == "tx_status"] + assert tx_statuses, "expected a tx_status frame" + assert tx_statuses[-1]["state"] == "error" + assert "disabled" in tx_statuses[-1]["message"].lower() diff --git a/tests/agent/test_streamer_tx.py b/tests/agent/test_streamer_tx.py new file mode 100644 index 0000000..6cb2bb4 --- /dev/null +++ b/tests/agent/test_streamer_tx.py @@ -0,0 +1,133 @@ +"""TX streaming happy path + shutdown semantics.""" + +from __future__ import annotations + +import asyncio +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class RecordingMockSDR(MockSDR): + """MockSDR that records each TX callback's returned buffer.""" + + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback) -> None: + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result)) + time.sleep(0.005) + + +class FakeWs: + def __init__(self): + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + + async def send_json(self, payload): + self.json_sent.append(payload) + + async def send_bytes(self, data): + self.bytes_sent.append(data) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_tx_start_streams_binary_to_callback(): + BUF = 16 + sdr = RecordingMockSDR(buffer_size=BUF) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + + # Frames of distinct content so we can assert ordering. + frame_a = np.arange(BUF, dtype=np.complex64) * (1 + 0j) + frame_b = (np.arange(BUF, dtype=np.complex64) + BUF) * (1 + 0j) + frame_c = (np.arange(BUF, dtype=np.complex64) + 2 * BUF) * (1 + 0j) + + await s.on_message( + { + "type": "tx_start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }, + } + ) + # Push three IQ frames. + await s.on_binary(_iq_frame(frame_a)) + await s.on_binary(_iq_frame(frame_b)) + await s.on_binary(_iq_frame(frame_c)) + + # Let the executor thread consume them. + for _ in range(100): + # At least the 3 real frames, plus any zero-fill from before they + # arrived. We stop once 3 non-trivial buffers are recorded. + nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)] + if len(nontrivial) >= 3: + break + await asyncio.sleep(0.01) + + await s.on_message({"type": "tx_stop", "app_id": "app-1"}) + return ws, sdr, s + + ws, sdr, streamer = asyncio.run(scenario()) + + nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)] + assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers" + + # First three nontrivial buffers match the order we pushed them. + np.testing.assert_array_equal(nontrivial[0], np.arange(BUF, dtype=np.complex64)) + np.testing.assert_array_equal(nontrivial[1], np.arange(BUF, 2 * BUF, dtype=np.complex64)) + np.testing.assert_array_equal(nontrivial[2], np.arange(2 * BUF, 3 * BUF, dtype=np.complex64)) + + # Lifecycle: armed → transmitting → done. + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert states[0] == "armed" + assert "transmitting" in states + assert states[-1] == "done" + # Session cleared. + assert streamer._tx is None + + +def test_tx_stop_releases_sdr(): + sdr = RecordingMockSDR(buffer_size=8) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message( + { + "type": "tx_start", + "app_id": "a", + "radio_config": {"device": "mock", "buffer_size": 8, "underrun_policy": "zero"}, + } + ) + await asyncio.sleep(0.03) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return s + + s = asyncio.run(scenario()) + # After stop, the registry has no outstanding references to ("mock", None). + assert s._registry.refcount(("mock", None)) == 0 + assert s._tx is None diff --git a/tests/agent/test_tx_safety.py b/tests/agent/test_tx_safety.py new file mode 100644 index 0000000..5307917 --- /dev/null +++ b/tests/agent/test_tx_safety.py @@ -0,0 +1,167 @@ +"""Agent-side TX interlocks: gain cap, freq ranges, duplicate sessions, disabled.""" + +from __future__ import annotations + +import asyncio + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class FakeWs: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _last_tx_status(ws): + frames = [m for m in ws.json_sent if m.get("type") == "tx_status"] + return frames[-1] if frames else None + + +def _tx_start(app_id="a", **radio): + rc = {"device": "mock", "buffer_size": 16, "underrun_policy": "zero"} + rc.update(radio) + return {"type": "tx_start", "app_id": app_id, "radio_config": rc} + + +def _make_streamer(cfg): + built: list = [] + + def factory(device, identifier): + sdr = MockSDR(buffer_size=16) + built.append(sdr) + return sdr + + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=factory, cfg=cfg) + return s, ws, built + + +def test_rejects_when_tx_disabled(): + async def scenario(): + s, ws, built = _make_streamer(AgentConfig(tx_enabled=False)) + await s.on_message(_tx_start(tx_gain=-20, tx_center_frequency=2.45e9)) + return s, ws, built + + s, ws, built = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "disabled" in status["message"].lower() + assert not built, "SDR should never have been constructed" + assert s._tx is None + + +def test_rejects_when_tx_gain_exceeds_cap(): + async def scenario(): + s, ws, built = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-15.0)) + await s.on_message(_tx_start(tx_gain=-5, tx_center_frequency=2.45e9)) + return ws, built + + ws, built = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "exceeds cap" in status["message"] + assert not built + + +def test_allows_gain_at_cap_boundary(): + async def scenario(): + s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-10.0)) + await s.on_message(_tx_start(tx_gain=-10, tx_center_frequency=2.45e9)) + # Stop promptly to avoid keeping an executor thread around. + await asyncio.sleep(0.02) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws + + ws = asyncio.run(scenario()) + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert "armed" in states + assert states[-1] == "done" + + +def test_rejects_when_freq_outside_ranges(): + async def scenario(): + s, ws, built = _make_streamer( + AgentConfig( + tx_enabled=True, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9]], + ) + ) + await s.on_message(_tx_start(tx_center_frequency=5.8e9, tx_gain=-20)) + return ws, built + + ws, built = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "outside allowed ranges" in status["message"] + assert not built + + +def test_allows_freq_inside_a_range(): + async def scenario(): + s, ws, _ = _make_streamer( + AgentConfig( + tx_enabled=True, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]], + ) + ) + await s.on_message(_tx_start(tx_center_frequency=5.75e9, tx_gain=-20)) + await asyncio.sleep(0.02) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws + + ws = asyncio.run(scenario()) + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert "armed" in states + assert states[-1] == "done" + + +def test_rejects_duplicate_tx_session(): + async def scenario(): + s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True)) + await s.on_message(_tx_start(app_id="a", tx_gain=-20, tx_center_frequency=2.45e9)) + await asyncio.sleep(0.01) + await s.on_message(_tx_start(app_id="b", tx_gain=-20, tx_center_frequency=2.45e9)) + # Let the second request process, then stop cleanly. + await asyncio.sleep(0.01) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws + + ws = asyncio.run(scenario()) + errors = [ + m for m in ws.json_sent + if m.get("type") == "tx_status" and m.get("state") == "error" + ] + assert any("already active" in e.get("message", "") for e in errors) + + +def test_rejects_invalid_underrun_policy(): + async def scenario(): + s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True)) + await s.on_message( + { + "type": "tx_start", + "app_id": "a", + "radio_config": { + "device": "mock", + "buffer_size": 8, + "tx_gain": -20, + "tx_center_frequency": 2.45e9, + "underrun_policy": "teleport", + }, + } + ) + return ws + + ws = asyncio.run(scenario()) + status = _last_tx_status(ws) + assert status and status["state"] == "error" + assert "underrun_policy" in status["message"] diff --git a/tests/agent/test_tx_underrun.py b/tests/agent/test_tx_underrun.py new file mode 100644 index 0000000..e95feec --- /dev/null +++ b/tests/agent/test_tx_underrun.py @@ -0,0 +1,136 @@ +"""Underrun policies: pause, zero, repeat.""" + +from __future__ import annotations + +import asyncio +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +class RecordingMockSDR(MockSDR): + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.tx_produced: list[np.ndarray] = [] + + def _stream_tx(self, callback): + self._enable_tx = True + self._tx_initialized = True + while self._enable_tx: + result = callback(self.rx_buffer_size) + self.tx_produced.append(np.asarray(result).copy()) + time.sleep(0.005) + + +class FakeWs: + def __init__(self): + self.json_sent = [] + self.bytes_sent = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def _start_cfg(policy: str, buf: int = 8) -> dict: + return { + "type": "tx_start", + "app_id": "a", + "radio_config": { + "device": "mock", + "buffer_size": buf, + "tx_gain": -20, + "tx_center_frequency": 2.45e9, + "underrun_policy": policy, + }, + } + + +def test_underrun_pause_stops_session_and_emits_status(): + sdr = RecordingMockSDR(buffer_size=8) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message(_start_cfg("pause")) + # Do not push any buffers. The callback underruns on first tick and + # the watchdog should emit "underrun" and tear down. + for _ in range(100): + if any( + m.get("type") == "tx_status" and m.get("state") == "underrun" + for m in ws.json_sent + ): + break + await asyncio.sleep(0.01) + for _ in range(50): + if s._tx is None: + break + await asyncio.sleep(0.01) + return ws, s + + ws, s = asyncio.run(scenario()) + states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"] + assert "underrun" in states + assert s._tx is None + + +def test_underrun_zero_keeps_session_alive(): + sdr = RecordingMockSDR(buffer_size=8) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message(_start_cfg("zero")) + # Let it produce several underrun-filled buffers. + await asyncio.sleep(0.08) + still_alive = s._tx is not None + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws, still_alive + + ws, still_alive = asyncio.run(scenario()) + # No underrun status emitted (policy absorbs it silently). + assert not any( + m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent + ) + assert still_alive + # All produced buffers are zero (no real data was pushed). + assert sdr.tx_produced, "expected at least one TX callback invocation" + assert all(not np.any(b != 0) for b in sdr.tx_produced) + + +def test_underrun_repeat_replays_last_buffer(): + BUF = 8 + sdr = RecordingMockSDR(buffer_size=BUF) + marker = np.arange(BUF, dtype=np.complex64) + 1 # distinct non-zero buffer + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + await s.on_message(_start_cfg("repeat", buf=BUF)) + await s.on_binary(_iq_frame(marker)) + # Give the executor time to consume the real frame + several repeats. + await asyncio.sleep(0.08) + await s.on_message({"type": "tx_stop", "app_id": "a"}) + return ws, sdr + + ws, sdr = asyncio.run(scenario()) + # No underrun status emitted. + assert not any( + m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent + ) + # At least two buffers equal to the marker — the real one and ≥1 repeat. + matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)] + assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}" diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py index 0994a5b..4061f32 100644 --- a/tests/agent/test_ws_client.py +++ b/tests/agent/test_ws_client.py @@ -113,6 +113,109 @@ def test_reconnects_after_server_drop(): assert n >= 2 +def test_binary_frame_forwarded_to_handler(): + payload = bytes(range(128)) + + async def scenario(): + received: list[bytes] = [] + done = asyncio.Event() + + async def handler(ws): + await ws.send(payload) + done.set() + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_bin(data): + received.append(data) + + task = asyncio.create_task( + client.run( + on_message=lambda _m: asyncio.sleep(0), + heartbeat=lambda: {"type": "heartbeat"}, + on_binary=on_bin, + ) + ) + for _ in range(50): + if received: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return received + + received = asyncio.run(scenario()) + assert received == [payload] + + +def test_binary_frame_dropped_when_no_handler(): + # Regression guard: existing behavior (drop server-sent binary) preserved when + # on_binary is not supplied. + async def scenario(): + crashes: list[Exception] = [] + + async def handler(ws): + await ws.send(b"\x00\x01\x02\x03") + await ws.send(json.dumps({"type": "ping"})) + try: + await ws.wait_closed() + except Exception: + pass + + messages: list[dict] = [] + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_msg(m): + messages.append(m) + + task = asyncio.create_task( + client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) + ) + for _ in range(50): + if messages: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception) as exc: + crashes.append(exc) + finally: + server.close() + await server.wait_closed() + return messages, crashes + + messages, crashes = asyncio.run(scenario()) + # JSON still delivered; binary silently dropped; no uncaught crash. + assert messages and messages[0] == {"type": "ping"} + + def test_malformed_control_frame_does_not_crash(): async def scenario(): handled: list[dict] = [] From 8c247f9f7a7838d46ff1dff96485fbcbf1938430 Mon Sep 17 00:00:00 2001 From: jonny Date: Thu, 16 Apr 2026 15:12:56 -0400 Subject: [PATCH 06/12] transmit further updates --- scripts/pluto_tx_smoke.py | 225 +++++++++++++++++++++ scripts/pluto_tx_ws_smoke.py | 236 ++++++++++++++++++++++ src/ria_toolkit_oss/agent/hardware.py | 12 ++ src/ria_toolkit_oss/agent/streamer.py | 58 +++++- src/ria_toolkit_oss/sdr/pluto.py | 128 ++++++------ tests/agent/test_hardware.py | 27 +++ tests/agent/test_param_lock_contention.py | 210 +++++++++++++++++++ tests/agent/test_streamer.py | 15 ++ tests/agent/test_ws_client.py | 110 +--------- tests/agent/test_ws_client_binary.py | 186 +++++++++++++++++ 10 files changed, 1042 insertions(+), 165 deletions(-) create mode 100755 scripts/pluto_tx_smoke.py create mode 100755 scripts/pluto_tx_ws_smoke.py create mode 100644 tests/agent/test_param_lock_contention.py create mode 100644 tests/agent/test_ws_client_binary.py diff --git a/scripts/pluto_tx_smoke.py b/scripts/pluto_tx_smoke.py new file mode 100755 index 0000000..64adbb9 --- /dev/null +++ b/scripts/pluto_tx_smoke.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto. + +End-to-end smoke test for the Pluto + Streamer TX path. Drives the same +``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so +the script is self-contained — no hub required. + +Default: 100 kHz baseband tone × 2 450 MHz LO → carrier at 2 450.1 MHz, +continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum +analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as +``tx_status: transmitting`` prints. + +Usage:: + + python3 scripts/pluto_tx_smoke.py # auto-discover Pluto + python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1 + python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60 + +Flags map 1:1 onto the agent's ``radio_config``: + + --identifier Pluto IP or hostname (omitted → ip:pluto.local). + --frequency TX LO in Hz. Default 2 450 MHz. + --gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative + = more attenuation = less power. Default -30. + --sample-rate Baseband sample rate. Default 1 MHz. + --tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC + (unmodulated carrier at exactly --frequency, but Pluto's + LO leakage will dominate). + --buffer-size Complex samples per WS frame. Default 4096. + --duration Stop after this many seconds (0 = run until Ctrl-C). +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import signal +import sys + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer + + +class LoggingFakeWs: + """In-process stand-in for the hub's WebSocket. + + Prints every ``tx_status`` + ``error`` frame the Streamer emits so the + operator can watch the lifecycle (armed → transmitting → done) on stdout. + """ + + async def send_json(self, payload: dict) -> None: + t = payload.get("type") + if t == "tx_status": + state = payload.get("state") + msg = payload.get("message") + tail = f" — {msg}" if msg else "" + print(f"[tx_status] {state}{tail}") + elif t == "error": + print(f"[error] {payload.get('message')}") + + async def send_bytes(self, data: bytes) -> None: + # Agent side won't send RX bytes in this script (no RX session). + pass + + +def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, + phase_offset: float = 0.0) -> tuple[bytes, float]: + """Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone. + + Emitting one continuous phase-coherent tone requires threading the phase + across frames; the returned ``next_phase`` should be fed back as + ``phase_offset`` on the next call so the sinusoid doesn't glitch at frame + boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap + that ``_verify_sample_format`` polices elsewhere in the toolkit. + """ + n = np.arange(buffer_size, dtype=np.float64) + phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset + amp = 0.7 + iq = amp * (np.cos(phase) + 1j * np.sin(phase)) + iq = iq.astype(np.complex64) + interleaved = np.empty(buffer_size * 2, dtype=np.float32) + interleaved[0::2] = iq.real + interleaved[1::2] = iq.imag + next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi) + return interleaved.tobytes(), next_phase + + +def _make_pluto_factory(identifier: str | None): + def factory(device: str, _ident: str | None): + if device != "pluto": + raise ValueError(f"this script only drives pluto; got device={device!r}") + from ria_toolkit_oss.sdr.pluto import Pluto + return Pluto(identifier=identifier) + return factory + + +async def _run(args: argparse.Namespace) -> int: + ws = LoggingFakeWs() + cfg = AgentConfig( + tx_enabled=True, + # Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered + # --gain=+5 still gets rejected at the agent boundary rather than + # turned into mystery attenuation by Pluto's setter. + tx_max_gain_db=0.0, + tx_max_duration_s=float(args.duration) if args.duration > 0 else None, + ) + streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg) + + await streamer.on_message( + { + "type": "tx_start", + "app_id": "smoke", + "radio_config": { + "device": "pluto", + "identifier": args.identifier, + "tx_sample_rate": int(args.sample_rate), + "tx_center_frequency": int(args.frequency), + "tx_gain": int(args.gain), + "buffer_size": int(args.buffer_size), + # "repeat" keeps the last buffer on the air if we ever stall, + # so a continuous carrier stays up even when Python GC or + # asyncio scheduling briefly pauses the producer. + "underrun_policy": "repeat", + }, + } + ) + + # Abort if tx_start was rejected by an interlock (no session → nothing to do). + if streamer._tx is None: + print("tx_start rejected — see [tx_status] line above for the reason.", + file=sys.stderr) + return 2 + + print(f"Transmitting at {args.frequency/1e6:.3f} MHz with " + f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. " + f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}.") + + # Arrange a clean shutdown on Ctrl-C. + stop = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, stop.set) + except NotImplementedError: + # add_signal_handler is not available on Windows event loops. + pass + + # Produce buffers at the nominal sample-rate pace. We deliberately stay + # slightly ahead of the radio — queue is bounded at 8, so backpressure + # flows naturally. + phase = 0.0 + buffer_dt = args.buffer_size / args.sample_rate + # Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays + # topped up. The queue's own backpressure keeps us from spinning. + produce_interval = buffer_dt * 0.5 + try: + async def producer(): + nonlocal phase + while not stop.is_set(): + frame, phase = _make_iq_frame( + args.buffer_size, args.tone, args.sample_rate, phase + ) + await streamer.on_binary(frame) + await asyncio.sleep(produce_interval) + + producer_task = asyncio.create_task(producer()) + + if args.duration > 0: + try: + await asyncio.wait_for(stop.wait(), timeout=args.duration) + except asyncio.TimeoutError: + pass + else: + await stop.wait() + + stop.set() + producer_task.cancel() + try: + await producer_task + except (asyncio.CancelledError, Exception): + pass + finally: + await streamer.on_message({"type": "tx_stop", "app_id": "smoke"}) + + print("TX session closed.") + return 0 + + +def main() -> int: + p = argparse.ArgumentParser( + description="End-to-end TX smoke test: agent → Pluto continuous tone.", + ) + p.add_argument("--identifier", default=None, + help="Pluto IP/hostname (default: auto-discover pluto.local)") + p.add_argument("--frequency", type=float, default=3_410_000_000.0, + help="TX LO in Hz (default 2.45 GHz)") + p.add_argument("--gain", type=float, default=-0.0, + help="TX gain in dB; Pluto range [-89, 0] (default -30)") + p.add_argument("--sample-rate", type=float, default=1_000_000.0, + help="Baseband sample rate (default 1 Msps)") + p.add_argument("--tone", type=float, default=100_000.0, + help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)") + p.add_argument("--buffer-size", type=int, default=4096, + help="Complex samples per frame (default 4096)") + p.add_argument("--duration", type=float, default=60.0, + help="Seconds to transmit; 0 = run until Ctrl-C (default 30)") + p.add_argument("--log-level", default="INFO") + args = p.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + try: + return asyncio.run(_run(args)) + except KeyboardInterrupt: + return 130 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/pluto_tx_ws_smoke.py b/scripts/pluto_tx_ws_smoke.py new file mode 100755 index 0000000..d4c8344 --- /dev/null +++ b/scripts/pluto_tx_ws_smoke.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto. + +Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz), +but drives the agent through the *real* WebSocket path instead of calling +handlers in-process. Proves that the hub-driven path behaves identically: + + mock hub ── ws:// ──▶ WsClient.run() ──▶ Streamer.on_message + └▶ Streamer.on_binary + │ + ▼ + real Pluto + +This is the most rigorous check short of pointing the real ``ria-agent stream`` +at a live ria-hub. If a tone appears on the spectrum analyzer here but *not* +when ria-hub drives it, the fault is above the WS decoder (registration, +capability gate, TX operator, hub's binary-frame publisher); everything +downstream of ``ws.recv()`` is this script's code path. + +Usage:: + + python3 scripts/pluto_tx_ws_smoke.py # default 30s tone + python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1 + python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import signal +import sys + +import numpy as np +import websockets + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.agent.ws_client import WsClient + + +def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, + phase_offset: float) -> tuple[bytes, float]: + n = np.arange(buffer_size, dtype=np.float64) + phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset + amp = 0.7 + iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64) + interleaved = np.empty(buffer_size * 2, dtype=np.float32) + interleaved[0::2] = iq.real + interleaved[1::2] = iq.imag + next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi) + return interleaved.tobytes(), next_phase + + +def _make_pluto_factory(identifier: str | None): + def factory(device: str, _ident: str | None): + if device != "pluto": + raise ValueError(f"this script only drives pluto; got device={device!r}") + from ria_toolkit_oss.sdr.pluto import Pluto + return Pluto(identifier=identifier) + return factory + + +async def _mock_hub_handler(ws, args, stop: asyncio.Event): + """Server side of the WS. Sends tx_start, streams IQ, then tx_stop.""" + # Drain the first heartbeat so the log is clean; we don't need to gate on + # it for a localhost smoke test. + try: + first = await asyncio.wait_for(ws.recv(), timeout=2.0) + if isinstance(first, str): + payload = json.loads(first) + if payload.get("type") == "heartbeat": + caps = payload.get("capabilities") + print(f"[mock-hub] agent heartbeat: capabilities={caps} " + f"tx_enabled={payload.get('tx_enabled')}") + except asyncio.TimeoutError: + print("[mock-hub] warning: no heartbeat received in first 2s") + + # Arm the agent's TX path. + await ws.send(json.dumps({ + "type": "tx_start", + "app_id": "ws-smoke", + "radio_config": { + "device": "pluto", + "identifier": args.identifier, + "tx_sample_rate": int(args.sample_rate), + "tx_center_frequency": int(args.frequency), + "tx_gain": int(args.gain), + "buffer_size": int(args.buffer_size), + "underrun_policy": "repeat", + }, + })) + print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " + f"gain={args.gain} dB") + + # Producer: push IQ frames at a steady clip. Use a concurrent receiver so + # tx_status frames show up in real time rather than being queued behind + # the sends. + phase = 0.0 + buffer_dt = args.buffer_size / args.sample_rate + + async def receiver(): + try: + while True: + msg = await ws.recv() + if isinstance(msg, str): + print(f"[mock-hub] ← {msg}") + except (websockets.ConnectionClosed, asyncio.CancelledError): + pass + + recv_task = asyncio.create_task(receiver()) + try: + deadline = None if args.duration <= 0 else ( + asyncio.get_event_loop().time() + args.duration + ) + while not stop.is_set(): + if deadline is not None and asyncio.get_event_loop().time() >= deadline: + break + frame, phase = _make_iq_frame( + args.buffer_size, args.tone, args.sample_rate, phase + ) + try: + await ws.send(frame) + except websockets.ConnectionClosed: + break + # Slightly ahead of real-time; WS backpressure handles the rest. + await asyncio.sleep(buffer_dt * 0.5) + finally: + try: + await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"})) + print("[mock-hub] sent tx_stop") + except websockets.ConnectionClosed: + pass + # Give the agent a moment to emit `tx_status: done` before we tear down. + await asyncio.sleep(0.3) + recv_task.cancel() + try: + await recv_task + except (asyncio.CancelledError, Exception): + pass + + +async def _run(args: argparse.Namespace) -> int: + stop = asyncio.Event() + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, stop.set) + except NotImplementedError: + pass + + # Start the mock hub on a local port. + async def handler(ws): + try: + await _mock_hub_handler(ws, args, stop) + finally: + stop.set() + + server = await websockets.serve(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + print(f"[mock-hub] listening on ws://127.0.0.1:{port}") + + # Run the agent — exactly as ``ria-agent stream`` would, just with a + # different URL and an in-memory AgentConfig instead of one loaded from + # ``~/.ria/agent.json``. + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=5.0, + reconnect_pause=0.5, + ) + streamer = Streamer( + ws=client, + sdr_factory=_make_pluto_factory(args.identifier), + cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0), + ) + client_task = asyncio.create_task( + client.run( + on_message=streamer.on_message, + heartbeat=streamer.build_heartbeat, + on_binary=streamer.on_binary, + ) + ) + + try: + await stop.wait() + finally: + client.stop() + client_task.cancel() + try: + await client_task + except (asyncio.CancelledError, Exception): + pass + server.close() + await server.wait_closed() + + print("Done.") + return 0 + + +def main() -> int: + p = argparse.ArgumentParser( + description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.", + ) + p.add_argument("--identifier", default=None, + help="Pluto IP/hostname (default: auto-discover pluto.local)") + p.add_argument("--frequency", type=float, default=2_450_000_000.0, + help="TX LO in Hz (default 2.45 GHz)") + p.add_argument("--gain", type=float, default=0.0, + help="TX gain in dB; Pluto range [-89, 0] (default 0)") + p.add_argument("--sample-rate", type=float, default=1_000_000.0, + help="Baseband sample rate (default 1 Msps)") + p.add_argument("--tone", type=float, default=100_000.0, + help="Baseband tone offset in Hz (default 100 kHz)") + p.add_argument("--buffer-size", type=int, default=4096, + help="Complex samples per frame (default 4096)") + p.add_argument("--duration", type=float, default=30.0, + help="Seconds to transmit; 0 = run until Ctrl-C (default 30)") + p.add_argument("--log-level", default="INFO") + args = p.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + try: + return asyncio.run(_run(args)) + except KeyboardInterrupt: + return 130 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py index 32a65e5..d585e8f 100644 --- a/src/ria_toolkit_oss/agent/hardware.py +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -37,6 +37,18 @@ def heartbeat_payload( "capabilities": capabilities, "tx_enabled": bool(c.tx_enabled), } + # Surface configured interlock values so the hub can pre-filter UI controls + # before sending a tx_start that would be rejected. Only included when TX + # is opted in AND the operator set a cap. + if c.tx_enabled: + if c.tx_max_gain_db is not None: + payload["tx_max_gain_db"] = float(c.tx_max_gain_db) + if c.tx_max_duration_s is not None: + payload["tx_max_duration_s"] = float(c.tx_max_duration_s) + if c.tx_allowed_freq_ranges: + payload["tx_allowed_freq_ranges"] = [ + [float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges + ] if app_id: payload["app_id"] = app_id if sessions: diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py index 8570a73..6cf73e6 100644 --- a/src/ria_toolkit_oss/agent/streamer.py +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -197,8 +197,14 @@ class Streamer: sessions=sessions or None, ) + # Advisory / keepalive message types we accept and ignore without warning. + _IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"}) + async def on_message(self, msg: dict) -> None: t = msg.get("type") + if t in self._IGNORED_MESSAGE_TYPES: + logger.debug("Ignoring advisory message: %r", t) + return handler = { "start": self._handle_rx_start, "stop": self._handle_rx_stop, @@ -469,9 +475,12 @@ class Streamer: def _tx_executor_body(self, session: TxSession) -> None: try: session.sdr._stream_tx(lambda n: self._tx_callback(session, n)) - except Exception: + except Exception as exc: logger.exception("TX stream crashed") - self._schedule(self._send_tx_status(session.app_id, "error", "tx stream crashed")) + # Schedule both the error frame and session teardown on the loop + # so ``self._tx`` clears, subsequent binary frames are rejected, + # and the SDR handle is released. + self._schedule(self._tx_crash_teardown(session, str(exc))) def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray: n = int(num_samples) @@ -561,6 +570,18 @@ class Streamer: return await asyncio.sleep(0.05) + async def _tx_crash_teardown(self, session: TxSession, message: str) -> None: + # Called from the executor thread via _schedule when _stream_tx raises. + # Emit the error, mark stopped, drain the queue, release the SDR. + await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}") + if self._tx is not session: + return + session.stop_event.set() + self._drain_tx_queue(session) + self._close_session_sdr(session) + if self._tx is session: + self._tx = None + async def _teardown_tx_after_underrun(self, session: TxSession) -> None: if self._tx is not session: return @@ -643,13 +664,44 @@ _CONFIG_ATTR_MAP = { } +def _is_stub_setter(method: Any) -> bool: + """True when *method* is an unimplemented base-class stub. + + The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain`` + etc. as zero-argument ``NotImplementedError`` stubs. A driver (Pluto) that + actually transmits overrides them with a real ``(value, ...)`` signature. + Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply. + """ + return getattr(method, "__qualname__", "").startswith("SDR.") + + def _apply_sdr_config(sdr: Any, cfg: dict) -> None: - """Apply a radio_config dict to an SDR, trying multiple attribute aliases.""" + """Apply a radio_config dict to an SDR. + + Prefers ``sdr.set_(value)`` when the driver implements it — Pluto's + setters take ``_param_lock``, so routing through them keeps concurrent + RX + TX reconfigures from racing on shared native attributes. Falls back + to ``setattr`` for drivers (MockSDR, tests) that don't override the + base-class stubs. + """ for key, value in cfg.items(): if value is None: continue attrs = _CONFIG_ATTR_MAP.get(key, (key,)) applied = False + for attr in attrs: + setter = getattr(sdr, f"set_{attr}", None) + if callable(setter) and not _is_stub_setter(setter): + try: + setter(value) + applied = True + break + except Exception as exc: + logger.debug("set_%s(%r) failed: %s", attr, value, exc) + # Fall through to setattr; some drivers may partially + # implement setters. + if applied: + continue for attr in attrs: if hasattr(sdr, attr): try: diff --git a/src/ria_toolkit_oss/sdr/pluto.py b/src/ria_toolkit_oss/sdr/pluto.py index 7ed3be0..88243b1 100644 --- a/src/ria_toolkit_oss/sdr/pluto.py +++ b/src/ria_toolkit_oss/sdr/pluto.py @@ -384,7 +384,10 @@ class Pluto(SDR): self._enable_tx = True while self._enable_tx is True: buffer = self._convert_tx_samples(callback(self.tx_buffer_size)) - self.radio.tx(buffer[0]) + # pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input. + # Indexing ``buffer[0]`` was a latent bug for callbacks that + # returned 1-D samples (scalar → TypeError inside pyadi). + self.radio.tx(buffer) def set_rx_center_frequency(self, center_frequency): """ @@ -514,74 +517,85 @@ class Pluto(SDR): raise SDRError(e) def set_tx_center_frequency(self, center_frequency): - if center_frequency < 70e6 or center_frequency > 6e9: - raise SDRParameterError( - f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " - f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" - f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" - ) + # ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent + # RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX + # setters at the same time. Serialize with ``_param_lock`` — RX setters hold + # the same reentrant lock — so native attribute writes don't interleave. + with self._param_lock: + if center_frequency < 70e6 or center_frequency > 6e9: + raise SDRParameterError( + f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " + f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" + f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" + ) - try: - self.radio.tx_lo = int(center_frequency) - self.tx_center_frequency = center_frequency - except OSError as e: - raise SDRError(e) - except ValueError: - raise SDRParameterError( - f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " - f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" - f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" - ) + try: + self.radio.tx_lo = int(center_frequency) + self.tx_center_frequency = center_frequency + except OSError as e: + raise SDRError(e) + except ValueError: + raise SDRParameterError( + f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz " + f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t" + f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]" + ) def set_tx_sample_rate(self, sample_rate): - min_rate, max_rate = 65.1e3, 61.44e6 - if sample_rate < min_rate or sample_rate > max_rate: - raise SDRParameterError( - f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " - f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" - ) + # ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's + # ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock`` + # so full-duplex sessions can't interleave writes. + with self._param_lock: + min_rate, max_rate = 65.1e3, 61.44e6 + if sample_rate < min_rate or sample_rate > max_rate: + raise SDRParameterError( + f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " + f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" + ) - try: - self.radio.sample_rate = sample_rate - self.tx_sample_rate = sample_rate - except OSError as e: - raise SDRError(e) - except ValueError: - raise SDRParameterError( - f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " - f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" - ) + try: + self.radio.sample_rate = sample_rate + self.tx_sample_rate = sample_rate + except OSError as e: + raise SDRError(e) + except ValueError: + raise SDRParameterError( + f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps " + f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]" + ) def set_tx_gain(self, gain, channel=0, gain_mode="absolute"): - tx_gain_min = -89 - tx_gain_max = 0 + # Serialize with RX setters: see ``set_tx_sample_rate`` above. + with self._param_lock: + tx_gain_min = -89 + tx_gain_max = 0 - if gain_mode == "relative": - if gain > 0: - raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ - the gain relative to the maximum possible gain.") + if gain_mode == "relative": + if gain > 0: + raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\ + the gain relative to the maximum possible gain.") + else: + abs_gain = tx_gain_max + gain else: - abs_gain = tx_gain_max + gain - else: - abs_gain = gain + abs_gain = gain - if abs_gain < tx_gain_min or abs_gain > tx_gain_max: - abs_gain = min(max(gain, tx_gain_min), tx_gain_max) - print(f"Gain {gain} out of range for Pluto.") - print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB") + if abs_gain < tx_gain_min or abs_gain > tx_gain_max: + abs_gain = min(max(gain, tx_gain_min), tx_gain_max) + print(f"Gain {gain} out of range for Pluto.") + print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB") - try: - self.tx_gain = abs_gain + try: + self.tx_gain = abs_gain - if channel == 0: - self.radio.tx_hardwaregain_chan0 = int(abs_gain) - elif channel == 1: - self.radio.tx_hardwaregain_chan1 = int(abs_gain) - else: - raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.") + if channel == 0: + self.radio.tx_hardwaregain_chan0 = int(abs_gain) + elif channel == 1: + self.radio.tx_hardwaregain_chan1 = int(abs_gain) + else: + raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.") - except Exception as e: - raise SDRError(e) + except Exception as e: + raise SDRError(e) def set_tx_channel(self, channel): if channel == 0: diff --git a/tests/agent/test_hardware.py b/tests/agent/test_hardware.py index 51b2e45..6a9cdf3 100644 --- a/tests/agent/test_hardware.py +++ b/tests/agent/test_hardware.py @@ -44,3 +44,30 @@ def test_heartbeat_payload_sessions_field(): sessions = {"rx": {"app_id": "a", "state": "streaming"}} p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions) assert p["sessions"] == sessions + + +def test_heartbeat_payload_surfaces_tx_caps_when_enabled(): + from ria_toolkit_oss.agent.config import AgentConfig + + cfg = AgentConfig( + tx_enabled=True, + tx_max_gain_db=-10.0, + tx_max_duration_s=60.0, + tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]], + ) + p = hardware.heartbeat_payload(cfg=cfg) + assert p["tx_max_gain_db"] == -10.0 + assert p["tx_max_duration_s"] == 60.0 + assert p["tx_allowed_freq_ranges"] == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]] + + +def test_heartbeat_payload_omits_caps_when_tx_disabled(): + from ria_toolkit_oss.agent.config import AgentConfig + + # Caps set but tx_enabled=False — don't leak them; they're only meaningful + # when the hub can attempt a tx_start. + cfg = AgentConfig(tx_enabled=False, tx_max_gain_db=-10.0) + p = hardware.heartbeat_payload(cfg=cfg) + assert "tx_max_gain_db" not in p + assert "tx_max_duration_s" not in p + assert "tx_allowed_freq_ranges" not in p diff --git a/tests/agent/test_param_lock_contention.py b/tests/agent/test_param_lock_contention.py new file mode 100644 index 0000000..e3d84fc --- /dev/null +++ b/tests/agent/test_param_lock_contention.py @@ -0,0 +1,210 @@ +"""Step-A6 (Pluto lock audit) coverage. + +Verifies the two invariants the handoff doc calls for when RX and TX run +concurrently on one shared SDR handle: + +1. ``_param_lock`` actually serializes concurrent RX + TX setter calls — the + spec's §A6 acceptance criterion is *"``_param_lock`` instrumented for + contention"*. We drive parallel ``set_{rx,tx}_sample_rate`` calls through + the lock and assert it's hit often enough to prove both paths fight for it. +2. Under a sustained full-duplex session (RX capturing + TX transmitting on + one ``(device, identifier)``), no setter write is dropped and no exception + escapes the executor — i.e., the shared-handle assumption holds. Runs + against ``MockSDR`` per the spec; the real Pluto driver now takes the + same lock on its TX setters so the production code path is isomorphic. + +The stress window is 2 seconds by default — the handoff mentions 30 s but +that's impractical in CI. Set ``RIA_LOCK_STRESS_S`` to override. +""" + +from __future__ import annotations + +import asyncio +import os +import threading +import time + +import numpy as np + +from ria_toolkit_oss.agent.config import AgentConfig +from ria_toolkit_oss.agent.streamer import Streamer +from ria_toolkit_oss.sdr.mock import MockSDR + + +_STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0")) + + +class InstrumentedMockSDR(MockSDR): + """MockSDR that counts lock acquisitions and exposes a real ``_param_lock``. + + ``_param_lock`` is inherited from ``SDR`` as a reentrant lock; we wrap it + with a counter that records every time RX or TX setters grab it, so the + test can assert real contention rather than just "the code compiles". + """ + + def __init__(self, buffer_size: int): + super().__init__(buffer_size=buffer_size) + self.rx_lock_hits = 0 + self.tx_lock_hits = 0 + self.param_lock_hits = 0 + # Shadow lock that increments a counter each time __enter__ fires. + real_lock = self._param_lock + + test = self + + class CountingLock: + def __enter__(self_inner): + test.param_lock_hits += 1 + real_lock.acquire() + return self_inner + + def __exit__(self_inner, *a): + real_lock.release() + return False + + # ``threading.RLock`` interop for any code that calls acquire/release directly. + def acquire(self_inner, *a, **k): + test.param_lock_hits += 1 + return real_lock.acquire(*a, **k) + + def release(self_inner): + return real_lock.release() + + self._param_lock = CountingLock() + + # The MockSDR doesn't ship RX setter methods that hit the lock — override + # ``sample_rate`` / ``center_freq`` / ``gain`` writes to route through the + # same lock the real Pluto driver uses, so this test faithfully models the + # production contention path. + def set_rx_sample_rate(self, sample_rate): + with self._param_lock: + self.rx_lock_hits += 1 + self.rx_sample_rate = float(sample_rate) + self.sample_rate = self.rx_sample_rate + + def set_tx_sample_rate(self, sample_rate): + with self._param_lock: + self.tx_lock_hits += 1 + self.tx_sample_rate = float(sample_rate) + # Mirror Pluto: both RX and TX write the same native attribute. + self.sample_rate = self.tx_sample_rate + + +class FakeWs: + def __init__(self): + self.json_sent: list[dict] = [] + self.bytes_sent: list[bytes] = [] + + async def send_json(self, p): + self.json_sent.append(p) + + async def send_bytes(self, b): + self.bytes_sent.append(b) + + +def _iq_frame(samples: np.ndarray) -> bytes: + interleaved = np.empty(samples.size * 2, dtype=np.float32) + interleaved[0::2] = samples.real + interleaved[1::2] = samples.imag + return interleaved.tobytes() + + +def test_param_lock_contended_under_concurrent_setters(): + """Run two threads that hammer RX + TX sample-rate setters and assert both + lock paths fire. This proves the lock is doing work — if either setter + bypassed ``_param_lock``, one of the counters would stay at zero.""" + sdr = InstrumentedMockSDR(buffer_size=16) + stop = threading.Event() + + def rx_setter(): + i = 0 + while not stop.is_set(): + sdr.set_rx_sample_rate(1_000_000 + (i % 1000)) + i += 1 + + def tx_setter(): + i = 0 + while not stop.is_set(): + sdr.set_tx_sample_rate(2_000_000 + (i % 1000)) + i += 1 + + t1 = threading.Thread(target=rx_setter) + t2 = threading.Thread(target=tx_setter) + t1.start() + t2.start() + time.sleep(min(_STRESS_S, 2.0)) + stop.set() + t1.join() + t2.join() + + assert sdr.rx_lock_hits > 100, f"RX setter barely ran: {sdr.rx_lock_hits}" + assert sdr.tx_lock_hits > 100, f"TX setter barely ran: {sdr.tx_lock_hits}" + # Every setter call should have passed through _param_lock exactly once. + assert sdr.param_lock_hits >= sdr.rx_lock_hits + sdr.tx_lock_hits + + +def test_full_duplex_stays_healthy_over_stress_window(): + """Start RX + TX on one shared SDR and drive both paths for ``_STRESS_S`` + seconds, pushing binary frames and emitting ``tx_configure`` mid-stream. + The session must survive, deliver buffers in both directions, and leave + the registry clean on shutdown.""" + BUF = 32 + sdr = InstrumentedMockSDR(buffer_size=BUF) + + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) + + await s.on_message( + {"type": "start", "app_id": "app-1", + "radio_config": {"device": "mock", "buffer_size": BUF}} + ) + await s.on_message( + {"type": "tx_start", "app_id": "app-1", + "radio_config": { + "device": "mock", "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }} + ) + + marker = np.arange(BUF, dtype=np.complex64) + 1 + deadline = time.monotonic() + _STRESS_S + i = 0 + while time.monotonic() < deadline: + await s.on_binary(_iq_frame(marker)) + if i % 8 == 0: + # Mid-stream parameter reconfiguration touches _apply_sdr_config, + # which routes through the same setters the stress test above + # verifies. + await s.on_message( + {"type": "tx_configure", "app_id": "app-1", + "radio_config": {"tx_sample_rate": 1_000_000 + i}} + ) + await s.on_message( + {"type": "configure", "app_id": "app-1", + "radio_config": {"sample_rate": 2_000_000 + i}} + ) + i += 1 + await asyncio.sleep(0.005) + + await s.on_message({"type": "tx_stop", "app_id": "app-1"}) + await s.on_message({"type": "stop", "app_id": "app-1"}) + return ws, s + + ws, s = asyncio.run(scenario()) + + # No error frame leaked out. + errors = [m for m in ws.json_sent + if m.get("type") in ("error", "tx_status") and m.get("state") == "error"] + assert errors == [], f"Unexpected error frames: {errors}" + # RX produced IQ frames and TX's callback ran — heartbeat-level contention + # check: both setter paths were hit at least once during configure dispatch. + assert ws.bytes_sent, "RX produced no IQ frames" + assert sdr.param_lock_hits > 0 + # Sessions cleaned up; registry drained. + assert s._tx is None + assert s._rx is None + assert s._registry.refcount(("mock", None)) == 0 diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py index 2aa842e..da2956c 100644 --- a/tests/agent/test_streamer.py +++ b/tests/agent/test_streamer.py @@ -139,6 +139,21 @@ def test_unknown_message_type_is_ignored(): asyncio.run(scenario()) +def test_tx_data_available_is_a_silent_noop(): + # Hub sends this as a keepalive; we should accept and ignore without + # emitting a WARNING or treating it as an error. + async def scenario(): + ws = FakeWs() + s = Streamer(ws=ws, sdr_factory=_factory) + await s.on_message({"type": "tx_data_available", "app_id": "x"}) + return ws + + ws = asyncio.run(scenario()) + # No outbound frames emitted. + assert ws.json_sent == [] + assert ws.bytes_sent == [] + + def test_registry_shares_sdr_across_start_stop_cycles(): # Two sequential start/stop cycles with the same (device, identifier) # should hit the registry's cache path rather than constructing a new SDR. diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py index 4061f32..c113b64 100644 --- a/tests/agent/test_ws_client.py +++ b/tests/agent/test_ws_client.py @@ -1,11 +1,14 @@ -"""Reconnect + heartbeat timing against a real local websockets server.""" +"""Reconnect + heartbeat + malformed-control-frame behavior. + +Binary-frame delivery lives in ``test_ws_client_binary.py`` to match the +test matrix spelled out in ``Agent TX Streaming Handoff.md`` §A7. +""" from __future__ import annotations import asyncio import json -import pytest import websockets from ria_toolkit_oss.agent.ws_client import WsClient @@ -113,109 +116,6 @@ def test_reconnects_after_server_drop(): assert n >= 2 -def test_binary_frame_forwarded_to_handler(): - payload = bytes(range(128)) - - async def scenario(): - received: list[bytes] = [] - done = asyncio.Event() - - async def handler(ws): - await ws.send(payload) - done.set() - try: - await ws.wait_closed() - except Exception: - pass - - server, port = await _open_server(handler) - try: - client = WsClient( - f"ws://127.0.0.1:{port}", - token="", - heartbeat_interval=10.0, - reconnect_pause=0.05, - ) - - async def on_bin(data): - received.append(data) - - task = asyncio.create_task( - client.run( - on_message=lambda _m: asyncio.sleep(0), - heartbeat=lambda: {"type": "heartbeat"}, - on_binary=on_bin, - ) - ) - for _ in range(50): - if received: - break - await asyncio.sleep(0.02) - client.stop() - task.cancel() - try: - await task - except (asyncio.CancelledError, Exception): - pass - finally: - server.close() - await server.wait_closed() - return received - - received = asyncio.run(scenario()) - assert received == [payload] - - -def test_binary_frame_dropped_when_no_handler(): - # Regression guard: existing behavior (drop server-sent binary) preserved when - # on_binary is not supplied. - async def scenario(): - crashes: list[Exception] = [] - - async def handler(ws): - await ws.send(b"\x00\x01\x02\x03") - await ws.send(json.dumps({"type": "ping"})) - try: - await ws.wait_closed() - except Exception: - pass - - messages: list[dict] = [] - server, port = await _open_server(handler) - try: - client = WsClient( - f"ws://127.0.0.1:{port}", - token="", - heartbeat_interval=10.0, - reconnect_pause=0.05, - ) - - async def on_msg(m): - messages.append(m) - - task = asyncio.create_task( - client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) - ) - for _ in range(50): - if messages: - break - await asyncio.sleep(0.02) - client.stop() - task.cancel() - try: - await task - except (asyncio.CancelledError, Exception) as exc: - crashes.append(exc) - finally: - server.close() - await server.wait_closed() - return messages, crashes - - messages, crashes = asyncio.run(scenario()) - # JSON still delivered; binary silently dropped; no uncaught crash. - assert messages and messages[0] == {"type": "ping"} - - def test_malformed_control_frame_does_not_crash(): async def scenario(): handled: list[dict] = [] diff --git a/tests/agent/test_ws_client_binary.py b/tests/agent/test_ws_client_binary.py new file mode 100644 index 0000000..4d9ddc1 --- /dev/null +++ b/tests/agent/test_ws_client_binary.py @@ -0,0 +1,186 @@ +"""Binary-frame delivery on the hub → agent WebSocket. + +Named to match the test matrix in ``Agent TX Streaming Handoff.md`` §A7. +Exercises: + +- Binary frames are forwarded to an ``on_binary`` coroutine when supplied. +- Binary frames are silently dropped (no crash) when ``on_binary`` is omitted, + preserving the pre-TX behavior for RX-only deployments. +""" + +from __future__ import annotations + +import asyncio +import json + +import websockets + +from ria_toolkit_oss.agent.ws_client import WsClient + + +async def _open_server(handler): + server = await websockets.serve(handler, "127.0.0.1", 0) + port = server.sockets[0].getsockname()[1] + return server, port + + +def test_binary_frame_forwarded_to_handler(): + payload = bytes(range(128)) + + async def scenario(): + received: list[bytes] = [] + done = asyncio.Event() + + async def handler(ws): + await ws.send(payload) + done.set() + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_bin(data): + received.append(data) + + task = asyncio.create_task( + client.run( + on_message=lambda _m: asyncio.sleep(0), + heartbeat=lambda: {"type": "heartbeat"}, + on_binary=on_bin, + ) + ) + for _ in range(50): + if received: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return received + + received = asyncio.run(scenario()) + assert received == [payload] + + +def test_binary_frame_dropped_when_no_handler(): + async def scenario(): + crashes: list[Exception] = [] + + async def handler(ws): + await ws.send(b"\x00\x01\x02\x03") + await ws.send(json.dumps({"type": "ping"})) + try: + await ws.wait_closed() + except Exception: + pass + + messages: list[dict] = [] + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_msg(m): + messages.append(m) + + task = asyncio.create_task( + client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) + ) + for _ in range(50): + if messages: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception) as exc: + crashes.append(exc) + finally: + server.close() + await server.wait_closed() + return messages, crashes + + messages, _ = asyncio.run(scenario()) + assert messages and messages[0] == {"type": "ping"} + + +def test_on_binary_exception_does_not_kill_connection(): + """A buggy ``on_binary`` raises mid-stream; the WS loop keeps accepting frames.""" + + async def scenario(): + delivered_binary = 0 + delivered_control: list[dict] = [] + + async def handler(ws): + await ws.send(b"\x10\x20\x30") + await ws.send(b"\x40\x50\x60") + await ws.send(json.dumps({"type": "ping"})) + try: + await ws.wait_closed() + except Exception: + pass + + server, port = await _open_server(handler) + try: + client = WsClient( + f"ws://127.0.0.1:{port}", + token="", + heartbeat_interval=10.0, + reconnect_pause=0.05, + ) + + async def on_bin(data): + nonlocal delivered_binary + delivered_binary += 1 + raise RuntimeError("handler broke") + + async def on_msg(m): + delivered_control.append(m) + + task = asyncio.create_task( + client.run( + on_message=on_msg, + heartbeat=lambda: {"type": "heartbeat"}, + on_binary=on_bin, + ) + ) + for _ in range(60): + if delivered_control: + break + await asyncio.sleep(0.02) + client.stop() + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + finally: + server.close() + await server.wait_closed() + return delivered_binary, delivered_control + + bins, ctrls = asyncio.run(scenario()) + # Both binary frames were delivered to the (crashing) handler. + assert bins == 2 + # The subsequent JSON frame still arrived — loop didn't die on the exceptions. + assert ctrls and ctrls[0] == {"type": "ping"} From 5035f0654a05da3ab756ad0c9764b10a38ae93f0 Mon Sep 17 00:00:00 2001 From: jonny Date: Thu, 16 Apr 2026 15:38:35 -0400 Subject: [PATCH 07/12] tx_race_condtion_fix --- src/ria_toolkit_oss/agent/streamer.py | 26 +++++++++++++++++--------- tests/agent/test_full_duplex.py | 1 + tests/agent/test_streamer_tx.py | 9 ++++++++- tests/agent/test_tx_safety.py | 9 ++++++++- tests/agent/test_tx_underrun.py | 1 + 5 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py index 6cf73e6..51f1dce 100644 --- a/src/ria_toolkit_oss/agent/streamer.py +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -396,15 +396,23 @@ class Streamer: try: sdr, device_key = self._registry.acquire(device, identifier) _apply_sdr_config(sdr, radio_config) - # Only call init_tx when the hub supplied the three required - # parameters. Drivers that gate _stream_tx on _tx_initialized - # (e.g. Pluto) need this; drivers that don't (e.g. Mock) tolerate - # its absence. - init_args = { - k: radio_config.get(f"tx_{k}") - for k in ("sample_rate", "center_frequency", "gain") - } - if hasattr(sdr, "init_tx") and all(v is not None for v in init_args.values()): + # init_tx is mandatory for any driver that exposes it: drivers + # that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP, + # …) crash with a confusing "TX was not initialized" error 2 s + # later in the executor thread if we skip it. Treat the three + # required keys as a hard contract — a missing one is a hub-side + # manifest bug and we want it surfaced immediately, not papered + # over with stale radio state. + if hasattr(sdr, "init_tx"): + init_args = { + k: radio_config.get(f"tx_{k}") + for k in ("sample_rate", "center_frequency", "gain") + } + missing = [f"tx_{k}" for k, v in init_args.items() if v is None] + if missing: + raise ValueError( + f"tx_start missing required radio_config keys: {missing}" + ) sdr.init_tx( sample_rate=init_args["sample_rate"], center_frequency=init_args["center_frequency"], diff --git a/tests/agent/test_full_duplex.py b/tests/agent/test_full_duplex.py index 6ad2f62..05de3c1 100644 --- a/tests/agent/test_full_duplex.py +++ b/tests/agent/test_full_duplex.py @@ -75,6 +75,7 @@ def test_rx_and_tx_share_one_sdr_instance(): "radio_config": { "device": "mock", "buffer_size": 16, + "tx_sample_rate": 1_000_000, "tx_gain": -20, "tx_center_frequency": 2.45e9, "underrun_policy": "zero", diff --git a/tests/agent/test_streamer_tx.py b/tests/agent/test_streamer_tx.py index 6cb2bb4..ea1ba5b 100644 --- a/tests/agent/test_streamer_tx.py +++ b/tests/agent/test_streamer_tx.py @@ -120,7 +120,14 @@ def test_tx_stop_releases_sdr(): { "type": "tx_start", "app_id": "a", - "radio_config": {"device": "mock", "buffer_size": 8, "underrun_policy": "zero"}, + "radio_config": { + "device": "mock", + "buffer_size": 8, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }, } ) await asyncio.sleep(0.03) diff --git a/tests/agent/test_tx_safety.py b/tests/agent/test_tx_safety.py index 5307917..2de2939 100644 --- a/tests/agent/test_tx_safety.py +++ b/tests/agent/test_tx_safety.py @@ -27,7 +27,14 @@ def _last_tx_status(ws): def _tx_start(app_id="a", **radio): - rc = {"device": "mock", "buffer_size": 16, "underrun_policy": "zero"} + rc = { + "device": "mock", + "buffer_size": 16, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + } rc.update(radio) return {"type": "tx_start", "app_id": app_id, "radio_config": rc} diff --git a/tests/agent/test_tx_underrun.py b/tests/agent/test_tx_underrun.py index e95feec..8fbe020 100644 --- a/tests/agent/test_tx_underrun.py +++ b/tests/agent/test_tx_underrun.py @@ -52,6 +52,7 @@ def _start_cfg(policy: str, buf: int = 8) -> dict: "radio_config": { "device": "mock", "buffer_size": buf, + "tx_sample_rate": 1_000_000, "tx_gain": -20, "tx_center_frequency": 2.45e9, "underrun_policy": policy, From ea8ed56a7d910328e71cc811be015ba2619c7a02 Mon Sep 17 00:00:00 2001 From: jonny Date: Mon, 20 Apr 2026 11:50:15 -0400 Subject: [PATCH 08/12] add random name genration to agent registration --- src/ria_toolkit_oss/agent/cli.py | 7 +- src/ria_toolkit_oss/agent/namegen.py | 147 +++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 src/ria_toolkit_oss/agent/namegen.py diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py index 6b88473..dec8420 100644 --- a/src/ria_toolkit_oss/agent/cli.py +++ b/src/ria_toolkit_oss/agent/cli.py @@ -23,6 +23,7 @@ import sys from . import config as _config from .hardware import available_devices from .legacy_executor import main as _legacy_main +from .namegen import generate_agent_name _LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"} @@ -42,7 +43,8 @@ def _cmd_register(args: argparse.Namespace) -> int: hub_url = args.hub.rstrip("/") url = f"{hub_url}/screens/agents/register" - body = json.dumps({"name": args.name or ""}).encode() + name = args.name or generate_agent_name() + body = json.dumps({"name": name}).encode() req = urllib.request.Request( url, data=body, @@ -66,8 +68,7 @@ def _cmd_register(args: argparse.Namespace) -> int: cfg.agent_id = agent_id cfg.token = token cfg.api_key = args.api_key - if args.name: - cfg.name = args.name + cfg.name = name cfg.insecure = bool(args.insecure) cfg.tx_enabled = bool(getattr(args, "allow_tx", False)) if (v := getattr(args, "tx_max_gain_db", None)) is not None: diff --git a/src/ria_toolkit_oss/agent/namegen.py b/src/ria_toolkit_oss/agent/namegen.py new file mode 100644 index 0000000..9d3299a --- /dev/null +++ b/src/ria_toolkit_oss/agent/namegen.py @@ -0,0 +1,147 @@ +"""Generate random human-readable agent names. + +Produces names in the form ``adjective-colour-animal``, e.g. +``swift-teal-falcon`` or ``brave-coral-otter``. All words are chosen +to be friendly and inoffensive. +""" + +from __future__ import annotations + +import random + +ADJECTIVES: list[str] = [ + "brave", + "bright", + "calm", + "clever", + "cool", + "daring", + "eager", + "fair", + "fancy", + "fast", + "fierce", + "gentle", + "grand", + "happy", + "jolly", + "keen", + "kind", + "lively", + "lucky", + "mighty", + "noble", + "plucky", + "proud", + "quick", + "quiet", + "sharp", + "shiny", + "sleek", + "smart", + "steady", + "stellar", + "strong", + "sturdy", + "sunny", + "sure", + "swift", + "tall", + "vivid", + "warm", + "wise", +] + +COLOURS: list[str] = [ + "amber", + "aqua", + "azure", + "beige", + "blue", + "bronze", + "coral", + "copper", + "crimson", + "cyan", + "denim", + "gold", + "green", + "grey", + "indigo", + "ivory", + "jade", + "lemon", + "lilac", + "lime", + "maroon", + "mint", + "navy", + "olive", + "onyx", + "peach", + "pearl", + "plum", + "red", + "rose", + "ruby", + "rust", + "sage", + "sand", + "scarlet", + "silver", + "slate", + "steel", + "teal", + "violet", +] + +ANIMALS: list[str] = [ + "badger", + "bear", + "bison", + "crane", + "deer", + "dolphin", + "eagle", + "elk", + "falcon", + "finch", + "fox", + "gecko", + "hawk", + "heron", + "horse", + "ibis", + "jaguar", + "jay", + "kite", + "koala", + "lark", + "lion", + "lynx", + "marten", + "moose", + "newt", + "orca", + "osprey", + "otter", + "owl", + "panda", + "puma", + "raven", + "robin", + "salmon", + "seal", + "shark", + "stork", + "swift", + "wolf", +] + + +def generate_agent_name() -> str: + """Return a random ``adjective-colour-animal`` name.""" + adj = random.choice(ADJECTIVES) + col = random.choice(COLOURS) + ani = random.choice(ANIMALS) + return f"{adj}-{col}-{ani}" From 8e23558d90a20c1986c8c7b0592f501c4c676873 Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Apr 2026 13:50:59 -0400 Subject: [PATCH 09/12] Fix flake8 lint errors and regenerate poetry.lock - Add TYPE_CHECKING guard for paramiko/zmq annotations in remote_transmitter_controller.py - Remove unused imports (sys, threading, importlib, call) from remote_control tests - Remove unused mock_ctrl_kwarg variable - Add noqa C901 to _handle_tx_start (legitimately complex interlock logic) - Regenerate poetry.lock to sync with pyproject.toml Co-Authored-By: Claude Sonnet 4.6 --- poetry.lock | 4 +-- src/ria_toolkit_oss/agent/streamer.py | 31 ++++++------------- .../remote_transmitter_controller.py | 28 +++++++++++------ .../test_remote_transmitter_controller.py | 8 +---- .../test_sdr_remote_integration.py | 16 +++++----- 5 files changed, 41 insertions(+), 46 deletions(-) diff --git a/poetry.lock b/poetry.lock index cb7a9f0..f0a69f7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "alabaster" @@ -1096,7 +1096,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.3.6" +jsonschema-specifications = ">=2023.03.6" referencing = ">=0.28.4" rpds-py = ">=0.25.0" diff --git a/src/ria_toolkit_oss/agent/streamer.py b/src/ria_toolkit_oss/agent/streamer.py index 51f1dce..6f727ff 100644 --- a/src/ria_toolkit_oss/agent/streamer.py +++ b/src/ria_toolkit_oss/agent/streamer.py @@ -270,9 +270,7 @@ class Streamer: ) self._rx = session await self._send_status("streaming", app_id) - session.task = asyncio.create_task( - self._capture_loop(session), name="ria-streamer-capture" - ) + session.task = asyncio.create_task(self._capture_loop(session), name="ria-streamer-capture") async def _handle_rx_stop(self, msg: dict) -> None: session = self._rx @@ -310,9 +308,7 @@ class Streamer: logger.warning("Applying configure failed: %s", exc) try: - samples = await loop.run_in_executor( - None, session.sdr.rx, session.buffer_size - ) + samples = await loop.run_in_executor(None, session.sdr.rx, session.buffer_size) except Exception as exc: from ria_toolkit_oss.sdr import SdrDisconnectedError @@ -342,7 +338,7 @@ class Streamer: # ================================================================== # TX - async def _handle_tx_start(self, msg: dict) -> None: + async def _handle_tx_start(self, msg: dict) -> None: # noqa: C901 app_id = msg.get("app_id") or "" radio_config = dict(msg.get("radio_config") or {}) @@ -383,9 +379,7 @@ class Streamer: buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE)) underrun_policy = str(radio_config.pop("underrun_policy", "pause")) if underrun_policy not in ("pause", "zero", "repeat"): - await self._send_tx_status( - app_id, "error", f"invalid underrun_policy {underrun_policy!r}" - ) + await self._send_tx_status(app_id, "error", f"invalid underrun_policy {underrun_policy!r}") return if not device: await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device") @@ -404,15 +398,10 @@ class Streamer: # manifest bug and we want it surfaced immediately, not papered # over with stale radio state. if hasattr(sdr, "init_tx"): - init_args = { - k: radio_config.get(f"tx_{k}") - for k in ("sample_rate", "center_frequency", "gain") - } + init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")} missing = [f"tx_{k}" for k, v in init_args.items() if v is None] if missing: - raise ValueError( - f"tx_start missing required radio_config keys: {missing}" - ) + raise ValueError(f"tx_start missing required radio_config keys: {missing}") sdr.init_tx( sample_rate=init_args["sample_rate"], center_frequency=init_args["center_frequency"], @@ -498,9 +487,8 @@ class Streamer: return _silence(n) # Max-duration watchdog. - if ( - session.max_duration_s is not None - and (time.monotonic() - session.started_at) >= float(session.max_duration_s) + if session.max_duration_s is not None and (time.monotonic() - session.started_at) >= float( + session.max_duration_s ): session.stop_event.set() try: @@ -528,7 +516,7 @@ class Streamer: if arr.size < 2 or arr.size % 2 != 0: logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size) return self._underrun_fill(session, n) - samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64)) + samples = arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64) if samples.size < n: out = np.zeros(n, dtype=np.complex64) out[: samples.size] = samples @@ -747,6 +735,7 @@ def _default_sdr_factory(device: str, identifier: str | None): # --------------------------------------------------------------------------- # Top-level entry + async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None: """Connect to *ws_url* and run the streamer loop until cancelled.""" ws = WsClient(ws_url, token) diff --git a/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py b/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py index e7ee746..1e9e345 100644 --- a/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py +++ b/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py @@ -13,6 +13,11 @@ import json import logging import threading import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import paramiko + import zmq logger = logging.getLogger(__name__) @@ -158,16 +163,21 @@ class RemoteTransmitterController: """ logger.info( "init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d", - center_frequency / 1e6, sample_rate / 1e6, gain, channel, + center_frequency / 1e6, + sample_rate / 1e6, + gain, + channel, + ) + self._send( + { + "function_name": "init_tx", + "center_frequency": center_frequency, + "sample_rate": sample_rate, + "gain": gain, + "channel": channel, + "gain_mode": gain_mode, + } ) - self._send({ - "function_name": "init_tx", - "center_frequency": center_frequency, - "sample_rate": sample_rate, - "gain": gain, - "channel": channel, - "gain_mode": gain_mode, - }) def transmit_async(self, duration_s: float) -> None: """Start a timed CW transmission in a background thread. diff --git a/tests/remote_control/test_remote_transmitter_controller.py b/tests/remote_control/test_remote_transmitter_controller.py index f2b6de7..8e132ef 100644 --- a/tests/remote_control/test_remote_transmitter_controller.py +++ b/tests/remote_control/test_remote_transmitter_controller.py @@ -7,8 +7,6 @@ sys.modules so they run regardless of whether the packages are installed. from __future__ import annotations import json -import sys -import threading import time from types import ModuleType from unittest.mock import MagicMock, patch @@ -199,15 +197,11 @@ class TestErrorHandling: def test_missing_paramiko_raises_runtime_error(self): """If paramiko is absent, connecting gives a clear RuntimeError.""" - import importlib - import ria_toolkit_oss.remote_control.remote_transmitter_controller as mod with patch.dict("sys.modules", {"paramiko": None}): with pytest.raises((RuntimeError, ImportError)): - mod.RemoteTransmitterController( - host="h", ssh_user="u", ssh_key_path="/k" - ) + mod.RemoteTransmitterController(host="h", ssh_user="u", ssh_key_path="/k") # --------------------------------------------------------------------------- diff --git a/tests/remote_control/test_sdr_remote_integration.py b/tests/remote_control/test_sdr_remote_integration.py index 2f13bed..123cbf1 100644 --- a/tests/remote_control/test_sdr_remote_integration.py +++ b/tests/remote_control/test_sdr_remote_integration.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, patch import pytest @@ -12,7 +12,6 @@ from ria_toolkit_oss.orchestration.campaign import ( TransmitterConfig, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -179,9 +178,7 @@ class TestInitRemoteTxControllers: } ] executor = _make_executor(d) - with patch( - "ria_toolkit_oss.remote_control.RemoteTransmitterController" - ) as mock_cls: + with patch("ria_toolkit_oss.remote_control.RemoteTransmitterController") as mock_cls: executor._init_remote_tx_controllers() mock_cls.assert_not_called() assert executor._remote_tx_controllers == {} @@ -264,7 +261,7 @@ class TestStartTransmitterSdrRemote: tx = executor.config.transmitters[0] step = CaptureStep(duration=5.0, label="nochan") executor._start_transmitter(tx, step) - _, kwargs = mock_ctrl_kwarg = ctrl.init_tx.call_args + _, kwargs = ctrl.init_tx.call_args assert kwargs["channel"] == 0 def test_missing_controller_raises(self): @@ -381,7 +378,11 @@ class TestRunWithSdrRemote: ), patch.object(executor, "_close_sdr"), patch.object(executor, "_close_remote_tx_controllers"), - patch.object(executor, "_execute_step", return_value=MagicMock(error=None, qa=MagicMock(flagged=False, snr_db=20.0, duration_s=10.0))), + patch.object( + executor, + "_execute_step", + return_value=MagicMock(error=None, qa=MagicMock(flagged=False, snr_db=20.0, duration_s=10.0)), + ), ): executor.run() @@ -401,6 +402,7 @@ class TestTransmitBufferAndTimeout: def _executor_with_ctrl(self): from ria_toolkit_oss.orchestration.executor import CampaignExecutor + cfg = CampaignConfig.from_dict(_FULL_CAMPAIGN_DICT) executor = CampaignExecutor(cfg) ctrl = MagicMock() From 22b035dbeea70003da6d6afc747ccf9d34f3f8aa Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Apr 2026 13:51:15 -0400 Subject: [PATCH 10/12] format fixes --- scripts/pluto_tx_smoke.py | 48 ++++++------- scripts/pluto_tx_ws_smoke.py | 72 +++++++++---------- src/ria_toolkit_oss/agent/cli.py | 4 +- src/ria_toolkit_oss/agent/config.py | 1 + src/ria_toolkit_oss/agent/hardware.py | 4 +- src/ria_toolkit_oss/app/cli.py | 6 +- .../remote_control/remote_transmitter.py | 5 ++ src/ria_toolkit_oss/sdr/__init__.py | 9 ++- src/ria_toolkit_oss/sdr/pluto.py | 7 +- src/ria_toolkit_oss/sdr/sdr.py | 2 +- tests/agent/test_cli_tx.py | 16 +++-- tests/agent/test_integration.py | 4 +- tests/agent/test_integration_tx.py | 5 +- tests/agent/test_param_lock_contention.py | 33 +++++---- tests/agent/test_streamer.py | 4 +- tests/agent/test_tx_safety.py | 5 +- tests/agent/test_tx_underrun.py | 13 +--- tests/agent/test_ws_client.py | 4 +- tests/agent/test_ws_client_binary.py | 4 +- .../remote_control/test_remote_transmitter.py | 43 ++++++----- 20 files changed, 143 insertions(+), 146 deletions(-) diff --git a/scripts/pluto_tx_smoke.py b/scripts/pluto_tx_smoke.py index 64adbb9..97913ec 100755 --- a/scripts/pluto_tx_smoke.py +++ b/scripts/pluto_tx_smoke.py @@ -66,8 +66,9 @@ class LoggingFakeWs: pass -def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, - phase_offset: float = 0.0) -> tuple[bytes, float]: +def _make_iq_frame( + buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float = 0.0 +) -> tuple[bytes, float]: """Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone. Emitting one continuous phase-coherent tone requires threading the phase @@ -93,7 +94,9 @@ def _make_pluto_factory(identifier: str | None): if device != "pluto": raise ValueError(f"this script only drives pluto; got device={device!r}") from ria_toolkit_oss.sdr.pluto import Pluto + return Pluto(identifier=identifier) + return factory @@ -130,13 +133,14 @@ async def _run(args: argparse.Namespace) -> int: # Abort if tx_start was rejected by an interlock (no session → nothing to do). if streamer._tx is None: - print("tx_start rejected — see [tx_status] line above for the reason.", - file=sys.stderr) + print("tx_start rejected — see [tx_status] line above for the reason.", file=sys.stderr) return 2 - print(f"Transmitting at {args.frequency/1e6:.3f} MHz with " - f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. " - f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}.") + print( + f"Transmitting at {args.frequency/1e6:.3f} MHz with " + f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. " + f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}." + ) # Arrange a clean shutdown on Ctrl-C. stop = asyncio.Event() @@ -157,12 +161,11 @@ async def _run(args: argparse.Namespace) -> int: # topped up. The queue's own backpressure keeps us from spinning. produce_interval = buffer_dt * 0.5 try: + async def producer(): nonlocal phase while not stop.is_set(): - frame, phase = _make_iq_frame( - args.buffer_size, args.tone, args.sample_rate, phase - ) + frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase) await streamer.on_binary(frame) await asyncio.sleep(produce_interval) @@ -193,20 +196,17 @@ def main() -> int: p = argparse.ArgumentParser( description="End-to-end TX smoke test: agent → Pluto continuous tone.", ) - p.add_argument("--identifier", default=None, - help="Pluto IP/hostname (default: auto-discover pluto.local)") - p.add_argument("--frequency", type=float, default=3_410_000_000.0, - help="TX LO in Hz (default 2.45 GHz)") - p.add_argument("--gain", type=float, default=-0.0, - help="TX gain in dB; Pluto range [-89, 0] (default -30)") - p.add_argument("--sample-rate", type=float, default=1_000_000.0, - help="Baseband sample rate (default 1 Msps)") - p.add_argument("--tone", type=float, default=100_000.0, - help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)") - p.add_argument("--buffer-size", type=int, default=4096, - help="Complex samples per frame (default 4096)") - p.add_argument("--duration", type=float, default=60.0, - help="Seconds to transmit; 0 = run until Ctrl-C (default 30)") + p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)") + p.add_argument("--frequency", type=float, default=3_410_000_000.0, help="TX LO in Hz (default 2.45 GHz)") + p.add_argument("--gain", type=float, default=-0.0, help="TX gain in dB; Pluto range [-89, 0] (default -30)") + p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)") + p.add_argument( + "--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)" + ) + p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)") + p.add_argument( + "--duration", type=float, default=60.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)" + ) p.add_argument("--log-level", default="INFO") args = p.parse_args() diff --git a/scripts/pluto_tx_ws_smoke.py b/scripts/pluto_tx_ws_smoke.py index d4c8344..f828e0c 100755 --- a/scripts/pluto_tx_ws_smoke.py +++ b/scripts/pluto_tx_ws_smoke.py @@ -41,8 +41,7 @@ from ria_toolkit_oss.agent.streamer import Streamer from ria_toolkit_oss.agent.ws_client import WsClient -def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, - phase_offset: float) -> tuple[bytes, float]: +def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float) -> tuple[bytes, float]: n = np.arange(buffer_size, dtype=np.float64) phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset amp = 0.7 @@ -59,7 +58,9 @@ def _make_pluto_factory(identifier: str | None): if device != "pluto": raise ValueError(f"this script only drives pluto; got device={device!r}") from ria_toolkit_oss.sdr.pluto import Pluto + return Pluto(identifier=identifier) + return factory @@ -73,27 +74,29 @@ async def _mock_hub_handler(ws, args, stop: asyncio.Event): payload = json.loads(first) if payload.get("type") == "heartbeat": caps = payload.get("capabilities") - print(f"[mock-hub] agent heartbeat: capabilities={caps} " - f"tx_enabled={payload.get('tx_enabled')}") + print(f"[mock-hub] agent heartbeat: capabilities={caps} " f"tx_enabled={payload.get('tx_enabled')}") except asyncio.TimeoutError: print("[mock-hub] warning: no heartbeat received in first 2s") # Arm the agent's TX path. - await ws.send(json.dumps({ - "type": "tx_start", - "app_id": "ws-smoke", - "radio_config": { - "device": "pluto", - "identifier": args.identifier, - "tx_sample_rate": int(args.sample_rate), - "tx_center_frequency": int(args.frequency), - "tx_gain": int(args.gain), - "buffer_size": int(args.buffer_size), - "underrun_policy": "repeat", - }, - })) - print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " - f"gain={args.gain} dB") + await ws.send( + json.dumps( + { + "type": "tx_start", + "app_id": "ws-smoke", + "radio_config": { + "device": "pluto", + "identifier": args.identifier, + "tx_sample_rate": int(args.sample_rate), + "tx_center_frequency": int(args.frequency), + "tx_gain": int(args.gain), + "buffer_size": int(args.buffer_size), + "underrun_policy": "repeat", + }, + } + ) + ) + print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " f"gain={args.gain} dB") # Producer: push IQ frames at a steady clip. Use a concurrent receiver so # tx_status frames show up in real time rather than being queued behind @@ -112,15 +115,11 @@ async def _mock_hub_handler(ws, args, stop: asyncio.Event): recv_task = asyncio.create_task(receiver()) try: - deadline = None if args.duration <= 0 else ( - asyncio.get_event_loop().time() + args.duration - ) + deadline = None if args.duration <= 0 else (asyncio.get_event_loop().time() + args.duration) while not stop.is_set(): if deadline is not None and asyncio.get_event_loop().time() >= deadline: break - frame, phase = _make_iq_frame( - args.buffer_size, args.tone, args.sample_rate, phase - ) + frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase) try: await ws.send(frame) except websockets.ConnectionClosed: @@ -204,20 +203,15 @@ def main() -> int: p = argparse.ArgumentParser( description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.", ) - p.add_argument("--identifier", default=None, - help="Pluto IP/hostname (default: auto-discover pluto.local)") - p.add_argument("--frequency", type=float, default=2_450_000_000.0, - help="TX LO in Hz (default 2.45 GHz)") - p.add_argument("--gain", type=float, default=0.0, - help="TX gain in dB; Pluto range [-89, 0] (default 0)") - p.add_argument("--sample-rate", type=float, default=1_000_000.0, - help="Baseband sample rate (default 1 Msps)") - p.add_argument("--tone", type=float, default=100_000.0, - help="Baseband tone offset in Hz (default 100 kHz)") - p.add_argument("--buffer-size", type=int, default=4096, - help="Complex samples per frame (default 4096)") - p.add_argument("--duration", type=float, default=30.0, - help="Seconds to transmit; 0 = run until Ctrl-C (default 30)") + p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)") + p.add_argument("--frequency", type=float, default=2_450_000_000.0, help="TX LO in Hz (default 2.45 GHz)") + p.add_argument("--gain", type=float, default=0.0, help="TX gain in dB; Pluto range [-89, 0] (default 0)") + p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)") + p.add_argument("--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz (default 100 kHz)") + p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)") + p.add_argument( + "--duration", type=float, default=30.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)" + ) p.add_argument("--log-level", default="INFO") args = p.parse_args() diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py index dec8420..83a7769 100644 --- a/src/ria_toolkit_oss/agent/cli.py +++ b/src/ria_toolkit_oss/agent/cli.py @@ -118,9 +118,9 @@ def _derive_ws_url(hub_url: str, agent_id: str) -> str: return "" base = hub_url.rstrip("/") if base.startswith("https://"): - base = "wss://" + base[len("https://"):] + base = "wss://" + base[len("https://") :] elif base.startswith("http://"): - base = "ws://" + base[len("http://"):] + base = "ws://" + base[len("http://") :] suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws" return base + suffix diff --git a/src/ria_toolkit_oss/agent/config.py b/src/ria_toolkit_oss/agent/config.py index 431094a..37d20c8 100644 --- a/src/ria_toolkit_oss/agent/config.py +++ b/src/ria_toolkit_oss/agent/config.py @@ -22,6 +22,7 @@ import os from dataclasses import asdict, dataclass, field from pathlib import Path + def _resolve_default_path() -> Path: return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json"))) diff --git a/src/ria_toolkit_oss/agent/hardware.py b/src/ria_toolkit_oss/agent/hardware.py index d585e8f..98b4683 100644 --- a/src/ria_toolkit_oss/agent/hardware.py +++ b/src/ria_toolkit_oss/agent/hardware.py @@ -46,9 +46,7 @@ def heartbeat_payload( if c.tx_max_duration_s is not None: payload["tx_max_duration_s"] = float(c.tx_max_duration_s) if c.tx_allowed_freq_ranges: - payload["tx_allowed_freq_ranges"] = [ - [float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges - ] + payload["tx_allowed_freq_ranges"] = [[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges] if app_id: payload["app_id"] = app_id if sessions: diff --git a/src/ria_toolkit_oss/app/cli.py b/src/ria_toolkit_oss/app/cli.py index 9bfb479..7a2b7c7 100644 --- a/src/ria_toolkit_oss/app/cli.py +++ b/src/ria_toolkit_oss/app/cli.py @@ -37,7 +37,7 @@ def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]: for exe in ("docker", "podman"): if shutil.which(exe): use_sudo = sudo_override or cfg.sudo - return (["sudo", exe] if use_sudo else [exe]) + return ["sudo", exe] if use_sudo else [exe] print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr) sys.exit(2) @@ -96,7 +96,9 @@ def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) if _gpu_available(): flags += ["--gpus", "all"] else: - notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)") + notes.append( + "image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)" + ) if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb: flags += ["--device", "/dev/bus/usb"] diff --git a/src/ria_toolkit_oss/remote_control/remote_transmitter.py b/src/ria_toolkit_oss/remote_control/remote_transmitter.py index c6ea3ab..d5b24a0 100644 --- a/src/ria_toolkit_oss/remote_control/remote_transmitter.py +++ b/src/ria_toolkit_oss/remote_control/remote_transmitter.py @@ -40,15 +40,19 @@ class RemoteTransmitter: try: if radio_str in ("pluto", "plutosdr"): from ria_toolkit_oss.sdr.pluto import Pluto + self._sdr = Pluto(identifier) elif radio_str in ("usrp",): from ria_toolkit_oss.sdr.usrp import USRP + self._sdr = USRP(identifier) elif radio_str in ("hackrf", "hackrf_one"): from ria_toolkit_oss.sdr.hackrf import HackRF + self._sdr = HackRF(identifier) elif radio_str in ("bladerf", "blade"): from ria_toolkit_oss.sdr.blade import Blade + self._sdr = Blade(identifier) else: raise ValueError(f"Unknown SDR type: {radio_str!r}") @@ -77,6 +81,7 @@ class RemoteTransmitter: if self._sdr is None: raise RuntimeError("Call set_radio() and init_tx() before transmit()") import time + # Transmit in a loop until duration has elapsed end = time.monotonic() + duration_s while time.monotonic() < end: diff --git a/src/ria_toolkit_oss/sdr/__init__.py b/src/ria_toolkit_oss/sdr/__init__.py index 4b327a2..a712be6 100644 --- a/src/ria_toolkit_oss/sdr/__init__.py +++ b/src/ria_toolkit_oss/sdr/__init__.py @@ -15,8 +15,13 @@ __all__ = [ ] from .mock import MockSDR -from .sdr import SDR, SDRError, SdrDisconnectedError, SDRParameterError, translate_disconnect # noqa: F401 - +from .sdr import ( # noqa: F401 + SDR, + SdrDisconnectedError, + SDRError, + SDRParameterError, + translate_disconnect, +) _DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = ( ("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"), diff --git a/src/ria_toolkit_oss/sdr/pluto.py b/src/ria_toolkit_oss/sdr/pluto.py index 88243b1..c78d36f 100644 --- a/src/ria_toolkit_oss/sdr/pluto.py +++ b/src/ria_toolkit_oss/sdr/pluto.py @@ -8,7 +8,12 @@ import adi import numpy as np from ria_toolkit_oss.datatypes.recording import Recording -from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError, translate_disconnect +from ria_toolkit_oss.sdr.sdr import ( + SDR, + SDRError, + SDRParameterError, + translate_disconnect, +) class Pluto(SDR): diff --git a/src/ria_toolkit_oss/sdr/sdr.py b/src/ria_toolkit_oss/sdr/sdr.py index aba68d5..443f8fa 100644 --- a/src/ria_toolkit_oss/sdr/sdr.py +++ b/src/ria_toolkit_oss/sdr/sdr.py @@ -583,7 +583,7 @@ _DISCONNECT_MARKERS = ( "i/o error", "input/output error", "errno 19", # ENODEV - "errno 5", # EIO + "errno 5", # EIO ) diff --git a/tests/agent/test_cli_tx.py b/tests/agent/test_cli_tx.py index 1543d4c..da66e91 100644 --- a/tests/agent/test_cli_tx.py +++ b/tests/agent/test_cli_tx.py @@ -26,9 +26,11 @@ class _FakeResp: def _run_register(argv: list[str], cfg_path) -> int: fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"}) - with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \ - patch("urllib.request.urlopen", return_value=fake_resp), \ - patch.object(sys, "argv", ["ria-agent", *argv]): + with ( + patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), + patch("urllib.request.urlopen", return_value=fake_resp), + patch.object(sys, "argv", ["ria-agent", *argv]), + ): try: agent_cli.main() except SystemExit as exc: @@ -96,9 +98,11 @@ def test_stream_allow_tx_does_not_persist(tmp_path): captured["cfg"] = cfg return None - with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \ - patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \ - patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]): + with ( + patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), + patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), + patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]), + ): try: agent_cli.main() except SystemExit: diff --git a/tests/agent/test_integration.py b/tests/agent/test_integration.py index 168e7a6..01eb9ec 100644 --- a/tests/agent/test_integration.py +++ b/tests/agent/test_integration.py @@ -70,9 +70,7 @@ def test_server_start_stream_stop_cycle_over_real_ws(): reconnect_pause=0.05, ) streamer = Streamer(ws=client, sdr_factory=lambda d, i: MockSDR(buffer_size=32, seed=0)) - task = asyncio.create_task( - client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat) - ) + task = asyncio.create_task(client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat)) await asyncio.wait_for(ready.wait(), timeout=3.0) await asyncio.wait_for(stopped.wait(), timeout=3.0) client.stop() diff --git a/tests/agent/test_integration_tx.py b/tests/agent/test_integration_tx.py index 4fc13af..e5239d7 100644 --- a/tests/agent/test_integration_tx.py +++ b/tests/agent/test_integration_tx.py @@ -77,10 +77,7 @@ def test_server_tx_start_binary_stop_cycle_over_real_ws(): msg = await asyncio.wait_for(ws.recv(), timeout=2.0) if isinstance(msg, str): control_frames.append(json.loads(msg)) - if any( - f.get("type") == "tx_status" and f.get("state") == "transmitting" - for f in control_frames - ): + if any(f.get("type") == "tx_status" and f.get("state") == "transmitting" for f in control_frames): break await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"})) diff --git a/tests/agent/test_param_lock_contention.py b/tests/agent/test_param_lock_contention.py index e3d84fc..e70229e 100644 --- a/tests/agent/test_param_lock_contention.py +++ b/tests/agent/test_param_lock_contention.py @@ -30,7 +30,6 @@ from ria_toolkit_oss.agent.config import AgentConfig from ria_toolkit_oss.agent.streamer import Streamer from ria_toolkit_oss.sdr.mock import MockSDR - _STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0")) @@ -156,18 +155,21 @@ def test_full_duplex_stays_healthy_over_stress_window(): s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True)) await s.on_message( - {"type": "start", "app_id": "app-1", - "radio_config": {"device": "mock", "buffer_size": BUF}} + {"type": "start", "app_id": "app-1", "radio_config": {"device": "mock", "buffer_size": BUF}} ) await s.on_message( - {"type": "tx_start", "app_id": "app-1", - "radio_config": { - "device": "mock", "buffer_size": BUF, - "tx_sample_rate": 1_000_000, - "tx_center_frequency": 2.45e9, - "tx_gain": -20, - "underrun_policy": "zero", - }} + { + "type": "tx_start", + "app_id": "app-1", + "radio_config": { + "device": "mock", + "buffer_size": BUF, + "tx_sample_rate": 1_000_000, + "tx_center_frequency": 2.45e9, + "tx_gain": -20, + "underrun_policy": "zero", + }, + } ) marker = np.arange(BUF, dtype=np.complex64) + 1 @@ -180,12 +182,10 @@ def test_full_duplex_stays_healthy_over_stress_window(): # which routes through the same setters the stress test above # verifies. await s.on_message( - {"type": "tx_configure", "app_id": "app-1", - "radio_config": {"tx_sample_rate": 1_000_000 + i}} + {"type": "tx_configure", "app_id": "app-1", "radio_config": {"tx_sample_rate": 1_000_000 + i}} ) await s.on_message( - {"type": "configure", "app_id": "app-1", - "radio_config": {"sample_rate": 2_000_000 + i}} + {"type": "configure", "app_id": "app-1", "radio_config": {"sample_rate": 2_000_000 + i}} ) i += 1 await asyncio.sleep(0.005) @@ -197,8 +197,7 @@ def test_full_duplex_stays_healthy_over_stress_window(): ws, s = asyncio.run(scenario()) # No error frame leaked out. - errors = [m for m in ws.json_sent - if m.get("type") in ("error", "tx_status") and m.get("state") == "error"] + errors = [m for m in ws.json_sent if m.get("type") in ("error", "tx_status") and m.get("state") == "error"] assert errors == [], f"Unexpected error frames: {errors}" # RX produced IQ frames and TX's callback ran — heartbeat-level contention # check: both setter paths were hit at least once during configure dispatch. diff --git a/tests/agent/test_streamer.py b/tests/agent/test_streamer.py index da2956c..44f98e0 100644 --- a/tests/agent/test_streamer.py +++ b/tests/agent/test_streamer.py @@ -121,9 +121,7 @@ def test_start_without_device_emits_error(): def test_configure_queues_update(): async def scenario(): streamer = Streamer(ws=FakeWs(), sdr_factory=_factory) - await streamer.on_message( - {"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}} - ) + await streamer.on_message({"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}) # Before start(), pending config lives on the standalone dict exposed via the _pending_config shim. return streamer._pending_config diff --git a/tests/agent/test_tx_safety.py b/tests/agent/test_tx_safety.py index 2de2939..385835f 100644 --- a/tests/agent/test_tx_safety.py +++ b/tests/agent/test_tx_safety.py @@ -143,10 +143,7 @@ def test_rejects_duplicate_tx_session(): return ws ws = asyncio.run(scenario()) - errors = [ - m for m in ws.json_sent - if m.get("type") == "tx_status" and m.get("state") == "error" - ] + errors = [m for m in ws.json_sent if m.get("type") == "tx_status" and m.get("state") == "error"] assert any("already active" in e.get("message", "") for e in errors) diff --git a/tests/agent/test_tx_underrun.py b/tests/agent/test_tx_underrun.py index 8fbe020..95e4277 100644 --- a/tests/agent/test_tx_underrun.py +++ b/tests/agent/test_tx_underrun.py @@ -70,10 +70,7 @@ def test_underrun_pause_stops_session_and_emits_status(): # Do not push any buffers. The callback underruns on first tick and # the watchdog should emit "underrun" and tear down. for _ in range(100): - if any( - m.get("type") == "tx_status" and m.get("state") == "underrun" - for m in ws.json_sent - ): + if any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent): break await asyncio.sleep(0.01) for _ in range(50): @@ -103,9 +100,7 @@ def test_underrun_zero_keeps_session_alive(): ws, still_alive = asyncio.run(scenario()) # No underrun status emitted (policy absorbs it silently). - assert not any( - m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent - ) + assert not any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent) assert still_alive # All produced buffers are zero (no real data was pushed). assert sdr.tx_produced, "expected at least one TX callback invocation" @@ -129,9 +124,7 @@ def test_underrun_repeat_replays_last_buffer(): ws, sdr = asyncio.run(scenario()) # No underrun status emitted. - assert not any( - m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent - ) + assert not any(m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent) # At least two buffers equal to the marker — the real one and ≥1 repeat. matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)] assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}" diff --git a/tests/agent/test_ws_client.py b/tests/agent/test_ws_client.py index c113b64..7717d5f 100644 --- a/tests/agent/test_ws_client.py +++ b/tests/agent/test_ws_client.py @@ -142,9 +142,7 @@ def test_malformed_control_frame_does_not_crash(): async def on_msg(m): handled.append(m) - task = asyncio.create_task( - client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) - ) + task = asyncio.create_task(client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})) for _ in range(50): if handled: break diff --git a/tests/agent/test_ws_client_binary.py b/tests/agent/test_ws_client_binary.py index 4d9ddc1..70bd97c 100644 --- a/tests/agent/test_ws_client_binary.py +++ b/tests/agent/test_ws_client_binary.py @@ -102,9 +102,7 @@ def test_binary_frame_dropped_when_no_handler(): async def on_msg(m): messages.append(m) - task = asyncio.create_task( - client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"}) - ) + task = asyncio.create_task(client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})) for _ in range(50): if messages: break diff --git a/tests/remote_control/test_remote_transmitter.py b/tests/remote_control/test_remote_transmitter.py index 9c50152..e6a0940 100644 --- a/tests/remote_control/test_remote_transmitter.py +++ b/tests/remote_control/test_remote_transmitter.py @@ -12,7 +12,6 @@ import pytest from ria_toolkit_oss.remote_control.remote_transmitter import RemoteTransmitter - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -241,34 +240,40 @@ class TestRunFunction: def test_init_tx_without_radio_returns_failure(self): tx = RemoteTransmitter() - resp = tx.run_function({ - "function_name": "init_tx", - "center_frequency": 2.4e9, - "sample_rate": 20e6, - "gain": 0, - }) + resp = tx.run_function( + { + "function_name": "init_tx", + "center_frequency": 2.4e9, + "sample_rate": 20e6, + "gain": 0, + } + ) assert resp["status"] is False assert resp["error_message"] def test_init_tx_with_radio_success(self): tx = self._tx_with_mock_sdr() - resp = tx.run_function({ - "function_name": "init_tx", - "center_frequency": 2.4e9, - "sample_rate": 20e6, - "gain": 30, - }) + resp = tx.run_function( + { + "function_name": "init_tx", + "center_frequency": 2.4e9, + "sample_rate": 20e6, + "gain": 30, + } + ) assert resp["status"] is True def test_transmit_runs_for_short_duration(self): tx = self._tx_with_mock_sdr() tx._sdr.init_tx = MagicMock() - resp = tx.run_function({ - "function_name": "init_tx", - "center_frequency": 2.4e9, - "sample_rate": 20e6, - "gain": 0, - }) + resp = tx.run_function( + { + "function_name": "init_tx", + "center_frequency": 2.4e9, + "sample_rate": 20e6, + "gain": 0, + } + ) resp = tx.run_function({"function_name": "transmit", "duration_s": 0.02}) assert resp["status"] is True From 98f63b622b022c04bba1d117cc2ccea1063d8e49 Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Apr 2026 15:27:54 -0400 Subject: [PATCH 11/12] remote transmitter fix --- .../remote_control/remote_transmitter_controller.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py b/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py index 212295c..dab760d 100644 --- a/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py +++ b/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py @@ -13,12 +13,6 @@ import json import logging import threading import time -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import paramiko - import zmq - import paramiko import zmq From 8d2f9eebaf79680d428f7951b2b3eb5c4952c756 Mon Sep 17 00:00:00 2001 From: ben Date: Mon, 20 Apr 2026 15:59:03 -0400 Subject: [PATCH 12/12] black fix --- .../remote_control/remote_transmitter_controller.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py b/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py index dab760d..b073b5d 100644 --- a/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py +++ b/src/ria_toolkit_oss/remote_control/remote_transmitter_controller.py @@ -13,6 +13,7 @@ import json import logging import threading import time + import paramiko import zmq