Compare commits
No commits in common. "912fc54f253e70a72728ac4813283f5373f6a13f" and "b884397f1f2dedd5c1b9c38da9b7c88285801c03" have entirely different histories.
912fc54f25
...
b884397f1f
159
poetry.lock
generated
159
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "alabaster"
|
||||
|
|
@ -1108,7 +1108,7 @@ files = [
|
|||
|
||||
[package.dependencies]
|
||||
attrs = ">=22.2.0"
|
||||
jsonschema-specifications = ">=2023.3.6"
|
||||
jsonschema-specifications = ">=2023.03.6"
|
||||
referencing = ">=0.28.4"
|
||||
rpds-py = ">=0.25.0"
|
||||
|
||||
|
|
@ -3463,101 +3463,76 @@ anyio = ">=3.0.0"
|
|||
|
||||
[[package]]
|
||||
name = "websockets"
|
||||
version = "13.1"
|
||||
version = "16.0"
|
||||
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["agent", "docs", "server", "test"]
|
||||
python-versions = ">=3.10"
|
||||
groups = ["docs", "server", "test"]
|
||||
files = [
|
||||
{file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"},
|
||||
{file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"},
|
||||
{file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"},
|
||||
{file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"},
|
||||
{file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"},
|
||||
{file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"},
|
||||
{file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"},
|
||||
{file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"},
|
||||
{file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"},
|
||||
{file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"},
|
||||
{file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"},
|
||||
{file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"},
|
||||
{file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"},
|
||||
{file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"},
|
||||
{file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"},
|
||||
{file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"},
|
||||
{file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"},
|
||||
{file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"},
|
||||
{file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"},
|
||||
{file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"},
|
||||
{file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"},
|
||||
{file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"},
|
||||
{file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"},
|
||||
{file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"},
|
||||
{file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"},
|
||||
{file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"},
|
||||
{file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"},
|
||||
{file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"},
|
||||
{file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"},
|
||||
{file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"},
|
||||
{file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"},
|
||||
{file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"},
|
||||
{file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"},
|
||||
{file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"},
|
||||
{file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"},
|
||||
{file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"},
|
||||
{file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"},
|
||||
{file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"},
|
||||
{file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"},
|
||||
{file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"},
|
||||
{file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"},
|
||||
{file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"},
|
||||
{file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"},
|
||||
{file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"},
|
||||
{file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"},
|
||||
{file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"},
|
||||
{file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"},
|
||||
{file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"},
|
||||
{file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"},
|
||||
{file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"},
|
||||
{file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"},
|
||||
{file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"},
|
||||
{file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"},
|
||||
{file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"},
|
||||
{file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"},
|
||||
{file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"},
|
||||
{file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"},
|
||||
{file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"},
|
||||
{file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"},
|
||||
{file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"},
|
||||
{file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"},
|
||||
{file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"},
|
||||
{file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"},
|
||||
{file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"},
|
||||
{file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"},
|
||||
{file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"},
|
||||
{file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"},
|
||||
{file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"},
|
||||
{file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"},
|
||||
{file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"},
|
||||
{file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"},
|
||||
{file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"},
|
||||
{file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"},
|
||||
{file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"},
|
||||
{file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"},
|
||||
{file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"},
|
||||
{file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"},
|
||||
{file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"},
|
||||
{file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"},
|
||||
{file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"},
|
||||
{file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"},
|
||||
{file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"},
|
||||
{file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"},
|
||||
{file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"},
|
||||
{file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"},
|
||||
{file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"},
|
||||
{file = "websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a"},
|
||||
{file = "websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0"},
|
||||
{file = "websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957"},
|
||||
{file = "websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72"},
|
||||
{file = "websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde"},
|
||||
{file = "websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3"},
|
||||
{file = "websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3"},
|
||||
{file = "websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9"},
|
||||
{file = "websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35"},
|
||||
{file = "websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8"},
|
||||
{file = "websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad"},
|
||||
{file = "websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d"},
|
||||
{file = "websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe"},
|
||||
{file = "websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b"},
|
||||
{file = "websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5"},
|
||||
{file = "websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64"},
|
||||
{file = "websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6"},
|
||||
{file = "websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac"},
|
||||
{file = "websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00"},
|
||||
{file = "websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79"},
|
||||
{file = "websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39"},
|
||||
{file = "websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c"},
|
||||
{file = "websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f"},
|
||||
{file = "websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1"},
|
||||
{file = "websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2"},
|
||||
{file = "websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89"},
|
||||
{file = "websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea"},
|
||||
{file = "websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9"},
|
||||
{file = "websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230"},
|
||||
{file = "websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c"},
|
||||
{file = "websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5"},
|
||||
{file = "websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82"},
|
||||
{file = "websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8"},
|
||||
{file = "websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f"},
|
||||
{file = "websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a"},
|
||||
{file = "websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156"},
|
||||
{file = "websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0"},
|
||||
{file = "websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904"},
|
||||
{file = "websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4"},
|
||||
{file = "websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e"},
|
||||
{file = "websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4"},
|
||||
{file = "websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1"},
|
||||
{file = "websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3"},
|
||||
{file = "websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8"},
|
||||
{file = "websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d"},
|
||||
{file = "websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244"},
|
||||
{file = "websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e"},
|
||||
{file = "websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641"},
|
||||
{file = "websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8"},
|
||||
{file = "websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e"},
|
||||
{file = "websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944"},
|
||||
{file = "websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206"},
|
||||
{file = "websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6"},
|
||||
{file = "websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd"},
|
||||
{file = "websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d"},
|
||||
{file = "websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03"},
|
||||
{file = "websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da"},
|
||||
{file = "websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c"},
|
||||
{file = "websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767"},
|
||||
{file = "websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec"},
|
||||
{file = "websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5"},
|
||||
]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10"
|
||||
content-hash = "7ddbf7d85e9ae7bd3a1b99ae481df20aaf6fd185d5f628b0fdf9b7bd278730ed"
|
||||
content-hash = "b1e5ddd7284aecf49624e51740b7a4c31bc8d0e703c255126ba5d9b2a4a0e519"
|
||||
|
|
|
|||
|
|
@ -100,7 +100,6 @@ optional = true
|
|||
|
||||
[tool.poetry.group.agent.dependencies]
|
||||
requests = ">=2.28,<3.0"
|
||||
websockets = ">=12.0,<14.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
flake8 = "^7.1.0"
|
||||
|
|
@ -117,8 +116,7 @@ pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
|
|||
ria = "ria_toolkit_oss_cli.cli:cli"
|
||||
ria-tools = "ria_toolkit_oss_cli.cli:cli"
|
||||
ria-server = "ria_toolkit_oss.server.cli:serve"
|
||||
ria-agent = "ria_toolkit_oss.agent.cli:main"
|
||||
ria-app = "ria_toolkit_oss.app.cli:main"
|
||||
ria-agent = "ria_toolkit_oss.agent:main"
|
||||
|
||||
[tool.poetry.group.server.dependencies]
|
||||
fastapi = ">=0.111,<1.0"
|
||||
|
|
|
|||
|
|
@ -1,225 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto.
|
||||
|
||||
End-to-end smoke test for the Pluto + Streamer TX path. Drives the same
|
||||
``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so
|
||||
the script is self-contained — no hub required.
|
||||
|
||||
Default: 100 kHz baseband tone × 2 450 MHz LO → carrier at 2 450.1 MHz,
|
||||
continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum
|
||||
analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as
|
||||
``tx_status: transmitting`` prints.
|
||||
|
||||
Usage::
|
||||
|
||||
python3 scripts/pluto_tx_smoke.py # auto-discover Pluto
|
||||
python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1
|
||||
python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60
|
||||
|
||||
Flags map 1:1 onto the agent's ``radio_config``:
|
||||
|
||||
--identifier Pluto IP or hostname (omitted → ip:pluto.local).
|
||||
--frequency TX LO in Hz. Default 2 450 MHz.
|
||||
--gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative
|
||||
= more attenuation = less power. Default -30.
|
||||
--sample-rate Baseband sample rate. Default 1 MHz.
|
||||
--tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC
|
||||
(unmodulated carrier at exactly --frequency, but Pluto's
|
||||
LO leakage will dominate).
|
||||
--buffer-size Complex samples per WS frame. Default 4096.
|
||||
--duration Stop after this many seconds (0 = run until Ctrl-C).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
|
||||
|
||||
class LoggingFakeWs:
|
||||
"""In-process stand-in for the hub's WebSocket.
|
||||
|
||||
Prints every ``tx_status`` + ``error`` frame the Streamer emits so the
|
||||
operator can watch the lifecycle (armed → transmitting → done) on stdout.
|
||||
"""
|
||||
|
||||
async def send_json(self, payload: dict) -> None:
|
||||
t = payload.get("type")
|
||||
if t == "tx_status":
|
||||
state = payload.get("state")
|
||||
msg = payload.get("message")
|
||||
tail = f" — {msg}" if msg else ""
|
||||
print(f"[tx_status] {state}{tail}")
|
||||
elif t == "error":
|
||||
print(f"[error] {payload.get('message')}")
|
||||
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
# Agent side won't send RX bytes in this script (no RX session).
|
||||
pass
|
||||
|
||||
|
||||
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float,
|
||||
phase_offset: float = 0.0) -> tuple[bytes, float]:
|
||||
"""Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone.
|
||||
|
||||
Emitting one continuous phase-coherent tone requires threading the phase
|
||||
across frames; the returned ``next_phase`` should be fed back as
|
||||
``phase_offset`` on the next call so the sinusoid doesn't glitch at frame
|
||||
boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap
|
||||
that ``_verify_sample_format`` polices elsewhere in the toolkit.
|
||||
"""
|
||||
n = np.arange(buffer_size, dtype=np.float64)
|
||||
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
|
||||
amp = 0.7
|
||||
iq = amp * (np.cos(phase) + 1j * np.sin(phase))
|
||||
iq = iq.astype(np.complex64)
|
||||
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = iq.real
|
||||
interleaved[1::2] = iq.imag
|
||||
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
|
||||
return interleaved.tobytes(), next_phase
|
||||
|
||||
|
||||
def _make_pluto_factory(identifier: str | None):
|
||||
def factory(device: str, _ident: str | None):
|
||||
if device != "pluto":
|
||||
raise ValueError(f"this script only drives pluto; got device={device!r}")
|
||||
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||
return Pluto(identifier=identifier)
|
||||
return factory
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
|
||||
ws = LoggingFakeWs()
|
||||
cfg = AgentConfig(
|
||||
tx_enabled=True,
|
||||
# Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered
|
||||
# --gain=+5 still gets rejected at the agent boundary rather than
|
||||
# turned into mystery attenuation by Pluto's setter.
|
||||
tx_max_gain_db=0.0,
|
||||
tx_max_duration_s=float(args.duration) if args.duration > 0 else None,
|
||||
)
|
||||
streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg)
|
||||
|
||||
await streamer.on_message(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "smoke",
|
||||
"radio_config": {
|
||||
"device": "pluto",
|
||||
"identifier": args.identifier,
|
||||
"tx_sample_rate": int(args.sample_rate),
|
||||
"tx_center_frequency": int(args.frequency),
|
||||
"tx_gain": int(args.gain),
|
||||
"buffer_size": int(args.buffer_size),
|
||||
# "repeat" keeps the last buffer on the air if we ever stall,
|
||||
# so a continuous carrier stays up even when Python GC or
|
||||
# asyncio scheduling briefly pauses the producer.
|
||||
"underrun_policy": "repeat",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Abort if tx_start was rejected by an interlock (no session → nothing to do).
|
||||
if streamer._tx is None:
|
||||
print("tx_start rejected — see [tx_status] line above for the reason.",
|
||||
file=sys.stderr)
|
||||
return 2
|
||||
|
||||
print(f"Transmitting at {args.frequency/1e6:.3f} MHz with "
|
||||
f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. "
|
||||
f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}.")
|
||||
|
||||
# Arrange a clean shutdown on Ctrl-C.
|
||||
stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, stop.set)
|
||||
except NotImplementedError:
|
||||
# add_signal_handler is not available on Windows event loops.
|
||||
pass
|
||||
|
||||
# Produce buffers at the nominal sample-rate pace. We deliberately stay
|
||||
# slightly ahead of the radio — queue is bounded at 8, so backpressure
|
||||
# flows naturally.
|
||||
phase = 0.0
|
||||
buffer_dt = args.buffer_size / args.sample_rate
|
||||
# Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays
|
||||
# topped up. The queue's own backpressure keeps us from spinning.
|
||||
produce_interval = buffer_dt * 0.5
|
||||
try:
|
||||
async def producer():
|
||||
nonlocal phase
|
||||
while not stop.is_set():
|
||||
frame, phase = _make_iq_frame(
|
||||
args.buffer_size, args.tone, args.sample_rate, phase
|
||||
)
|
||||
await streamer.on_binary(frame)
|
||||
await asyncio.sleep(produce_interval)
|
||||
|
||||
producer_task = asyncio.create_task(producer())
|
||||
|
||||
if args.duration > 0:
|
||||
try:
|
||||
await asyncio.wait_for(stop.wait(), timeout=args.duration)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
await stop.wait()
|
||||
|
||||
stop.set()
|
||||
producer_task.cancel()
|
||||
try:
|
||||
await producer_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
await streamer.on_message({"type": "tx_stop", "app_id": "smoke"})
|
||||
|
||||
print("TX session closed.")
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
p = argparse.ArgumentParser(
|
||||
description="End-to-end TX smoke test: agent → Pluto continuous tone.",
|
||||
)
|
||||
p.add_argument("--identifier", default=None,
|
||||
help="Pluto IP/hostname (default: auto-discover pluto.local)")
|
||||
p.add_argument("--frequency", type=float, default=3_410_000_000.0,
|
||||
help="TX LO in Hz (default 2.45 GHz)")
|
||||
p.add_argument("--gain", type=float, default=-0.0,
|
||||
help="TX gain in dB; Pluto range [-89, 0] (default -30)")
|
||||
p.add_argument("--sample-rate", type=float, default=1_000_000.0,
|
||||
help="Baseband sample rate (default 1 Msps)")
|
||||
p.add_argument("--tone", type=float, default=100_000.0,
|
||||
help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)")
|
||||
p.add_argument("--buffer-size", type=int, default=4096,
|
||||
help="Complex samples per frame (default 4096)")
|
||||
p.add_argument("--duration", type=float, default=60.0,
|
||||
help="Seconds to transmit; 0 = run until Ctrl-C (default 30)")
|
||||
p.add_argument("--log-level", default="INFO")
|
||||
args = p.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level.upper(), logging.INFO),
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
try:
|
||||
return asyncio.run(_run(args))
|
||||
except KeyboardInterrupt:
|
||||
return 130
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
|
@ -1,236 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto.
|
||||
|
||||
Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz),
|
||||
but drives the agent through the *real* WebSocket path instead of calling
|
||||
handlers in-process. Proves that the hub-driven path behaves identically:
|
||||
|
||||
mock hub ── ws:// ──▶ WsClient.run() ──▶ Streamer.on_message
|
||||
└▶ Streamer.on_binary
|
||||
│
|
||||
▼
|
||||
real Pluto
|
||||
|
||||
This is the most rigorous check short of pointing the real ``ria-agent stream``
|
||||
at a live ria-hub. If a tone appears on the spectrum analyzer here but *not*
|
||||
when ria-hub drives it, the fault is above the WS decoder (registration,
|
||||
capability gate, TX operator, hub's binary-frame publisher); everything
|
||||
downstream of ``ws.recv()`` is this script's code path.
|
||||
|
||||
Usage::
|
||||
|
||||
python3 scripts/pluto_tx_ws_smoke.py # default 30s tone
|
||||
python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1
|
||||
python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
|
||||
|
||||
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float,
|
||||
phase_offset: float) -> tuple[bytes, float]:
|
||||
n = np.arange(buffer_size, dtype=np.float64)
|
||||
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
|
||||
amp = 0.7
|
||||
iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64)
|
||||
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = iq.real
|
||||
interleaved[1::2] = iq.imag
|
||||
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
|
||||
return interleaved.tobytes(), next_phase
|
||||
|
||||
|
||||
def _make_pluto_factory(identifier: str | None):
|
||||
def factory(device: str, _ident: str | None):
|
||||
if device != "pluto":
|
||||
raise ValueError(f"this script only drives pluto; got device={device!r}")
|
||||
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||
return Pluto(identifier=identifier)
|
||||
return factory
|
||||
|
||||
|
||||
async def _mock_hub_handler(ws, args, stop: asyncio.Event):
|
||||
"""Server side of the WS. Sends tx_start, streams IQ, then tx_stop."""
|
||||
# Drain the first heartbeat so the log is clean; we don't need to gate on
|
||||
# it for a localhost smoke test.
|
||||
try:
|
||||
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||
if isinstance(first, str):
|
||||
payload = json.loads(first)
|
||||
if payload.get("type") == "heartbeat":
|
||||
caps = payload.get("capabilities")
|
||||
print(f"[mock-hub] agent heartbeat: capabilities={caps} "
|
||||
f"tx_enabled={payload.get('tx_enabled')}")
|
||||
except asyncio.TimeoutError:
|
||||
print("[mock-hub] warning: no heartbeat received in first 2s")
|
||||
|
||||
# Arm the agent's TX path.
|
||||
await ws.send(json.dumps({
|
||||
"type": "tx_start",
|
||||
"app_id": "ws-smoke",
|
||||
"radio_config": {
|
||||
"device": "pluto",
|
||||
"identifier": args.identifier,
|
||||
"tx_sample_rate": int(args.sample_rate),
|
||||
"tx_center_frequency": int(args.frequency),
|
||||
"tx_gain": int(args.gain),
|
||||
"buffer_size": int(args.buffer_size),
|
||||
"underrun_policy": "repeat",
|
||||
},
|
||||
}))
|
||||
print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, "
|
||||
f"gain={args.gain} dB")
|
||||
|
||||
# Producer: push IQ frames at a steady clip. Use a concurrent receiver so
|
||||
# tx_status frames show up in real time rather than being queued behind
|
||||
# the sends.
|
||||
phase = 0.0
|
||||
buffer_dt = args.buffer_size / args.sample_rate
|
||||
|
||||
async def receiver():
|
||||
try:
|
||||
while True:
|
||||
msg = await ws.recv()
|
||||
if isinstance(msg, str):
|
||||
print(f"[mock-hub] ← {msg}")
|
||||
except (websockets.ConnectionClosed, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
recv_task = asyncio.create_task(receiver())
|
||||
try:
|
||||
deadline = None if args.duration <= 0 else (
|
||||
asyncio.get_event_loop().time() + args.duration
|
||||
)
|
||||
while not stop.is_set():
|
||||
if deadline is not None and asyncio.get_event_loop().time() >= deadline:
|
||||
break
|
||||
frame, phase = _make_iq_frame(
|
||||
args.buffer_size, args.tone, args.sample_rate, phase
|
||||
)
|
||||
try:
|
||||
await ws.send(frame)
|
||||
except websockets.ConnectionClosed:
|
||||
break
|
||||
# Slightly ahead of real-time; WS backpressure handles the rest.
|
||||
await asyncio.sleep(buffer_dt * 0.5)
|
||||
finally:
|
||||
try:
|
||||
await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"}))
|
||||
print("[mock-hub] sent tx_stop")
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
# Give the agent a moment to emit `tx_status: done` before we tear down.
|
||||
await asyncio.sleep(0.3)
|
||||
recv_task.cancel()
|
||||
try:
|
||||
await recv_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
|
||||
stop = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, stop.set)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
# Start the mock hub on a local port.
|
||||
async def handler(ws):
|
||||
try:
|
||||
await _mock_hub_handler(ws, args, stop)
|
||||
finally:
|
||||
stop.set()
|
||||
|
||||
server = await websockets.serve(handler, "127.0.0.1", 0)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
print(f"[mock-hub] listening on ws://127.0.0.1:{port}")
|
||||
|
||||
# Run the agent — exactly as ``ria-agent stream`` would, just with a
|
||||
# different URL and an in-memory AgentConfig instead of one loaded from
|
||||
# ``~/.ria/agent.json``.
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=5.0,
|
||||
reconnect_pause=0.5,
|
||||
)
|
||||
streamer = Streamer(
|
||||
ws=client,
|
||||
sdr_factory=_make_pluto_factory(args.identifier),
|
||||
cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0),
|
||||
)
|
||||
client_task = asyncio.create_task(
|
||||
client.run(
|
||||
on_message=streamer.on_message,
|
||||
heartbeat=streamer.build_heartbeat,
|
||||
on_binary=streamer.on_binary,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await stop.wait()
|
||||
finally:
|
||||
client.stop()
|
||||
client_task.cancel()
|
||||
try:
|
||||
await client_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
print("Done.")
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.",
|
||||
)
|
||||
p.add_argument("--identifier", default=None,
|
||||
help="Pluto IP/hostname (default: auto-discover pluto.local)")
|
||||
p.add_argument("--frequency", type=float, default=2_450_000_000.0,
|
||||
help="TX LO in Hz (default 2.45 GHz)")
|
||||
p.add_argument("--gain", type=float, default=0.0,
|
||||
help="TX gain in dB; Pluto range [-89, 0] (default 0)")
|
||||
p.add_argument("--sample-rate", type=float, default=1_000_000.0,
|
||||
help="Baseband sample rate (default 1 Msps)")
|
||||
p.add_argument("--tone", type=float, default=100_000.0,
|
||||
help="Baseband tone offset in Hz (default 100 kHz)")
|
||||
p.add_argument("--buffer-size", type=int, default=4096,
|
||||
help="Complex samples per frame (default 4096)")
|
||||
p.add_argument("--duration", type=float, default=30.0,
|
||||
help="Seconds to transmit; 0 = run until Ctrl-C (default 30)")
|
||||
p.add_argument("--log-level", default="INFO")
|
||||
args = p.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level.upper(), logging.INFO),
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
try:
|
||||
return asyncio.run(_run(args))
|
||||
except KeyboardInterrupt:
|
||||
return 130
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""RIA Toolkit agent package.
|
||||
|
||||
Provides two execution modes:
|
||||
|
||||
- **Legacy long-poll executor** (`NodeAgent` in :mod:`legacy_executor`) — an
|
||||
HTTP long-polling agent that runs ONNX inference locally on the host.
|
||||
- **Streamer** (:mod:`streamer`) — a thin WebSocket client that opens an SDR
|
||||
and streams raw IQ to the RIA Hub server, which performs all inference.
|
||||
|
||||
Back-compat: ``from ria_toolkit_oss.agent import NodeAgent`` and the ``main``
|
||||
entry point continue to work.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .legacy_executor import NodeAgent
|
||||
from .legacy_executor import main as _legacy_main
|
||||
|
||||
__all__ = ["NodeAgent", "main"]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Unified CLI entry point. Dispatches to streamer/legacy subcommands."""
|
||||
from .cli import main as _cli_main
|
||||
|
||||
_cli_main()
|
||||
|
|
@ -1,212 +0,0 @@
|
|||
"""Unified ``ria-agent`` CLI.
|
||||
|
||||
Subcommands:
|
||||
|
||||
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
||||
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
||||
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
||||
- ``ria-agent register --hub URL --api-key KEY`` — register with the hub and
|
||||
save credentials (and optional TX interlocks) to ``~/.ria/agent.json``.
|
||||
|
||||
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
||||
long-poll behavior for back-compatibility with existing deployments.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from . import config as _config
|
||||
from .hardware import available_devices
|
||||
from .legacy_executor import main as _legacy_main
|
||||
from .namegen import generate_agent_name
|
||||
|
||||
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
||||
|
||||
|
||||
def _cmd_detect(_args: argparse.Namespace) -> int:
|
||||
devices = available_devices()
|
||||
if not devices:
|
||||
print("No SDR drivers available (install ria-toolkit-oss[all-sdr] or per-driver extras).")
|
||||
return 0
|
||||
for name in devices:
|
||||
print(name)
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_register(args: argparse.Namespace) -> int:
|
||||
import urllib.request
|
||||
|
||||
hub_url = args.hub.rstrip("/")
|
||||
url = f"{hub_url}/screens/agents/register"
|
||||
name = args.name or generate_agent_name()
|
||||
body = json.dumps({"name": name}).encode()
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=body,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-API-Key": args.api_key,
|
||||
},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
data = json.loads(resp.read())
|
||||
except Exception as e:
|
||||
print(f"error: registration failed: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
agent_id = data["agent_id"]
|
||||
token = data["token"]
|
||||
|
||||
cfg = _config.load()
|
||||
cfg.hub_url = hub_url
|
||||
cfg.agent_id = agent_id
|
||||
cfg.token = token
|
||||
cfg.api_key = args.api_key
|
||||
cfg.name = name
|
||||
cfg.insecure = bool(args.insecure)
|
||||
cfg.tx_enabled = bool(getattr(args, "allow_tx", False))
|
||||
if (v := getattr(args, "tx_max_gain_db", None)) is not None:
|
||||
cfg.tx_max_gain_db = float(v)
|
||||
if (v := getattr(args, "tx_max_duration_s", None)) is not None:
|
||||
cfg.tx_max_duration_s = float(v)
|
||||
freq_ranges = getattr(args, "tx_freq_range", None) or []
|
||||
if freq_ranges:
|
||||
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
|
||||
path = _config.save(cfg)
|
||||
|
||||
print(f"Registered agent: {agent_id}")
|
||||
if cfg.tx_enabled:
|
||||
caps: list[str] = []
|
||||
if cfg.tx_max_gain_db is not None:
|
||||
caps.append(f"gain<={cfg.tx_max_gain_db} dB")
|
||||
if cfg.tx_max_duration_s is not None:
|
||||
caps.append(f"duration<={cfg.tx_max_duration_s} s")
|
||||
if cfg.tx_allowed_freq_ranges:
|
||||
caps.append(f"freq in {cfg.tx_allowed_freq_ranges}")
|
||||
tail = f" ({', '.join(caps)})" if caps else ""
|
||||
print(f"TX enabled{tail}")
|
||||
print(f"Credentials saved to {path}")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_stream(args: argparse.Namespace) -> int:
|
||||
from .streamer import run_streamer
|
||||
|
||||
cfg = _config.load()
|
||||
url = args.url or _derive_ws_url(cfg.hub_url, cfg.agent_id)
|
||||
token = args.token or cfg.token
|
||||
if not url:
|
||||
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
|
||||
return 2
|
||||
if getattr(args, "allow_tx", False):
|
||||
cfg.tx_enabled = True
|
||||
try:
|
||||
asyncio.run(run_streamer(url, token, cfg=cfg))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
def _derive_ws_url(hub_url: str, agent_id: str) -> str:
|
||||
if not hub_url:
|
||||
return ""
|
||||
base = hub_url.rstrip("/")
|
||||
if base.startswith("https://"):
|
||||
base = "wss://" + base[len("https://"):]
|
||||
elif base.startswith("http://"):
|
||||
base = "ws://" + base[len("http://"):]
|
||||
suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws"
|
||||
return base + suffix
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Back-compat: if the first non-flag token matches a known legacy flag,
|
||||
# or there is no subcommand at all, dispatch to the legacy CLI.
|
||||
argv = sys.argv[1:]
|
||||
if not argv or (argv[0].startswith("--") and argv[0] in _LEGACY_ALIASES):
|
||||
_legacy_main()
|
||||
return
|
||||
|
||||
parser = argparse.ArgumentParser(prog="ria-agent")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)")
|
||||
sub.add_parser("detect", help="List available SDR drivers")
|
||||
|
||||
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
||||
p_reg.add_argument("--hub", required=True, help="RIA Hub URL (e.g. http://whitehorse:3005)")
|
||||
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
|
||||
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
|
||||
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
|
||||
p_reg.add_argument(
|
||||
"--allow-tx",
|
||||
dest="allow_tx",
|
||||
action="store_true",
|
||||
help="Opt this agent in to TX (required for any transmission from the hub)",
|
||||
)
|
||||
p_reg.add_argument(
|
||||
"--tx-max-gain-db",
|
||||
dest="tx_max_gain_db",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Reject tx_start frames whose tx_gain exceeds this cap (dB)",
|
||||
)
|
||||
p_reg.add_argument(
|
||||
"--tx-max-duration-s",
|
||||
dest="tx_max_duration_s",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Auto-stop any TX session after this many seconds",
|
||||
)
|
||||
p_reg.add_argument(
|
||||
"--tx-freq-range",
|
||||
dest="tx_freq_range",
|
||||
type=float,
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("LO", "HI"),
|
||||
default=None,
|
||||
help="Allowed TX center-frequency range in Hz (repeat for multiple bands)",
|
||||
)
|
||||
|
||||
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
|
||||
p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
|
||||
p_stream.add_argument("--token", default=None, help="Override bearer token")
|
||||
p_stream.add_argument("--log-level", default="INFO")
|
||||
p_stream.add_argument(
|
||||
"--allow-tx",
|
||||
dest="allow_tx",
|
||||
action="store_true",
|
||||
help="Runtime override: enable TX for this process without writing config",
|
||||
)
|
||||
|
||||
# Unknown extras are forwarded to the legacy CLI when command == "run".
|
||||
args, extras = parser.parse_known_args(argv)
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, getattr(args, "log_level", "INFO"), logging.INFO),
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
if args.command == "run":
|
||||
sys.argv = [sys.argv[0], *extras]
|
||||
_legacy_main()
|
||||
return
|
||||
if args.command == "detect":
|
||||
sys.exit(_cmd_detect(args))
|
||||
if args.command == "register":
|
||||
sys.exit(_cmd_register(args))
|
||||
if args.command == "stream":
|
||||
sys.exit(_cmd_stream(args))
|
||||
|
||||
parser.error(f"unknown command: {args.command}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
"""Agent configuration stored at ``~/.ria/agent.json``.
|
||||
|
||||
Schema::
|
||||
|
||||
{
|
||||
"hub_url": "https://riahub.example.com",
|
||||
"agent_id": "agent-abc123",
|
||||
"token": "rha_xxxx",
|
||||
"name": "lab-bench-1",
|
||||
"insecure": false,
|
||||
"tx_enabled": false,
|
||||
"tx_max_gain_db": null,
|
||||
"tx_max_duration_s": null,
|
||||
"tx_allowed_freq_ranges": null
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
def _resolve_default_path() -> Path:
|
||||
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
hub_url: str = ""
|
||||
agent_id: str = ""
|
||||
token: str = ""
|
||||
name: str = ""
|
||||
insecure: bool = False
|
||||
api_key: str = ""
|
||||
tx_enabled: bool = False
|
||||
tx_max_gain_db: float | None = None
|
||||
tx_max_duration_s: float | None = None
|
||||
tx_allowed_freq_ranges: list[list[float]] | None = None
|
||||
extra: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
def default_path() -> Path:
|
||||
return _resolve_default_path()
|
||||
|
||||
|
||||
def _coerce_ranges(raw) -> list[list[float]] | None:
|
||||
if raw is None:
|
||||
return None
|
||||
out: list[list[float]] = []
|
||||
for pair in raw:
|
||||
lo, hi = pair
|
||||
out.append([float(lo), float(hi)])
|
||||
return out
|
||||
|
||||
|
||||
def load(path: Path | None = None) -> AgentConfig:
|
||||
p = path or _resolve_default_path()
|
||||
if not p.exists():
|
||||
return AgentConfig()
|
||||
data = json.loads(p.read_text())
|
||||
known = {f for f in AgentConfig.__dataclass_fields__ if f != "extra"}
|
||||
extra = {k: v for k, v in data.items() if k not in known}
|
||||
return AgentConfig(
|
||||
hub_url=data.get("hub_url", ""),
|
||||
agent_id=data.get("agent_id", ""),
|
||||
token=data.get("token", ""),
|
||||
name=data.get("name", ""),
|
||||
insecure=bool(data.get("insecure", False)),
|
||||
api_key=data.get("api_key", ""),
|
||||
tx_enabled=bool(data.get("tx_enabled", False)),
|
||||
tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None),
|
||||
tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None),
|
||||
tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def save(cfg: AgentConfig, path: Path | None = None) -> Path:
|
||||
p = path or _resolve_default_path()
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = asdict(cfg)
|
||||
extra = data.pop("extra", {}) or {}
|
||||
data.update(extra)
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
os.chmod(p, 0o600)
|
||||
return p
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
"""Hardware detection and heartbeat payload construction for the streamer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ria_toolkit_oss.sdr import detect_available
|
||||
|
||||
from .config import AgentConfig
|
||||
|
||||
|
||||
def available_devices() -> list[str]:
|
||||
"""Return a sorted list of device names whose driver modules import cleanly."""
|
||||
return sorted(detect_available().keys())
|
||||
|
||||
|
||||
def heartbeat_payload(
|
||||
status: str = "idle",
|
||||
app_id: str | None = None,
|
||||
*,
|
||||
cfg: AgentConfig | None = None,
|
||||
sessions: dict | None = None,
|
||||
) -> dict:
|
||||
"""Build the JSON body of a periodic heartbeat frame.
|
||||
|
||||
*cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not
|
||||
supplied, the heartbeat advertises RX-only with ``tx_enabled=False`` —
|
||||
matching the pre-TX shape.
|
||||
"""
|
||||
c = cfg or AgentConfig()
|
||||
capabilities = ["rx"]
|
||||
if c.tx_enabled:
|
||||
capabilities.append("tx")
|
||||
|
||||
payload: dict = {
|
||||
"type": "heartbeat",
|
||||
"hardware": available_devices(),
|
||||
"status": status,
|
||||
"capabilities": capabilities,
|
||||
"tx_enabled": bool(c.tx_enabled),
|
||||
}
|
||||
# Surface configured interlock values so the hub can pre-filter UI controls
|
||||
# before sending a tx_start that would be rejected. Only included when TX
|
||||
# is opted in AND the operator set a cap.
|
||||
if c.tx_enabled:
|
||||
if c.tx_max_gain_db is not None:
|
||||
payload["tx_max_gain_db"] = float(c.tx_max_gain_db)
|
||||
if c.tx_max_duration_s is not None:
|
||||
payload["tx_max_duration_s"] = float(c.tx_max_duration_s)
|
||||
if c.tx_allowed_freq_ranges:
|
||||
payload["tx_allowed_freq_ranges"] = [
|
||||
[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges
|
||||
]
|
||||
if app_id:
|
||||
payload["app_id"] = app_id
|
||||
if sessions:
|
||||
payload["sessions"] = sessions
|
||||
return payload
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
"""Generate random human-readable agent names.
|
||||
|
||||
Produces names in the form ``adjective-colour-animal``, e.g.
|
||||
``swift-teal-falcon`` or ``brave-coral-otter``. All words are chosen
|
||||
to be friendly and inoffensive.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
ADJECTIVES: list[str] = [
|
||||
"brave",
|
||||
"bright",
|
||||
"calm",
|
||||
"clever",
|
||||
"cool",
|
||||
"daring",
|
||||
"eager",
|
||||
"fair",
|
||||
"fancy",
|
||||
"fast",
|
||||
"fierce",
|
||||
"gentle",
|
||||
"grand",
|
||||
"happy",
|
||||
"jolly",
|
||||
"keen",
|
||||
"kind",
|
||||
"lively",
|
||||
"lucky",
|
||||
"mighty",
|
||||
"noble",
|
||||
"plucky",
|
||||
"proud",
|
||||
"quick",
|
||||
"quiet",
|
||||
"sharp",
|
||||
"shiny",
|
||||
"sleek",
|
||||
"smart",
|
||||
"steady",
|
||||
"stellar",
|
||||
"strong",
|
||||
"sturdy",
|
||||
"sunny",
|
||||
"sure",
|
||||
"swift",
|
||||
"tall",
|
||||
"vivid",
|
||||
"warm",
|
||||
"wise",
|
||||
]
|
||||
|
||||
COLOURS: list[str] = [
|
||||
"amber",
|
||||
"aqua",
|
||||
"azure",
|
||||
"beige",
|
||||
"blue",
|
||||
"bronze",
|
||||
"coral",
|
||||
"copper",
|
||||
"crimson",
|
||||
"cyan",
|
||||
"denim",
|
||||
"gold",
|
||||
"green",
|
||||
"grey",
|
||||
"indigo",
|
||||
"ivory",
|
||||
"jade",
|
||||
"lemon",
|
||||
"lilac",
|
||||
"lime",
|
||||
"maroon",
|
||||
"mint",
|
||||
"navy",
|
||||
"olive",
|
||||
"onyx",
|
||||
"peach",
|
||||
"pearl",
|
||||
"plum",
|
||||
"red",
|
||||
"rose",
|
||||
"ruby",
|
||||
"rust",
|
||||
"sage",
|
||||
"sand",
|
||||
"scarlet",
|
||||
"silver",
|
||||
"slate",
|
||||
"steel",
|
||||
"teal",
|
||||
"violet",
|
||||
]
|
||||
|
||||
ANIMALS: list[str] = [
|
||||
"badger",
|
||||
"bear",
|
||||
"bison",
|
||||
"crane",
|
||||
"deer",
|
||||
"dolphin",
|
||||
"eagle",
|
||||
"elk",
|
||||
"falcon",
|
||||
"finch",
|
||||
"fox",
|
||||
"gecko",
|
||||
"hawk",
|
||||
"heron",
|
||||
"horse",
|
||||
"ibis",
|
||||
"jaguar",
|
||||
"jay",
|
||||
"kite",
|
||||
"koala",
|
||||
"lark",
|
||||
"lion",
|
||||
"lynx",
|
||||
"marten",
|
||||
"moose",
|
||||
"newt",
|
||||
"orca",
|
||||
"osprey",
|
||||
"otter",
|
||||
"owl",
|
||||
"panda",
|
||||
"puma",
|
||||
"raven",
|
||||
"robin",
|
||||
"salmon",
|
||||
"seal",
|
||||
"shark",
|
||||
"stork",
|
||||
"swift",
|
||||
"wolf",
|
||||
]
|
||||
|
||||
|
||||
def generate_agent_name() -> str:
|
||||
"""Return a random ``adjective-colour-animal`` name."""
|
||||
adj = random.choice(ADJECTIVES)
|
||||
col = random.choice(COLOURS)
|
||||
ani = random.choice(ANIMALS)
|
||||
return f"{adj}-{col}-{ani}"
|
||||
|
|
@ -1,758 +0,0 @@
|
|||
"""IQ-streaming agent.
|
||||
|
||||
Listens for control messages from the RIA Hub over a persistent WebSocket.
|
||||
Supports:
|
||||
|
||||
- An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens
|
||||
the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ).
|
||||
- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus
|
||||
binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires
|
||||
up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False;
|
||||
Phase 4 implements the full TX loop.
|
||||
|
||||
Both sessions can run concurrently on the same physical SDR (FDD) — a
|
||||
ref-counted SDR registry shares one driver instance when RX and TX name the
|
||||
same ``(device, identifier)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .config import AgentConfig
|
||||
from .hardware import heartbeat_payload
|
||||
from .ws_client import WsClient
|
||||
|
||||
logger = logging.getLogger("ria_agent.streamer")
|
||||
|
||||
_DEFAULT_BUFFER_SIZE = 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session dataclasses
|
||||
|
||||
|
||||
@dataclass
|
||||
class RxSession:
|
||||
app_id: str
|
||||
sdr: Any
|
||||
device_key: tuple[str, str | None]
|
||||
buffer_size: int
|
||||
task: asyncio.Task | None = None
|
||||
pending_config: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TxSession:
|
||||
app_id: str
|
||||
sdr: Any
|
||||
device_key: tuple[str, str | None]
|
||||
buffer_size: int
|
||||
task: Any = None # concurrent.futures.Future from run_in_executor
|
||||
pending_config: dict = field(default_factory=dict)
|
||||
underrun_policy: str = "pause"
|
||||
last_buffer: np.ndarray | None = None
|
||||
stop_event: threading.Event = field(default_factory=threading.Event)
|
||||
started_at: float = 0.0
|
||||
max_duration_s: float | None = None
|
||||
state: str = "armed"
|
||||
# Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so
|
||||
# hub-side over-production triggers WS backpressure rather than memory
|
||||
# growth in the agent.
|
||||
in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8))
|
||||
# Set by the TX callback when it hits an underrun while policy=="pause";
|
||||
# asyncio side flips the session state and emits tx_status.
|
||||
underrun_flag: threading.Event = field(default_factory=threading.Event)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously)
|
||||
|
||||
|
||||
class _SdrRegistry:
|
||||
def __init__(self, factory):
|
||||
self._factory = factory
|
||||
self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]:
|
||||
key = (device, identifier)
|
||||
with self._lock:
|
||||
if key in self._instances:
|
||||
sdr, rc = self._instances[key]
|
||||
self._instances[key] = (sdr, rc + 1)
|
||||
return sdr, key
|
||||
# Build outside the lock: driver init can be slow and we don't want to
|
||||
# block concurrent releases on unrelated devices.
|
||||
sdr = self._factory(device, identifier)
|
||||
with self._lock:
|
||||
if key in self._instances:
|
||||
# Raced another acquirer; discard our duplicate and share theirs.
|
||||
other_sdr, rc = self._instances[key]
|
||||
try:
|
||||
sdr.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._instances[key] = (other_sdr, rc + 1)
|
||||
return other_sdr, key
|
||||
self._instances[key] = (sdr, 1)
|
||||
return sdr, key
|
||||
|
||||
def release(self, key: tuple[str, str | None]) -> bool:
|
||||
"""Decrement refcount. Returns True if the caller owns the last reference
|
||||
and should close the SDR."""
|
||||
with self._lock:
|
||||
sdr, rc = self._instances.get(key, (None, 0))
|
||||
if sdr is None:
|
||||
return False
|
||||
if rc <= 1:
|
||||
del self._instances[key]
|
||||
return True
|
||||
self._instances[key] = (sdr, rc - 1)
|
||||
return False
|
||||
|
||||
def refcount(self, key: tuple[str, str | None]) -> int:
|
||||
with self._lock:
|
||||
return self._instances.get(key, (None, 0))[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streamer
|
||||
|
||||
|
||||
class Streamer:
|
||||
"""Main streamer loop.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ws:
|
||||
Connected :class:`WsClient`.
|
||||
sdr_factory:
|
||||
Callable ``(device, identifier) -> SDR``. Defaults to the helper in
|
||||
:mod:`ria_toolkit_oss.sdr`. Injectable for tests.
|
||||
cfg:
|
||||
:class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and
|
||||
heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which
|
||||
leaves TX disabled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ws,
|
||||
sdr_factory=None,
|
||||
cfg: AgentConfig | None = None,
|
||||
) -> None:
|
||||
self.ws = ws
|
||||
self._cfg = cfg or AgentConfig()
|
||||
self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory)
|
||||
self._rx: RxSession | None = None
|
||||
self._tx: TxSession | None = None
|
||||
# Pending radio_config accepted via ``configure`` before ``start``.
|
||||
self._standalone_pending_config: dict = {}
|
||||
# Cached asyncio event loop, set the first time a handler runs. Used
|
||||
# to schedule async callbacks from the TX executor thread.
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Back-compat read-only shims for callers that check ``._sdr`` etc.
|
||||
# Writes to these attributes are not supported — use the session objects.
|
||||
|
||||
@property
|
||||
def _sdr(self):
|
||||
return self._rx.sdr if self._rx is not None else None
|
||||
|
||||
@property
|
||||
def _pending_config(self) -> dict:
|
||||
return self._rx.pending_config if self._rx is not None else self._standalone_pending_config
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WsClient wiring
|
||||
|
||||
def build_heartbeat(self) -> dict:
|
||||
status = "streaming" if (self._rx is not None or self._tx is not None) else "idle"
|
||||
app_id: str | None = None
|
||||
if self._rx is not None:
|
||||
app_id = self._rx.app_id
|
||||
elif self._tx is not None:
|
||||
app_id = self._tx.app_id
|
||||
|
||||
sessions: dict[str, dict] = {}
|
||||
if self._rx is not None:
|
||||
sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"}
|
||||
if self._tx is not None:
|
||||
sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state}
|
||||
|
||||
return heartbeat_payload(
|
||||
status=status,
|
||||
app_id=app_id,
|
||||
cfg=self._cfg,
|
||||
sessions=sessions or None,
|
||||
)
|
||||
|
||||
# Advisory / keepalive message types we accept and ignore without warning.
|
||||
_IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"})
|
||||
|
||||
async def on_message(self, msg: dict) -> None:
|
||||
t = msg.get("type")
|
||||
if t in self._IGNORED_MESSAGE_TYPES:
|
||||
logger.debug("Ignoring advisory message: %r", t)
|
||||
return
|
||||
handler = {
|
||||
"start": self._handle_rx_start,
|
||||
"stop": self._handle_rx_stop,
|
||||
"configure": self._handle_rx_configure,
|
||||
"tx_start": self._handle_tx_start,
|
||||
"tx_stop": self._handle_tx_stop,
|
||||
"tx_configure": self._handle_tx_configure,
|
||||
}.get(t)
|
||||
if handler is None:
|
||||
logger.warning("Unknown server message type: %r", t)
|
||||
return
|
||||
await handler(msg)
|
||||
|
||||
async def on_binary(self, data: bytes) -> None:
|
||||
tx = self._tx
|
||||
if tx is None:
|
||||
logger.debug("Dropping %d-byte binary frame: no TX session", len(data))
|
||||
return
|
||||
# Backpressure: if the TX queue is full, await briefly so the hub's
|
||||
# ``await ws.send`` throttles naturally via TCP. We don't block
|
||||
# indefinitely — a 2s stall means something else is wrong.
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0))
|
||||
except queue.Full:
|
||||
logger.warning("TX queue stalled; dropping frame")
|
||||
|
||||
# ==================================================================
|
||||
# RX
|
||||
|
||||
async def _handle_rx_start(self, msg: dict) -> None:
|
||||
if self._rx is not None:
|
||||
logger.warning("start received while already streaming — ignoring")
|
||||
return
|
||||
|
||||
app_id = msg.get("app_id") or ""
|
||||
radio_config = dict(msg.get("radio_config") or {})
|
||||
device = radio_config.pop("device", None)
|
||||
identifier = radio_config.pop("identifier", None)
|
||||
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
||||
if not device:
|
||||
await self._send_error(app_id, "start missing radio_config.device")
|
||||
return
|
||||
|
||||
try:
|
||||
sdr, device_key = self._registry.acquire(device, identifier)
|
||||
_apply_sdr_config(sdr, radio_config)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to open SDR %r", device)
|
||||
await self._send_error(app_id, f"SDR init failed: {exc}")
|
||||
return
|
||||
|
||||
# Inherit any pending config that was queued before start.
|
||||
pending = dict(self._standalone_pending_config)
|
||||
self._standalone_pending_config = {}
|
||||
|
||||
session = RxSession(
|
||||
app_id=app_id,
|
||||
sdr=sdr,
|
||||
device_key=device_key,
|
||||
buffer_size=buffer_size,
|
||||
pending_config=pending,
|
||||
)
|
||||
self._rx = session
|
||||
await self._send_status("streaming", app_id)
|
||||
session.task = asyncio.create_task(
|
||||
self._capture_loop(session), name="ria-streamer-capture"
|
||||
)
|
||||
|
||||
async def _handle_rx_stop(self, msg: dict) -> None:
|
||||
session = self._rx
|
||||
if session is None:
|
||||
return
|
||||
if session.task is not None:
|
||||
session.task.cancel()
|
||||
try:
|
||||
await session.task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
self._close_session_sdr(session)
|
||||
app_id = session.app_id
|
||||
self._rx = None
|
||||
await self._send_status("idle", app_id)
|
||||
|
||||
async def _handle_rx_configure(self, msg: dict) -> None:
|
||||
cfg = dict(msg.get("radio_config") or {})
|
||||
if self._rx is not None:
|
||||
self._rx.pending_config.update(cfg)
|
||||
else:
|
||||
self._standalone_pending_config.update(cfg)
|
||||
logger.debug("Queued configure: %s", cfg)
|
||||
|
||||
async def _capture_loop(self, session: RxSession) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
while True:
|
||||
if session.pending_config:
|
||||
cfg = session.pending_config
|
||||
session.pending_config = {}
|
||||
try:
|
||||
_apply_sdr_config(session.sdr, cfg)
|
||||
except Exception as exc:
|
||||
logger.warning("Applying configure failed: %s", exc)
|
||||
|
||||
try:
|
||||
samples = await loop.run_in_executor(
|
||||
None, session.sdr.rx, session.buffer_size
|
||||
)
|
||||
except Exception as exc:
|
||||
from ria_toolkit_oss.sdr import SdrDisconnectedError
|
||||
|
||||
if isinstance(exc, SdrDisconnectedError):
|
||||
logger.warning("SDR disconnected: %s", exc)
|
||||
await self._send_error(session.app_id, f"SDR disconnected: {exc}")
|
||||
else:
|
||||
logger.exception("SDR rx error")
|
||||
await self._send_error(session.app_id, f"SDR capture failed: {exc}")
|
||||
break
|
||||
|
||||
payload = _samples_to_interleaved_float32(samples)
|
||||
try:
|
||||
await self.ws.send_bytes(payload)
|
||||
except Exception as exc:
|
||||
logger.warning("Send failed: %s — ending capture", exc)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
finally:
|
||||
self._close_session_sdr(session)
|
||||
# If the loop died on its own (e.g. SDR disconnect), clear the
|
||||
# session handle so future ``start`` messages can proceed.
|
||||
if self._rx is session:
|
||||
self._rx = None
|
||||
|
||||
# ==================================================================
|
||||
# TX
|
||||
|
||||
async def _handle_tx_start(self, msg: dict) -> None:
|
||||
app_id = msg.get("app_id") or ""
|
||||
radio_config = dict(msg.get("radio_config") or {})
|
||||
|
||||
# --- interlocks (agent-enforced; never trust the hub alone) ---
|
||||
if not self._cfg.tx_enabled:
|
||||
await self._send_tx_status(app_id, "error", "tx disabled on this agent")
|
||||
return
|
||||
tx_gain = radio_config.get("tx_gain")
|
||||
if (
|
||||
self._cfg.tx_max_gain_db is not None
|
||||
and tx_gain is not None
|
||||
and float(tx_gain) > float(self._cfg.tx_max_gain_db)
|
||||
):
|
||||
await self._send_tx_status(
|
||||
app_id,
|
||||
"error",
|
||||
f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}",
|
||||
)
|
||||
return
|
||||
tx_freq = radio_config.get("tx_center_frequency")
|
||||
if self._cfg.tx_allowed_freq_ranges and tx_freq is not None:
|
||||
f = float(tx_freq)
|
||||
if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges):
|
||||
await self._send_tx_status(
|
||||
app_id,
|
||||
"error",
|
||||
f"tx_center_frequency {tx_freq} outside allowed ranges",
|
||||
)
|
||||
return
|
||||
|
||||
if self._tx is not None:
|
||||
await self._send_tx_status(app_id, "error", "tx already active on this agent")
|
||||
return
|
||||
|
||||
# --- device ---
|
||||
device = radio_config.pop("device", None)
|
||||
identifier = radio_config.pop("identifier", None)
|
||||
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
||||
underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
|
||||
if underrun_policy not in ("pause", "zero", "repeat"):
|
||||
await self._send_tx_status(
|
||||
app_id, "error", f"invalid underrun_policy {underrun_policy!r}"
|
||||
)
|
||||
return
|
||||
if not device:
|
||||
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
|
||||
return
|
||||
|
||||
device_key: tuple[str, str | None] | None = None
|
||||
sdr: Any = None
|
||||
try:
|
||||
sdr, device_key = self._registry.acquire(device, identifier)
|
||||
_apply_sdr_config(sdr, radio_config)
|
||||
# init_tx is mandatory for any driver that exposes it: drivers
|
||||
# that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP,
|
||||
# …) crash with a confusing "TX was not initialized" error 2 s
|
||||
# later in the executor thread if we skip it. Treat the three
|
||||
# required keys as a hard contract — a missing one is a hub-side
|
||||
# manifest bug and we want it surfaced immediately, not papered
|
||||
# over with stale radio state.
|
||||
if hasattr(sdr, "init_tx"):
|
||||
init_args = {
|
||||
k: radio_config.get(f"tx_{k}")
|
||||
for k in ("sample_rate", "center_frequency", "gain")
|
||||
}
|
||||
missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"tx_start missing required radio_config keys: {missing}"
|
||||
)
|
||||
sdr.init_tx(
|
||||
sample_rate=init_args["sample_rate"],
|
||||
center_frequency=init_args["center_frequency"],
|
||||
gain=init_args["gain"],
|
||||
channel=radio_config.get("tx_channel", 0),
|
||||
gain_mode=radio_config.get("tx_gain_mode", "manual"),
|
||||
)
|
||||
except Exception as exc:
|
||||
if device_key is not None:
|
||||
if self._registry.release(device_key):
|
||||
try:
|
||||
sdr.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("Failed to init TX on %r", device)
|
||||
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
|
||||
return
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
session = TxSession(
|
||||
app_id=app_id,
|
||||
sdr=sdr,
|
||||
device_key=device_key,
|
||||
buffer_size=buffer_size,
|
||||
underrun_policy=underrun_policy,
|
||||
started_at=time.monotonic(),
|
||||
max_duration_s=self._cfg.tx_max_duration_s,
|
||||
)
|
||||
self._tx = session
|
||||
await self._send_tx_status(app_id, "armed")
|
||||
session.task = self._loop.run_in_executor(None, self._tx_executor_body, session)
|
||||
# Spawn a small watchdog that transitions armed → transmitting when
|
||||
# the first buffer has been consumed, and surfaces underrun / max-
|
||||
# duration terminations back to the hub.
|
||||
asyncio.create_task(self._tx_watchdog(session))
|
||||
|
||||
async def _handle_tx_stop(self, msg: dict) -> None:
|
||||
session = self._tx
|
||||
if session is None:
|
||||
return
|
||||
app_id = session.app_id
|
||||
session.stop_event.set()
|
||||
try:
|
||||
session.sdr.pause_tx()
|
||||
except Exception:
|
||||
logger.debug("pause_tx raised during stop", exc_info=True)
|
||||
# Wake the executor thread if it's blocked on ``queue.get``.
|
||||
self._drain_tx_queue(session)
|
||||
if session.task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("TX executor did not exit within 1.5s after stop")
|
||||
except Exception:
|
||||
logger.debug("TX executor raised on shutdown", exc_info=True)
|
||||
self._close_session_sdr(session)
|
||||
self._tx = None
|
||||
await self._send_tx_status(app_id, "done")
|
||||
|
||||
async def _handle_tx_configure(self, msg: dict) -> None:
|
||||
if self._tx is None:
|
||||
return
|
||||
self._tx.pending_config.update(msg.get("radio_config") or {})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TX executor & watchdog
|
||||
|
||||
def _tx_executor_body(self, session: TxSession) -> None:
|
||||
try:
|
||||
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
|
||||
except Exception as exc:
|
||||
logger.exception("TX stream crashed")
|
||||
# Schedule both the error frame and session teardown on the loop
|
||||
# so ``self._tx`` clears, subsequent binary frames are rejected,
|
||||
# and the SDR handle is released.
|
||||
self._schedule(self._tx_crash_teardown(session, str(exc)))
|
||||
|
||||
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
|
||||
n = int(num_samples)
|
||||
# Honor stop requests: return silence one last time and let the driver
|
||||
# exit its loop on the next iteration (pause_tx flips _enable_tx).
|
||||
if session.stop_event.is_set():
|
||||
return _silence(n)
|
||||
|
||||
# Max-duration watchdog.
|
||||
if (
|
||||
session.max_duration_s is not None
|
||||
and (time.monotonic() - session.started_at) >= float(session.max_duration_s)
|
||||
):
|
||||
session.stop_event.set()
|
||||
try:
|
||||
session.sdr.pause_tx()
|
||||
except Exception:
|
||||
pass
|
||||
self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached"))
|
||||
return _silence(n)
|
||||
|
||||
# Apply queued configure at buffer boundary.
|
||||
if session.pending_config:
|
||||
cfg = session.pending_config
|
||||
session.pending_config = {}
|
||||
try:
|
||||
_apply_sdr_config(session.sdr, cfg)
|
||||
except Exception as exc:
|
||||
logger.debug("tx_configure apply failed: %s", exc)
|
||||
|
||||
try:
|
||||
raw = session.in_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
return self._underrun_fill(session, n)
|
||||
|
||||
arr = np.frombuffer(raw, dtype=np.float32)
|
||||
if arr.size < 2 or arr.size % 2 != 0:
|
||||
logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
|
||||
return self._underrun_fill(session, n)
|
||||
samples = (arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64))
|
||||
if samples.size < n:
|
||||
out = np.zeros(n, dtype=np.complex64)
|
||||
out[: samples.size] = samples
|
||||
session.last_buffer = out
|
||||
return out
|
||||
if samples.size > n:
|
||||
samples = samples[:n]
|
||||
session.last_buffer = samples
|
||||
if session.state == "armed":
|
||||
session.state = "transmitting"
|
||||
self._schedule(self._send_tx_status(session.app_id, "transmitting"))
|
||||
return samples
|
||||
|
||||
def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray:
|
||||
policy = session.underrun_policy
|
||||
if policy == "zero":
|
||||
return _silence(n)
|
||||
if policy == "repeat" and session.last_buffer is not None:
|
||||
buf = session.last_buffer
|
||||
if buf.size == n:
|
||||
return buf
|
||||
if buf.size > n:
|
||||
return buf[:n].copy()
|
||||
out = np.zeros(n, dtype=np.complex64)
|
||||
out[: buf.size] = buf
|
||||
return out
|
||||
# "pause" policy (default) or "repeat" before any buffer arrived.
|
||||
if not session.underrun_flag.is_set():
|
||||
session.underrun_flag.set()
|
||||
session.stop_event.set()
|
||||
try:
|
||||
session.sdr.pause_tx()
|
||||
except Exception:
|
||||
pass
|
||||
return _silence(n)
|
||||
|
||||
async def _tx_watchdog(self, session: TxSession) -> None:
|
||||
# Poll the underrun flag so we can emit status + tear down cleanly
|
||||
# when the callback flips the flag from the executor thread. Check
|
||||
# underrun_flag before stop_event, since the "pause" path sets both.
|
||||
while session is self._tx:
|
||||
if session.underrun_flag.is_set():
|
||||
await self._send_tx_status(session.app_id, "underrun")
|
||||
await self._teardown_tx_after_underrun(session)
|
||||
return
|
||||
if session.stop_event.is_set():
|
||||
return
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def _tx_crash_teardown(self, session: TxSession, message: str) -> None:
|
||||
# Called from the executor thread via _schedule when _stream_tx raises.
|
||||
# Emit the error, mark stopped, drain the queue, release the SDR.
|
||||
await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}")
|
||||
if self._tx is not session:
|
||||
return
|
||||
session.stop_event.set()
|
||||
self._drain_tx_queue(session)
|
||||
self._close_session_sdr(session)
|
||||
if self._tx is session:
|
||||
self._tx = None
|
||||
|
||||
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
|
||||
if self._tx is not session:
|
||||
return
|
||||
self._drain_tx_queue(session)
|
||||
if session.task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("TX executor did not exit within 1s after underrun")
|
||||
except Exception:
|
||||
logger.debug("TX executor raised during underrun teardown", exc_info=True)
|
||||
self._close_session_sdr(session)
|
||||
if self._tx is session:
|
||||
self._tx = None
|
||||
|
||||
def _drain_tx_queue(self, session: TxSession) -> None:
|
||||
try:
|
||||
while True:
|
||||
session.in_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
def _schedule(self, coro) -> None:
|
||||
loop = self._loop
|
||||
if loop is None:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
except Exception:
|
||||
logger.debug("_schedule failed", exc_info=True)
|
||||
|
||||
# ==================================================================
|
||||
# Helpers
|
||||
|
||||
def _close_session_sdr(self, session) -> None:
|
||||
if session.sdr is None:
|
||||
return
|
||||
should_close = self._registry.release(session.device_key)
|
||||
if should_close:
|
||||
try:
|
||||
session.sdr.close()
|
||||
except Exception:
|
||||
logger.debug("SDR close raised", exc_info=True)
|
||||
|
||||
async def _send_status(self, status: str, app_id: str) -> None:
|
||||
try:
|
||||
await self.ws.send_json({"type": "status", "status": status, "app_id": app_id})
|
||||
except Exception as exc:
|
||||
logger.debug("Status send failed: %s", exc)
|
||||
|
||||
async def _send_error(self, app_id: str, message: str) -> None:
|
||||
try:
|
||||
await self.ws.send_json({"type": "error", "app_id": app_id, "message": message})
|
||||
except Exception as exc:
|
||||
logger.debug("Error-frame send failed: %s", exc)
|
||||
|
||||
async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None:
|
||||
payload: dict = {"type": "tx_status", "app_id": app_id, "state": state}
|
||||
if message is not None:
|
||||
payload["message"] = message
|
||||
try:
|
||||
await self.ws.send_json(payload)
|
||||
except Exception as exc:
|
||||
logger.debug("tx_status send failed: %s", exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
||||
_CONFIG_ATTR_MAP = {
|
||||
"sample_rate": ("sample_rate", "rx_sample_rate"),
|
||||
"center_frequency": ("center_freq", "rx_center_frequency"),
|
||||
"center_freq": ("center_freq", "rx_center_frequency"),
|
||||
"gain": ("gain", "rx_gain"),
|
||||
"bandwidth": ("bandwidth", "rx_bandwidth"),
|
||||
"tx_sample_rate": ("tx_sample_rate",),
|
||||
"tx_center_frequency": ("tx_center_frequency", "tx_lo"),
|
||||
"tx_gain": ("tx_gain",),
|
||||
"tx_bandwidth": ("tx_bandwidth",),
|
||||
}
|
||||
|
||||
|
||||
def _is_stub_setter(method: Any) -> bool:
|
||||
"""True when *method* is an unimplemented base-class stub.
|
||||
|
||||
The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain``
|
||||
etc. as zero-argument ``NotImplementedError`` stubs. A driver (Pluto) that
|
||||
actually transmits overrides them with a real ``(value, ...)`` signature.
|
||||
Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply.
|
||||
"""
|
||||
return getattr(method, "__qualname__", "").startswith("SDR.")
|
||||
|
||||
|
||||
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
|
||||
"""Apply a radio_config dict to an SDR.
|
||||
|
||||
Prefers ``sdr.set_<attr>(value)`` when the driver implements it — Pluto's
|
||||
setters take ``_param_lock``, so routing through them keeps concurrent
|
||||
RX + TX reconfigures from racing on shared native attributes. Falls back
|
||||
to ``setattr`` for drivers (MockSDR, tests) that don't override the
|
||||
base-class stubs.
|
||||
"""
|
||||
for key, value in cfg.items():
|
||||
if value is None:
|
||||
continue
|
||||
attrs = _CONFIG_ATTR_MAP.get(key, (key,))
|
||||
applied = False
|
||||
for attr in attrs:
|
||||
setter = getattr(sdr, f"set_{attr}", None)
|
||||
if callable(setter) and not _is_stub_setter(setter):
|
||||
try:
|
||||
setter(value)
|
||||
applied = True
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.debug("set_%s(%r) failed: %s", attr, value, exc)
|
||||
# Fall through to setattr; some drivers may partially
|
||||
# implement setters.
|
||||
if applied:
|
||||
continue
|
||||
for attr in attrs:
|
||||
if hasattr(sdr, attr):
|
||||
try:
|
||||
setattr(sdr, attr, value)
|
||||
applied = True
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.debug("setattr %s=%r failed: %s", attr, value, exc)
|
||||
if not applied:
|
||||
logger.debug("radio_config key %r ignored (no matching attr)", key)
|
||||
|
||||
|
||||
def _silence(num_samples: int) -> np.ndarray:
|
||||
"""Return a ``num_samples``-length zero-filled complex64 buffer."""
|
||||
return np.zeros(int(num_samples), dtype=np.complex64)
|
||||
|
||||
|
||||
def _samples_to_interleaved_float32(samples: Any) -> bytes:
|
||||
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
|
||||
arr = np.asarray(samples)
|
||||
if np.iscomplexobj(arr):
|
||||
interleaved = np.empty(arr.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = arr.real.astype(np.float32, copy=False).ravel()
|
||||
interleaved[1::2] = arr.imag.astype(np.float32, copy=False).ravel()
|
||||
return interleaved.tobytes()
|
||||
return arr.astype(np.float32, copy=False).tobytes()
|
||||
|
||||
|
||||
def _default_sdr_factory(device: str, identifier: str | None):
|
||||
from ria_toolkit_oss.sdr import get_sdr_device
|
||||
|
||||
return get_sdr_device(device, ident=identifier)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level entry
|
||||
|
||||
async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
|
||||
"""Connect to *ws_url* and run the streamer loop until cancelled."""
|
||||
ws = WsClient(ws_url, token)
|
||||
streamer = Streamer(ws, cfg=cfg)
|
||||
await ws.run(
|
||||
streamer.on_message,
|
||||
streamer.build_heartbeat,
|
||||
on_binary=streamer.on_binary,
|
||||
)
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
"""Persistent WebSocket client for the streamer agent.
|
||||
|
||||
Handles connection lifecycle: connect, heartbeat, auto-reconnect on drop.
|
||||
The caller drives the I/O loop via ``run()`` with a message handler callback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
logger = logging.getLogger("ria_agent.ws")
|
||||
|
||||
MessageHandler = Callable[[dict], Awaitable[None]]
|
||||
HeartbeatBuilder = Callable[[], dict]
|
||||
BinaryHandler = Callable[[bytes], Awaitable[None]]
|
||||
|
||||
|
||||
class WsClient:
|
||||
"""Persistent WebSocket connection with heartbeat and auto-reconnect.
|
||||
|
||||
``url`` should be a full ``wss://host/path`` (or ``ws://``) URL. ``token``
|
||||
is sent as a bearer in the ``Authorization`` header on connect.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
token: str,
|
||||
*,
|
||||
heartbeat_interval: float = 30.0,
|
||||
reconnect_pause: float = 5.0,
|
||||
) -> None:
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
self.reconnect_pause = reconnect_pause
|
||||
self._ws = None
|
||||
self._stop = asyncio.Event()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
async def _connect(self):
|
||||
import websockets
|
||||
|
||||
headers = [("Authorization", f"Bearer {self.token}")] if self.token else None
|
||||
# websockets >= 12 accepts additional_headers; fall back to extra_headers for older versions.
|
||||
try:
|
||||
return await websockets.connect(self.url, additional_headers=headers)
|
||||
except TypeError:
|
||||
return await websockets.connect(self.url, extra_headers=headers)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
async def send_json(self, payload: dict) -> None:
|
||||
if self._ws is None:
|
||||
raise ConnectionError("WebSocket is not connected")
|
||||
await self._ws.send(json.dumps(payload))
|
||||
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
if self._ws is None:
|
||||
raise ConnectionError("WebSocket is not connected")
|
||||
await self._ws.send(data)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
async def run(
|
||||
self,
|
||||
on_message: MessageHandler,
|
||||
heartbeat: HeartbeatBuilder,
|
||||
on_binary: BinaryHandler | None = None,
|
||||
) -> None:
|
||||
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
|
||||
while not self._stop.is_set():
|
||||
try:
|
||||
self._ws = await self._connect()
|
||||
logger.info("Connected to %s", self.url)
|
||||
hb_task = asyncio.create_task(self._heartbeat_loop(heartbeat))
|
||||
try:
|
||||
async for raw in self._ws:
|
||||
if isinstance(raw, bytes):
|
||||
if on_binary is None:
|
||||
logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
|
||||
continue
|
||||
try:
|
||||
await on_binary(raw)
|
||||
except Exception:
|
||||
logger.exception("on_binary handler raised; dropping frame")
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Malformed control frame: %r", raw[:200])
|
||||
continue
|
||||
await on_message(msg)
|
||||
finally:
|
||||
hb_task.cancel()
|
||||
try:
|
||||
await hb_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if self._stop.is_set():
|
||||
break
|
||||
logger.warning("WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause)
|
||||
finally:
|
||||
try:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
if self._stop.is_set():
|
||||
break
|
||||
await asyncio.sleep(self.reconnect_pause)
|
||||
|
||||
async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
|
||||
while True:
|
||||
try:
|
||||
await self.send_json(heartbeat())
|
||||
except Exception as exc:
|
||||
logger.debug("Heartbeat send failed: %s", exc)
|
||||
return
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""App runner: pull and run containerized RIA applications."""
|
||||
|
|
@ -1,276 +0,0 @@
|
|||
"""Unified ``ria-app`` CLI.
|
||||
|
||||
Subcommands:
|
||||
|
||||
- ``ria-app pull <app>[:tag]`` — pull a RIA app image from the configured registry.
|
||||
- ``ria-app run <app>[:tag]`` — pull (if needed) and run, auto-configuring
|
||||
GPU/USB/network flags from image labels set by CI.
|
||||
- ``ria-app list`` — list locally cached RIA app images.
|
||||
- ``ria-app stop <app>`` — stop a running app container.
|
||||
- ``ria-app logs <app>`` — tail logs of a running app container.
|
||||
- ``ria-app configure`` — set default registry/namespace.
|
||||
|
||||
Image references resolve as::
|
||||
|
||||
my-classifier -> {registry}/{namespace}/my-classifier:latest
|
||||
group/my-classifier -> {registry}/group/my-classifier:latest
|
||||
host/group/app:tag -> host/group/app:tag (fully-qualified passthrough)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from . import config as _config
|
||||
|
||||
_LABEL_PROFILE = "ria.profile"
|
||||
_LABEL_HARDWARE = "ria.hardware"
|
||||
_LABEL_APP = "ria.app"
|
||||
|
||||
|
||||
def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]:
|
||||
for exe in ("docker", "podman"):
|
||||
if shutil.which(exe):
|
||||
use_sudo = sudo_override or cfg.sudo
|
||||
return (["sudo", exe] if use_sudo else [exe])
|
||||
print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
|
||||
def _resolve_ref(app: str, cfg: _config.AppConfig) -> str:
|
||||
ref = app if ":" in app.split("/")[-1] else f"{app}:latest"
|
||||
slashes = ref.count("/")
|
||||
if slashes >= 2:
|
||||
return ref
|
||||
if slashes == 1:
|
||||
return f"{cfg.registry}/{ref}" if cfg.registry else ref
|
||||
if not cfg.registry or not cfg.namespace:
|
||||
print(
|
||||
"error: app is not fully qualified and no default registry/namespace configured. "
|
||||
"Run `ria-app configure` or pass a full image reference (registry/namespace/app:tag).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
return f"{cfg.registry}/{cfg.namespace}/{ref}"
|
||||
|
||||
|
||||
def _container_name(ref: str) -> str:
|
||||
name = ref.rsplit("/", 1)[-1].split(":", 1)[0]
|
||||
return f"ria-app-{name}"
|
||||
|
||||
|
||||
def _inspect_labels(engine: list[str], ref: str) -> dict:
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
[*engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref],
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
return {}
|
||||
try:
|
||||
return json.loads(out.decode().strip()) or {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
|
||||
def _gpu_available() -> bool:
|
||||
if os.path.exists("/dev/nvidia0"):
|
||||
return True
|
||||
return shutil.which("nvidia-smi") is not None
|
||||
|
||||
|
||||
def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]:
|
||||
flags: list[str] = []
|
||||
notes: list[str] = []
|
||||
profile = (labels.get(_LABEL_PROFILE) or "").lower()
|
||||
hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
|
||||
hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
|
||||
|
||||
wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda"))
|
||||
if wants_gpu and not no_gpu:
|
||||
if _gpu_available():
|
||||
flags += ["--gpus", "all"]
|
||||
else:
|
||||
notes.append("image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)")
|
||||
|
||||
if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
|
||||
flags += ["--device", "/dev/bus/usb"]
|
||||
|
||||
if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net:
|
||||
flags += ["--net", "host"]
|
||||
|
||||
return flags, notes
|
||||
|
||||
|
||||
def _cmd_configure(args: argparse.Namespace) -> int:
|
||||
cfg = _config.load()
|
||||
if args.registry:
|
||||
cfg.registry = args.registry
|
||||
if args.namespace:
|
||||
cfg.namespace = args.namespace
|
||||
if args.sudo is not None:
|
||||
cfg.sudo = args.sudo
|
||||
path = _config.save(cfg)
|
||||
print(f"Saved app config to {path}")
|
||||
print(f" registry: {cfg.registry or '(unset)'}")
|
||||
print(f" namespace: {cfg.namespace or '(unset)'}")
|
||||
print(f" sudo: {cfg.sudo}")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_pull(args: argparse.Namespace) -> int:
|
||||
cfg = _config.load()
|
||||
engine = _engine(cfg, args.sudo)
|
||||
ref = _resolve_ref(args.app, cfg)
|
||||
print(f"Pulling {ref}")
|
||||
return subprocess.call([*engine, "pull", ref])
|
||||
|
||||
|
||||
def _cmd_run(args: argparse.Namespace) -> int:
|
||||
cfg = _config.load()
|
||||
engine = _engine(cfg, args.sudo)
|
||||
ref = _resolve_ref(args.app, cfg)
|
||||
|
||||
if not _inspect_labels(engine, ref):
|
||||
rc = subprocess.call([*engine, "pull", ref])
|
||||
if rc != 0:
|
||||
return rc
|
||||
|
||||
labels = _inspect_labels(engine, ref)
|
||||
no_gpu = args.no_gpu and not args.force_gpu
|
||||
hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net)
|
||||
if args.force_gpu and "--gpus" not in hw_flags:
|
||||
hw_flags = ["--gpus", "all", *hw_flags]
|
||||
|
||||
cmd = [*engine, "run", "--rm"]
|
||||
if not args.foreground:
|
||||
cmd += ["-d"]
|
||||
cmd += ["--name", args.name or _container_name(ref)]
|
||||
cmd += hw_flags
|
||||
|
||||
if args.config:
|
||||
cmd += ["-v", f"{args.config}:/config/config.yaml:ro", "-e", "RIA_CONFIG=/config/config.yaml"]
|
||||
|
||||
for env in args.env or []:
|
||||
cmd += ["-e", env]
|
||||
for vol in args.volume or []:
|
||||
cmd += ["-v", vol]
|
||||
for port in args.publish or []:
|
||||
cmd += ["-p", port]
|
||||
|
||||
cmd += list(args.docker_args or [])
|
||||
cmd += [ref]
|
||||
cmd += list(args.app_args or [])
|
||||
|
||||
if args.dry_run:
|
||||
print(" ".join(cmd))
|
||||
return 0
|
||||
|
||||
label_str = ", ".join(f"{k}={v}" for k, v in labels.items() if k.startswith("ria.")) or "(no ria.* labels)"
|
||||
print(f"Running {ref} [{label_str}]")
|
||||
if hw_flags:
|
||||
print(f" auto flags: {' '.join(hw_flags)}")
|
||||
for note in notes:
|
||||
print(f" note: {note}")
|
||||
return subprocess.call(cmd)
|
||||
|
||||
|
||||
def _cmd_list(args: argparse.Namespace) -> int:
|
||||
cfg = _config.load()
|
||||
engine = _engine(cfg, args.sudo)
|
||||
return subprocess.call(
|
||||
[
|
||||
*engine,
|
||||
"images",
|
||||
"--filter",
|
||||
f"label={_LABEL_APP}",
|
||||
"--format",
|
||||
"table {{.Repository}}:{{.Tag}}\t{{.ID}}\t{{.Size}}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _cmd_stop(args: argparse.Namespace) -> int:
|
||||
cfg = _config.load()
|
||||
engine = _engine(cfg, args.sudo)
|
||||
name = args.name or _container_name(_resolve_ref(args.app, cfg))
|
||||
return subprocess.call([*engine, "stop", name])
|
||||
|
||||
|
||||
def _cmd_logs(args: argparse.Namespace) -> int:
|
||||
cfg = _config.load()
|
||||
engine = _engine(cfg, args.sudo)
|
||||
name = args.name or _container_name(_resolve_ref(args.app, cfg))
|
||||
cmd = [*engine, "logs"]
|
||||
if args.follow:
|
||||
cmd += ["-f"]
|
||||
cmd += [name]
|
||||
return subprocess.call(cmd)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(prog="ria-app")
|
||||
parser.add_argument("--sudo", action="store_true", default=False, help="Run docker/podman via sudo")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
p_cfg = sub.add_parser("configure", help="Set default registry/namespace")
|
||||
p_cfg.add_argument("--registry", default=None, help="Default container registry (e.g. registry.riahub.ai)")
|
||||
p_cfg.add_argument("--namespace", default=None, help="Default namespace (e.g. qoherent)")
|
||||
p_cfg.add_argument(
|
||||
"--sudo",
|
||||
dest="sudo",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=None,
|
||||
help="Persist sudo default (--sudo / --no-sudo)",
|
||||
)
|
||||
|
||||
p_pull = sub.add_parser("pull", help="Pull an app image")
|
||||
p_pull.add_argument("app", help="App name or image reference")
|
||||
|
||||
p_run = sub.add_parser("run", help="Run an app, auto-detecting hardware flags")
|
||||
p_run.add_argument("app", help="App name or image reference")
|
||||
p_run.add_argument("--name", default=None, help="Container name (default: ria-app-<app>)")
|
||||
p_run.add_argument("--config", default=None, help="Path to config.yaml to mount into the container")
|
||||
p_run.add_argument("-e", "--env", action="append", help="Extra env var (KEY=VALUE)")
|
||||
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
|
||||
p_run.add_argument("-p", "--publish", action="append", help="Publish port")
|
||||
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
|
||||
p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU")
|
||||
p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected")
|
||||
p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb")
|
||||
p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host")
|
||||
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
|
||||
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
|
||||
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")
|
||||
|
||||
sub.add_parser("list", help="List locally cached RIA app images")
|
||||
|
||||
p_stop = sub.add_parser("stop", help="Stop a running app")
|
||||
p_stop.add_argument("app", help="App name or image reference")
|
||||
p_stop.add_argument("--name", default=None, help="Container name override")
|
||||
|
||||
p_logs = sub.add_parser("logs", help="Tail logs of a running app")
|
||||
p_logs.add_argument("app", help="App name or image reference")
|
||||
p_logs.add_argument("--name", default=None, help="Container name override")
|
||||
p_logs.add_argument("-f", "--follow", action="store_true", help="Follow log output")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dispatch = {
|
||||
"configure": _cmd_configure,
|
||||
"pull": _cmd_pull,
|
||||
"run": _cmd_run,
|
||||
"list": _cmd_list,
|
||||
"stop": _cmd_stop,
|
||||
"logs": _cmd_logs,
|
||||
}
|
||||
sys.exit(dispatch[args.command](args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
"""App runner configuration at ``~/.ria/toolkit.json``.
|
||||
|
||||
Schema::
|
||||
|
||||
{
|
||||
"registry": "registry.riahub.ai",
|
||||
"namespace": "qoherent"
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
_DEFAULT_PATH = Path(os.environ.get("RIA_TOOLKIT_CONFIG", str(Path.home() / ".ria" / "toolkit.json")))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
registry: str = ""
|
||||
namespace: str = ""
|
||||
sudo: bool = False
|
||||
|
||||
|
||||
def default_path() -> Path:
|
||||
return _DEFAULT_PATH
|
||||
|
||||
|
||||
def load(path: Path | None = None) -> AppConfig:
|
||||
p = path or _DEFAULT_PATH
|
||||
if not p.exists():
|
||||
return AppConfig(
|
||||
registry=os.environ.get("RIA_REGISTRY", ""),
|
||||
namespace=os.environ.get("RIA_NAMESPACE", ""),
|
||||
)
|
||||
data = json.loads(p.read_text())
|
||||
return AppConfig(
|
||||
registry=data.get("registry", "") or os.environ.get("RIA_REGISTRY", ""),
|
||||
namespace=data.get("namespace", "") or os.environ.get("RIA_NAMESPACE", ""),
|
||||
sudo=bool(data.get("sudo", False)) or os.environ.get("RIA_DOCKER_SUDO", "") not in ("", "0", "false"),
|
||||
)
|
||||
|
||||
|
||||
def save(cfg: AppConfig, path: Path | None = None) -> Path:
|
||||
p = path or _DEFAULT_PATH
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_text(json.dumps(asdict(cfg), indent=2))
|
||||
return p
|
||||
|
|
@ -4,48 +4,10 @@ It streamlines tasks involving signal reception and transmission, as well as com
|
|||
operations such as detecting and configuring available devices.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"SDR",
|
||||
"SDRError",
|
||||
"SDRParameterError",
|
||||
"SdrDisconnectedError",
|
||||
"MockSDR",
|
||||
"get_sdr_device",
|
||||
"detect_available",
|
||||
]
|
||||
__all__ = ["SDR", "SDRError", "SDRParameterError", "MockSDR", "get_sdr_device"]
|
||||
|
||||
from .mock import MockSDR
|
||||
from .sdr import SDR, SDRError, SdrDisconnectedError, SDRParameterError, translate_disconnect # noqa: F401
|
||||
|
||||
|
||||
_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = (
|
||||
("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"),
|
||||
("pluto", "ria_toolkit_oss.sdr.pluto", "Pluto"),
|
||||
("hackrf", "ria_toolkit_oss.sdr.hackrf", "HackRF"),
|
||||
("rtlsdr", "ria_toolkit_oss.sdr.rtlsdr", "RTLSDR"),
|
||||
("usrp", "ria_toolkit_oss.sdr.usrp", "USRP"),
|
||||
("blade", "ria_toolkit_oss.sdr.blade", "Blade"),
|
||||
("thinkrf", "ria_toolkit_oss.sdr.thinkrf", "ThinkRF"),
|
||||
)
|
||||
|
||||
|
||||
def detect_available() -> dict[str, type]:
|
||||
"""Return ``{device_name: driver_class}`` for every driver whose module imports cleanly.
|
||||
|
||||
Importability is a proxy for "the user has installed this driver's optional dependency".
|
||||
It does not probe for physical hardware presence — that requires actually instantiating
|
||||
the driver, which can be slow and side-effectful.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
out: dict[str, type] = {}
|
||||
for name, module_path, cls_name in _DRIVER_CANDIDATES:
|
||||
try:
|
||||
mod = importlib.import_module(module_path)
|
||||
out[name] = getattr(mod, cls_name)
|
||||
except Exception:
|
||||
continue
|
||||
return out
|
||||
from .sdr import SDR, SDRError, SDRParameterError
|
||||
|
||||
|
||||
def get_sdr_device(device_type: str, ident: str | None = None, tx: bool = False) -> SDR:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import adi
|
|||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError, translate_disconnect
|
||||
from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError
|
||||
|
||||
|
||||
class Pluto(SDR):
|
||||
|
|
@ -164,25 +164,6 @@ class Pluto(SDR):
|
|||
# send callback complex signal
|
||||
callback(buffer=signal, metadata=None)
|
||||
|
||||
def rx(self, num_samples: Optional[int] = None) -> np.ndarray:
|
||||
"""PlutoSDR-style single-buffer capture returning a complex64 array.
|
||||
|
||||
Sets the radio buffer size to *num_samples* (if given) and returns one
|
||||
buffer directly from ``self.radio.rx()``. Raises
|
||||
:class:`SdrDisconnectedError` on USB/device drop so callers (e.g. the
|
||||
streamer) can report the failure and stop cleanly instead of crashing.
|
||||
"""
|
||||
if num_samples is not None:
|
||||
try:
|
||||
self.set_rx_buffer_size(buffer_size=int(num_samples))
|
||||
except Exception as exc:
|
||||
raise translate_disconnect(exc) from exc
|
||||
try:
|
||||
samples = self.radio.rx()
|
||||
except Exception as exc:
|
||||
raise translate_disconnect(exc) from exc
|
||||
return np.asarray(samples)
|
||||
|
||||
def _record_fast(self, num_samples):
|
||||
"""Optimized single-buffer capture for ≤16M samples."""
|
||||
|
||||
|
|
@ -384,10 +365,7 @@ class Pluto(SDR):
|
|||
self._enable_tx = True
|
||||
while self._enable_tx is True:
|
||||
buffer = self._convert_tx_samples(callback(self.tx_buffer_size))
|
||||
# pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input.
|
||||
# Indexing ``buffer[0]`` was a latent bug for callbacks that
|
||||
# returned 1-D samples (scalar → TypeError inside pyadi).
|
||||
self.radio.tx(buffer)
|
||||
self.radio.tx(buffer[0])
|
||||
|
||||
def set_rx_center_frequency(self, center_frequency):
|
||||
"""
|
||||
|
|
@ -517,85 +495,74 @@ class Pluto(SDR):
|
|||
raise SDRError(e)
|
||||
|
||||
def set_tx_center_frequency(self, center_frequency):
|
||||
# ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent
|
||||
# RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX
|
||||
# setters at the same time. Serialize with ``_param_lock`` — RX setters hold
|
||||
# the same reentrant lock — so native attribute writes don't interleave.
|
||||
with self._param_lock:
|
||||
if center_frequency < 70e6 or center_frequency > 6e9:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
||||
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
||||
)
|
||||
if center_frequency < 70e6 or center_frequency > 6e9:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
||||
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
||||
)
|
||||
|
||||
try:
|
||||
self.radio.tx_lo = int(center_frequency)
|
||||
self.tx_center_frequency = center_frequency
|
||||
except OSError as e:
|
||||
raise SDRError(e)
|
||||
except ValueError:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
||||
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
||||
)
|
||||
try:
|
||||
self.radio.tx_lo = int(center_frequency)
|
||||
self.tx_center_frequency = center_frequency
|
||||
except OSError as e:
|
||||
raise SDRError(e)
|
||||
except ValueError:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
||||
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
||||
)
|
||||
|
||||
def set_tx_sample_rate(self, sample_rate):
|
||||
# ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's
|
||||
# ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock``
|
||||
# so full-duplex sessions can't interleave writes.
|
||||
with self._param_lock:
|
||||
min_rate, max_rate = 65.1e3, 61.44e6
|
||||
if sample_rate < min_rate or sample_rate > max_rate:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
||||
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
||||
)
|
||||
min_rate, max_rate = 65.1e3, 61.44e6
|
||||
if sample_rate < min_rate or sample_rate > max_rate:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
||||
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
||||
)
|
||||
|
||||
try:
|
||||
self.radio.sample_rate = sample_rate
|
||||
self.tx_sample_rate = sample_rate
|
||||
except OSError as e:
|
||||
raise SDRError(e)
|
||||
except ValueError:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
||||
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
||||
)
|
||||
try:
|
||||
self.radio.sample_rate = sample_rate
|
||||
self.tx_sample_rate = sample_rate
|
||||
except OSError as e:
|
||||
raise SDRError(e)
|
||||
except ValueError:
|
||||
raise SDRParameterError(
|
||||
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
||||
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
||||
)
|
||||
|
||||
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
|
||||
# Serialize with RX setters: see ``set_tx_sample_rate`` above.
|
||||
with self._param_lock:
|
||||
tx_gain_min = -89
|
||||
tx_gain_max = 0
|
||||
tx_gain_min = -89
|
||||
tx_gain_max = 0
|
||||
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
else:
|
||||
abs_gain = tx_gain_max + gain
|
||||
if gain_mode == "relative":
|
||||
if gain > 0:
|
||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||
the gain relative to the maximum possible gain.")
|
||||
else:
|
||||
abs_gain = gain
|
||||
abs_gain = tx_gain_max + gain
|
||||
else:
|
||||
abs_gain = gain
|
||||
|
||||
if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
|
||||
abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
|
||||
print(f"Gain {gain} out of range for Pluto.")
|
||||
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
|
||||
if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
|
||||
abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
|
||||
print(f"Gain {gain} out of range for Pluto.")
|
||||
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
|
||||
|
||||
try:
|
||||
self.tx_gain = abs_gain
|
||||
try:
|
||||
self.tx_gain = abs_gain
|
||||
|
||||
if channel == 0:
|
||||
self.radio.tx_hardwaregain_chan0 = int(abs_gain)
|
||||
elif channel == 1:
|
||||
self.radio.tx_hardwaregain_chan1 = int(abs_gain)
|
||||
else:
|
||||
raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.")
|
||||
if channel == 0:
|
||||
self.radio.tx_hardwaregain_chan0 = int(abs_gain)
|
||||
elif channel == 1:
|
||||
self.radio.tx_hardwaregain_chan1 = int(abs_gain)
|
||||
else:
|
||||
raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.")
|
||||
|
||||
except Exception as e:
|
||||
raise SDRError(e)
|
||||
except Exception as e:
|
||||
raise SDRError(e)
|
||||
|
||||
def set_tx_channel(self, channel):
|
||||
if channel == 0:
|
||||
|
|
|
|||
|
|
@ -561,51 +561,3 @@ class SDROverflowError(SDRError):
|
|||
"""Buffer overflow detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SdrDisconnectedError(SDRError):
|
||||
"""Raised when the SDR device disappears mid-operation (USB unplug, network drop)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Substrings that strongly indicate a device has disappeared rather than a
|
||||
# transient / recoverable error. Checked case-insensitively against str(exc).
|
||||
_DISCONNECT_MARKERS = (
|
||||
"no such device",
|
||||
"device not found",
|
||||
"not found",
|
||||
"broken pipe",
|
||||
"disconnected",
|
||||
"no device",
|
||||
"device unplugged",
|
||||
"usb",
|
||||
"i/o error",
|
||||
"input/output error",
|
||||
"errno 19", # ENODEV
|
||||
"errno 5", # EIO
|
||||
)
|
||||
|
||||
|
||||
def translate_disconnect(exc: BaseException) -> BaseException:
|
||||
"""Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*.
|
||||
|
||||
Drivers wrap their native-API calls with::
|
||||
|
||||
try:
|
||||
return self.radio.rx()
|
||||
except Exception as exc:
|
||||
raise translate_disconnect(exc) from exc
|
||||
|
||||
The caller (e.g. the streamer) can then catch ``SdrDisconnectedError``
|
||||
specifically and report it to the hub rather than crashing the loop.
|
||||
"""
|
||||
if isinstance(exc, SdrDisconnectedError):
|
||||
return exc
|
||||
msg = str(exc).lower()
|
||||
if any(marker in msg for marker in _DISCONNECT_MARKERS):
|
||||
return SdrDisconnectedError(str(exc))
|
||||
# OSError subclass with ENODEV / EIO errno is also a disconnect signal.
|
||||
if isinstance(exc, OSError) and getattr(exc, "errno", None) in (5, 19):
|
||||
return SdrDisconnectedError(str(exc))
|
||||
return exc
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
"""CLI flags for TX opt-in and interlocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from ria_toolkit_oss.agent import cli as agent_cli
|
||||
from ria_toolkit_oss.agent import config as agent_config
|
||||
|
||||
|
||||
class _FakeResp:
|
||||
def __init__(self, payload: dict):
|
||||
self._payload = payload
|
||||
|
||||
def read(self) -> bytes:
|
||||
return json.dumps(self._payload).encode()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_a):
|
||||
return False
|
||||
|
||||
|
||||
def _run_register(argv: list[str], cfg_path) -> int:
|
||||
fake_resp = _FakeResp({"agent_id": "agent-1", "token": "tok-abc"})
|
||||
with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
|
||||
patch("urllib.request.urlopen", return_value=fake_resp), \
|
||||
patch.object(sys, "argv", ["ria-agent", *argv]):
|
||||
try:
|
||||
agent_cli.main()
|
||||
except SystemExit as exc:
|
||||
return int(exc.code or 0)
|
||||
return 0
|
||||
|
||||
|
||||
def test_register_without_allow_tx_keeps_tx_disabled(tmp_path):
|
||||
cfg_path = tmp_path / "agent.json"
|
||||
_run_register(
|
||||
["register", "--hub", "http://hub:3005", "--api-key", "K"],
|
||||
cfg_path,
|
||||
)
|
||||
cfg = agent_config.load(path=cfg_path)
|
||||
assert cfg.agent_id == "agent-1"
|
||||
assert cfg.tx_enabled is False
|
||||
assert cfg.tx_max_gain_db is None
|
||||
|
||||
|
||||
def test_register_with_allow_tx_and_caps(tmp_path):
|
||||
cfg_path = tmp_path / "agent.json"
|
||||
_run_register(
|
||||
[
|
||||
"register",
|
||||
"--hub",
|
||||
"http://hub:3005",
|
||||
"--api-key",
|
||||
"K",
|
||||
"--allow-tx",
|
||||
"--tx-max-gain-db",
|
||||
"-10",
|
||||
"--tx-max-duration-s",
|
||||
"60",
|
||||
"--tx-freq-range",
|
||||
"2.4e9",
|
||||
"2.5e9",
|
||||
"--tx-freq-range",
|
||||
"5.7e9",
|
||||
"5.8e9",
|
||||
],
|
||||
cfg_path,
|
||||
)
|
||||
cfg = agent_config.load(path=cfg_path)
|
||||
assert cfg.tx_enabled is True
|
||||
assert cfg.tx_max_gain_db == -10.0
|
||||
assert cfg.tx_max_duration_s == 60.0
|
||||
assert cfg.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
|
||||
|
||||
|
||||
def test_stream_allow_tx_does_not_persist(tmp_path):
|
||||
# Pre-register with tx_enabled=False, then simulate `stream --allow-tx`.
|
||||
# The on-disk config must remain unchanged; the runtime flag is process-local.
|
||||
cfg_path = tmp_path / "agent.json"
|
||||
base = agent_config.AgentConfig(
|
||||
hub_url="http://hub:3005",
|
||||
agent_id="agent-1",
|
||||
token="tok-abc",
|
||||
tx_enabled=False,
|
||||
)
|
||||
agent_config.save(base, path=cfg_path)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
async def _fake_run_streamer(url, token, *, cfg):
|
||||
captured["cfg"] = cfg
|
||||
return None
|
||||
|
||||
with patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False), \
|
||||
patch("ria_toolkit_oss.agent.streamer.run_streamer", new=_fake_run_streamer), \
|
||||
patch.object(sys, "argv", ["ria-agent", "stream", "--allow-tx"]):
|
||||
try:
|
||||
agent_cli.main()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
# Runtime cfg had TX flipped on
|
||||
assert captured["cfg"].tx_enabled is True
|
||||
# But the persisted file is untouched
|
||||
on_disk = agent_config.load(path=cfg_path)
|
||||
assert on_disk.tx_enabled is False
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
from ria_toolkit_oss.agent import config as agent_config
|
||||
|
||||
|
||||
def test_round_trip(tmp_path):
|
||||
p = tmp_path / "agent.json"
|
||||
cfg = agent_config.AgentConfig(
|
||||
hub_url="https://hub.example.com",
|
||||
agent_id="agent-1",
|
||||
token="t",
|
||||
name="bench",
|
||||
insecure=True,
|
||||
)
|
||||
agent_config.save(cfg, path=p)
|
||||
loaded = agent_config.load(path=p)
|
||||
assert loaded == cfg
|
||||
|
||||
|
||||
def test_load_missing_returns_empty(tmp_path):
|
||||
loaded = agent_config.load(path=tmp_path / "none.json")
|
||||
assert loaded == agent_config.AgentConfig()
|
||||
|
||||
|
||||
def test_tx_fields_round_trip(tmp_path):
|
||||
p = tmp_path / "agent.json"
|
||||
cfg = agent_config.AgentConfig(
|
||||
hub_url="https://hub.example.com",
|
||||
agent_id="agent-1",
|
||||
token="t",
|
||||
tx_enabled=True,
|
||||
tx_max_gain_db=-10.0,
|
||||
tx_max_duration_s=60.0,
|
||||
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
|
||||
)
|
||||
agent_config.save(cfg, path=p)
|
||||
loaded = agent_config.load(path=p)
|
||||
assert loaded.tx_enabled is True
|
||||
assert loaded.tx_max_gain_db == -10.0
|
||||
assert loaded.tx_max_duration_s == 60.0
|
||||
assert loaded.tx_allowed_freq_ranges == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
|
||||
|
||||
|
||||
def test_tx_fields_default_when_absent(tmp_path):
|
||||
# Old configs written before TX existed should load cleanly with safe defaults.
|
||||
p = tmp_path / "agent.json"
|
||||
p.write_text('{"hub_url": "x", "agent_id": "a", "token": "t"}')
|
||||
cfg = agent_config.load(path=p)
|
||||
assert cfg.tx_enabled is False
|
||||
assert cfg.tx_max_gain_db is None
|
||||
assert cfg.tx_max_duration_s is None
|
||||
assert cfg.tx_allowed_freq_ranges is None
|
||||
|
||||
|
||||
def test_extra_keys_preserved(tmp_path):
|
||||
p = tmp_path / "agent.json"
|
||||
p.write_text('{"hub_url": "x", "custom": 42}')
|
||||
cfg = agent_config.load(path=p)
|
||||
assert cfg.hub_url == "x"
|
||||
assert cfg.extra == {"custom": 42}
|
||||
agent_config.save(cfg, path=p)
|
||||
import json
|
||||
|
||||
data = json.loads(p.read_text())
|
||||
assert data["custom"] == 42
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
"""SdrDisconnectedError translation + streamer handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr import SdrDisconnectedError
|
||||
from ria_toolkit_oss.sdr.sdr import translate_disconnect
|
||||
|
||||
|
||||
def test_translate_disconnect_usb_message():
|
||||
exc = RuntimeError("libiio: Input/output error (errno 5)")
|
||||
out = translate_disconnect(exc)
|
||||
assert isinstance(out, SdrDisconnectedError)
|
||||
|
||||
|
||||
def test_translate_disconnect_enodev_oserror():
|
||||
exc = OSError(19, "No such device")
|
||||
assert isinstance(translate_disconnect(exc), SdrDisconnectedError)
|
||||
|
||||
|
||||
def test_translate_disconnect_passes_through_unrelated():
|
||||
exc = ValueError("bad sample rate")
|
||||
assert translate_disconnect(exc) is exc
|
||||
|
||||
|
||||
def test_translate_disconnect_preserves_sdr_disconnected():
|
||||
original = SdrDisconnectedError("already typed")
|
||||
assert translate_disconnect(original) is original
|
||||
|
||||
|
||||
class _FlakySdr:
|
||||
"""SDR that raises SdrDisconnectedError on the first rx() call."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.closed = False
|
||||
|
||||
def rx(self, n): # noqa: D401 - trivial
|
||||
raise SdrDisconnectedError("usb gone")
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class _Ws:
|
||||
def __init__(self):
|
||||
self.json_sent = []
|
||||
self.bytes_sent = []
|
||||
|
||||
async def send_json(self, p):
|
||||
self.json_sent.append(p)
|
||||
|
||||
async def send_bytes(self, b):
|
||||
self.bytes_sent.append(b)
|
||||
|
||||
|
||||
def test_streamer_reports_disconnected_and_ends_capture():
|
||||
async def scenario():
|
||||
ws = _Ws()
|
||||
sdr = _FlakySdr()
|
||||
streamer = Streamer(ws=ws, sdr_factory=lambda d, i: sdr)
|
||||
await streamer.on_message(
|
||||
{
|
||||
"type": "start",
|
||||
"app_id": "a",
|
||||
"radio_config": {"device": "fake", "buffer_size": 8},
|
||||
}
|
||||
)
|
||||
# Wait for the capture loop to emit its error frame and tear down the session.
|
||||
for _ in range(100):
|
||||
if any(m.get("type") == "error" for m in ws.json_sent) and streamer._rx is None:
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
return ws, sdr, streamer
|
||||
|
||||
ws, sdr, streamer = asyncio.run(scenario())
|
||||
assert sdr.closed
|
||||
errors = [m for m in ws.json_sent if m.get("type") == "error"]
|
||||
assert errors, "expected an error frame"
|
||||
assert "disconnected" in errors[-1]["message"].lower()
|
||||
# Session handle cleared so future starts can proceed.
|
||||
assert streamer._rx is None
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
"""Concurrent RX + TX sessions on the same agent — shared SDR via registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
class FullDuplexMockSDR(MockSDR):
|
||||
"""MockSDR with a recording TX path so the test can assert both directions."""
|
||||
|
||||
def __init__(self, buffer_size: int):
|
||||
super().__init__(buffer_size=buffer_size)
|
||||
self.tx_produced: list[np.ndarray] = []
|
||||
|
||||
def _stream_tx(self, callback):
|
||||
self._enable_tx = True
|
||||
self._tx_initialized = True
|
||||
while self._enable_tx:
|
||||
result = callback(self.rx_buffer_size)
|
||||
self.tx_produced.append(np.asarray(result).copy())
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
class FakeWs:
|
||||
def __init__(self):
|
||||
self.json_sent = []
|
||||
self.bytes_sent = []
|
||||
|
||||
async def send_json(self, p):
|
||||
self.json_sent.append(p)
|
||||
|
||||
async def send_bytes(self, b):
|
||||
self.bytes_sent.append(b)
|
||||
|
||||
|
||||
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = samples.real
|
||||
interleaved[1::2] = samples.imag
|
||||
return interleaved.tobytes()
|
||||
|
||||
|
||||
def test_rx_and_tx_share_one_sdr_instance():
|
||||
built: list[FullDuplexMockSDR] = []
|
||||
|
||||
def factory(device, identifier):
|
||||
sdr = FullDuplexMockSDR(buffer_size=16)
|
||||
built.append(sdr)
|
||||
return sdr
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=factory, cfg=AgentConfig(tx_enabled=True))
|
||||
|
||||
# Start RX first.
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "start",
|
||||
"app_id": "app-1",
|
||||
"radio_config": {"device": "mock", "buffer_size": 16},
|
||||
}
|
||||
)
|
||||
# Then start TX on the same device — should share the SDR handle.
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "app-1",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": 16,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_gain": -20,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"underrun_policy": "zero",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Push a known TX buffer.
|
||||
marker = np.arange(16, dtype=np.complex64) + 7
|
||||
await s.on_binary(_iq_frame(marker))
|
||||
|
||||
# Let both directions produce output.
|
||||
for _ in range(80):
|
||||
rx_ok = len(ws.bytes_sent) >= 2
|
||||
tx_ok = any(np.array_equal(b, marker) for b in built[0].tx_produced) if built else False
|
||||
if rx_ok and tx_ok:
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Heartbeat should show both sessions.
|
||||
hb = s.build_heartbeat()
|
||||
|
||||
# Stop TX first, RX keeps running.
|
||||
await s.on_message({"type": "tx_stop", "app_id": "app-1"})
|
||||
tx_after_stop = s._tx is None
|
||||
rx_still_active = s._rx is not None
|
||||
|
||||
# Now stop RX.
|
||||
await s.on_message({"type": "stop", "app_id": "app-1"})
|
||||
|
||||
return ws, s, built, hb, tx_after_stop, rx_still_active
|
||||
|
||||
ws, s, built, hb, tx_after_stop, rx_still_active = asyncio.run(scenario())
|
||||
|
||||
# One SDR was built and shared.
|
||||
assert len(built) == 1, f"expected exactly one SDR instance, got {len(built)}"
|
||||
|
||||
# Both directions produced output.
|
||||
assert len(ws.bytes_sent) >= 1, "RX produced no IQ frames"
|
||||
marker = np.arange(16, dtype=np.complex64) + 7
|
||||
assert any(
|
||||
np.array_equal(b, marker) for b in built[0].tx_produced
|
||||
), "TX callback never saw the pushed marker buffer"
|
||||
|
||||
# Heartbeat reflected both sessions while they were active.
|
||||
assert hb["sessions"]["rx"]["app_id"] == "app-1"
|
||||
assert hb["sessions"]["tx"]["app_id"] == "app-1"
|
||||
|
||||
# Stopping TX does not tear down RX.
|
||||
assert tx_after_stop
|
||||
assert rx_still_active
|
||||
|
||||
# After both stops, registry is empty.
|
||||
assert s._registry.refcount(("mock", None)) == 0
|
||||
assert s._rx is None
|
||||
assert s._tx is None
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
from ria_toolkit_oss.agent import hardware
|
||||
from ria_toolkit_oss.sdr import detect_available
|
||||
|
||||
|
||||
def test_detect_available_includes_mock():
|
||||
drivers = detect_available()
|
||||
assert "mock" in drivers
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
assert drivers["mock"] is MockSDR
|
||||
|
||||
|
||||
def test_available_devices_sorted_list():
|
||||
devices = hardware.available_devices()
|
||||
assert isinstance(devices, list)
|
||||
assert devices == sorted(devices)
|
||||
assert "mock" in devices
|
||||
|
||||
|
||||
def test_heartbeat_payload_shape():
|
||||
p = hardware.heartbeat_payload()
|
||||
assert p["type"] == "heartbeat"
|
||||
assert p["status"] == "idle"
|
||||
assert "mock" in p["hardware"]
|
||||
assert "app_id" not in p
|
||||
# New fields, default shape
|
||||
assert p["capabilities"] == ["rx"]
|
||||
assert p["tx_enabled"] is False
|
||||
|
||||
p2 = hardware.heartbeat_payload(status="streaming", app_id="abc")
|
||||
assert p2["status"] == "streaming"
|
||||
assert p2["app_id"] == "abc"
|
||||
|
||||
|
||||
def test_heartbeat_payload_tx_capability_from_cfg():
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
|
||||
p = hardware.heartbeat_payload(cfg=AgentConfig(tx_enabled=True))
|
||||
assert p["capabilities"] == ["rx", "tx"]
|
||||
assert p["tx_enabled"] is True
|
||||
|
||||
|
||||
def test_heartbeat_payload_sessions_field():
|
||||
sessions = {"rx": {"app_id": "a", "state": "streaming"}}
|
||||
p = hardware.heartbeat_payload(status="streaming", app_id="a", sessions=sessions)
|
||||
assert p["sessions"] == sessions
|
||||
|
||||
|
||||
def test_heartbeat_payload_surfaces_tx_caps_when_enabled():
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
|
||||
cfg = AgentConfig(
|
||||
tx_enabled=True,
|
||||
tx_max_gain_db=-10.0,
|
||||
tx_max_duration_s=60.0,
|
||||
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
|
||||
)
|
||||
p = hardware.heartbeat_payload(cfg=cfg)
|
||||
assert p["tx_max_gain_db"] == -10.0
|
||||
assert p["tx_max_duration_s"] == 60.0
|
||||
assert p["tx_allowed_freq_ranges"] == [[2.4e9, 2.5e9], [5.7e9, 5.8e9]]
|
||||
|
||||
|
||||
def test_heartbeat_payload_omits_caps_when_tx_disabled():
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
|
||||
# Caps set but tx_enabled=False — don't leak them; they're only meaningful
|
||||
# when the hub can attempt a tx_start.
|
||||
cfg = AgentConfig(tx_enabled=False, tx_max_gain_db=-10.0)
|
||||
p = hardware.heartbeat_payload(cfg=cfg)
|
||||
assert "tx_max_gain_db" not in p
|
||||
assert "tx_max_duration_s" not in p
|
||||
assert "tx_allowed_freq_ranges" not in p
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
"""End-to-end: local websockets server drives a Streamer with a MockSDR."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
def test_server_start_stream_stop_cycle_over_real_ws():
|
||||
async def scenario():
|
||||
control_frames: list[dict] = []
|
||||
binary_frames: list[bytes] = []
|
||||
ready = asyncio.Event()
|
||||
stopped = asyncio.Event()
|
||||
|
||||
async def server_handler(ws):
|
||||
# Agent will open the connection; wait for heartbeat first.
|
||||
try:
|
||||
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||
control_frames.append(json.loads(first))
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "start",
|
||||
"app_id": "app-1",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": 32,
|
||||
"sample_rate": 1_000_000,
|
||||
"center_frequency": 2.45e9,
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
while len(binary_frames) < 3:
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||
if isinstance(msg, bytes):
|
||||
binary_frames.append(msg)
|
||||
else:
|
||||
control_frames.append(json.loads(msg))
|
||||
ready.set()
|
||||
await ws.send(json.dumps({"type": "stop", "app_id": "app-1"}))
|
||||
# Drain final status frame.
|
||||
try:
|
||||
while True:
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=0.5)
|
||||
if isinstance(msg, bytes):
|
||||
binary_frames.append(msg)
|
||||
else:
|
||||
control_frames.append(json.loads(msg))
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
pass
|
||||
stopped.set()
|
||||
except Exception:
|
||||
stopped.set()
|
||||
|
||||
server = await websockets.serve(server_handler, "127.0.0.1", 0)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
streamer = Streamer(ws=client, sdr_factory=lambda d, i: MockSDR(buffer_size=32, seed=0))
|
||||
task = asyncio.create_task(
|
||||
client.run(on_message=streamer.on_message, heartbeat=streamer.build_heartbeat)
|
||||
)
|
||||
await asyncio.wait_for(ready.wait(), timeout=3.0)
|
||||
await asyncio.wait_for(stopped.wait(), timeout=3.0)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return control_frames, binary_frames
|
||||
|
||||
controls, binaries = asyncio.run(scenario())
|
||||
|
||||
# Heartbeat reached the server.
|
||||
assert any(f.get("type") == "heartbeat" for f in controls)
|
||||
# Status transitioned idle -> streaming -> idle.
|
||||
statuses = [f["status"] for f in controls if f.get("type") == "status"]
|
||||
assert "streaming" in statuses
|
||||
assert statuses[-1] == "idle"
|
||||
# Three binary IQ frames of 32 samples × 2 floats × 4 bytes.
|
||||
assert len(binaries) >= 3
|
||||
for b in binaries[:3]:
|
||||
assert len(b) == 32 * 2 * 4
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
"""End-to-end: local websockets server drives a Streamer's TX path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
class RecordingMockSDR(MockSDR):
|
||||
def __init__(self, buffer_size: int):
|
||||
super().__init__(buffer_size=buffer_size)
|
||||
self.tx_produced: list[np.ndarray] = []
|
||||
|
||||
def _stream_tx(self, callback):
|
||||
self._enable_tx = True
|
||||
self._tx_initialized = True
|
||||
while self._enable_tx:
|
||||
result = callback(self.rx_buffer_size)
|
||||
self.tx_produced.append(np.asarray(result).copy())
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = samples.real
|
||||
interleaved[1::2] = samples.imag
|
||||
return interleaved.tobytes()
|
||||
|
||||
|
||||
def test_server_tx_start_binary_stop_cycle_over_real_ws():
|
||||
BUF = 16
|
||||
sdr = RecordingMockSDR(buffer_size=BUF)
|
||||
marker = np.arange(BUF, dtype=np.complex64) + 1
|
||||
|
||||
async def scenario():
|
||||
control_frames: list[dict] = []
|
||||
done = asyncio.Event()
|
||||
|
||||
async def server_handler(ws):
|
||||
try:
|
||||
# Drain initial heartbeat.
|
||||
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||
control_frames.append(json.loads(first))
|
||||
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "tx-app",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": BUF,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"tx_gain": -20,
|
||||
"underrun_policy": "zero",
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Push a few binary IQ frames.
|
||||
for _ in range(3):
|
||||
await ws.send(_iq_frame(marker))
|
||||
|
||||
# Wait for at least "armed" + "transmitting" statuses.
|
||||
for _ in range(100):
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||
if isinstance(msg, str):
|
||||
control_frames.append(json.loads(msg))
|
||||
if any(
|
||||
f.get("type") == "tx_status" and f.get("state") == "transmitting"
|
||||
for f in control_frames
|
||||
):
|
||||
break
|
||||
|
||||
await ws.send(json.dumps({"type": "tx_stop", "app_id": "tx-app"}))
|
||||
|
||||
# Drain trailing statuses.
|
||||
try:
|
||||
while True:
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=0.5)
|
||||
if isinstance(msg, str):
|
||||
control_frames.append(json.loads(msg))
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
pass
|
||||
finally:
|
||||
done.set()
|
||||
|
||||
server = await websockets.serve(server_handler, "127.0.0.1", 0)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
streamer = Streamer(
|
||||
ws=client,
|
||||
sdr_factory=lambda d, i: sdr,
|
||||
cfg=AgentConfig(tx_enabled=True),
|
||||
)
|
||||
task = asyncio.create_task(
|
||||
client.run(
|
||||
on_message=streamer.on_message,
|
||||
heartbeat=streamer.build_heartbeat,
|
||||
on_binary=streamer.on_binary,
|
||||
)
|
||||
)
|
||||
await asyncio.wait_for(done.wait(), timeout=5.0)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return control_frames, streamer
|
||||
|
||||
controls, streamer = asyncio.run(scenario())
|
||||
|
||||
# Heartbeat reached the server.
|
||||
assert any(f.get("type") == "heartbeat" for f in controls)
|
||||
# tx_status lifecycle: armed → transmitting → done.
|
||||
tx_states = [f["state"] for f in controls if f.get("type") == "tx_status"]
|
||||
assert tx_states[0] == "armed"
|
||||
assert "transmitting" in tx_states
|
||||
assert tx_states[-1] == "done"
|
||||
# TX callback saw our marker buffer at least once.
|
||||
assert any(np.array_equal(b, marker) for b in sdr.tx_produced)
|
||||
# Session cleared.
|
||||
assert streamer._tx is None
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
"""Regression: legacy NodeAgent still importable after the package move."""
|
||||
|
||||
|
||||
def test_import_node_agent_from_package():
|
||||
from ria_toolkit_oss.agent import NodeAgent
|
||||
|
||||
assert NodeAgent.__name__ == "NodeAgent"
|
||||
|
||||
|
||||
def test_main_entry_point_exists():
|
||||
from ria_toolkit_oss.agent import main
|
||||
|
||||
assert callable(main)
|
||||
|
||||
|
||||
def test_legacy_module_still_direct_importable():
|
||||
from ria_toolkit_oss.agent.legacy_executor import NodeAgent as LegacyNodeAgent
|
||||
|
||||
assert LegacyNodeAgent.__name__ == "NodeAgent"
|
||||
|
|
@ -1,210 +0,0 @@
|
|||
"""Step-A6 (Pluto lock audit) coverage.
|
||||
|
||||
Verifies the two invariants the handoff doc calls for when RX and TX run
|
||||
concurrently on one shared SDR handle:
|
||||
|
||||
1. ``_param_lock`` actually serializes concurrent RX + TX setter calls — the
|
||||
spec's §A6 acceptance criterion is *"``_param_lock`` instrumented for
|
||||
contention"*. We drive parallel ``set_{rx,tx}_sample_rate`` calls through
|
||||
the lock and assert it's hit often enough to prove both paths fight for it.
|
||||
2. Under a sustained full-duplex session (RX capturing + TX transmitting on
|
||||
one ``(device, identifier)``), no setter write is dropped and no exception
|
||||
escapes the executor — i.e., the shared-handle assumption holds. Runs
|
||||
against ``MockSDR`` per the spec; the real Pluto driver now takes the
|
||||
same lock on its TX setters so the production code path is isomorphic.
|
||||
|
||||
The stress window is 2 seconds by default — the handoff mentions 30 s but
|
||||
that's impractical in CI. Set ``RIA_LOCK_STRESS_S`` to override.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
_STRESS_S = float(os.environ.get("RIA_LOCK_STRESS_S", "2.0"))
|
||||
|
||||
|
||||
class InstrumentedMockSDR(MockSDR):
|
||||
"""MockSDR that counts lock acquisitions and exposes a real ``_param_lock``.
|
||||
|
||||
``_param_lock`` is inherited from ``SDR`` as a reentrant lock; we wrap it
|
||||
with a counter that records every time RX or TX setters grab it, so the
|
||||
test can assert real contention rather than just "the code compiles".
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int):
|
||||
super().__init__(buffer_size=buffer_size)
|
||||
self.rx_lock_hits = 0
|
||||
self.tx_lock_hits = 0
|
||||
self.param_lock_hits = 0
|
||||
# Shadow lock that increments a counter each time __enter__ fires.
|
||||
real_lock = self._param_lock
|
||||
|
||||
test = self
|
||||
|
||||
class CountingLock:
|
||||
def __enter__(self_inner):
|
||||
test.param_lock_hits += 1
|
||||
real_lock.acquire()
|
||||
return self_inner
|
||||
|
||||
def __exit__(self_inner, *a):
|
||||
real_lock.release()
|
||||
return False
|
||||
|
||||
# ``threading.RLock`` interop for any code that calls acquire/release directly.
|
||||
def acquire(self_inner, *a, **k):
|
||||
test.param_lock_hits += 1
|
||||
return real_lock.acquire(*a, **k)
|
||||
|
||||
def release(self_inner):
|
||||
return real_lock.release()
|
||||
|
||||
self._param_lock = CountingLock()
|
||||
|
||||
# The MockSDR doesn't ship RX setter methods that hit the lock — override
|
||||
# ``sample_rate`` / ``center_freq`` / ``gain`` writes to route through the
|
||||
# same lock the real Pluto driver uses, so this test faithfully models the
|
||||
# production contention path.
|
||||
def set_rx_sample_rate(self, sample_rate):
|
||||
with self._param_lock:
|
||||
self.rx_lock_hits += 1
|
||||
self.rx_sample_rate = float(sample_rate)
|
||||
self.sample_rate = self.rx_sample_rate
|
||||
|
||||
def set_tx_sample_rate(self, sample_rate):
|
||||
with self._param_lock:
|
||||
self.tx_lock_hits += 1
|
||||
self.tx_sample_rate = float(sample_rate)
|
||||
# Mirror Pluto: both RX and TX write the same native attribute.
|
||||
self.sample_rate = self.tx_sample_rate
|
||||
|
||||
|
||||
class FakeWs:
|
||||
def __init__(self):
|
||||
self.json_sent: list[dict] = []
|
||||
self.bytes_sent: list[bytes] = []
|
||||
|
||||
async def send_json(self, p):
|
||||
self.json_sent.append(p)
|
||||
|
||||
async def send_bytes(self, b):
|
||||
self.bytes_sent.append(b)
|
||||
|
||||
|
||||
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = samples.real
|
||||
interleaved[1::2] = samples.imag
|
||||
return interleaved.tobytes()
|
||||
|
||||
|
||||
def test_param_lock_contended_under_concurrent_setters():
|
||||
"""Run two threads that hammer RX + TX sample-rate setters and assert both
|
||||
lock paths fire. This proves the lock is doing work — if either setter
|
||||
bypassed ``_param_lock``, one of the counters would stay at zero."""
|
||||
sdr = InstrumentedMockSDR(buffer_size=16)
|
||||
stop = threading.Event()
|
||||
|
||||
def rx_setter():
|
||||
i = 0
|
||||
while not stop.is_set():
|
||||
sdr.set_rx_sample_rate(1_000_000 + (i % 1000))
|
||||
i += 1
|
||||
|
||||
def tx_setter():
|
||||
i = 0
|
||||
while not stop.is_set():
|
||||
sdr.set_tx_sample_rate(2_000_000 + (i % 1000))
|
||||
i += 1
|
||||
|
||||
t1 = threading.Thread(target=rx_setter)
|
||||
t2 = threading.Thread(target=tx_setter)
|
||||
t1.start()
|
||||
t2.start()
|
||||
time.sleep(min(_STRESS_S, 2.0))
|
||||
stop.set()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
assert sdr.rx_lock_hits > 100, f"RX setter barely ran: {sdr.rx_lock_hits}"
|
||||
assert sdr.tx_lock_hits > 100, f"TX setter barely ran: {sdr.tx_lock_hits}"
|
||||
# Every setter call should have passed through _param_lock exactly once.
|
||||
assert sdr.param_lock_hits >= sdr.rx_lock_hits + sdr.tx_lock_hits
|
||||
|
||||
|
||||
def test_full_duplex_stays_healthy_over_stress_window():
|
||||
"""Start RX + TX on one shared SDR and drive both paths for ``_STRESS_S``
|
||||
seconds, pushing binary frames and emitting ``tx_configure`` mid-stream.
|
||||
The session must survive, deliver buffers in both directions, and leave
|
||||
the registry clean on shutdown."""
|
||||
BUF = 32
|
||||
sdr = InstrumentedMockSDR(buffer_size=BUF)
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||
|
||||
await s.on_message(
|
||||
{"type": "start", "app_id": "app-1",
|
||||
"radio_config": {"device": "mock", "buffer_size": BUF}}
|
||||
)
|
||||
await s.on_message(
|
||||
{"type": "tx_start", "app_id": "app-1",
|
||||
"radio_config": {
|
||||
"device": "mock", "buffer_size": BUF,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"tx_gain": -20,
|
||||
"underrun_policy": "zero",
|
||||
}}
|
||||
)
|
||||
|
||||
marker = np.arange(BUF, dtype=np.complex64) + 1
|
||||
deadline = time.monotonic() + _STRESS_S
|
||||
i = 0
|
||||
while time.monotonic() < deadline:
|
||||
await s.on_binary(_iq_frame(marker))
|
||||
if i % 8 == 0:
|
||||
# Mid-stream parameter reconfiguration touches _apply_sdr_config,
|
||||
# which routes through the same setters the stress test above
|
||||
# verifies.
|
||||
await s.on_message(
|
||||
{"type": "tx_configure", "app_id": "app-1",
|
||||
"radio_config": {"tx_sample_rate": 1_000_000 + i}}
|
||||
)
|
||||
await s.on_message(
|
||||
{"type": "configure", "app_id": "app-1",
|
||||
"radio_config": {"sample_rate": 2_000_000 + i}}
|
||||
)
|
||||
i += 1
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
await s.on_message({"type": "tx_stop", "app_id": "app-1"})
|
||||
await s.on_message({"type": "stop", "app_id": "app-1"})
|
||||
return ws, s
|
||||
|
||||
ws, s = asyncio.run(scenario())
|
||||
|
||||
# No error frame leaked out.
|
||||
errors = [m for m in ws.json_sent
|
||||
if m.get("type") in ("error", "tx_status") and m.get("state") == "error"]
|
||||
assert errors == [], f"Unexpected error frames: {errors}"
|
||||
# RX produced IQ frames and TX's callback ran — heartbeat-level contention
|
||||
# check: both setter paths were hit at least once during configure dispatch.
|
||||
assert ws.bytes_sent, "RX produced no IQ frames"
|
||||
assert sdr.param_lock_hits > 0
|
||||
# Sessions cleaned up; registry drained.
|
||||
assert s._tx is None
|
||||
assert s._rx is None
|
||||
assert s._registry.refcount(("mock", None)) == 0
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
"""Unit tests for the streamer: drive it with a fake WsClient + MockSDR."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.streamer import (
|
||||
Streamer,
|
||||
_apply_sdr_config,
|
||||
_samples_to_interleaved_float32,
|
||||
)
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
class FakeWs:
|
||||
def __init__(self):
|
||||
self.json_sent: list[dict] = []
|
||||
self.bytes_sent: list[bytes] = []
|
||||
|
||||
async def send_json(self, payload: dict) -> None:
|
||||
self.json_sent.append(payload)
|
||||
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
self.bytes_sent.append(data)
|
||||
|
||||
|
||||
def _factory(device: str, identifier):
|
||||
return MockSDR(buffer_size=32, seed=0)
|
||||
|
||||
|
||||
def test_samples_to_interleaved_float32_roundtrip():
|
||||
c = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
|
||||
raw = _samples_to_interleaved_float32(c)
|
||||
arr = np.frombuffer(raw, dtype=np.float32)
|
||||
assert arr.tolist() == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
|
||||
def test_apply_sdr_config_sets_attributes():
|
||||
sdr = MockSDR(buffer_size=16)
|
||||
_apply_sdr_config(sdr, {"sample_rate": 2e6, "center_frequency": 9.15e8, "gain": 30})
|
||||
assert sdr.sample_rate == 2e6
|
||||
assert sdr.center_freq == 9.15e8
|
||||
assert sdr.gain == 30
|
||||
|
||||
|
||||
def test_heartbeat_reflects_status_and_app():
|
||||
async def scenario():
|
||||
s = Streamer(ws=FakeWs(), sdr_factory=_factory)
|
||||
hb = s.build_heartbeat()
|
||||
assert hb["type"] == "heartbeat"
|
||||
assert hb["status"] == "idle"
|
||||
# capabilities default to rx-only
|
||||
assert hb["capabilities"] == ["rx"]
|
||||
assert hb["tx_enabled"] is False
|
||||
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "start",
|
||||
"app_id": "app-42",
|
||||
"radio_config": {"device": "mock", "buffer_size": 32},
|
||||
}
|
||||
)
|
||||
hb2 = s.build_heartbeat()
|
||||
assert hb2["status"] == "streaming"
|
||||
assert hb2["app_id"] == "app-42"
|
||||
assert hb2["sessions"]["rx"]["app_id"] == "app-42"
|
||||
await s.on_message({"type": "stop", "app_id": "app-42"})
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_full_start_stream_stop_cycle():
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
streamer = Streamer(ws=ws, sdr_factory=_factory)
|
||||
|
||||
await streamer.on_message(
|
||||
{
|
||||
"type": "start",
|
||||
"app_id": "abc",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"sample_rate": 1_000_000,
|
||||
"center_frequency": 2.45e9,
|
||||
"gain": 40,
|
||||
"buffer_size": 64,
|
||||
},
|
||||
}
|
||||
)
|
||||
for _ in range(30):
|
||||
if len(ws.bytes_sent) >= 2:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
await streamer.on_message({"type": "stop", "app_id": "abc"})
|
||||
return ws, streamer
|
||||
|
||||
ws, streamer = asyncio.run(scenario())
|
||||
assert len(ws.bytes_sent) >= 1
|
||||
for frame in ws.bytes_sent:
|
||||
assert len(frame) == 64 * 2 * 4 # 64 samples × (I,Q) × float32
|
||||
statuses = [m for m in ws.json_sent if m.get("type") == "status"]
|
||||
assert statuses[0]["status"] == "streaming"
|
||||
assert statuses[-1]["status"] == "idle"
|
||||
assert streamer._rx is None
|
||||
|
||||
|
||||
def test_start_without_device_emits_error():
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
streamer = Streamer(ws=ws, sdr_factory=_factory)
|
||||
await streamer.on_message({"type": "start", "app_id": "x", "radio_config": {}})
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
errors = [m for m in ws.json_sent if m.get("type") == "error"]
|
||||
assert errors and "device" in errors[0]["message"]
|
||||
|
||||
|
||||
def test_configure_queues_update():
|
||||
async def scenario():
|
||||
streamer = Streamer(ws=FakeWs(), sdr_factory=_factory)
|
||||
await streamer.on_message(
|
||||
{"type": "configure", "app_id": "x", "radio_config": {"center_frequency": 915e6}}
|
||||
)
|
||||
# Before start(), pending config lives on the standalone dict exposed via the _pending_config shim.
|
||||
return streamer._pending_config
|
||||
|
||||
pending = asyncio.run(scenario())
|
||||
assert pending == {"center_frequency": 915e6}
|
||||
|
||||
|
||||
def test_unknown_message_type_is_ignored():
|
||||
async def scenario():
|
||||
s = Streamer(ws=FakeWs(), sdr_factory=_factory)
|
||||
await s.on_message({"type": "nope"})
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_tx_data_available_is_a_silent_noop():
|
||||
# Hub sends this as a keepalive; we should accept and ignore without
|
||||
# emitting a WARNING or treating it as an error.
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=_factory)
|
||||
await s.on_message({"type": "tx_data_available", "app_id": "x"})
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
# No outbound frames emitted.
|
||||
assert ws.json_sent == []
|
||||
assert ws.bytes_sent == []
|
||||
|
||||
|
||||
def test_registry_shares_sdr_across_start_stop_cycles():
|
||||
# Two sequential start/stop cycles with the same (device, identifier)
|
||||
# should hit the registry's cache path rather than constructing a new SDR.
|
||||
built: list[MockSDR] = []
|
||||
|
||||
def counting_factory(device: str, identifier):
|
||||
sdr = MockSDR(buffer_size=16, seed=0)
|
||||
built.append(sdr)
|
||||
return sdr
|
||||
|
||||
async def scenario():
|
||||
s = Streamer(ws=FakeWs(), sdr_factory=counting_factory)
|
||||
for _ in range(2):
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "start",
|
||||
"app_id": "a",
|
||||
"radio_config": {"device": "mock", "buffer_size": 16},
|
||||
}
|
||||
)
|
||||
# Let one capture buffer flow before stopping so the loop is engaged.
|
||||
await asyncio.sleep(0.02)
|
||||
await s.on_message({"type": "stop", "app_id": "a"})
|
||||
|
||||
asyncio.run(scenario())
|
||||
# A new SDR per cycle (we fully close between starts) — registry refcount
|
||||
# drops to zero on each stop. This test confirms close-and-rebuild works;
|
||||
# the ref-counting share-while-open case is covered in the full-duplex tests.
|
||||
assert len(built) == 2
|
||||
|
||||
|
||||
def test_tx_start_rejected_when_tx_disabled():
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=_factory, cfg=AgentConfig(tx_enabled=False))
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "a",
|
||||
"radio_config": {"device": "mock", "tx_center_frequency": 2.45e9, "tx_gain": -20},
|
||||
}
|
||||
)
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
tx_statuses = [m for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||
assert tx_statuses, "expected a tx_status frame"
|
||||
assert tx_statuses[-1]["state"] == "error"
|
||||
assert "disabled" in tx_statuses[-1]["message"].lower()
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
"""TX streaming happy path + shutdown semantics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
class RecordingMockSDR(MockSDR):
|
||||
"""MockSDR that records each TX callback's returned buffer."""
|
||||
|
||||
def __init__(self, buffer_size: int):
|
||||
super().__init__(buffer_size=buffer_size)
|
||||
self.tx_produced: list[np.ndarray] = []
|
||||
|
||||
def _stream_tx(self, callback) -> None:
|
||||
self._enable_tx = True
|
||||
self._tx_initialized = True
|
||||
while self._enable_tx:
|
||||
result = callback(self.rx_buffer_size)
|
||||
self.tx_produced.append(np.asarray(result))
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
class FakeWs:
|
||||
def __init__(self):
|
||||
self.json_sent: list[dict] = []
|
||||
self.bytes_sent: list[bytes] = []
|
||||
|
||||
async def send_json(self, payload):
|
||||
self.json_sent.append(payload)
|
||||
|
||||
async def send_bytes(self, data):
|
||||
self.bytes_sent.append(data)
|
||||
|
||||
|
||||
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = samples.real
|
||||
interleaved[1::2] = samples.imag
|
||||
return interleaved.tobytes()
|
||||
|
||||
|
||||
def test_tx_start_streams_binary_to_callback():
|
||||
BUF = 16
|
||||
sdr = RecordingMockSDR(buffer_size=BUF)
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||
|
||||
# Frames of distinct content so we can assert ordering.
|
||||
frame_a = np.arange(BUF, dtype=np.complex64) * (1 + 0j)
|
||||
frame_b = (np.arange(BUF, dtype=np.complex64) + BUF) * (1 + 0j)
|
||||
frame_c = (np.arange(BUF, dtype=np.complex64) + 2 * BUF) * (1 + 0j)
|
||||
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "app-1",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": BUF,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"tx_gain": -20,
|
||||
"underrun_policy": "zero",
|
||||
},
|
||||
}
|
||||
)
|
||||
# Push three IQ frames.
|
||||
await s.on_binary(_iq_frame(frame_a))
|
||||
await s.on_binary(_iq_frame(frame_b))
|
||||
await s.on_binary(_iq_frame(frame_c))
|
||||
|
||||
# Let the executor thread consume them.
|
||||
for _ in range(100):
|
||||
# At least the 3 real frames, plus any zero-fill from before they
|
||||
# arrived. We stop once 3 non-trivial buffers are recorded.
|
||||
nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
|
||||
if len(nontrivial) >= 3:
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await s.on_message({"type": "tx_stop", "app_id": "app-1"})
|
||||
return ws, sdr, s
|
||||
|
||||
ws, sdr, streamer = asyncio.run(scenario())
|
||||
|
||||
nontrivial = [b for b in sdr.tx_produced if np.any(b != 0)]
|
||||
assert len(nontrivial) >= 3, "expected ≥3 nontrivial TX buffers"
|
||||
|
||||
# First three nontrivial buffers match the order we pushed them.
|
||||
np.testing.assert_array_equal(nontrivial[0], np.arange(BUF, dtype=np.complex64))
|
||||
np.testing.assert_array_equal(nontrivial[1], np.arange(BUF, 2 * BUF, dtype=np.complex64))
|
||||
np.testing.assert_array_equal(nontrivial[2], np.arange(2 * BUF, 3 * BUF, dtype=np.complex64))
|
||||
|
||||
# Lifecycle: armed → transmitting → done.
|
||||
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||
assert states[0] == "armed"
|
||||
assert "transmitting" in states
|
||||
assert states[-1] == "done"
|
||||
# Session cleared.
|
||||
assert streamer._tx is None
|
||||
|
||||
|
||||
def test_tx_stop_releases_sdr():
|
||||
sdr = RecordingMockSDR(buffer_size=8)
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "a",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": 8,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"tx_gain": -20,
|
||||
"underrun_policy": "zero",
|
||||
},
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0.03)
|
||||
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||
return s
|
||||
|
||||
s = asyncio.run(scenario())
|
||||
# After stop, the registry has no outstanding references to ("mock", None).
|
||||
assert s._registry.refcount(("mock", None)) == 0
|
||||
assert s._tx is None
|
||||
|
|
@ -1,174 +0,0 @@
|
|||
"""Agent-side TX interlocks: gain cap, freq ranges, duplicate sessions, disabled."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
class FakeWs:
|
||||
def __init__(self):
|
||||
self.json_sent = []
|
||||
self.bytes_sent = []
|
||||
|
||||
async def send_json(self, p):
|
||||
self.json_sent.append(p)
|
||||
|
||||
async def send_bytes(self, b):
|
||||
self.bytes_sent.append(b)
|
||||
|
||||
|
||||
def _last_tx_status(ws):
|
||||
frames = [m for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||
return frames[-1] if frames else None
|
||||
|
||||
|
||||
def _tx_start(app_id="a", **radio):
|
||||
rc = {
|
||||
"device": "mock",
|
||||
"buffer_size": 16,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"tx_gain": -20,
|
||||
"underrun_policy": "zero",
|
||||
}
|
||||
rc.update(radio)
|
||||
return {"type": "tx_start", "app_id": app_id, "radio_config": rc}
|
||||
|
||||
|
||||
def _make_streamer(cfg):
|
||||
built: list = []
|
||||
|
||||
def factory(device, identifier):
|
||||
sdr = MockSDR(buffer_size=16)
|
||||
built.append(sdr)
|
||||
return sdr
|
||||
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=factory, cfg=cfg)
|
||||
return s, ws, built
|
||||
|
||||
|
||||
def test_rejects_when_tx_disabled():
|
||||
async def scenario():
|
||||
s, ws, built = _make_streamer(AgentConfig(tx_enabled=False))
|
||||
await s.on_message(_tx_start(tx_gain=-20, tx_center_frequency=2.45e9))
|
||||
return s, ws, built
|
||||
|
||||
s, ws, built = asyncio.run(scenario())
|
||||
status = _last_tx_status(ws)
|
||||
assert status and status["state"] == "error"
|
||||
assert "disabled" in status["message"].lower()
|
||||
assert not built, "SDR should never have been constructed"
|
||||
assert s._tx is None
|
||||
|
||||
|
||||
def test_rejects_when_tx_gain_exceeds_cap():
|
||||
async def scenario():
|
||||
s, ws, built = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-15.0))
|
||||
await s.on_message(_tx_start(tx_gain=-5, tx_center_frequency=2.45e9))
|
||||
return ws, built
|
||||
|
||||
ws, built = asyncio.run(scenario())
|
||||
status = _last_tx_status(ws)
|
||||
assert status and status["state"] == "error"
|
||||
assert "exceeds cap" in status["message"]
|
||||
assert not built
|
||||
|
||||
|
||||
def test_allows_gain_at_cap_boundary():
|
||||
async def scenario():
|
||||
s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True, tx_max_gain_db=-10.0))
|
||||
await s.on_message(_tx_start(tx_gain=-10, tx_center_frequency=2.45e9))
|
||||
# Stop promptly to avoid keeping an executor thread around.
|
||||
await asyncio.sleep(0.02)
|
||||
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||
assert "armed" in states
|
||||
assert states[-1] == "done"
|
||||
|
||||
|
||||
def test_rejects_when_freq_outside_ranges():
|
||||
async def scenario():
|
||||
s, ws, built = _make_streamer(
|
||||
AgentConfig(
|
||||
tx_enabled=True,
|
||||
tx_allowed_freq_ranges=[[2.4e9, 2.5e9]],
|
||||
)
|
||||
)
|
||||
await s.on_message(_tx_start(tx_center_frequency=5.8e9, tx_gain=-20))
|
||||
return ws, built
|
||||
|
||||
ws, built = asyncio.run(scenario())
|
||||
status = _last_tx_status(ws)
|
||||
assert status and status["state"] == "error"
|
||||
assert "outside allowed ranges" in status["message"]
|
||||
assert not built
|
||||
|
||||
|
||||
def test_allows_freq_inside_a_range():
|
||||
async def scenario():
|
||||
s, ws, _ = _make_streamer(
|
||||
AgentConfig(
|
||||
tx_enabled=True,
|
||||
tx_allowed_freq_ranges=[[2.4e9, 2.5e9], [5.7e9, 5.8e9]],
|
||||
)
|
||||
)
|
||||
await s.on_message(_tx_start(tx_center_frequency=5.75e9, tx_gain=-20))
|
||||
await asyncio.sleep(0.02)
|
||||
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||
assert "armed" in states
|
||||
assert states[-1] == "done"
|
||||
|
||||
|
||||
def test_rejects_duplicate_tx_session():
|
||||
async def scenario():
|
||||
s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True))
|
||||
await s.on_message(_tx_start(app_id="a", tx_gain=-20, tx_center_frequency=2.45e9))
|
||||
await asyncio.sleep(0.01)
|
||||
await s.on_message(_tx_start(app_id="b", tx_gain=-20, tx_center_frequency=2.45e9))
|
||||
# Let the second request process, then stop cleanly.
|
||||
await asyncio.sleep(0.01)
|
||||
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
errors = [
|
||||
m for m in ws.json_sent
|
||||
if m.get("type") == "tx_status" and m.get("state") == "error"
|
||||
]
|
||||
assert any("already active" in e.get("message", "") for e in errors)
|
||||
|
||||
|
||||
def test_rejects_invalid_underrun_policy():
|
||||
async def scenario():
|
||||
s, ws, _ = _make_streamer(AgentConfig(tx_enabled=True))
|
||||
await s.on_message(
|
||||
{
|
||||
"type": "tx_start",
|
||||
"app_id": "a",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": 8,
|
||||
"tx_gain": -20,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"underrun_policy": "teleport",
|
||||
},
|
||||
}
|
||||
)
|
||||
return ws
|
||||
|
||||
ws = asyncio.run(scenario())
|
||||
status = _last_tx_status(ws)
|
||||
assert status and status["state"] == "error"
|
||||
assert "underrun_policy" in status["message"]
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
"""Underrun policies: pause, zero, repeat."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ria_toolkit_oss.agent.config import AgentConfig
|
||||
from ria_toolkit_oss.agent.streamer import Streamer
|
||||
from ria_toolkit_oss.sdr.mock import MockSDR
|
||||
|
||||
|
||||
class RecordingMockSDR(MockSDR):
|
||||
def __init__(self, buffer_size: int):
|
||||
super().__init__(buffer_size=buffer_size)
|
||||
self.tx_produced: list[np.ndarray] = []
|
||||
|
||||
def _stream_tx(self, callback):
|
||||
self._enable_tx = True
|
||||
self._tx_initialized = True
|
||||
while self._enable_tx:
|
||||
result = callback(self.rx_buffer_size)
|
||||
self.tx_produced.append(np.asarray(result).copy())
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
class FakeWs:
|
||||
def __init__(self):
|
||||
self.json_sent = []
|
||||
self.bytes_sent = []
|
||||
|
||||
async def send_json(self, p):
|
||||
self.json_sent.append(p)
|
||||
|
||||
async def send_bytes(self, b):
|
||||
self.bytes_sent.append(b)
|
||||
|
||||
|
||||
def _iq_frame(samples: np.ndarray) -> bytes:
|
||||
interleaved = np.empty(samples.size * 2, dtype=np.float32)
|
||||
interleaved[0::2] = samples.real
|
||||
interleaved[1::2] = samples.imag
|
||||
return interleaved.tobytes()
|
||||
|
||||
|
||||
def _start_cfg(policy: str, buf: int = 8) -> dict:
|
||||
return {
|
||||
"type": "tx_start",
|
||||
"app_id": "a",
|
||||
"radio_config": {
|
||||
"device": "mock",
|
||||
"buffer_size": buf,
|
||||
"tx_sample_rate": 1_000_000,
|
||||
"tx_gain": -20,
|
||||
"tx_center_frequency": 2.45e9,
|
||||
"underrun_policy": policy,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_underrun_pause_stops_session_and_emits_status():
|
||||
sdr = RecordingMockSDR(buffer_size=8)
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||
await s.on_message(_start_cfg("pause"))
|
||||
# Do not push any buffers. The callback underruns on first tick and
|
||||
# the watchdog should emit "underrun" and tear down.
|
||||
for _ in range(100):
|
||||
if any(
|
||||
m.get("type") == "tx_status" and m.get("state") == "underrun"
|
||||
for m in ws.json_sent
|
||||
):
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
for _ in range(50):
|
||||
if s._tx is None:
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
return ws, s
|
||||
|
||||
ws, s = asyncio.run(scenario())
|
||||
states = [m["state"] for m in ws.json_sent if m.get("type") == "tx_status"]
|
||||
assert "underrun" in states
|
||||
assert s._tx is None
|
||||
|
||||
|
||||
def test_underrun_zero_keeps_session_alive():
|
||||
sdr = RecordingMockSDR(buffer_size=8)
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||
await s.on_message(_start_cfg("zero"))
|
||||
# Let it produce several underrun-filled buffers.
|
||||
await asyncio.sleep(0.08)
|
||||
still_alive = s._tx is not None
|
||||
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||
return ws, still_alive
|
||||
|
||||
ws, still_alive = asyncio.run(scenario())
|
||||
# No underrun status emitted (policy absorbs it silently).
|
||||
assert not any(
|
||||
m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
|
||||
)
|
||||
assert still_alive
|
||||
# All produced buffers are zero (no real data was pushed).
|
||||
assert sdr.tx_produced, "expected at least one TX callback invocation"
|
||||
assert all(not np.any(b != 0) for b in sdr.tx_produced)
|
||||
|
||||
|
||||
def test_underrun_repeat_replays_last_buffer():
|
||||
BUF = 8
|
||||
sdr = RecordingMockSDR(buffer_size=BUF)
|
||||
marker = np.arange(BUF, dtype=np.complex64) + 1 # distinct non-zero buffer
|
||||
|
||||
async def scenario():
|
||||
ws = FakeWs()
|
||||
s = Streamer(ws=ws, sdr_factory=lambda d, i: sdr, cfg=AgentConfig(tx_enabled=True))
|
||||
await s.on_message(_start_cfg("repeat", buf=BUF))
|
||||
await s.on_binary(_iq_frame(marker))
|
||||
# Give the executor time to consume the real frame + several repeats.
|
||||
await asyncio.sleep(0.08)
|
||||
await s.on_message({"type": "tx_stop", "app_id": "a"})
|
||||
return ws, sdr
|
||||
|
||||
ws, sdr = asyncio.run(scenario())
|
||||
# No underrun status emitted.
|
||||
assert not any(
|
||||
m.get("type") == "tx_status" and m.get("state") == "underrun" for m in ws.json_sent
|
||||
)
|
||||
# At least two buffers equal to the marker — the real one and ≥1 repeat.
|
||||
matching = [b for b in sdr.tx_produced if np.array_equal(b, marker)]
|
||||
assert len(matching) >= 2, f"expected ≥2 buffers matching marker, got {len(matching)}"
|
||||
|
|
@ -1,164 +0,0 @@
|
|||
"""Reconnect + heartbeat + malformed-control-frame behavior.
|
||||
|
||||
Binary-frame delivery lives in ``test_ws_client_binary.py`` to match the
|
||||
test matrix spelled out in ``Agent TX Streaming Handoff.md`` §A7.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
|
||||
|
||||
async def _recv_json(ws) -> dict:
|
||||
raw = await ws.recv()
|
||||
return json.loads(raw)
|
||||
|
||||
|
||||
async def _open_server(handler):
|
||||
# websockets 13 ignores extra positional args; bind to localhost:0 for an
|
||||
# ephemeral port and return both the server and the port.
|
||||
server = await websockets.serve(handler, "127.0.0.1", 0)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
return server, port
|
||||
|
||||
|
||||
def test_heartbeat_sent_on_connect():
|
||||
async def scenario():
|
||||
received: list[dict] = []
|
||||
connected = asyncio.Event()
|
||||
|
||||
async def handler(ws):
|
||||
connected.set()
|
||||
msg = await _recv_json(ws)
|
||||
received.append(msg)
|
||||
|
||||
server, port = await _open_server(handler)
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=0.05,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
task = asyncio.create_task(
|
||||
client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat", "n": 1})
|
||||
)
|
||||
await asyncio.wait_for(connected.wait(), timeout=2.0)
|
||||
for _ in range(50):
|
||||
if received:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return received
|
||||
|
||||
received = asyncio.run(scenario())
|
||||
assert received and received[0]["type"] == "heartbeat"
|
||||
|
||||
|
||||
def test_reconnects_after_server_drop():
|
||||
async def scenario():
|
||||
connections = 0
|
||||
first_dropped = asyncio.Event()
|
||||
|
||||
async def handler(ws):
|
||||
nonlocal connections
|
||||
connections += 1
|
||||
if connections == 1:
|
||||
await ws.close()
|
||||
first_dropped.set()
|
||||
else:
|
||||
try:
|
||||
await ws.recv()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
server, port = await _open_server(handler)
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
task = asyncio.create_task(
|
||||
client.run(on_message=lambda _m: asyncio.sleep(0), heartbeat=lambda: {"type": "heartbeat"})
|
||||
)
|
||||
await asyncio.wait_for(first_dropped.wait(), timeout=2.0)
|
||||
for _ in range(100):
|
||||
if connections >= 2:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return connections
|
||||
|
||||
n = asyncio.run(scenario())
|
||||
assert n >= 2
|
||||
|
||||
|
||||
def test_malformed_control_frame_does_not_crash():
|
||||
async def scenario():
|
||||
handled: list[dict] = []
|
||||
done = asyncio.Event()
|
||||
|
||||
async def handler(ws):
|
||||
await ws.send("not json")
|
||||
await ws.send(json.dumps({"type": "ping"}))
|
||||
done.set()
|
||||
try:
|
||||
await ws.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
server, port = await _open_server(handler)
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
|
||||
async def on_msg(m):
|
||||
handled.append(m)
|
||||
|
||||
task = asyncio.create_task(
|
||||
client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
|
||||
)
|
||||
for _ in range(50):
|
||||
if handled:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return handled
|
||||
|
||||
handled = asyncio.run(scenario())
|
||||
assert handled and handled[0] == {"type": "ping"}
|
||||
|
|
@ -1,186 +0,0 @@
|
|||
"""Binary-frame delivery on the hub → agent WebSocket.
|
||||
|
||||
Named to match the test matrix in ``Agent TX Streaming Handoff.md`` §A7.
|
||||
Exercises:
|
||||
|
||||
- Binary frames are forwarded to an ``on_binary`` coroutine when supplied.
|
||||
- Binary frames are silently dropped (no crash) when ``on_binary`` is omitted,
|
||||
preserving the pre-TX behavior for RX-only deployments.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import websockets
|
||||
|
||||
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||
|
||||
|
||||
async def _open_server(handler):
|
||||
server = await websockets.serve(handler, "127.0.0.1", 0)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
return server, port
|
||||
|
||||
|
||||
def test_binary_frame_forwarded_to_handler():
|
||||
payload = bytes(range(128))
|
||||
|
||||
async def scenario():
|
||||
received: list[bytes] = []
|
||||
done = asyncio.Event()
|
||||
|
||||
async def handler(ws):
|
||||
await ws.send(payload)
|
||||
done.set()
|
||||
try:
|
||||
await ws.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
server, port = await _open_server(handler)
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
|
||||
async def on_bin(data):
|
||||
received.append(data)
|
||||
|
||||
task = asyncio.create_task(
|
||||
client.run(
|
||||
on_message=lambda _m: asyncio.sleep(0),
|
||||
heartbeat=lambda: {"type": "heartbeat"},
|
||||
on_binary=on_bin,
|
||||
)
|
||||
)
|
||||
for _ in range(50):
|
||||
if received:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return received
|
||||
|
||||
received = asyncio.run(scenario())
|
||||
assert received == [payload]
|
||||
|
||||
|
||||
def test_binary_frame_dropped_when_no_handler():
|
||||
async def scenario():
|
||||
crashes: list[Exception] = []
|
||||
|
||||
async def handler(ws):
|
||||
await ws.send(b"\x00\x01\x02\x03")
|
||||
await ws.send(json.dumps({"type": "ping"}))
|
||||
try:
|
||||
await ws.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
messages: list[dict] = []
|
||||
server, port = await _open_server(handler)
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
|
||||
async def on_msg(m):
|
||||
messages.append(m)
|
||||
|
||||
task = asyncio.create_task(
|
||||
client.run(on_message=on_msg, heartbeat=lambda: {"type": "heartbeat"})
|
||||
)
|
||||
for _ in range(50):
|
||||
if messages:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception) as exc:
|
||||
crashes.append(exc)
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return messages, crashes
|
||||
|
||||
messages, _ = asyncio.run(scenario())
|
||||
assert messages and messages[0] == {"type": "ping"}
|
||||
|
||||
|
||||
def test_on_binary_exception_does_not_kill_connection():
|
||||
"""A buggy ``on_binary`` raises mid-stream; the WS loop keeps accepting frames."""
|
||||
|
||||
async def scenario():
|
||||
delivered_binary = 0
|
||||
delivered_control: list[dict] = []
|
||||
|
||||
async def handler(ws):
|
||||
await ws.send(b"\x10\x20\x30")
|
||||
await ws.send(b"\x40\x50\x60")
|
||||
await ws.send(json.dumps({"type": "ping"}))
|
||||
try:
|
||||
await ws.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
server, port = await _open_server(handler)
|
||||
try:
|
||||
client = WsClient(
|
||||
f"ws://127.0.0.1:{port}",
|
||||
token="",
|
||||
heartbeat_interval=10.0,
|
||||
reconnect_pause=0.05,
|
||||
)
|
||||
|
||||
async def on_bin(data):
|
||||
nonlocal delivered_binary
|
||||
delivered_binary += 1
|
||||
raise RuntimeError("handler broke")
|
||||
|
||||
async def on_msg(m):
|
||||
delivered_control.append(m)
|
||||
|
||||
task = asyncio.create_task(
|
||||
client.run(
|
||||
on_message=on_msg,
|
||||
heartbeat=lambda: {"type": "heartbeat"},
|
||||
on_binary=on_bin,
|
||||
)
|
||||
)
|
||||
for _ in range(60):
|
||||
if delivered_control:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
client.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
return delivered_binary, delivered_control
|
||||
|
||||
bins, ctrls = asyncio.run(scenario())
|
||||
# Both binary frames were delivered to the (crashing) handler.
|
||||
assert bins == 2
|
||||
# The subsequent JSON frame still arrived — loop didn't die on the exceptions.
|
||||
assert ctrls and ctrls[0] == {"type": "ping"}
|
||||
Loading…
Reference in New Issue
Block a user