Compare commits

..

158 Commits
v0.1.3 ... main

Author SHA1 Message Date
2baae2f63e Merge pull request 'Update SDR guides, Getting Started Guide and fix Sphinx warnings for release' (#29) from docs/sdr-guides-update into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 30s
Build Project / Build Project (3.10) (push) Successful in 11m37s
Build Project / Build Project (3.12) (push) Successful in 12m20s
Build Project / Build Project (3.11) (push) Successful in 13m40s
Test with tox / Test with tox (3.10) (push) Successful in 13m58s
Test with tox / Test with tox (3.11) (push) Successful in 14m24s
Test with tox / Test with tox (3.12) (push) Successful in 14m10s
Reviewed-on: #29
Reviewed-by: muq <muq@noreply.localhost>
2026-04-24 11:52:45 -04:00
4df5455af4 Merge branch 'main' into docs/sdr-guides-update
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 35s
Build Project / Build Project (3.10) (pull_request) Successful in 5m49s
Build Project / Build Project (3.11) (pull_request) Successful in 19m39s
Build Project / Build Project (3.12) (pull_request) Successful in 19m21s
Test with tox / Test with tox (3.11) (pull_request) Successful in 21m31s
Test with tox / Test with tox (3.12) (pull_request) Successful in 17m24s
Test with tox / Test with tox (3.10) (pull_request) Successful in 21m51s
2026-04-24 10:36:18 -04:00
2881aaf06e Merge pull request 'zfp-oss' (#27) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 44m43s
Test with tox / Test with tox (3.10) (push) Successful in 1h4m45s
Build Project / Build Project (3.10) (push) Successful in 1h16m56s
Build Project / Build Project (3.12) (push) Successful in 1h16m52s
Test with tox / Test with tox (3.12) (push) Successful in 31m45s
Test with tox / Test with tox (3.11) (push) Successful in 47m45s
Build Project / Build Project (3.11) (push) Failing after 1h9m0s
Reviewed-on: #27
2026-04-23 11:10:43 -04:00
ben
50d04161b7 Merge remote-tracking branch 'origin/main' into zfp-oss
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 35s
Build Project / Build Project (3.10) (pull_request) Successful in 8m8s
Test with tox / Test with tox (3.11) (pull_request) Successful in 8m0s
Build Project / Build Project (3.11) (pull_request) Successful in 8m6s
Build Project / Build Project (3.12) (pull_request) Successful in 8m6s
Test with tox / Test with tox (3.12) (pull_request) Successful in 9m8s
Test with tox / Test with tox (3.10) (pull_request) Successful in 13m58s
2026-04-22 15:44:12 -04:00
ben
07c72294f5 removing orchestrator references
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19s
Test with tox / Test with tox (3.12) (pull_request) Successful in 10m47s
Test with tox / Test with tox (3.11) (pull_request) Successful in 15m47s
Build Project / Build Project (3.12) (pull_request) Successful in 15m55s
Build Project / Build Project (3.11) (pull_request) Successful in 16m46s
Build Project / Build Project (3.10) (pull_request) Successful in 16m49s
Test with tox / Test with tox (3.10) (pull_request) Successful in 18m15s
2026-04-22 10:10:25 -04:00
ben
c9b19949ad timeout chunk improvements
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19m57s
Build Project / Build Project (3.10) (pull_request) Successful in 19m59s
Test with tox / Test with tox (3.10) (pull_request) Successful in 19m46s
Build Project / Build Project (3.11) (pull_request) Successful in 20m19s
Build Project / Build Project (3.12) (pull_request) Successful in 20m21s
Test with tox / Test with tox (3.11) (pull_request) Successful in 18m48s
Test with tox / Test with tox (3.12) (pull_request) Successful in 1m25s
2026-04-21 17:11:16 -04:00
ben
53e8e5adb6 chunk timeout error
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 3m26s
Build Project / Build Project (3.10) (pull_request) Successful in 20m28s
Test with tox / Test with tox (3.10) (pull_request) Successful in 22m26s
Build Project / Build Project (3.11) (pull_request) Successful in 24m14s
Build Project / Build Project (3.12) (pull_request) Successful in 24m26s
Test with tox / Test with tox (3.11) (pull_request) Successful in 22m45s
Test with tox / Test with tox (3.12) (pull_request) Successful in 24m13s
2026-04-21 16:40:49 -04:00
a502dd97a9 Merge pull request 'Moved all contents of datatypes to data, refactored accordingly' (#28) from fix/unify_data_folders into main
All checks were successful
Build Project / Build Project (3.10) (push) Successful in 19m47s
Build Sphinx Docs Set / Build Docs (push) Successful in 23m31s
Build Project / Build Project (3.11) (push) Successful in 24m54s
Test with tox / Test with tox (3.11) (push) Successful in 18m11s
Build Project / Build Project (3.12) (push) Successful in 25m35s
Test with tox / Test with tox (3.10) (push) Successful in 26m50s
Test with tox / Test with tox (3.12) (push) Successful in 8m7s
Reviewed-on: #28
Reviewed-by: gillian <gillian@qoherent.ai>
2026-04-21 16:04:28 -04:00
ben
34b67c0c17 campaign loop support
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13m32s
Build Project / Build Project (3.12) (pull_request) Successful in 13m49s
Build Project / Build Project (3.11) (pull_request) Successful in 15m28s
Build Project / Build Project (3.10) (pull_request) Successful in 15m37s
Test with tox / Test with tox (3.10) (pull_request) Successful in 6m40s
Test with tox / Test with tox (3.11) (pull_request) Successful in 4m27s
Test with tox / Test with tox (3.12) (pull_request) Successful in 7m57s
2026-04-21 15:56:04 -04:00
ben
39d5d74d6a large memory fix
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 20m54s
Build Project / Build Project (3.12) (pull_request) Successful in 5m13s
Build Project / Build Project (3.10) (pull_request) Successful in 25m24s
Build Project / Build Project (3.11) (pull_request) Successful in 25m31s
Test with tox / Test with tox (3.10) (pull_request) Successful in 6m18s
Test with tox / Test with tox (3.11) (pull_request) Successful in 15m2s
Test with tox / Test with tox (3.12) (pull_request) Successful in 19m57s
2026-04-21 15:03:57 -04:00
8a66860d33 Moved all contents of to , refactored accordingly
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15m51s
Build Project / Build Project (3.10) (pull_request) Successful in 16m14s
Build Project / Build Project (3.11) (pull_request) Successful in 17m9s
Build Project / Build Project (3.12) (pull_request) Successful in 2m29s
Test with tox / Test with tox (3.12) (pull_request) Successful in 21m28s
Test with tox / Test with tox (3.10) (pull_request) Successful in 22m50s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23m18s
2026-04-21 14:38:06 -04:00
ben
4d3aaf6ec8 json access issue
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 26s
Build Project / Build Project (3.12) (pull_request) Successful in 2m39s
Build Project / Build Project (3.10) (pull_request) Successful in 3m9s
Build Project / Build Project (3.11) (pull_request) Successful in 3m7s
Test with tox / Test with tox (3.10) (pull_request) Successful in 8m2s
Test with tox / Test with tox (3.11) (pull_request) Successful in 13m37s
Test with tox / Test with tox (3.12) (pull_request) Successful in 13m28s
2026-04-21 14:34:48 -04:00
ben
4aea2841be two-machine TX/RX
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17m1s
Test with tox / Test with tox (3.10) (pull_request) Successful in 17m10s
Build Project / Build Project (3.10) (pull_request) Successful in 17m31s
Test with tox / Test with tox (3.11) (pull_request) Successful in 17m38s
Build Project / Build Project (3.11) (pull_request) Successful in 17m48s
Build Project / Build Project (3.12) (pull_request) Successful in 17m47s
Test with tox / Test with tox (3.12) (pull_request) Successful in 3m12s
2026-04-21 14:09:36 -04:00
ben
4c2c9c0288 rx and tx test
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19s
Build Project / Build Project (3.11) (pull_request) Successful in 1m15s
Build Project / Build Project (3.10) (pull_request) Successful in 1m18s
Build Project / Build Project (3.12) (pull_request) Successful in 1m16s
Test with tox / Test with tox (3.11) (pull_request) Successful in 1m47s
Test with tox / Test with tox (3.12) (pull_request) Successful in 1m44s
Test with tox / Test with tox (3.10) (pull_request) Successful in 2m4s
2026-04-21 13:23:49 -04:00
Mmuq
a68a325cb4 Update SDR guides and fix Sphinx warnings for release
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 6m4s
Build Project / Build Project (3.10) (pull_request) Successful in 11m52s
Build Project / Build Project (3.11) (pull_request) Successful in 11m50s
Build Project / Build Project (3.12) (pull_request) Successful in 11m51s
Test with tox / Test with tox (3.11) (pull_request) Successful in 12m10s
Test with tox / Test with tox (3.12) (pull_request) Successful in 6m22s
Test with tox / Test with tox (3.10) (pull_request) Successful in 12m28s
Fix Sphinx build errors:
- Add missing blank lines in rtlsdr.rst code-block directives
- Rename duplicate label in examples/sdr/index.rst
- Fix field list indentation in usrp.py and hackrf.py docstrings

Update SDR setup guides (all guides now cover both pip/venv and Radioconda):
- rtlsdr: switch to rtl-sdr-blog fork (required for rtlsdr_set_dithering
  symbol), add pyrtlsdr==0.3.0 and setuptools==69.5.1 version pinning,
  preserve Radioconda blacklist and udev symlink paths alongside new steps
- pluto: simplify primary path to apt install libiio, add Avahi network
  discovery note, preserve Radioconda udev symlink as alternative
- hackrf: note out-of-box support, preserve Radioconda udev symlink
- blade: note no extra Python packages needed, preserve Radioconda udev symlinks
- usrp: add build-from-source path for pip/venv users with cmake flags,
  Python binding copy step, and version mismatch warning; keep conda install
  as primary option; preserve Radioconda udev symlink
- thinkrf: add lib2to3 install step, Python <=3.12 restriction, and full
  Python 3 patching command to replace internal script reference

Update copyright year to 2026 in conf.py
2026-04-21 12:29:18 -04:00
50438558d4 Merge pull request 'qac-cli-commands' (#26) from qac-cli-commands into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 22s
Build Project / Build Project (3.10) (push) Successful in 1m26s
Build Project / Build Project (3.11) (push) Successful in 1m24s
Build Project / Build Project (3.12) (push) Successful in 1m26s
Test with tox / Test with tox (3.11) (push) Successful in 2m5s
Test with tox / Test with tox (3.12) (push) Successful in 1m56s
Test with tox / Test with tox (3.10) (push) Successful in 2m19s
Reviewed-on: #26
Reviewed-by: madrigal <madrigal@qoherent.ai>
2026-04-21 09:03:29 -04:00
ben
c27a5944c7 formats
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 8m30s
Build Project / Build Project (3.12) (pull_request) Successful in 4m15s
Build Project / Build Project (3.11) (pull_request) Successful in 4m17s
Build Project / Build Project (3.10) (pull_request) Successful in 4m19s
Test with tox / Test with tox (3.11) (pull_request) Successful in 14m59s
Test with tox / Test with tox (3.10) (pull_request) Successful in 20m7s
Test with tox / Test with tox (3.12) (pull_request) Successful in 18m9s
2026-04-20 16:49:52 -04:00
ben
062a0e766f Merge origin/main into zfp-oss; regenerate poetry.lock
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-20 16:44:59 -04:00
ben
cdcc03327b Merge remote-tracking branch 'origin/main' into zfp-oss 2026-04-20 16:42:08 -04:00
ben
8d2f9eebaf black fix
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 23s
Build Project / Build Project (3.11) (pull_request) Successful in 21m33s
Build Project / Build Project (3.10) (pull_request) Successful in 21m55s
Test with tox / Test with tox (3.10) (pull_request) Successful in 21m31s
Build Project / Build Project (3.12) (pull_request) Successful in 21m41s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23m12s
Test with tox / Test with tox (3.12) (pull_request) Successful in 19m16s
2026-04-20 15:59:03 -04:00
ben
6019a38b8b Merge remote-tracking branch 'origin/main' into qac-cli-commands 2026-04-20 15:58:27 -04:00
ef772a3755 Merge pull request 'annotationsfix' (#19) from annotationsfix into main
All checks were successful
Test with tox / Test with tox (3.10) (push) Successful in 17m15s
Build Sphinx Docs Set / Build Docs (push) Successful in 19m13s
Build Project / Build Project (3.10) (push) Successful in 19m7s
Build Project / Build Project (3.12) (push) Successful in 17m47s
Build Project / Build Project (3.11) (push) Successful in 19m11s
Test with tox / Test with tox (3.11) (push) Successful in 14m44s
Test with tox / Test with tox (3.12) (push) Successful in 7m25s
Reviewed-on: #19
Reviewed-by: madrigal <madrigal@qoherent.ai>
2026-04-20 15:57:23 -04:00
ben
98f63b622b remote transmitter fix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17m18s
Build Project / Build Project (3.12) (pull_request) Successful in 16m12s
Build Project / Build Project (3.11) (pull_request) Successful in 17m50s
Build Project / Build Project (3.10) (pull_request) Successful in 18m33s
Test with tox / Test with tox (3.11) (pull_request) Successful in 5m1s
Test with tox / Test with tox (3.12) (pull_request) Successful in 4m45s
Test with tox / Test with tox (3.10) (pull_request) Failing after 21m16s
2026-04-20 15:27:54 -04:00
ben
db000da517 Add noqa C901 to view() to pass flake8 complexity check
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17m39s
Build Project / Build Project (3.10) (pull_request) Successful in 18m13s
Build Project / Build Project (3.12) (pull_request) Successful in 8m42s
Build Project / Build Project (3.11) (pull_request) Successful in 9m2s
Test with tox / Test with tox (3.10) (pull_request) Successful in 4m34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 5m27s
Test with tox / Test with tox (3.11) (pull_request) Successful in 17m4s
2026-04-20 15:08:31 -04:00
ben
c043eb0377 Merge clifix fixes: lint, onnxruntime py3.10, hackrf test, poetry.lock
Some checks failed
Build Project / Build Project (3.10) (pull_request) Successful in 14m39s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 20m16s
Build Project / Build Project (3.12) (pull_request) Successful in 20m9s
Build Project / Build Project (3.11) (pull_request) Successful in 20m31s
Test with tox / Test with tox (3.11) (pull_request) Successful in 18m1s
Test with tox / Test with tox (3.12) (pull_request) Successful in 8m34s
Test with tox / Test with tox (3.10) (pull_request) Failing after 31m15s
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-20 14:47:00 -04:00
d6c66d2a07 Merge branch 'main' of https://riahub.ai/qoherent/ria-toolkit-oss into annotationsfix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 3m24s
Test with tox / Test with tox (3.11) (pull_request) Successful in 4m22s
Test with tox / Test with tox (3.10) (pull_request) Failing after 5m51s
Test with tox / Test with tox (3.12) (pull_request) Successful in 6m37s
Build Project / Build Project (3.10) (pull_request) Successful in 15m30s
Build Project / Build Project (3.12) (pull_request) Successful in 18m52s
Build Project / Build Project (3.11) (pull_request) Successful in 19m20s
2026-04-20 14:38:57 -04:00
414597e940 Merge pull request 'updated tags for new release' (#25) from clifix into main
All checks were successful
Test with tox / Test with tox (3.11) (push) Successful in 6m10s
Build Project / Build Project (3.10) (push) Successful in 15m3s
Build Sphinx Docs Set / Build Docs (push) Successful in 20m10s
Test with tox / Test with tox (3.10) (push) Successful in 20m10s
Build Project / Build Project (3.11) (push) Successful in 21m38s
Build Project / Build Project (3.12) (push) Successful in 21m40s
Test with tox / Test with tox (3.12) (push) Successful in 15m26s
Reviewed-on: #25
Reviewed-by: madrigal <madrigal@qoherent.ai>
2026-04-20 14:36:59 -04:00
ben
1fa9ab2495 Fix onnxruntime Python 3.10 incompatibility and hackrf test import failure
All checks were successful
Build Project / Build Project (3.10) (pull_request) Successful in 13m7s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16m5s
Test with tox / Test with tox (3.10) (pull_request) Successful in 16m2s
Build Project / Build Project (3.11) (pull_request) Successful in 16m29s
Build Project / Build Project (3.12) (pull_request) Successful in 16m28s
Test with tox / Test with tox (3.11) (pull_request) Successful in 16m44s
Test with tox / Test with tox (3.12) (pull_request) Successful in 3m59s
- Restrict onnxruntime to Python >=3.11 (1.24.3+ dropped cp310 wheels)
- Fix hackrf tests to mock sys.modules instead of using patch(), which
  triggered a CDLL import of libhackrf.so.0 at module load time in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-20 14:18:04 -04:00
ben
ab4cb0ea5a Remove poetry.toml system-site-packages workaround; regenerate lock file
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 7m9s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m38s
Test with tox / Test with tox (3.11) (pull_request) Failing after 9m24s
Build Project / Build Project (3.12) (pull_request) Successful in 9m40s
Build Project / Build Project (3.11) (pull_request) Successful in 9m53s
Test with tox / Test with tox (3.12) (pull_request) Failing after 3m1s
Build Project / Build Project (3.10) (pull_request) Successful in 10m17s
system-site-packages = true caused poetry install to fail in CI due to
conflicting system packages. Lock file regenerated without that constraint.
2026-04-20 13:56:18 -04:00
ben
22b035dbee format fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Has been cancelled
Test with tox / Test with tox (3.10) (pull_request) Has been cancelled
Test with tox / Test with tox (3.11) (pull_request) Has been cancelled
Test with tox / Test with tox (3.12) (pull_request) Has been cancelled
Build Project / Build Project (3.12) (pull_request) Has been cancelled
Build Project / Build Project (3.11) (pull_request) Has been cancelled
Build Project / Build Project (3.10) (pull_request) Has been cancelled
2026-04-20 13:51:15 -04:00
ben
8e23558d90 Fix flake8 lint errors and regenerate poetry.lock
- Add TYPE_CHECKING guard for paramiko/zmq annotations in remote_transmitter_controller.py
- Remove unused imports (sys, threading, importlib, call) from remote_control tests
- Remove unused mock_ctrl_kwarg variable
- Add noqa C901 to _handle_tx_start (legitimately complex interlock logic)
- Regenerate poetry.lock to sync with pyproject.toml

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-20 13:50:59 -04:00
ben
ae07eef885 Regenerate poetry.lock to sync with pyproject.toml paramiko addition
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 22s
Build Project / Build Project (3.10) (pull_request) Successful in 2m54s
Test with tox / Test with tox (3.10) (pull_request) Failing after 3m24s
Build Project / Build Project (3.12) (pull_request) Successful in 4m55s
Build Project / Build Project (3.11) (pull_request) Successful in 10m10s
Test with tox / Test with tox (3.11) (pull_request) Has been cancelled
Test with tox / Test with tox (3.12) (pull_request) Has been cancelled
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-20 13:33:00 -04:00
ben
912fc54f25 Merge remote-tracking branch 'origin/qac-cli-commands' into zfp-oss
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 23s
Build Project / Build Project (3.10) (pull_request) Successful in 11m47s
Test with tox / Test with tox (3.10) (pull_request) Failing after 21m33s
Build Project / Build Project (3.12) (pull_request) Successful in 21m47s
Build Project / Build Project (3.11) (pull_request) Successful in 21m52s
Test with tox / Test with tox (3.12) (pull_request) Failing after 26m45s
Test with tox / Test with tox (3.11) (pull_request) Failing after 28m40s
2026-04-20 13:28:34 -04:00
ben
b884397f1f Merge remote-tracking branch 'origin/main' into zfp-oss 2026-04-20 13:28:12 -04:00
ben
f03825a6db Merge remote-tracking branch 'origin/main' into qac-cli-commands
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 17m10s
Build Project / Build Project (3.12) (pull_request) Successful in 17m21s
Build Project / Build Project (3.11) (pull_request) Successful in 17m30s
Build Project / Build Project (3.10) (pull_request) Successful in 17m36s
Test with tox / Test with tox (3.11) (pull_request) Failing after 18m4s
Test with tox / Test with tox (3.12) (pull_request) Failing after 18m0s
2026-04-20 13:27:17 -04:00
e506d26450 Replaces the getting_started.rst placeholder with a full CLI reference
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19s
Test with tox / Test with tox (3.11) (pull_request) Failing after 17s
Test with tox / Test with tox (3.12) (pull_request) Failing after 16s
Test with tox / Test with tox (3.10) (pull_request) Failing after 33s
Build Project / Build Project (3.10) (pull_request) Successful in 1m5s
Build Project / Build Project (3.12) (pull_request) Successful in 1m4s
Build Project / Build Project (3.11) (pull_request) Successful in 1m6s
covering installation, commands, YAML config patterns, and a cheat sheet.
Adds custom CSS/JS for heading colours, warning admonition styling, code
block colours, and ria command highlighting. Fixes .gitignore to exclude
docs/_build/.
2026-04-20 13:10:31 -04:00
138fdeb68b Added getting started guide 2026-04-20 13:06:54 -04:00
ben
dae9510981 transmission code 2026-04-20 12:33:14 -04:00
Mmuq
d3a7e9ef0f Resolve merge conflict: keep Pass 2 spillover fix over remote's buggy version
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Build Project / Build Project (3.12) (pull_request) Successful in 3m21s
Test with tox / Test with tox (3.11) (pull_request) Successful in 3m24s
Build Project / Build Project (3.11) (pull_request) Successful in 3m35s
Build Project / Build Project (3.10) (pull_request) Successful in 3m37s
Test with tox / Test with tox (3.10) (pull_request) Failing after 4m6s
Test with tox / Test with tox (3.12) (pull_request) Successful in 4m25s
Remote annotationsfix had a partial port of threshold_qualifier without
the Pass 2 hysteresis spillover fix. Kept our corrected version in both
conflicting sections:
- Pass 2 mask expanded by window_size guard band around Pass 1 ranges
- Pass 2 expansion runs against residual_power instead of smoothed_power
2026-04-20 12:22:46 -04:00
Mmuq
93ae08bc91 Port threshold_qualifier improvements and Pass 2 spillover fix from utils
The OSS threshold_qualifier was last synced from utils on Feb 23 2026,
before the major robustness improvements landed in utils on Mar 19 2026.
This commit brings it fully up to date.

Changes ported from utils:
- Multi-pass detection (Pass 1 strong burst, Pass 2 weak residual,
  Pass 3 sustained faint burst via macro-window averaging)
- Noise floor estimation via percentile instead of simple max*threshold
- Dynamic range ratio guard (early exit on low-contrast captures)
- Improved _find_ranges, _expand_and_filter_ranges, _merge_ranges helpers
- Spectral smoothing in _estimate_spectral_bounds for wideband bursts
- Minimum duration filter expressed in absolute time (5ms) not sample count

Also includes the Pass 2 hysteresis spillover fix:
- Pass 2 expansion now runs against residual_power (masked) instead of
  smoothed_power, preventing it from walking into Pass 1 territory
- Pass 2 mask now has a window_size guard band around Pass 1 ranges,
  matching the guard already used in Pass 3

Only change from utils: import swapped to ria_toolkit_oss.datatypes.
2026-04-20 12:11:05 -04:00
0642dcc2db Added paramiko to dependencies
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 8m50s
Test with tox / Test with tox (3.11) (pull_request) Failing after 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m16s
Test with tox / Test with tox (3.12) (pull_request) Failing after 16s
Build Project / Build Project (3.10) (pull_request) Successful in 9m23s
Build Project / Build Project (3.11) (pull_request) Successful in 8m52s
Build Project / Build Project (3.12) (pull_request) Successful in 8m53s
2026-04-20 11:51:57 -04:00
J jonny
ea8ed56a7d add random name genration to agent registration
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 20s
Build Project / Build Project (3.10) (pull_request) Successful in 11m17s
Build Project / Build Project (3.12) (pull_request) Successful in 11m2s
Build Project / Build Project (3.11) (pull_request) Successful in 13m18s
Test with tox / Test with tox (3.11) (pull_request) Successful in 15m24s
Test with tox / Test with tox (3.10) (pull_request) Failing after 17m19s
Test with tox / Test with tox (3.12) (pull_request) Successful in 17m48s
2026-04-20 11:50:15 -04:00
84a7893c8f Updated poetry.lock, linting
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.12) (pull_request) Failing after 2m24s
Test with tox / Test with tox (3.11) (pull_request) Failing after 2m12s
Build Project / Build Project (3.12) (pull_request) Successful in 8m59s
Build Project / Build Project (3.10) (pull_request) Successful in 9m17s
Build Project / Build Project (3.11) (pull_request) Successful in 9m16s
Test with tox / Test with tox (3.10) (pull_request) Failing after 16m0s
2026-04-20 11:43:03 -04:00
8e542919a8 updated tags for new release
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 13m9s
Build Project / Build Project (3.11) (pull_request) Successful in 13m21s
Build Project / Build Project (3.12) (pull_request) Successful in 13m19s
Build Project / Build Project (3.10) (pull_request) Successful in 13m26s
Test with tox / Test with tox (3.11) (pull_request) Failing after 13m40s
Test with tox / Test with tox (3.12) (pull_request) Failing after 13m33s
2026-04-20 11:21:36 -04:00
27049f00ea Merge pull request 'quick agent fix' (#24) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 21s
Build Project / Build Project (3.10) (push) Successful in 16m7s
Build Project / Build Project (3.12) (push) Successful in 20m21s
Build Project / Build Project (3.11) (push) Successful in 21m9s
Test with tox / Test with tox (3.10) (push) Failing after 21m25s
Test with tox / Test with tox (3.11) (push) Failing after 15m0s
Test with tox / Test with tox (3.12) (push) Failing after 14m7s
Reviewed-on: #24
2026-04-17 11:50:10 -04:00
ben
78ecd171bd quick agent fix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 51s
Build Project / Build Project (3.10) (pull_request) Successful in 1m12s
Build Project / Build Project (3.11) (pull_request) Successful in 1m12s
Build Project / Build Project (3.12) (pull_request) Successful in 1m14s
Test with tox / Test with tox (3.12) (pull_request) Failing after 2m29s
Test with tox / Test with tox (3.11) (pull_request) Failing after 1m37s
2026-04-17 11:49:44 -04:00
6fb73c1daa Merge pull request 'Ria Composer Support' (#23) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 21s
Build Project / Build Project (3.11) (push) Successful in 6m27s
Build Project / Build Project (3.10) (push) Successful in 8m3s
Test with tox / Test with tox (3.10) (push) Failing after 11m43s
Build Project / Build Project (3.12) (push) Successful in 13m18s
Test with tox / Test with tox (3.11) (push) Failing after 11m46s
Test with tox / Test with tox (3.12) (push) Failing after 11m36s
Reviewed-on: #23
2026-04-17 10:11:56 -04:00
ben
83515d6e3f Merge remote-tracking branch 'origin' into zfp-oss
Some checks failed
Build Project / Build Project (3.10) (pull_request) Successful in 7m23s
Test with tox / Test with tox (3.11) (pull_request) Failing after 7m15s
Test with tox / Test with tox (3.10) (pull_request) Failing after 7m21s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 7m32s
Build Project / Build Project (3.11) (pull_request) Successful in 8m53s
Build Project / Build Project (3.12) (pull_request) Successful in 9m3s
Test with tox / Test with tox (3.12) (pull_request) Failing after 2m23s
2026-04-17 10:11:31 -04:00
ben
638fe5df1f test suite
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1m51s
Test with tox / Test with tox (3.11) (pull_request) Failing after 2m43s
Test with tox / Test with tox (3.12) (pull_request) Failing after 2m45s
Build Project / Build Project (3.10) (pull_request) Successful in 4m2s
Build Project / Build Project (3.12) (pull_request) Successful in 3m48s
Build Project / Build Project (3.11) (pull_request) Successful in 4m40s
2026-04-17 10:04:06 -04:00
ben
efc0948110 ria composer support
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m47s
Build Project / Build Project (3.10) (pull_request) Successful in 9m5s
Build Project / Build Project (3.12) (pull_request) Successful in 9m3s
Build Project / Build Project (3.11) (pull_request) Successful in 9m5s
Test with tox / Test with tox (3.11) (pull_request) Failing after 9m16s
Test with tox / Test with tox (3.12) (pull_request) Failing after 9m10s
2026-04-17 09:43:59 -04:00
J jonny
5035f0654a tx_race_condtion_fix 2026-04-16 15:38:35 -04:00
J jonny
8c247f9f7a transmit further updates 2026-04-16 15:12:56 -04:00
J jonny
b955256479 Pluto TX streaming functionality base 2026-04-16 11:13:43 -04:00
J jonny
20fe86d399 allow sudo calls 2026-04-14 13:18:34 -04:00
J jonny
87bc78e063 new commands 2026-04-14 13:03:26 -04:00
8f39c4d855 Merge pull request 'Expand Number of samples' (#22) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 1m16s
Test with tox / Test with tox (3.11) (push) Successful in 16m50s
Build Project / Build Project (3.10) (push) Successful in 17m2s
Test with tox / Test with tox (3.10) (push) Failing after 16m59s
Build Project / Build Project (3.11) (push) Successful in 17m34s
Build Project / Build Project (3.12) (push) Successful in 17m32s
Test with tox / Test with tox (3.12) (push) Successful in 16m19s
Reviewed-on: #22
2026-04-14 10:53:33 -04:00
ben
195db4a27d quick fix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 58s
Build Project / Build Project (3.11) (pull_request) Successful in 1m11s
Build Project / Build Project (3.10) (pull_request) Successful in 1m15s
Build Project / Build Project (3.12) (pull_request) Successful in 1m12s
Test with tox / Test with tox (3.11) (pull_request) Successful in 1m23s
Test with tox / Test with tox (3.12) (pull_request) Successful in 1m18s
2026-04-14 10:45:54 -04:00
J jonny
84f3a63e8b simplifying cli 2026-04-13 12:54:05 -04:00
J jonny
a268f2ab25 Add changes for screens agent connections. 2026-04-13 11:48:15 -04:00
5718e109b5 Merge pull request 'Agent Error fix' (#21) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 21s
Build Project / Build Project (3.10) (push) Successful in 5m34s
Test with tox / Test with tox (3.11) (push) Successful in 4m57s
Build Project / Build Project (3.11) (push) Successful in 7m24s
Build Project / Build Project (3.12) (push) Successful in 10m19s
Test with tox / Test with tox (3.10) (push) Failing after 17m4s
Test with tox / Test with tox (3.12) (push) Successful in 15m16s
Reviewed-on: #21
Reviewed-by: jonny <jonny@noreply.localhost>
2026-04-13 09:11:41 -04:00
d81c61c3cf Merge branch 'main' into zfp-oss
Some checks failed
Test with tox / Test with tox (3.11) (pull_request) Successful in 10m40s
Build Project / Build Project (3.10) (pull_request) Successful in 10m54s
Test with tox / Test with tox (3.10) (pull_request) Failing after 10m53s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 11m4s
Build Project / Build Project (3.11) (pull_request) Successful in 11m31s
Build Project / Build Project (3.12) (pull_request) Successful in 11m32s
Test with tox / Test with tox (3.12) (pull_request) Successful in 2m26s
2026-04-13 09:11:09 -04:00
ben
54b9bd4fc8 Agent Error fix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Build Project / Build Project (3.10) (pull_request) Successful in 1m15s
Build Project / Build Project (3.11) (pull_request) Successful in 1m13s
Build Project / Build Project (3.12) (pull_request) Successful in 1m23s
Test with tox / Test with tox (3.11) (pull_request) Successful in 2m25s
Test with tox / Test with tox (3.12) (pull_request) Successful in 2m18s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1m12s
2026-04-10 16:39:11 -04:00
1005228d69 Linting and updated poetry.lock
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15m55s
Build Project / Build Project (3.10) (pull_request) Successful in 59s
Build Project / Build Project (3.11) (pull_request) Successful in 58s
Build Project / Build Project (3.12) (pull_request) Successful in 56s
Test with tox / Test with tox (3.12) (pull_request) Successful in 1m4s
Test with tox / Test with tox (3.11) (pull_request) Successful in 3m8s
Test with tox / Test with tox (3.10) (pull_request) Failing after 46s
2026-04-02 11:07:12 -04:00
3eb084bc08 Merge branch 'main' of https://riahub.ai/qoherent/ria-toolkit-oss into annotationsfix
Some checks failed
Test with tox / Test with tox (3.12) (pull_request) Failing after 0s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15s
Build Project / Build Project (3.10) (pull_request) Successful in 57s
Build Project / Build Project (3.11) (pull_request) Successful in 56s
Build Project / Build Project (3.12) (pull_request) Successful in 55s
Test with tox / Test with tox (3.11) (pull_request) Failing after 12s
Test with tox / Test with tox (3.10) (pull_request) Failing after 23s
2026-04-02 11:00:42 -04:00
c7b88f1f14 Fixed import typo
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1m56s
Build Project / Build Project (3.11) (pull_request) Failing after 1m54s
Build Project / Build Project (3.12) (pull_request) Failing after 5s
Test with tox / Test with tox (3.10) (pull_request) Failing after 3s
Test with tox / Test with tox (3.11) (pull_request) Failing after 0s
Test with tox / Test with tox (3.12) (pull_request) Failing after 0s
Build Project / Build Project (3.10) (pull_request) Successful in 3m3s
2026-04-02 10:52:31 -04:00
e1025794a8 Fixed overwrite issue 2026-04-02 10:52:17 -04:00
cfa1da9f4d Linting
Some checks failed
Build Project / Build Project (3.12) (pull_request) Failing after 1s
Test with tox / Test with tox (3.11) (pull_request) Failing after 1s
Test with tox / Test with tox (3.12) (pull_request) Failing after 1s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 46s
Build Project / Build Project (3.10) (pull_request) Successful in 7m7s
Test with tox / Test with tox (3.10) (pull_request) Failing after 13m42s
Build Project / Build Project (3.11) (pull_request) Successful in 4m31s
2026-04-02 10:37:42 -04:00
0e3e022084 Updated poetry.lock 2026-04-02 10:35:21 -04:00
2182899162 Merge pull request 'reporting campaign status' (#20) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 20s
Test with tox / Test with tox (3.12) (push) Successful in 6m44s
Test with tox / Test with tox (3.11) (push) Successful in 7m59s
Build Project / Build Project (3.12) (push) Successful in 12m58s
Build Project / Build Project (3.10) (push) Successful in 13m5s
Build Project / Build Project (3.11) (push) Successful in 13m4s
Test with tox / Test with tox (3.10) (push) Failing after 13m6s
Reviewed-on: #20
2026-04-01 15:25:02 -04:00
ben
da9a0b07bd poetry update
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 38s
Test with tox / Test with tox (3.11) (pull_request) Failing after 14m19s
Build Project / Build Project (3.10) (pull_request) Successful in 17m2s
Test with tox / Test with tox (3.10) (pull_request) Failing after 17m1s
Build Project / Build Project (3.12) (pull_request) Successful in 17m29s
Build Project / Build Project (3.11) (pull_request) Successful in 17m31s
Test with tox / Test with tox (3.12) (pull_request) Failing after 28m46s
2026-04-01 15:05:30 -04:00
ben
3e9ac43800 reporting campaign status
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 7m42s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m9s
Build Project / Build Project (3.10) (pull_request) Successful in 8m22s
Build Project / Build Project (3.11) (pull_request) Successful in 8m21s
Build Project / Build Project (3.12) (pull_request) Successful in 8m21s
Test with tox / Test with tox (3.11) (pull_request) Successful in 8m33s
Test with tox / Test with tox (3.12) (pull_request) Successful in 1m14s
2026-04-01 14:08:13 -04:00
44df45160e Merge pull request 'zfp-oss tools' (#18) from zfp-oss into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 17s
Test with tox / Test with tox (3.10) (push) Failing after 1m54s
Test with tox / Test with tox (3.11) (push) Successful in 2m54s
Build Project / Build Project (3.10) (push) Successful in 3m19s
Build Project / Build Project (3.12) (push) Successful in 3m25s
Build Project / Build Project (3.11) (push) Successful in 3m27s
Test with tox / Test with tox (3.12) (push) Successful in 3m23s
Reviewed-on: #18
Reviewed-by: gillian <gillian@qoherent.ai>
2026-04-01 13:52:06 -04:00
ben
f67c995846 reformats
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16s
Test with tox / Test with tox (3.11) (pull_request) Successful in 3m56s
Test with tox / Test with tox (3.12) (pull_request) Successful in 7m51s
Build Project / Build Project (3.12) (pull_request) Successful in 12m32s
Build Project / Build Project (3.11) (pull_request) Successful in 13m9s
Build Project / Build Project (3.10) (pull_request) Successful in 13m14s
Test with tox / Test with tox (3.10) (pull_request) Failing after 13m13s
2026-04-01 12:28:41 -04:00
ben
c36fdcf607 optimiztions and fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16s
Test with tox / Test with tox (3.10) (pull_request) Failing after 17m6s
Build Project / Build Project (3.10) (pull_request) Successful in 17m26s
Build Project / Build Project (3.11) (pull_request) Successful in 17m25s
Build Project / Build Project (3.12) (pull_request) Successful in 17m27s
Test with tox / Test with tox (3.12) (pull_request) Successful in 17m21s
Test with tox / Test with tox (3.11) (pull_request) Failing after 21m50s
2026-04-01 11:57:59 -04:00
F fordg1
5d909c4a22 Update to be accurate to ria toolkit
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 14m18s
Build Project / Build Project (3.10) (pull_request) Successful in 5m4s
Test with tox / Test with tox (3.10) (pull_request) Failing after 47s
Build Project / Build Project (3.11) (pull_request) Successful in 1m16s
Build Project / Build Project (3.12) (pull_request) Successful in 1m16s
Test with tox / Test with tox (3.11) (pull_request) Failing after 34s
Test with tox / Test with tox (3.12) (pull_request) Failing after 34s
2026-03-31 15:27:45 -04:00
F fordg1
5cfced8855 Fix merge conflicts and port all imports from utils to ria_toolkit_oss
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Failing after 1s
Build Project / Build Project (3.11) (pull_request) Failing after 1s
Build Project / Build Project (3.12) (pull_request) Failing after 1s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1s
Test with tox / Test with tox (3.11) (pull_request) Failing after 1s
Test with tox / Test with tox (3.12) (pull_request) Failing after 1s
Resolves unresolved merge conflict markers left in committed files across
the annotations, view, data, and CLI packages. Updates all remaining
imports from the old utils.* namespace to ria_toolkit_oss.datatypes,
ria_toolkit_oss.io, and ria_toolkit_oss.view equivalents.
2026-03-31 15:16:32 -04:00
F fordg1
ee2ce3b1f4 Merge branch 'annotationsfix' of https://riahub.ai/qoherent/ria-toolkit-oss into annotationsfix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Failing after 1s
Build Project / Build Project (3.11) (pull_request) Failing after 1s
Build Project / Build Project (3.12) (pull_request) Failing after 1s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1s
Test with tox / Test with tox (3.11) (pull_request) Failing after 1s
Test with tox / Test with tox (3.12) (pull_request) Failing after 1s
2026-03-31 14:56:36 -04:00
F fordg1
5b1c51797b logos 2026-03-31 14:54:27 -04:00
ben
9a960e2f29 zfp functionality and servers
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Successful in 57s
Build Project / Build Project (3.11) (pull_request) Successful in 1m7s
Build Project / Build Project (3.12) (pull_request) Successful in 56s
Test with tox / Test with tox (3.12) (pull_request) Failing after 5m13s
Test with tox / Test with tox (3.11) (pull_request) Failing after 5m48s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m46s
2026-03-31 13:51:10 -04:00
F fordg1
2bb2d9d5a7 Removing logos
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Failing after 1s
Build Project / Build Project (3.11) (pull_request) Failing after 1s
Build Project / Build Project (3.12) (pull_request) Failing after 1s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1s
Test with tox / Test with tox (3.11) (pull_request) Failing after 1s
Test with tox / Test with tox (3.12) (pull_request) Failing after 1s
2026-03-31 13:48:38 -04:00
F fordg1
11d9532b5c Port annotation system from utils and fix ria package imports
Annotations package (new):
- Add threshold_qualifier with 3-pass hysteresis detector (Pass 1: strong
  bursts, Pass 2: weak residual bursts, Pass 3: macro-window faint burst
  detection), auto window_size scaled to 1ms, channel selection, and
  stable noise_floor baseline throughout
- Add energy_detector, cusum_annotator, parallel_signal_separator,
  qualify_slice, signal_isolation, annotation_transforms
- Add __init__.py exporting the four functions used by the CLI
- Fix all imports from utils.data → ria_toolkit_oss.datatypes

CLI annotate command (new):
- Port full annotate CLI from utils including list, add, remove, clear,
  energy, cusum, threshold, and separate subcommands
- Fix imports from utils.* → ria_toolkit_oss.* and utils_cli.* →
  ria_toolkit_oss_cli.*
- Safe overwrite logic: _annotated files always writable, originals
  protected; --overwrite writes in-place only on _annotated inputs

CLI view command:
- Add 'annotations' as a valid --type, wiring view_annotations from
  view_signal

view_signal.py:
- Add view_annotations function with blue/purple alternating palette and
  threshold %-sorted drawing order (lower % renders on top)

recording.py (datatypes):
- Fix lazy imports in to_wav() and to_blue() from utils.io → ria_toolkit_oss.io

io/recording.py:
- Add compatibility shim in from_npy to remap utils.data.annotation.Annotation
  to ria_toolkit_oss.datatypes.annotation.Annotation when loading .npy files
  pickled by the utils package
2026-03-31 13:34:00 -04:00
ben
7335dc4c52 server change 2026-03-12 11:45:07 -04:00
ben
019b0c6f4b reformats and campaign additions 2026-03-11 10:27:18 -04:00
Mmuq
e41f061caa Merge branch 'annotationsfix' of https://riahub.ai/qoherent/ria-toolkit-oss into annotationsfix 2026-02-23 14:14:16 -05:00
Mmuq
16ac8dbfb6 updated annotations from utils to oss 2026-02-23 14:12:34 -05:00
af3ae03baf Moving annotate into CLI 2026-02-23 14:09:42 -05:00
5c0c20619f Moving over from utils 2026-02-23 14:00:59 -05:00
4ee8ee5fe0 Moving from utils 2026-02-23 14:00:06 -05:00
f7eedfa2bd Annotate added to cli 2026-02-23 13:48:46 -05:00
fc6a1824a5 Added change log for future code from utils 2026-02-20 16:38:27 -05:00
b1e3ebf74f Merge pull request 'viewfix' (#17) from viewfix into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 10s
Build Project / Build Project (3.10) (push) Successful in 32s
Build Project / Build Project (3.11) (push) Successful in 31s
Build Project / Build Project (3.12) (push) Successful in 30s
Test with tox / Test with tox (3.11) (push) Successful in 22s
Test with tox / Test with tox (3.10) (push) Successful in 28s
Test with tox / Test with tox (3.12) (push) Successful in 23s
Reviewed-on: #17
Reviewed-by: madrigal <madrigal@qoherent.ai>
2026-02-02 11:02:59 -05:00
1cb9cb6463 Linting
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 7s
Build Project / Build Project (3.10) (pull_request) Successful in 33s
Build Project / Build Project (3.11) (pull_request) Successful in 30s
Build Project / Build Project (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.10) (pull_request) Successful in 29s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23s
Test with tox / Test with tox (3.12) (pull_request) Successful in 22s
2026-01-30 17:51:01 -05:00
823a0aba85 Linting
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 10s
Build Project / Build Project (3.10) (pull_request) Successful in 31s
Build Project / Build Project (3.11) (pull_request) Successful in 32s
Build Project / Build Project (3.12) (pull_request) Successful in 29s
Test with tox / Test with tox (3.11) (pull_request) Successful in 22s
Test with tox / Test with tox (3.10) (pull_request) Failing after 29s
Test with tox / Test with tox (3.12) (pull_request) Successful in 23s
2026-01-30 17:43:10 -05:00
1719057529 Moved titles to the left
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 8s
Build Project / Build Project (3.10) (pull_request) Successful in 33s
Build Project / Build Project (3.11) (pull_request) Successful in 30s
Test with tox / Test with tox (3.10) (pull_request) Failing after 27s
Build Project / Build Project (3.12) (pull_request) Successful in 30s
Test with tox / Test with tox (3.11) (pull_request) Successful in 22s
Test with tox / Test with tox (3.12) (pull_request) Successful in 22s
2026-01-29 16:21:58 -05:00
19f63bf3d0 Titles moved to center
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 8s
Build Project / Build Project (3.10) (pull_request) Successful in 32s
Build Project / Build Project (3.11) (pull_request) Successful in 30s
Test with tox / Test with tox (3.10) (pull_request) Failing after 26s
Build Project / Build Project (3.12) (pull_request) Successful in 30s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23s
Test with tox / Test with tox (3.12) (pull_request) Successful in 22s
2026-01-29 16:16:38 -05:00
0b4824d1cb Adding Qoherent images
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 8s
Build Project / Build Project (3.10) (pull_request) Successful in 43s
Build Project / Build Project (3.11) (pull_request) Successful in 40s
Test with tox / Test with tox (3.10) (pull_request) Failing after 30s
Build Project / Build Project (3.12) (pull_request) Successful in 39s
Test with tox / Test with tox (3.11) (pull_request) Successful in 25s
Test with tox / Test with tox (3.12) (pull_request) Successful in 24s
2026-01-29 16:11:38 -05:00
44032cd799 Removed EngFormatter 2026-01-29 16:05:41 -05:00
f5209233d7 - Changed font size to be smaller
- Changed graph size to be larger
2026-01-29 16:00:40 -05:00
262032f424 Moved view titles to the left 2026-01-29 14:03:18 -05:00
6cc120062d Adjusted view formatting to give the title more room 2026-01-29 13:28:25 -05:00
00aec7278a Merge branch 'main' of https://riahub.ai/qoherent/ria-toolkit-oss into viewfix 2026-01-29 13:23:54 -05:00
71f23e3a96 Merge pull request 'recording performance improvements' (#16) from sred-recording-fix into main
Some checks failed
Build Sphinx Docs Set / Build Docs (push) Successful in 8s
Build Project / Build Project (3.10) (push) Successful in 34s
Build Project / Build Project (3.11) (push) Successful in 31s
Build Project / Build Project (3.12) (push) Successful in 30s
Test with tox / Test with tox (3.10) (push) Failing after 29s
Test with tox / Test with tox (3.11) (push) Successful in 22s
Test with tox / Test with tox (3.12) (push) Successful in 22s
Reviewed-on: #16
2026-01-27 12:56:38 -05:00
ben
0178adcdb5 lint fix
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 7s
Build Project / Build Project (3.10) (pull_request) Successful in 33s
Build Project / Build Project (3.11) (pull_request) Successful in 30s
Build Project / Build Project (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23s
Test with tox / Test with tox (3.12) (pull_request) Successful in 22s
Test with tox / Test with tox (3.10) (pull_request) Failing after 27s
2026-01-27 12:49:09 -05:00
ben
0ee6f5e63f recording performance improvements
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 8s
Build Project / Build Project (3.10) (pull_request) Successful in 34s
Build Project / Build Project (3.11) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Failing after 26s
Build Project / Build Project (3.12) (pull_request) Successful in 30s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23s
Test with tox / Test with tox (3.12) (pull_request) Successful in 22s
2026-01-27 12:44:27 -05:00
70f132c54c Merge pull request 'cli' (#15) from cli into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 4m12s
Build Project / Build Project (3.10) (push) Successful in 6m50s
Build Project / Build Project (3.11) (push) Successful in 8m22s
Test with tox / Test with tox (3.11) (push) Successful in 4m57s
Test with tox / Test with tox (3.12) (push) Successful in 4m57s
Test with tox / Test with tox (3.10) (push) Successful in 6m33s
Build Project / Build Project (3.12) (push) Successful in 6m39s
Reviewed-on: #15
Reviewed-by: gillian <gillian@qoherent.ai>
2025-12-22 10:42:57 -05:00
9d010b8cd7 Updated poetry.lock file
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 1m25s
Test with tox / Test with tox (3.11) (pull_request) Successful in 1m58s
Test with tox / Test with tox (3.10) (pull_request) Successful in 2m28s
Build Project / Build Project (3.10) (pull_request) Successful in 4m8s
Build Project / Build Project (3.11) (pull_request) Successful in 4m33s
Test with tox / Test with tox (3.12) (pull_request) Successful in 2m38s
Build Project / Build Project (3.12) (pull_request) Successful in 4m35s
2025-12-19 11:27:24 -05:00
262d6ce9ee Formatting
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 1m29s
Test with tox / Test with tox (3.11) (pull_request) Failing after 34s
Test with tox / Test with tox (3.12) (pull_request) Failing after 33s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1m15s
Build Project / Build Project (3.10) (pull_request) Successful in 2m34s
Build Project / Build Project (3.11) (pull_request) Successful in 2m39s
Build Project / Build Project (3.12) (pull_request) Successful in 2m35s
2025-12-19 11:25:06 -05:00
787ad8449b Added generated files to gitignore
Some checks failed
Test with tox / Test with tox (3.11) (pull_request) Failing after 36s
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 2m13s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1m1s
Test with tox / Test with tox (3.12) (pull_request) Failing after 56s
Build Project / Build Project (3.12) (pull_request) Successful in 2m49s
Build Project / Build Project (3.10) (pull_request) Successful in 3m37s
Build Project / Build Project (3.11) (pull_request) Successful in 3m35s
2025-12-19 11:19:03 -05:00
e967a33024 Removed recordings 2025-12-17 15:21:47 -05:00
97dc1ff272 Added number of samples to capture example 2025-12-17 15:21:04 -05:00
ab6e5dcb2f Updated imports to reflect new file structure 2025-12-15 15:59:48 -05:00
fb799013ad Merge branch 'cli' of https://riahub.ai/qoherent/ria-toolkit-oss into cli 2025-12-15 15:20:42 -05:00
a87af3d835 removed src formatting 2025-12-15 15:19:54 -05:00
a898a0a076 Fixed imports and capture error message 2025-12-15 15:19:49 -05:00
b33c0c2e01 Merge branch 'cli' of https://riahub.ai/qoherent/ria-toolkit-oss into cli 2025-12-15 15:08:16 -05:00
873b5bcab8 Fixed importing of blade drivers 2025-12-15 15:07:48 -05:00
dd0b1dd215 added click and matplotlib to pyproject 2025-12-15 15:07:21 -05:00
fb622f2ec6 moved the cli directory location 2025-12-15 13:51:17 -05:00
7868e2cf2a added examples and formatting 2025-12-15 12:08:57 -05:00
eccbe9f187 Removed unincluded channel models from use in cli 2025-12-12 16:46:21 -05:00
5c56fac7b4 Fixed usage of 'ria' extension in sigmf files 2025-12-12 16:45:42 -05:00
35a87131c2 Formatting 2025-12-12 14:52:19 -05:00
6ba108c908 Added necessary methods, suppressed unnecessary warnings 2025-12-12 14:43:26 -05:00
5f0ab7ac71 Fixed merging errors and import errors 2025-12-11 16:53:26 -05:00
806fcf8293 added tests for cli 2025-12-11 15:59:08 -05:00
155b13928b changed file to a string 2025-12-11 15:12:01 -05:00
54b41c246e changed load_rec to load_recording 2025-12-11 13:37:16 -05:00
cbd94c8fe0 Added signal package 2025-12-11 11:13:27 -05:00
e32f987715 added generate file back in, will need to change things when signal is added to ria toolkit oss 2025-12-09 14:49:34 -05:00
a8642f7b1d changed examples to use ria 2025-12-09 14:48:30 -05:00
14539d9269 removed iq_channel_models in transform.py
removed view_annotations from view.py
2025-12-09 14:38:49 -05:00
18395a0af8 adding view channels to view_signal 2025-12-09 14:12:11 -05:00
5398b292e7 added extra files to view, changed common and annotate files to be compatible with ria oss 2025-12-09 12:51:16 -05:00
2429d62067 Structure for cli implemented 2025-12-09 12:40:55 -05:00
7c1313a210 added IO and datatypes WAV file and midas blu file format 2025-12-08 15:28:45 -05:00
aeccbbdcae Merge pull request 'PSD, FFT, and Spectrogram 3D' (#14) from new-recording-widgets into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 16s
Test with tox / Test with tox (3.11) (push) Successful in 37s
Test with tox / Test with tox (3.12) (push) Successful in 33s
Test with tox / Test with tox (3.10) (push) Successful in 47s
Build Project / Build Project (3.10) (push) Successful in 56s
Build Project / Build Project (3.11) (push) Successful in 55s
Build Project / Build Project (3.12) (push) Successful in 54s
Reviewed-on: #14
Reviewed-by: madrigal <madrigal@qoherent.ai>
2025-12-05 13:01:47 -05:00
f8007014d3 Merge branch 'main' into new-recording-widgets
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 32s
Build Project / Build Project (3.10) (pull_request) Successful in 53s
Build Project / Build Project (3.11) (pull_request) Successful in 53s
Test with tox / Test with tox (3.10) (pull_request) Successful in 47s
Build Project / Build Project (3.12) (pull_request) Successful in 52s
2025-12-05 11:42:32 -05:00
ben
5d3c67bb89 PSD, FFT, and Spectrogram 3D
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 21s
Test with tox / Test with tox (3.11) (pull_request) Successful in 43s
Test with tox / Test with tox (3.12) (pull_request) Successful in 37s
Test with tox / Test with tox (3.10) (pull_request) Successful in 54s
Build Project / Build Project (3.12) (pull_request) Successful in 59s
Build Project / Build Project (3.10) (pull_request) Successful in 1m6s
Build Project / Build Project (3.11) (pull_request) Successful in 1m4s
2025-12-05 11:19:06 -05:00
c251bf3633 Merge pull request 'New stylings for onnx and pytorch' (#13) from widget-style-fixes into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 15s
Test with tox / Test with tox (3.11) (push) Successful in 33s
Test with tox / Test with tox (3.12) (push) Successful in 30s
Build Project / Build Project (3.10) (push) Successful in 50s
Test with tox / Test with tox (3.10) (push) Successful in 44s
Build Project / Build Project (3.11) (push) Successful in 51s
Build Project / Build Project (3.12) (push) Successful in 50s
Reviewed-on: #13
Reviewed-by: madrigal <madrigal@qoherent.ai>
2025-11-20 09:38:24 -05:00
ben
557c46f632 format fix
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 33s
Test with tox / Test with tox (3.11) (pull_request) Successful in 28s
Test with tox / Test with tox (3.10) (pull_request) Successful in 41s
Build Project / Build Project (3.10) (pull_request) Successful in 50s
Build Project / Build Project (3.12) (pull_request) Successful in 48s
Build Project / Build Project (3.11) (pull_request) Successful in 51s
Test with tox / Test with tox (3.12) (pull_request) Successful in 27s
2025-11-19 16:16:17 -05:00
ben
19a86e2a67 New stylings for onnx and pytorch
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 32s
Test with tox / Test with tox (3.11) (pull_request) Successful in 30s
Test with tox / Test with tox (3.10) (pull_request) Failing after 40s
Build Project / Build Project (3.11) (pull_request) Successful in 53s
Build Project / Build Project (3.10) (pull_request) Successful in 57s
Test with tox / Test with tox (3.12) (pull_request) Successful in 30s
Build Project / Build Project (3.12) (pull_request) Successful in 56s
2025-11-19 16:07:29 -05:00
fc098a12ee Merge pull request 'updates_and_fixes' (#12) from updates_and_fixes into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 35s
Test with tox / Test with tox (3.11) (push) Successful in 31s
Build Project / Build Project (3.10) (push) Successful in 52s
Test with tox / Test with tox (3.10) (push) Successful in 47s
Build Project / Build Project (3.11) (push) Successful in 53s
Build Project / Build Project (3.12) (push) Successful in 54s
Test with tox / Test with tox (3.12) (push) Successful in 31s
Reviewed-on: #12
Reviewed-by: benchinnery <ben@qoherent.ai>
Reviewed-by: gillian <gillian@qoherent.ai>
2025-11-18 15:01:25 -05:00
ad2ffe7a3a Merge branch 'main' into updates_and_fixes
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 14s
Build Project / Build Project (3.12) (pull_request) Successful in 44s
Build Project / Build Project (3.11) (pull_request) Successful in 47s
Test with tox / Test with tox (3.10) (pull_request) Successful in 43s
Build Project / Build Project (3.10) (pull_request) Successful in 56s
Test with tox / Test with tox (3.12) (pull_request) Successful in 27s
Test with tox / Test with tox (3.11) (pull_request) Successful in 24s
2025-11-18 09:32:32 -05:00
e88cfafc50 Added garbage collection to viewers to prevent crashing, minor fixes to plots
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15s
Test with tox / Test with tox (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.11) (pull_request) Successful in 37s
Build Project / Build Project (3.10) (pull_request) Successful in 51s
Test with tox / Test with tox (3.10) (pull_request) Successful in 45s
Build Project / Build Project (3.12) (pull_request) Successful in 51s
Build Project / Build Project (3.11) (pull_request) Successful in 55s
2025-11-17 13:51:27 -05:00
10801ffb57 Implemented close method, minor updates and improvements 2025-11-17 13:40:40 -05:00
a0d0899eab Added gain setter, fixed setting sample rate in init_rx 2025-11-17 13:29:32 -05:00
c575fa798c Updated setters, added buffer size calculation, standardized errors 2025-11-17 12:09:45 -05:00
0ea81c37ba Updated setters, removed redundant shutdown, added chunked recording 2025-11-17 11:54:05 -05:00
bca962d7b2 Added setter methods, fixed rx sample conversion, minor fixes 2025-11-17 11:39:57 -05:00
96d864aa0b Fixed shutdown and cleanup, standardized setters, and improved TX 2025-11-17 11:24:54 -05:00
c673967a90 Updated methods, added setters, and created standardized SDRError classes 2025-11-17 11:20:38 -05:00
4ac4e9c642 Merge pull request 'Added extra blacklist items to rtlsdr.rst' (#11) from rtlsdr into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 12s
Test with tox / Test with tox (3.11) (push) Successful in 32s
Test with tox / Test with tox (3.12) (push) Successful in 30s
Test with tox / Test with tox (3.10) (push) Successful in 43s
Build Project / Build Project (3.10) (push) Successful in 49s
Build Project / Build Project (3.11) (push) Successful in 49s
Build Project / Build Project (3.12) (push) Successful in 49s
Reviewed-on: #11
Reviewed-by: madrigal <madrigal@qoherent.ai>
2025-11-14 11:13:06 -05:00
ca45c5e86d Added extra blacklist items to rtlsdr.rst
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15s
Test with tox / Test with tox (3.11) (pull_request) Successful in 31s
Test with tox / Test with tox (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 51s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 49s
2025-11-14 11:01:51 -05:00
61ea2c2a0a Merge pull request 'pytorch-widgets' (#10) from pytorch-widgets into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 17s
Test with tox / Test with tox (3.12) (push) Successful in 33s
Test with tox / Test with tox (3.10) (push) Successful in 40s
Build Project / Build Project (3.10) (push) Successful in 54s
Build Project / Build Project (3.11) (push) Successful in 51s
Build Project / Build Project (3.12) (push) Successful in 49s
Test with tox / Test with tox (3.11) (push) Successful in 23s
Reviewed-on: #10
Reviewed-by: madrigal <madrigal@qoherent.ai>
2025-11-13 11:02:18 -05:00
c237164a68 Merge branch 'main' into pytorch-widgets
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.11) (pull_request) Successful in 33s
Test with tox / Test with tox (3.12) (pull_request) Successful in 33s
Build Project / Build Project (3.10) (pull_request) Successful in 54s
Test with tox / Test with tox (3.10) (pull_request) Successful in 45s
Build Project / Build Project (3.11) (pull_request) Successful in 54s
Build Project / Build Project (3.12) (pull_request) Successful in 52s
2025-11-13 10:32:26 -05:00
ben
48f6b303f5 Pytorch Widgets
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 54s
Test with tox / Test with tox (3.12) (pull_request) Successful in 38s
Test with tox / Test with tox (3.11) (pull_request) Successful in 1m9s
Test with tox / Test with tox (3.10) (pull_request) Successful in 2m3s
Build Project / Build Project (3.10) (pull_request) Successful in 2m15s
Build Project / Build Project (3.11) (pull_request) Successful in 2m13s
Build Project / Build Project (3.12) (pull_request) Successful in 2m13s
2025-10-31 12:12:24 -04:00
ben
b8ccead21e Updated to Version 0.1.4
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 18s
Test with tox / Test with tox (3.12) (pull_request) Successful in 36s
Test with tox / Test with tox (3.11) (pull_request) Successful in 45s
Build Project / Build Project (3.10) (pull_request) Successful in 59s
Build Project / Build Project (3.12) (pull_request) Successful in 56s
Build Project / Build Project (3.11) (pull_request) Successful in 1m0s
Test with tox / Test with tox (3.10) (pull_request) Failing after 29s
2025-10-27 11:50:03 -04:00
ben
c21b522a67 Format Fixes 2025-10-24 08:59:44 -04:00
241 changed files with 40813 additions and 2180 deletions

12
.gitignore vendored
View File

@ -52,6 +52,7 @@ tests/sdr/
# Sphinx documentation # Sphinx documentation
docs/build/ docs/build/
docs/_build/
# Jupyter Notebook # Jupyter Notebook
.ipynb_checkpoints .ipynb_checkpoints
@ -88,3 +89,14 @@ cython_debug/
# pyenv # pyenv
.python-version .python-version
# Generated files
*.dot
*.hdf5
*.npy
*.png
*.sigmf-data
*.sigmf-meta
*.blue
*.wav
images/

36
CHANGELOG.md Normal file
View File

@ -0,0 +1,36 @@
# Changelog
## [0.1.0] - 2026-02-20
### Added
- **Dual-Threshold Detection:** Logic to capture the start and end of signals, not just the peak.
- **Signal Smoothing & Noise Filters:** Prevents detections from breaking into fragments and ignores short interference spikes.
- **Auto-Frequency Calculation:** Automatically adjusts bounding boxes to fit signal frequency ranges tightly.
### Changed
- **Signal Power Detection:** Switched from raw signal strength to power for improved accuracy.
- **CLI Workflow:** `Clear` and `Remove` commands now modify files directly (in-place) to avoid redundant copies.
- **Metadata Logic:** Updated labels to show detection percentages and overhauled internal metadata cleaning.
- **Viewer UI:** Moved legend outside the plot, added a black background, and adjusted transparency for better spectrogram visibility.
### Fixed
- Prevented redundant `_annotated` suffixes in file naming patterns.
- Simplified internal math to increase processing speed and precision.
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
---
## [0.1.1] - 2026-03-20
### Added
- **Campaign orchestration** — new `orchestration` module that manages the full lifecycle of an RF data collection campaign: SDR capture, automatic labeling, QA checks, and dataset packaging.
- **HTTP inference server**`ria-server` command starts a REST API server for deploying campaigns and controlling live inference from external systems such as the RIA Hub platform.
- **Campaign CLI**`ria campaign` commands for starting, monitoring, and managing campaigns from the terminal.
### Changed
- **Visualization layout** — recording and dataset views have been reformatted with improved sizing, repositioned titles, and updated Qoherent branding.
---

View File

@ -159,7 +159,7 @@ Finally, RIA Toolkit OSS can be installed directly from the source code. This ap
Once the project is installed, you can import modules, functions, and classes from the Toolkit for use in your Python code. For example, you can use the following import statement to access the `Recording` object: Once the project is installed, you can import modules, functions, and classes from the Toolkit for use in your Python code. For example, you can use the following import statement to access the `Recording` object:
```python ```python
from ria_toolkit_oss.datatypes import Recording from ria_toolkit_oss.data import Recording
``` ```
Additional usage information is provided in the project documentation: [RIA Toolkit OSS Documentation](https://ria-toolkit-oss.readthedocs.io/). Additional usage information is provided in the project documentation: [RIA Toolkit OSS Documentation](https://ria-toolkit-oss.readthedocs.io/).

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,29 @@
/* Change the hex values below to customize heading colours */
.rst-content h1 { color: #2c3e50; }
.rst-content h2,
.rst-content h2 a { color: #ffffff !important; font-size: 22px !important; }
.rst-content h3,
.rst-content h3 a { color: #ffffff !important; font-size: 16px !important; }
.rst-content h3 code { font-size: inherit !important; }
.rst-content .admonition.warning {
background: #1a1a2e !important;
border-left: 4px solid #c0392b !important;
}
.rst-content .admonition.warning .admonition-title {
background: #c0392b !important;
color: #ffffff !important;
}
.rst-content .admonition.warning p {
color: #ffffff !important;
}
.rst-content h4 { color: #404040; }
.highlight * { color: #ffffff !important; }
.ria-cmd { color: #2980b9 !important; }

View File

@ -0,0 +1,8 @@
document.addEventListener('DOMContentLoaded', function () {
document.querySelectorAll('.highlight pre').forEach(function (pre) {
pre.innerHTML = pre.innerHTML.replace(
/((?:^|\n|>))(ria)(?=[ \t]|<)/g,
'$1<span class="ria-cmd">$2</span>'
);
});
});

View File

@ -12,9 +12,9 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = 'ria-toolkit-oss' project = 'ria-toolkit-oss'
copyright = '2025, Qoherent Inc' copyright = '2026, Qoherent Inc'
author = 'Qoherent Inc.' author = 'Qoherent Inc.'
release = '0.1.3' release = '0.1.5'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@ -73,3 +73,6 @@ def setup(app):
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme' html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
html_css_files = ['custom.css']
html_js_files = ['custom.js']

View File

@ -1,4 +1,4 @@
.. _examples: .. _sdr_examples:
############ ############
SDR Examples SDR Examples

View File

@ -25,7 +25,7 @@ In this example, we initialize the `Blade` SDR, configure it to record a signal
import time import time
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.sdr.blade import Blade from ria_toolkit_oss.sdr.blade import Blade
my_radio = Blade() my_radio = Blade()

View File

@ -21,7 +21,7 @@ Code
import numpy as np import numpy as np
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.sdr.blade import Blade from ria_toolkit_oss.sdr.blade import Blade
# Parameters # Parameters

File diff suppressed because it is too large Load Diff

View File

@ -11,15 +11,15 @@ The Radio Dataset Framework provides a software interface to access and manipula
the need for users to interface with the source files directly. Instead, users initialize and interact with a Python the need for users to interface with the source files directly. Instead, users initialize and interact with a Python
object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes. object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes.
Utils includes an abstract class called :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`, which defines common properties and Ria Toolkit OSS includes an abstract class called :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset`, which defines common properties and
behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset` can be considered a blueprint for all behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset` can be considered a blueprint for all
other radio dataset classes. This class is then subclassed to define more specific blueprints for different types other radio dataset classes. This class is then subclassed to define more specific blueprints for different types
of radio datasets. For example, :py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset`, which is tailored for machine learning tasks of radio datasets. For example, :py:obj:`ria_toolkit_oss.data.datasets.IQDataset`, which is tailored for machine learning tasks
involving the processing of signals represented as IQ (In-phase and Quadrature) samples. involving the processing of signals represented as IQ (In-phase and Quadrature) samples.
Then, in the various project backends, there are concrete dataset classes, which inherit from both Utils and the base Then, in the various project backends, there are concrete dataset classes, which inherit from both Ria Toolkit OSS and the base
dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both
:py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset` from Utils and :py:obj:`torch.ria_toolkit_oss.datatypes.IterableDataset` from :py:obj:`ria_toolkit_oss.data.datasets.IQDataset` from Ria Toolkit OSS and :py:obj:`torch.ria_toolkit_oss.data.IterableDataset` from
PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend. PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend.
Dataset initialization Dataset initialization
@ -130,7 +130,7 @@ Dataset processing and manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent, All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent,
inherited from the blueprints in Utils like :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`. inherited from the blueprints in Ria Toolkit OSS like :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset`.
For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset: For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset:

View File

@ -1,7 +1,7 @@
Dataset License SubModule Dataset License SubModule
========================= =========================
.. automodule:: ria_toolkit_oss.datatypes.datasets.license .. automodule:: ria_toolkit_oss.data.datasets.license
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:

View File

@ -1,11 +1,11 @@
Datatypes Package (ria_toolkit_oss.datatypes) Datatypes Package (ria_toolkit_oss.data)
============================================= =============================================
.. |br| raw:: html .. |br| raw:: html
<br /> <br />
.. automodule:: ria_toolkit_oss.datatypes .. automodule:: ria_toolkit_oss.data
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
@ -13,7 +13,7 @@ Datatypes Package (ria_toolkit_oss.datatypes)
Radio Dataset SubPackage Radio Dataset SubPackage
------------------------ ------------------------
.. automodule:: ria_toolkit_oss.datatypes.datasets .. automodule:: ria_toolkit_oss.data.datasets
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
@ -21,5 +21,5 @@ Radio Dataset SubPackage
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
Dataset License SubModule <ria_toolkit_oss.datatypes.datasets.license> Dataset License SubModule <ria_toolkit_oss.data.datasets.license>
Radio Datasets <radio_datasets> Radio Datasets <radio_datasets>

View File

@ -11,7 +11,7 @@ class and function signatures, and doctest examples where available.
:maxdepth: 2 :maxdepth: 2
:caption: Contents: :caption: Contents:
Datatypes Package <datatypes/ria_toolkit_oss.datatypes> Data Package <data/ria_toolkit_oss.data>
SDR Package <ria_toolkit_oss.sdr> SDR Package <ria_toolkit_oss.sdr>
IO Package <ria_toolkit_oss.io> IO Package <ria_toolkit_oss.io>
Transforms Package <ria_toolkit_oss.transforms> Transforms Package <ria_toolkit_oss.transforms>

View File

@ -40,34 +40,44 @@ Limitations
- USB 3.0 connectivity is required for optimal performance; using USB 2.0 will significantly limit data - USB 3.0 connectivity is required for optimal performance; using USB 2.0 will significantly limit data
transfer rates. transfer rates.
Set up instructions (Linux, Radioconda) Set up instructions (Linux)
--------------------------------------- ---------------------------
1. Activate your Radioconda environment. No additional Python packages are required for BladeRF beyond the base RIA Toolkit OSS installation.
1. Install the system library:
.. code-block:: bash .. code-block:: bash
conda activate <your-env-name> sudo apt install libbladerf-dev
2. Install the base dependencies and drivers (*Easy method*): For a more complete installation including CLI tools and FPGA images, use the Nuand PPA:
.. code-block:: bash .. code-block:: bash
sudo add-apt-repository ppa:nuandllc/bladerf sudo add-apt-repository ppa:nuandllc/bladerf
sudo apt-get update sudo apt-get update
sudo apt-get install bladerf sudo apt-get install bladerf libbladerf-dev
sudo apt-get install libbladerf-dev sudo apt-get install bladerf-fpga-hostedxa4 # Necessary for BladeRF 2.0 Micro xA4
sudo apt-get install bladerf-fpga-hostedxa4 # Necessary for installation of bladeRF 2.0 Micro A4.
3. Install a ``udev`` rule by creating a link into your Radioconda installation: 2. Install udev rules:
For most users:
.. code-block:: bash .. code-block:: bash
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf1.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf1.rules sudo udevadm control --reload
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf2.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf2.rules sudo udevadm trigger
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bootloader.rules /etc/udev/rules.d/88-radioconda-nuand-bootloader.rules
sudo udevadm control --reload For **Radioconda** users, create symlinks from your conda environment instead:
sudo udevadm trigger
.. code-block:: bash
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf1.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf1.rules
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf2.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf2.rules
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bootloader.rules /etc/udev/rules.d/88-radioconda-nuand-bootloader.rules
sudo udevadm control --reload
sudo udevadm trigger
Further Information Further Information
------------------- -------------------

View File

@ -39,39 +39,44 @@ Limitations
- Bandwidth is limited to 20 MHz. - Bandwidth is limited to 20 MHz.
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs. - USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
Set up instructions (Linux, Radioconda) Set up instructions (Linux)
--------------------------------------- ---------------------------
1. Activate your Radioconda environment: HackRF is supported out of the box after installing RIA Toolkit OSS.
1. Ensure ``libhackrf`` is installed at the system level. On most Ubuntu installations this is already
present. If not:
.. code-block:: bash .. code-block:: bash
conda activate <your-env-name> sudo apt install libhackrf-dev
2. Install the System Package (Ubuntu / Debian): 2. Install udev rules to allow non-root device access:
For most users:
.. code-block:: bash .. code-block:: bash
sudo apt-get update sudo udevadm control --reload
sudo apt-get install hackrf sudo udevadm trigger
3. Install a ``udev`` rule by creating a link into your Radioconda installation: For **Radioconda** users, create a symlink from your conda environment instead:
.. code-block:: bash .. code-block:: bash
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/53-hackrf.rules /etc/udev/rules.d/53-radioconda-hackrf.rules sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/53-hackrf.rules /etc/udev/rules.d/53-radioconda-hackrf.rules
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
Make sure your user account belongs to the plugdev group in order to access your device: Make sure your user account belongs to the ``plugdev`` group in order to access your device:
.. code-block:: bash .. code-block:: bash
sudo usermod -a -G plugdev <user> sudo usermod -a -G plugdev <user>
.. note:: .. note::
You may have to restart your system for changes to take effect. You may have to restart your system for group membership changes to take effect.
Further information Further information
------------------- -------------------

View File

@ -43,34 +43,34 @@ Limitations
affect stability. affect stability.
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs. - USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
Set up instructions (Linux, Radioconda) Set up instructions (Linux)
--------------------------------------- ---------------------------
1. Activate your Radioconda environment: The PlutoSDR is supported out of the box after installing RIA Toolkit OSS. The required Python package
(``pyadi-iio``) is included in the toolkit's dependencies.
1. Ensure ``libiio`` is installed at the system level. On most Ubuntu installations this is already present.
If not:
.. code-block:: bash .. code-block:: bash
conda activate <your-env-name> sudo apt install libiio-dev libiio-utils libiio0
2. Install system dependencies: .. note::
PlutoSDR devices are discoverable over both USB and network (mDNS). Network discovery uses Avahi — if
``avahi-daemon`` is not running, network discovery will be skipped but USB discovery still works.
2. Install a ``udev`` rule to allow non-root device access:
For most users:
.. code-block:: bash .. code-block:: bash
sudo apt-get update sudo udevadm control --reload
sudo apt-get install -y \ sudo udevadm trigger
build-essential \
git \
libxml2-dev \
bison \
flex \
libcdk5-dev \
cmake \
libusb-1.0-0-dev \
libavahi-client-dev \
libavahi-common-dev \
libaio-dev
3. Install a ``udev`` rule by creating a link into your Radioconda installation: For **Radioconda** users, create a symlink from your conda environment instead:
.. code-block:: bash .. code-block:: bash
@ -78,11 +78,18 @@ Set up instructions (Linux, Radioconda)
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
Once you can talk to the hardware, you may want to perform the post-install steps detailed on the `PlutoSDR Documentation <https://wiki.analog.com/university/tools/pluto>`_. Once you can communicate with the hardware, you may want to perform the post-install steps detailed on
the `PlutoSDR Documentation <https://wiki.analog.com/university/tools/pluto>`_.
4. (Optional) Building ``libiio`` or ``libad9361-iio`` from source: 3. (Optional) Building ``libiio`` or ``libad9361-iio`` from source:
This step is only required if you want the latest version of these libraries not provided in Radioconda. This step is only required if you need a version not available via ``apt``. First install build
dependencies:
.. code-block:: bash
sudo apt-get install -y build-essential git libxml2-dev bison flex libcdk5-dev cmake \
libusb-1.0-0-dev libavahi-client-dev libavahi-common-dev libaio-dev
.. code-block:: bash .. code-block:: bash

View File

@ -30,18 +30,10 @@ Limitations
- Sensitivity and performance can vary depending on the specific model and components. - Sensitivity and performance can vary depending on the specific model and components.
- Requires external software for signal processing and analysis. - Requires external software for signal processing and analysis.
Set up instructions (Linux, Radioconda) Set up instructions (Linux)
--------------------------------------- ---------------------------
1. Activate your Radioconda environment: 1. If you previously had RTL-SDR drivers installed, purge them first:
.. code-block:: bash
conda activate <your-env-name>
2. Purge drivers:
If you already have other drivers installed, purge them from your system.
.. code-block:: bash .. code-block:: bash
@ -53,34 +45,95 @@ If you already have other drivers installed, purge them from your system.
sudo rm -rvf /usr/local/include/rtl_* sudo rm -rvf /usr/local/include/rtl_*
sudo rm -rvf /usr/local/bin/rtl_* sudo rm -rvf /usr/local/bin/rtl_*
3. Install RTL-SDR Blog drivers: 2. Install build dependencies:
.. code-block:: bash .. code-block:: bash
sudo apt-get install libusb-1.0-0-dev git cmake pkg-config build-essential sudo apt install libusb-1.0-0-dev git cmake pkg-config build-essential
git clone https://github.com/osmocom/rtl-sdr
cd rtl-sdr 3. Build ``librtlsdr`` from source:
mkdir build
cd build The standard ``librtlsdr`` package available via ``apt`` is missing symbols required by the Python
cmake ../ -DINSTALL_UDEV_RULES=ON bindings. Build from the **rtl-sdr-blog fork**:
.. code-block:: bash
git clone https://github.com/rtlsdrblog/rtl-sdr-blog.git
cd rtl-sdr-blog
mkdir build && cd build
cmake .. -DINSTALL_UDEV_RULES=ON
make make
sudo make install sudo make install
sudo cp ../rtl-sdr.rules /etc/udev/rules.d/ sudo cp ../rtl-sdr.rules /etc/udev/rules.d/
sudo ldconfig sudo ldconfig
4. Blacklist the DVB-T modules that would otherwise claim the device: .. important::
Do not use the osmocom ``rtl-sdr`` repository or the Ubuntu ``librtlsdr-dev`` apt package. Neither
provides the ``rtlsdr_set_dithering`` symbol that the Python bindings require.
4. Blacklist the kernel DVB driver:
The kernel DVB-T driver (``dvb_usb_rtl28xxu``) claims the RTL-SDR device and prevents ``librtlsdr``
from accessing it.
For most users:
.. code-block:: bash .. code-block:: bash
echo 'blacklist dvb_usb_rtl28xxu' | sudo tee /etc/modprobe.d/blacklist-rtlsdr.conf
sudo modprobe -r dvb_usb_rtl28xxu
For **Radioconda** users, a blacklist configuration is already provided in your conda environment:
.. code-block:: bash
sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf
sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p') sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p')
5. Install a udev rule by creating a link into your radioconda installation: If ``modprobe -r`` fails with "Module is in use", unplug the RTL-SDR dongle, run the command again,
then plug it back in. Alternatively, reboot — the blacklist takes effect on next boot.
.. note::
Some systems also require blacklisting additional DVB-T modules. Add these entries to your
blacklist configuration if needed:
- ``rtl2832``
- ``rtl2830``
5. Reload udev rules:
For most users (rules are installed by the build step above):
.. code-block:: bash .. code-block:: bash
sudo udevadm control --reload
sudo udevadm trigger
For **Radioconda** users, create a symlink from your conda environment instead:
.. code-block:: bash
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules
sudo udevadm control --reload sudo udevadm control --reload
sudo udevadm trigger sudo udevadm trigger
6. Install Python packages:
.. code-block:: bash
pip install pyrtlsdr==0.3.0
pip install setuptools==69.5.1
.. note::
``pyrtlsdr`` 0.4.0 references a ``rtlsdr_set_dithering`` symbol not present in standard
``librtlsdr`` builds. Version 0.3.0 works correctly.
``pyrtlsdr`` 0.3.0 depends on ``pkg_resources``, which was removed in ``setuptools`` >= 82.
Pinning to 69.5.1 ensures ``pkg_resources`` is available.
Further Information Further Information
------------------- -------------------
- `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_ - `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_

View File

@ -39,18 +39,48 @@ Limitations
Set up instructions (Linux) Set up instructions (Linux)
--------------------------------- ---------------------------------
Install PyRF ThinkRF devices require the ``pyrf`` package, which is written in Python 2 syntax and must be patched
after installation to work with Python 3.
.. note::
``lib2to3`` was fully removed in Python 3.13. ThinkRF support is currently limited to
**Python 3.12 and below**.
1. Install ``lib2to3``:
On some distributions (including Ubuntu 24.04+), ``lib2to3`` is not included by default:
.. code-block:: bash .. code-block:: bash
pip install 'pyrf>=2.8.0' sudo apt install python3-lib2to3
Convert PyRF scripts to Python 3 2. Install ``pyrf``:
.. code-block:: bash .. code-block:: bash
cd ../scripts pip install pyrf
./convert_pyrf_to_python3.sh
3. Patch ``pyrf`` for Python 3:
The ``pyrf`` package contains Python 2 syntax throughout (e.g., ``dict.iteritems()``, ``print``
statements). Run the following to automatically convert the entire package to Python 3:
.. code-block:: bash
python -c "
from lib2to3.refactor import RefactoringTool, get_fixers_from_package
import pyrf, os
pyrf_path = os.path.dirname(pyrf.__file__)
fixers = get_fixers_from_package('lib2to3.fixes')
tool = RefactoringTool(fixers)
tool.refactor_dir(pyrf_path, write=True)
print('Done')
"
.. note::
This patches the entire ``pyrf`` package in place, which is required for the driver to fully load.
Further Information Further Information
------------------- -------------------

View File

@ -41,48 +41,111 @@ Limitations
- Compatibility with certain software tools may vary depending on the version of the UHD. - Compatibility with certain software tools may vary depending on the version of the UHD.
- Price range can be a consideration, especially for high-end models. - Price range can be a consideration, especially for high-end models.
Set up instructions (Linux, Radioconda) Set up instructions (Linux)
--------------------------------------- ---------------------------
1. Activate your Radioconda environment: USRP devices require the UHD (USRP Hardware Driver) library with Python bindings. There is no pip-installable
UHD package — it must either be installed via conda or built from source.
**Option A: Install via conda (recommended for conda environments)**
.. code-block:: bash .. code-block:: bash
conda activate <your-env-name> conda install conda-forge::uhd
2. Install UHD and Python bindings: **Option B: Build from source (required for pip/venv environments)**
.. code-block:: bash The Python bindings must target the same Python version used in your virtual environment.
conda install conda-forge::uhd 1. Install build dependencies:
3. Download UHD images: .. code-block:: bash
sudo apt install cmake build-essential libboost-all-dev libusb-1.0-0-dev \
python3-dev python3-numpy libncurses-dev
2. Install the Mako template library into your virtual environment (used by UHD's build system):
.. code-block:: bash
pip install mako
3. Clone and build UHD with your virtual environment activated:
.. code-block:: bash
git clone https://github.com/EttusResearch/uhd.git
cd uhd
git checkout v4.7.0.0
cd host
mkdir build && cd build
cmake -DENABLE_PYTHON_API=ON -DPYTHON_EXECUTABLE=$(which python3) ..
make -j$(nproc)
sudo make install
sudo ldconfig
.. important::
Run the ``cmake`` command with your virtual environment activated so ``$(which python3)`` points
to the correct interpreter. Before running ``make``, verify the cmake output includes::
-- * LibUHD - Python API → must say "Enabling"
-- Python interpreter: .../your-venv/bin/python3
If "LibUHD - Python API" is not listed under enabled components, the Python bindings will not be
built. The build typically takes 1030 minutes.
4. Copy the Python bindings into your virtual environment if ``import uhd`` fails after installation:
.. code-block:: bash
cp -r ~/uhd/host/build/python/uhd ~/.venv/lib/python3.XX/site-packages/
Replace ``python3.XX`` with your Python version (e.g., ``python3.12``).
.. note::
If you have a pre-existing UHD installation built against a different Python version, you will see
a circular import error. The bindings must match the Python version in your virtual environment exactly.
**After either installation method:**
1. Download UHD FPGA/firmware images:
.. code-block:: bash .. code-block:: bash
uhd_images_downloader uhd_images_downloader
4. Verify access to your device: 2. Verify device access:
.. code-block:: bash .. code-block:: bash
uhd_find_devices uhd_find_devices
For USB devices only (e.g. B series), install a ``udev`` rule by creating a link into your Radioconda installation. For USB devices (e.g. B-series), install a ``udev`` rule.
.. code-block:: bash For most users:
sudo ln -s $CONDA_PREFIX/lib/uhd/utils/uhd-usrp.rules /etc/udev/rules.d/radioconda-uhd-usrp.rules .. code-block:: bash
sudo udevadm control --reload
sudo udevadm trigger
5. (Optional) Update firmware/FPGA images: sudo udevadm control --reload
sudo udevadm trigger
.. code-block:: bash For **Radioconda** users, create a symlink from your conda environment instead:
uhd_usrp_probe .. code-block:: bash
This will ensure your device is running the latest firmware and FPGA versions. sudo ln -s $CONDA_PREFIX/lib/uhd/utils/uhd-usrp.rules /etc/udev/rules.d/radioconda-uhd-usrp.rules
sudo udevadm control --reload
sudo udevadm trigger
3. (Optional) Update firmware/FPGA images:
.. code-block:: bash
uhd_usrp_probe
This will ensure your device is running the latest firmware and FPGA versions.
Further information Further information
------------------- -------------------

3311
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ria-toolkit-oss" name = "ria-toolkit-oss"
version = "0.1.3" version = "0.1.5"
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications" description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
license = { text = "AGPL-3.0-only" } license = { text = "AGPL-3.0-only" }
readme = "README.md" readme = "README.md"
@ -12,6 +12,7 @@ maintainers = [
{ name = "Benjamin Chinnery", email = "ben@qoherent.ai" }, { name = "Benjamin Chinnery", email = "ben@qoherent.ai" },
{ name = "Ashkan Beigi", email = "ash@qoherent.ai" }, { name = "Ashkan Beigi", email = "ash@qoherent.ai" },
{ name = "Madrigal Weersink", email = "madrigal@qoherent.ai" }, { name = "Madrigal Weersink", email = "madrigal@qoherent.ai" },
{ name = "Gillian Ford", email = "gillian@qoherent.ai" }
] ]
keywords = [ keywords = [
"radio", "radio",
@ -46,6 +47,10 @@ dependencies = [
"h5py (>=3.14.0,<4.0.0)", "h5py (>=3.14.0,<4.0.0)",
"pandas (>=2.3.2,<3.0.0)", "pandas (>=2.3.2,<3.0.0)",
"pyzmq (>=27.1.0,<28.0.0)", "pyzmq (>=27.1.0,<28.0.0)",
"pyyaml (>=6.0.3,<7.0.0)",
"click (>=8.1.0,<9.0.0)",
"matplotlib (>=3.8.0,<4.0.0)",
"paramiko (>=3.5.1)"
] ]
# [project.optional-dependencies] Commented out to prevent Tox tests from failing # [project.optional-dependencies] Commented out to prevent Tox tests from failing
@ -67,7 +72,8 @@ all-sdr = [
[tool.poetry] [tool.poetry]
packages = [ packages = [
{ include = "ria_toolkit_oss", from = "src" } { include = "ria_toolkit_oss", from = "src" },
{ include = "ria_toolkit_oss_cli", from = "src" }
] ]
include = [ include = [
"**/*.so", # Required for Nuitkaification "**/*.so", # Required for Nuitkaification
@ -80,15 +86,26 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
pytest = "^8.0.0" pytest = "^8.0.0"
tox = "^4.19.0" tox = "^4.19.0"
fastapi = ">=0.111,<1.0"
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
onnxruntime = {version = ">=1.17,<2.0", python = ">=3.11"}
httpx = ">=0.27,<1.0"
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
sphinx = "^7.2.6" sphinx = "^7.2.6"
sphinx-rtd-theme = "^2.0.0" sphinx-rtd-theme = "^2.0.0"
sphinx-autobuild = "^2024.2.4" sphinx-autobuild = "^2024.2.4"
[tool.poetry.group.agent]
optional = true
[tool.poetry.group.agent.dependencies]
requests = ">=2.28,<3.0"
websockets = ">=12.0,<14.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
flake8 = "^7.1.0" flake8 = "^7.1.0"
black = "^24.3.0" black = "^26.3.1"
isort = "^5.13.2" isort = "^5.13.2"
pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
@ -97,6 +114,18 @@ pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
"Source" = "https://riahub.ai/qoherent/ria-toolkit-oss" "Source" = "https://riahub.ai/qoherent/ria-toolkit-oss"
"Issues Board" = "https://riahub.ai/qoherent/ria-toolkit-oss/issues" "Issues Board" = "https://riahub.ai/qoherent/ria-toolkit-oss/issues"
[tool.poetry.scripts]
ria = "ria_toolkit_oss_cli.cli:cli"
ria-tools = "ria_toolkit_oss_cli.cli:cli"
ria-server = "ria_toolkit_oss.server.cli:serve"
ria-agent = "ria_toolkit_oss.agent.cli:main"
ria-app = "ria_toolkit_oss.app.cli:main"
[tool.poetry.group.server.dependencies]
fastapi = ">=0.111,<1.0"
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
onnxruntime = {version = ">=1.17,<2.0", python = ">=3.11"}
[tool.black] [tool.black]
line-length = 119 line-length = 119
target-version = ["py310"] target-version = ["py310"]
@ -118,5 +147,13 @@ exclude = '''
)/ )/
''' '''
[tool.pytest.ini_options]
pythonpath = ["src"]
filterwarnings = [
# FastAPI emits this internally when handling 422 responses; the constant
# is not yet renamed in the installed starlette version, so we can't migrate.
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
]
[tool.isort] [tool.isort]
profile = "black" profile = "black"

225
scripts/pluto_tx_smoke.py Executable file
View File

@ -0,0 +1,225 @@
#!/usr/bin/env python3
"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto.
End-to-end smoke test for the Pluto + Streamer TX path. Drives the same
``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so
the script is self-contained no hub required.
Default: 100 kHz baseband tone × 2 450 MHz LO carrier at 2 450.1 MHz,
continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum
analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as
``tx_status: transmitting`` prints.
Usage::
python3 scripts/pluto_tx_smoke.py # auto-discover Pluto
python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1
python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60
Flags map 1:1 onto the agent's ``radio_config``:
--identifier Pluto IP or hostname (omitted ip:pluto.local).
--frequency TX LO in Hz. Default 2 450 MHz.
--gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative
= more attenuation = less power. Default -30.
--sample-rate Baseband sample rate. Default 1 MHz.
--tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC
(unmodulated carrier at exactly --frequency, but Pluto's
LO leakage will dominate).
--buffer-size Complex samples per WS frame. Default 4096.
--duration Stop after this many seconds (0 = run until Ctrl-C).
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import signal
import sys
import numpy as np
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
class LoggingFakeWs:
"""In-process stand-in for the hub's WebSocket.
Prints every ``tx_status`` + ``error`` frame the Streamer emits so the
operator can watch the lifecycle (armed transmitting done) on stdout.
"""
async def send_json(self, payload: dict) -> None:
t = payload.get("type")
if t == "tx_status":
state = payload.get("state")
msg = payload.get("message")
tail = f"{msg}" if msg else ""
print(f"[tx_status] {state}{tail}")
elif t == "error":
print(f"[error] {payload.get('message')}")
async def send_bytes(self, data: bytes) -> None:
# Agent side won't send RX bytes in this script (no RX session).
pass
def _make_iq_frame(
buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float = 0.0
) -> tuple[bytes, float]:
"""Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone.
Emitting one continuous phase-coherent tone requires threading the phase
across frames; the returned ``next_phase`` should be fed back as
``phase_offset`` on the next call so the sinusoid doesn't glitch at frame
boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap
that ``_verify_sample_format`` polices elsewhere in the toolkit.
"""
n = np.arange(buffer_size, dtype=np.float64)
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
amp = 0.7
iq = amp * (np.cos(phase) + 1j * np.sin(phase))
iq = iq.astype(np.complex64)
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
interleaved[0::2] = iq.real
interleaved[1::2] = iq.imag
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
return interleaved.tobytes(), next_phase
def _make_pluto_factory(identifier: str | None):
def factory(device: str, _ident: str | None):
if device != "pluto":
raise ValueError(f"this script only drives pluto; got device={device!r}")
from ria_toolkit_oss.sdr.pluto import Pluto
return Pluto(identifier=identifier)
return factory
async def _run(args: argparse.Namespace) -> int:
ws = LoggingFakeWs()
cfg = AgentConfig(
tx_enabled=True,
# Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered
# --gain=+5 still gets rejected at the agent boundary rather than
# turned into mystery attenuation by Pluto's setter.
tx_max_gain_db=0.0,
tx_max_duration_s=float(args.duration) if args.duration > 0 else None,
)
streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg)
await streamer.on_message(
{
"type": "tx_start",
"app_id": "smoke",
"radio_config": {
"device": "pluto",
"identifier": args.identifier,
"tx_sample_rate": int(args.sample_rate),
"tx_center_frequency": int(args.frequency),
"tx_gain": int(args.gain),
"buffer_size": int(args.buffer_size),
# "repeat" keeps the last buffer on the air if we ever stall,
# so a continuous carrier stays up even when Python GC or
# asyncio scheduling briefly pauses the producer.
"underrun_policy": "repeat",
},
}
)
# Abort if tx_start was rejected by an interlock (no session → nothing to do).
if streamer._tx is None:
print("tx_start rejected — see [tx_status] line above for the reason.", file=sys.stderr)
return 2
print(
f"Transmitting at {args.frequency/1e6:.3f} MHz with "
f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. "
f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}."
)
# Arrange a clean shutdown on Ctrl-C.
stop = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, stop.set)
except NotImplementedError:
# add_signal_handler is not available on Windows event loops.
pass
# Produce buffers at the nominal sample-rate pace. We deliberately stay
# slightly ahead of the radio — queue is bounded at 8, so backpressure
# flows naturally.
phase = 0.0
buffer_dt = args.buffer_size / args.sample_rate
# Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays
# topped up. The queue's own backpressure keeps us from spinning.
produce_interval = buffer_dt * 0.5
try:
async def producer():
nonlocal phase
while not stop.is_set():
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase)
await streamer.on_binary(frame)
await asyncio.sleep(produce_interval)
producer_task = asyncio.create_task(producer())
if args.duration > 0:
try:
await asyncio.wait_for(stop.wait(), timeout=args.duration)
except asyncio.TimeoutError:
pass
else:
await stop.wait()
stop.set()
producer_task.cancel()
try:
await producer_task
except (asyncio.CancelledError, Exception):
pass
finally:
await streamer.on_message({"type": "tx_stop", "app_id": "smoke"})
print("TX session closed.")
return 0
def main() -> int:
p = argparse.ArgumentParser(
description="End-to-end TX smoke test: agent → Pluto continuous tone.",
)
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)")
p.add_argument("--frequency", type=float, default=3_410_000_000.0, help="TX LO in Hz (default 2.45 GHz)")
p.add_argument("--gain", type=float, default=-0.0, help="TX gain in dB; Pluto range [-89, 0] (default -30)")
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)")
p.add_argument(
"--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)"
)
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)")
p.add_argument(
"--duration", type=float, default=60.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)"
)
p.add_argument("--log-level", default="INFO")
args = p.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
try:
return asyncio.run(_run(args))
except KeyboardInterrupt:
return 130
if __name__ == "__main__":
sys.exit(main())

230
scripts/pluto_tx_ws_smoke.py Executable file
View File

@ -0,0 +1,230 @@
#!/usr/bin/env python3
"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto.
Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz),
but drives the agent through the *real* WebSocket path instead of calling
handlers in-process. Proves that the hub-driven path behaves identically:
mock hub ws:// WsClient.run() Streamer.on_message
Streamer.on_binary
real Pluto
This is the most rigorous check short of pointing the real ``ria-agent stream``
at a live ria-hub. If a tone appears on the spectrum analyzer here but *not*
when ria-hub drives it, the fault is above the WS decoder (registration,
capability gate, TX operator, hub's binary-frame publisher); everything
downstream of ``ws.recv()`` is this script's code path.
Usage::
python3 scripts/pluto_tx_ws_smoke.py # default 30s tone
python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1
python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import signal
import sys
import numpy as np
import websockets
from ria_toolkit_oss.agent.config import AgentConfig
from ria_toolkit_oss.agent.streamer import Streamer
from ria_toolkit_oss.agent.ws_client import WsClient
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float) -> tuple[bytes, float]:
n = np.arange(buffer_size, dtype=np.float64)
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
amp = 0.7
iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64)
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
interleaved[0::2] = iq.real
interleaved[1::2] = iq.imag
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
return interleaved.tobytes(), next_phase
def _make_pluto_factory(identifier: str | None):
def factory(device: str, _ident: str | None):
if device != "pluto":
raise ValueError(f"this script only drives pluto; got device={device!r}")
from ria_toolkit_oss.sdr.pluto import Pluto
return Pluto(identifier=identifier)
return factory
async def _mock_hub_handler(ws, args, stop: asyncio.Event):
"""Server side of the WS. Sends tx_start, streams IQ, then tx_stop."""
# Drain the first heartbeat so the log is clean; we don't need to gate on
# it for a localhost smoke test.
try:
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
if isinstance(first, str):
payload = json.loads(first)
if payload.get("type") == "heartbeat":
caps = payload.get("capabilities")
print(f"[mock-hub] agent heartbeat: capabilities={caps} " f"tx_enabled={payload.get('tx_enabled')}")
except asyncio.TimeoutError:
print("[mock-hub] warning: no heartbeat received in first 2s")
# Arm the agent's TX path.
await ws.send(
json.dumps(
{
"type": "tx_start",
"app_id": "ws-smoke",
"radio_config": {
"device": "pluto",
"identifier": args.identifier,
"tx_sample_rate": int(args.sample_rate),
"tx_center_frequency": int(args.frequency),
"tx_gain": int(args.gain),
"buffer_size": int(args.buffer_size),
"underrun_policy": "repeat",
},
}
)
)
print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " f"gain={args.gain} dB")
# Producer: push IQ frames at a steady clip. Use a concurrent receiver so
# tx_status frames show up in real time rather than being queued behind
# the sends.
phase = 0.0
buffer_dt = args.buffer_size / args.sample_rate
async def receiver():
try:
while True:
msg = await ws.recv()
if isinstance(msg, str):
print(f"[mock-hub] ← {msg}")
except (websockets.ConnectionClosed, asyncio.CancelledError):
pass
recv_task = asyncio.create_task(receiver())
try:
deadline = None if args.duration <= 0 else (asyncio.get_event_loop().time() + args.duration)
while not stop.is_set():
if deadline is not None and asyncio.get_event_loop().time() >= deadline:
break
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase)
try:
await ws.send(frame)
except websockets.ConnectionClosed:
break
# Slightly ahead of real-time; WS backpressure handles the rest.
await asyncio.sleep(buffer_dt * 0.5)
finally:
try:
await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"}))
print("[mock-hub] sent tx_stop")
except websockets.ConnectionClosed:
pass
# Give the agent a moment to emit `tx_status: done` before we tear down.
await asyncio.sleep(0.3)
recv_task.cancel()
try:
await recv_task
except (asyncio.CancelledError, Exception):
pass
async def _run(args: argparse.Namespace) -> int:
stop = asyncio.Event()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, stop.set)
except NotImplementedError:
pass
# Start the mock hub on a local port.
async def handler(ws):
try:
await _mock_hub_handler(ws, args, stop)
finally:
stop.set()
server = await websockets.serve(handler, "127.0.0.1", 0)
port = server.sockets[0].getsockname()[1]
print(f"[mock-hub] listening on ws://127.0.0.1:{port}")
# Run the agent — exactly as ``ria-agent stream`` would, just with a
# different URL and an in-memory AgentConfig instead of one loaded from
# ``~/.ria/agent.json``.
client = WsClient(
f"ws://127.0.0.1:{port}",
token="",
heartbeat_interval=5.0,
reconnect_pause=0.5,
)
streamer = Streamer(
ws=client,
sdr_factory=_make_pluto_factory(args.identifier),
cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0),
)
client_task = asyncio.create_task(
client.run(
on_message=streamer.on_message,
heartbeat=streamer.build_heartbeat,
on_binary=streamer.on_binary,
)
)
try:
await stop.wait()
finally:
client.stop()
client_task.cancel()
try:
await client_task
except (asyncio.CancelledError, Exception):
pass
server.close()
await server.wait_closed()
print("Done.")
return 0
def main() -> int:
p = argparse.ArgumentParser(
description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.",
)
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)")
p.add_argument("--frequency", type=float, default=2_450_000_000.0, help="TX LO in Hz (default 2.45 GHz)")
p.add_argument("--gain", type=float, default=0.0, help="TX gain in dB; Pluto range [-89, 0] (default 0)")
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)")
p.add_argument("--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz (default 100 kHz)")
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)")
p.add_argument(
"--duration", type=float, default=30.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)"
)
p.add_argument("--log-level", default="INFO")
args = p.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
try:
return asyncio.run(_run(args))
except KeyboardInterrupt:
return 130
if __name__ == "__main__":
sys.exit(main())

0
src/__init__.py Normal file
View File

View File

@ -0,0 +1,26 @@
"""RIA Toolkit agent package.
Provides two execution modes:
- **Legacy long-poll executor** (`NodeAgent` in :mod:`legacy_executor`) an
HTTP long-polling agent that runs ONNX inference locally on the host.
- **Streamer** (:mod:`streamer`) a thin WebSocket client that opens an SDR
and streams raw IQ to the RIA Hub server, which performs all inference.
Back-compat: ``from ria_toolkit_oss.agent import NodeAgent`` and the ``main``
entry point continue to work.
"""
from __future__ import annotations
from .legacy_executor import NodeAgent
from .legacy_executor import main as _legacy_main
__all__ = ["NodeAgent", "main"]
def main() -> None:
"""Unified CLI entry point. Dispatches to streamer/legacy subcommands."""
from .cli import main as _cli_main
_cli_main()

View File

@ -0,0 +1,212 @@
"""Unified ``ria-agent`` CLI.
Subcommands:
- ``ria-agent run [legacy args]`` legacy long-poll NodeAgent (unchanged).
- ``ria-agent stream`` new WebSocket-based IQ streamer.
- ``ria-agent detect`` print SDR drivers whose modules import cleanly.
- ``ria-agent register --hub URL --api-key KEY`` register with the hub and
save credentials (and optional TX interlocks) to ``~/.ria/agent.json``.
Invoking ``ria-agent`` with no subcommand falls through to the legacy
long-poll behavior for back-compatibility with existing deployments.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
from . import config as _config
from .hardware import available_devices
from .legacy_executor import main as _legacy_main
from .namegen import generate_agent_name
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
def _cmd_detect(_args: argparse.Namespace) -> int:
devices = available_devices()
if not devices:
print("No SDR drivers available (install ria-toolkit-oss[all-sdr] or per-driver extras).")
return 0
for name in devices:
print(name)
return 0
def _cmd_register(args: argparse.Namespace) -> int:
import urllib.request
hub_url = args.hub.rstrip("/")
url = f"{hub_url}/screens/agents/register"
name = args.name or generate_agent_name()
body = json.dumps({"name": name}).encode()
req = urllib.request.Request(
url,
data=body,
headers={
"Content-Type": "application/json",
"X-API-Key": args.api_key,
},
)
try:
with urllib.request.urlopen(req) as resp:
data = json.loads(resp.read())
except Exception as e:
print(f"error: registration failed: {e}", file=sys.stderr)
return 1
agent_id = data["agent_id"]
token = data["token"]
cfg = _config.load()
cfg.hub_url = hub_url
cfg.agent_id = agent_id
cfg.token = token
cfg.api_key = args.api_key
cfg.name = name
cfg.insecure = bool(args.insecure)
cfg.tx_enabled = bool(getattr(args, "allow_tx", False))
if (v := getattr(args, "tx_max_gain_db", None)) is not None:
cfg.tx_max_gain_db = float(v)
if (v := getattr(args, "tx_max_duration_s", None)) is not None:
cfg.tx_max_duration_s = float(v)
freq_ranges = getattr(args, "tx_freq_range", None) or []
if freq_ranges:
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
path = _config.save(cfg)
print(f"Registered agent: {agent_id}")
if cfg.tx_enabled:
caps: list[str] = []
if cfg.tx_max_gain_db is not None:
caps.append(f"gain<={cfg.tx_max_gain_db} dB")
if cfg.tx_max_duration_s is not None:
caps.append(f"duration<={cfg.tx_max_duration_s} s")
if cfg.tx_allowed_freq_ranges:
caps.append(f"freq in {cfg.tx_allowed_freq_ranges}")
tail = f" ({', '.join(caps)})" if caps else ""
print(f"TX enabled{tail}")
print(f"Credentials saved to {path}")
return 0
def _cmd_stream(args: argparse.Namespace) -> int:
from .streamer import run_streamer
cfg = _config.load()
url = args.url or _derive_ws_url(cfg.hub_url, cfg.agent_id)
token = args.token or cfg.token
if not url:
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
return 2
if getattr(args, "allow_tx", False):
cfg.tx_enabled = True
try:
asyncio.run(run_streamer(url, token, cfg=cfg))
except KeyboardInterrupt:
pass
return 0
def _derive_ws_url(hub_url: str, agent_id: str) -> str:
if not hub_url:
return ""
base = hub_url.rstrip("/")
if base.startswith("https://"):
base = "wss://" + base[len("https://") :]
elif base.startswith("http://"):
base = "ws://" + base[len("http://") :]
suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws"
return base + suffix
def main() -> None:
# Back-compat: if the first non-flag token matches a known legacy flag,
# or there is no subcommand at all, dispatch to the legacy CLI.
argv = sys.argv[1:]
if not argv or (argv[0].startswith("--") and argv[0] in _LEGACY_ALIASES):
_legacy_main()
return
parser = argparse.ArgumentParser(prog="ria-agent")
sub = parser.add_subparsers(dest="command", required=True)
sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)")
sub.add_parser("detect", help="List available SDR drivers")
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
p_reg.add_argument("--hub", required=True, help="RIA Hub URL (e.g. http://whitehorse:3005)")
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
p_reg.add_argument(
"--allow-tx",
dest="allow_tx",
action="store_true",
help="Opt this agent in to TX (required for any transmission from the hub)",
)
p_reg.add_argument(
"--tx-max-gain-db",
dest="tx_max_gain_db",
type=float,
default=None,
help="Reject tx_start frames whose tx_gain exceeds this cap (dB)",
)
p_reg.add_argument(
"--tx-max-duration-s",
dest="tx_max_duration_s",
type=float,
default=None,
help="Auto-stop any TX session after this many seconds",
)
p_reg.add_argument(
"--tx-freq-range",
dest="tx_freq_range",
type=float,
nargs=2,
action="append",
metavar=("LO", "HI"),
default=None,
help="Allowed TX center-frequency range in Hz (repeat for multiple bands)",
)
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
p_stream.add_argument("--token", default=None, help="Override bearer token")
p_stream.add_argument("--log-level", default="INFO")
p_stream.add_argument(
"--allow-tx",
dest="allow_tx",
action="store_true",
help="Runtime override: enable TX for this process without writing config",
)
# Unknown extras are forwarded to the legacy CLI when command == "run".
args, extras = parser.parse_known_args(argv)
logging.basicConfig(
level=getattr(logging, getattr(args, "log_level", "INFO"), logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
if args.command == "run":
sys.argv = [sys.argv[0], *extras]
_legacy_main()
return
if args.command == "detect":
sys.exit(_cmd_detect(args))
if args.command == "register":
sys.exit(_cmd_register(args))
if args.command == "stream":
sys.exit(_cmd_stream(args))
parser.error(f"unknown command: {args.command}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,89 @@
"""Agent configuration stored at ``~/.ria/agent.json``.
Schema::
{
"hub_url": "https://riahub.example.com",
"agent_id": "agent-abc123",
"token": "rha_xxxx",
"name": "lab-bench-1",
"insecure": false,
"tx_enabled": false,
"tx_max_gain_db": null,
"tx_max_duration_s": null,
"tx_allowed_freq_ranges": null
}
"""
from __future__ import annotations
import json
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
def _resolve_default_path() -> Path:
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
@dataclass
class AgentConfig:
hub_url: str = ""
agent_id: str = ""
token: str = ""
name: str = ""
insecure: bool = False
api_key: str = ""
tx_enabled: bool = False
tx_max_gain_db: float | None = None
tx_max_duration_s: float | None = None
tx_allowed_freq_ranges: list[list[float]] | None = None
extra: dict = field(default_factory=dict)
def default_path() -> Path:
return _resolve_default_path()
def _coerce_ranges(raw) -> list[list[float]] | None:
if raw is None:
return None
out: list[list[float]] = []
for pair in raw:
lo, hi = pair
out.append([float(lo), float(hi)])
return out
def load(path: Path | None = None) -> AgentConfig:
p = path or _resolve_default_path()
if not p.exists():
return AgentConfig()
data = json.loads(p.read_text())
known = {f for f in AgentConfig.__dataclass_fields__ if f != "extra"}
extra = {k: v for k, v in data.items() if k not in known}
return AgentConfig(
hub_url=data.get("hub_url", ""),
agent_id=data.get("agent_id", ""),
token=data.get("token", ""),
name=data.get("name", ""),
insecure=bool(data.get("insecure", False)),
api_key=data.get("api_key", ""),
tx_enabled=bool(data.get("tx_enabled", False)),
tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None),
tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None),
tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")),
extra=extra,
)
def save(cfg: AgentConfig, path: Path | None = None) -> Path:
p = path or _resolve_default_path()
p.parent.mkdir(parents=True, exist_ok=True)
data = asdict(cfg)
extra = data.pop("extra", {}) or {}
data.update(extra)
p.write_text(json.dumps(data, indent=2))
os.chmod(p, 0o600)
return p

View File

@ -0,0 +1,54 @@
"""Hardware detection and heartbeat payload construction for the streamer."""
from __future__ import annotations
from ria_toolkit_oss.sdr import detect_available
from .config import AgentConfig
def available_devices() -> list[str]:
"""Return a sorted list of device names whose driver modules import cleanly."""
return sorted(detect_available().keys())
def heartbeat_payload(
status: str = "idle",
app_id: str | None = None,
*,
cfg: AgentConfig | None = None,
sessions: dict | None = None,
) -> dict:
"""Build the JSON body of a periodic heartbeat frame.
*cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not
supplied, the heartbeat advertises RX-only with ``tx_enabled=False``
matching the pre-TX shape.
"""
c = cfg or AgentConfig()
capabilities = ["rx"]
if c.tx_enabled:
capabilities.append("tx")
payload: dict = {
"type": "heartbeat",
"hardware": available_devices(),
"status": status,
"capabilities": capabilities,
"tx_enabled": bool(c.tx_enabled),
}
# Surface configured interlock values so the hub can pre-filter UI controls
# before sending a tx_start that would be rejected. Only included when TX
# is opted in AND the operator set a cap.
if c.tx_enabled:
if c.tx_max_gain_db is not None:
payload["tx_max_gain_db"] = float(c.tx_max_gain_db)
if c.tx_max_duration_s is not None:
payload["tx_max_duration_s"] = float(c.tx_max_duration_s)
if c.tx_allowed_freq_ranges:
payload["tx_allowed_freq_ranges"] = [[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges]
if app_id:
payload["app_id"] = app_id
if sessions:
payload["sessions"] = sessions
return payload

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,147 @@
"""Generate random human-readable agent names.
Produces names in the form ``adjective-colour-animal``, e.g.
``swift-teal-falcon`` or ``brave-coral-otter``. All words are chosen
to be friendly and inoffensive.
"""
from __future__ import annotations
import random
ADJECTIVES: list[str] = [
"brave",
"bright",
"calm",
"clever",
"cool",
"daring",
"eager",
"fair",
"fancy",
"fast",
"fierce",
"gentle",
"grand",
"happy",
"jolly",
"keen",
"kind",
"lively",
"lucky",
"mighty",
"noble",
"plucky",
"proud",
"quick",
"quiet",
"sharp",
"shiny",
"sleek",
"smart",
"steady",
"stellar",
"strong",
"sturdy",
"sunny",
"sure",
"swift",
"tall",
"vivid",
"warm",
"wise",
]
COLOURS: list[str] = [
"amber",
"aqua",
"azure",
"beige",
"blue",
"bronze",
"coral",
"copper",
"crimson",
"cyan",
"denim",
"gold",
"green",
"grey",
"indigo",
"ivory",
"jade",
"lemon",
"lilac",
"lime",
"maroon",
"mint",
"navy",
"olive",
"onyx",
"peach",
"pearl",
"plum",
"red",
"rose",
"ruby",
"rust",
"sage",
"sand",
"scarlet",
"silver",
"slate",
"steel",
"teal",
"violet",
]
ANIMALS: list[str] = [
"badger",
"bear",
"bison",
"crane",
"deer",
"dolphin",
"eagle",
"elk",
"falcon",
"finch",
"fox",
"gecko",
"hawk",
"heron",
"horse",
"ibis",
"jaguar",
"jay",
"kite",
"koala",
"lark",
"lion",
"lynx",
"marten",
"moose",
"newt",
"orca",
"osprey",
"otter",
"owl",
"panda",
"puma",
"raven",
"robin",
"salmon",
"seal",
"shark",
"stork",
"swift",
"wolf",
]
def generate_agent_name() -> str:
"""Return a random ``adjective-colour-animal`` name."""
adj = random.choice(ADJECTIVES)
col = random.choice(COLOURS)
ani = random.choice(ANIMALS)
return f"{adj}-{col}-{ani}"

View File

@ -0,0 +1,747 @@
"""IQ-streaming agent.
Listens for control messages from the RIA Hub over a persistent WebSocket.
Supports:
- An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens
the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ).
- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus
binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires
up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False;
Phase 4 implements the full TX loop.
Both sessions can run concurrently on the same physical SDR (FDD) a
ref-counted SDR registry shares one driver instance when RX and TX name the
same ``(device, identifier)``.
"""
from __future__ import annotations
import asyncio
import logging
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from .config import AgentConfig
from .hardware import heartbeat_payload
from .ws_client import WsClient
logger = logging.getLogger("ria_agent.streamer")
_DEFAULT_BUFFER_SIZE = 1024
# ---------------------------------------------------------------------------
# Session dataclasses
@dataclass
class RxSession:
app_id: str
sdr: Any
device_key: tuple[str, str | None]
buffer_size: int
task: asyncio.Task | None = None
pending_config: dict = field(default_factory=dict)
@dataclass
class TxSession:
app_id: str
sdr: Any
device_key: tuple[str, str | None]
buffer_size: int
task: Any = None # concurrent.futures.Future from run_in_executor
pending_config: dict = field(default_factory=dict)
underrun_policy: str = "pause"
last_buffer: np.ndarray | None = None
stop_event: threading.Event = field(default_factory=threading.Event)
started_at: float = 0.0
max_duration_s: float | None = None
state: str = "armed"
# Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so
# hub-side over-production triggers WS backpressure rather than memory
# growth in the agent.
in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8))
# Set by the TX callback when it hits an underrun while policy=="pause";
# asyncio side flips the session state and emits tx_status.
underrun_flag: threading.Event = field(default_factory=threading.Event)
# ---------------------------------------------------------------------------
# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously)
class _SdrRegistry:
def __init__(self, factory):
self._factory = factory
self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {}
self._lock = threading.Lock()
def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]:
key = (device, identifier)
with self._lock:
if key in self._instances:
sdr, rc = self._instances[key]
self._instances[key] = (sdr, rc + 1)
return sdr, key
# Build outside the lock: driver init can be slow and we don't want to
# block concurrent releases on unrelated devices.
sdr = self._factory(device, identifier)
with self._lock:
if key in self._instances:
# Raced another acquirer; discard our duplicate and share theirs.
other_sdr, rc = self._instances[key]
try:
sdr.close()
except Exception:
pass
self._instances[key] = (other_sdr, rc + 1)
return other_sdr, key
self._instances[key] = (sdr, 1)
return sdr, key
def release(self, key: tuple[str, str | None]) -> bool:
"""Decrement refcount. Returns True if the caller owns the last reference
and should close the SDR."""
with self._lock:
sdr, rc = self._instances.get(key, (None, 0))
if sdr is None:
return False
if rc <= 1:
del self._instances[key]
return True
self._instances[key] = (sdr, rc - 1)
return False
def refcount(self, key: tuple[str, str | None]) -> int:
with self._lock:
return self._instances.get(key, (None, 0))[1]
# ---------------------------------------------------------------------------
# Streamer
class Streamer:
"""Main streamer loop.
Parameters
----------
ws:
Connected :class:`WsClient`.
sdr_factory:
Callable ``(device, identifier) -> SDR``. Defaults to the helper in
:mod:`ria_toolkit_oss.sdr`. Injectable for tests.
cfg:
:class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and
heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which
leaves TX disabled.
"""
def __init__(
self,
ws,
sdr_factory=None,
cfg: AgentConfig | None = None,
) -> None:
self.ws = ws
self._cfg = cfg or AgentConfig()
self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory)
self._rx: RxSession | None = None
self._tx: TxSession | None = None
# Pending radio_config accepted via ``configure`` before ``start``.
self._standalone_pending_config: dict = {}
# Cached asyncio event loop, set the first time a handler runs. Used
# to schedule async callbacks from the TX executor thread.
self._loop: asyncio.AbstractEventLoop | None = None
# ------------------------------------------------------------------
# Back-compat read-only shims for callers that check ``._sdr`` etc.
# Writes to these attributes are not supported — use the session objects.
@property
def _sdr(self):
return self._rx.sdr if self._rx is not None else None
@property
def _pending_config(self) -> dict:
return self._rx.pending_config if self._rx is not None else self._standalone_pending_config
# ------------------------------------------------------------------
# WsClient wiring
def build_heartbeat(self) -> dict:
status = "streaming" if (self._rx is not None or self._tx is not None) else "idle"
app_id: str | None = None
if self._rx is not None:
app_id = self._rx.app_id
elif self._tx is not None:
app_id = self._tx.app_id
sessions: dict[str, dict] = {}
if self._rx is not None:
sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"}
if self._tx is not None:
sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state}
return heartbeat_payload(
status=status,
app_id=app_id,
cfg=self._cfg,
sessions=sessions or None,
)
# Advisory / keepalive message types we accept and ignore without warning.
_IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"})
async def on_message(self, msg: dict) -> None:
t = msg.get("type")
if t in self._IGNORED_MESSAGE_TYPES:
logger.debug("Ignoring advisory message: %r", t)
return
handler = {
"start": self._handle_rx_start,
"stop": self._handle_rx_stop,
"configure": self._handle_rx_configure,
"tx_start": self._handle_tx_start,
"tx_stop": self._handle_tx_stop,
"tx_configure": self._handle_tx_configure,
}.get(t)
if handler is None:
logger.warning("Unknown server message type: %r", t)
return
await handler(msg)
async def on_binary(self, data: bytes) -> None:
tx = self._tx
if tx is None:
logger.debug("Dropping %d-byte binary frame: no TX session", len(data))
return
# Backpressure: if the TX queue is full, await briefly so the hub's
# ``await ws.send`` throttles naturally via TCP. We don't block
# indefinitely — a 2s stall means something else is wrong.
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0))
except queue.Full:
logger.warning("TX queue stalled; dropping frame")
# ==================================================================
# RX
async def _handle_rx_start(self, msg: dict) -> None:
if self._rx is not None:
logger.warning("start received while already streaming — ignoring")
return
app_id = msg.get("app_id") or ""
radio_config = dict(msg.get("radio_config") or {})
device = radio_config.pop("device", None)
identifier = radio_config.pop("identifier", None)
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
if not device:
await self._send_error(app_id, "start missing radio_config.device")
return
try:
sdr, device_key = self._registry.acquire(device, identifier)
_apply_sdr_config(sdr, radio_config)
except Exception as exc:
logger.exception("Failed to open SDR %r", device)
await self._send_error(app_id, f"SDR init failed: {exc}")
return
# Inherit any pending config that was queued before start.
pending = dict(self._standalone_pending_config)
self._standalone_pending_config = {}
session = RxSession(
app_id=app_id,
sdr=sdr,
device_key=device_key,
buffer_size=buffer_size,
pending_config=pending,
)
self._rx = session
await self._send_status("streaming", app_id)
session.task = asyncio.create_task(self._capture_loop(session), name="ria-streamer-capture")
async def _handle_rx_stop(self, msg: dict) -> None:
session = self._rx
if session is None:
return
if session.task is not None:
session.task.cancel()
try:
await session.task
except (asyncio.CancelledError, Exception):
pass
self._close_session_sdr(session)
app_id = session.app_id
self._rx = None
await self._send_status("idle", app_id)
async def _handle_rx_configure(self, msg: dict) -> None:
cfg = dict(msg.get("radio_config") or {})
if self._rx is not None:
self._rx.pending_config.update(cfg)
else:
self._standalone_pending_config.update(cfg)
logger.debug("Queued configure: %s", cfg)
async def _capture_loop(self, session: RxSession) -> None:
loop = asyncio.get_running_loop()
try:
while True:
if session.pending_config:
cfg = session.pending_config
session.pending_config = {}
try:
_apply_sdr_config(session.sdr, cfg)
except Exception as exc:
logger.warning("Applying configure failed: %s", exc)
try:
samples = await loop.run_in_executor(None, session.sdr.rx, session.buffer_size)
except Exception as exc:
from ria_toolkit_oss.sdr import SdrDisconnectedError
if isinstance(exc, SdrDisconnectedError):
logger.warning("SDR disconnected: %s", exc)
await self._send_error(session.app_id, f"SDR disconnected: {exc}")
else:
logger.exception("SDR rx error")
await self._send_error(session.app_id, f"SDR capture failed: {exc}")
break
payload = _samples_to_interleaved_float32(samples)
try:
await self.ws.send_bytes(payload)
except Exception as exc:
logger.warning("Send failed: %s — ending capture", exc)
break
except asyncio.CancelledError:
raise
finally:
self._close_session_sdr(session)
# If the loop died on its own (e.g. SDR disconnect), clear the
# session handle so future ``start`` messages can proceed.
if self._rx is session:
self._rx = None
# ==================================================================
# TX
async def _handle_tx_start(self, msg: dict) -> None: # noqa: C901
app_id = msg.get("app_id") or ""
radio_config = dict(msg.get("radio_config") or {})
# --- interlocks (agent-enforced; never trust the hub alone) ---
if not self._cfg.tx_enabled:
await self._send_tx_status(app_id, "error", "tx disabled on this agent")
return
tx_gain = radio_config.get("tx_gain")
if (
self._cfg.tx_max_gain_db is not None
and tx_gain is not None
and float(tx_gain) > float(self._cfg.tx_max_gain_db)
):
await self._send_tx_status(
app_id,
"error",
f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}",
)
return
tx_freq = radio_config.get("tx_center_frequency")
if self._cfg.tx_allowed_freq_ranges and tx_freq is not None:
f = float(tx_freq)
if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges):
await self._send_tx_status(
app_id,
"error",
f"tx_center_frequency {tx_freq} outside allowed ranges",
)
return
if self._tx is not None:
await self._send_tx_status(app_id, "error", "tx already active on this agent")
return
# --- device ---
device = radio_config.pop("device", None)
identifier = radio_config.pop("identifier", None)
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
if underrun_policy not in ("pause", "zero", "repeat"):
await self._send_tx_status(app_id, "error", f"invalid underrun_policy {underrun_policy!r}")
return
if not device:
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
return
device_key: tuple[str, str | None] | None = None
sdr: Any = None
try:
sdr, device_key = self._registry.acquire(device, identifier)
_apply_sdr_config(sdr, radio_config)
# init_tx is mandatory for any driver that exposes it: drivers
# that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP,
# …) crash with a confusing "TX was not initialized" error 2 s
# later in the executor thread if we skip it. Treat the three
# required keys as a hard contract — a missing one is a hub-side
# manifest bug and we want it surfaced immediately, not papered
# over with stale radio state.
if hasattr(sdr, "init_tx"):
init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")}
missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
if missing:
raise ValueError(f"tx_start missing required radio_config keys: {missing}")
sdr.init_tx(
sample_rate=init_args["sample_rate"],
center_frequency=init_args["center_frequency"],
gain=init_args["gain"],
channel=radio_config.get("tx_channel", 0),
gain_mode=radio_config.get("tx_gain_mode", "manual"),
)
except Exception as exc:
if device_key is not None:
if self._registry.release(device_key):
try:
sdr.close()
except Exception:
pass
logger.exception("Failed to init TX on %r", device)
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
return
self._loop = asyncio.get_running_loop()
session = TxSession(
app_id=app_id,
sdr=sdr,
device_key=device_key,
buffer_size=buffer_size,
underrun_policy=underrun_policy,
started_at=time.monotonic(),
max_duration_s=self._cfg.tx_max_duration_s,
)
self._tx = session
await self._send_tx_status(app_id, "armed")
session.task = self._loop.run_in_executor(None, self._tx_executor_body, session)
# Spawn a small watchdog that transitions armed → transmitting when
# the first buffer has been consumed, and surfaces underrun / max-
# duration terminations back to the hub.
asyncio.create_task(self._tx_watchdog(session))
async def _handle_tx_stop(self, msg: dict) -> None:
session = self._tx
if session is None:
return
app_id = session.app_id
session.stop_event.set()
try:
session.sdr.pause_tx()
except Exception:
logger.debug("pause_tx raised during stop", exc_info=True)
# Wake the executor thread if it's blocked on ``queue.get``.
self._drain_tx_queue(session)
if session.task is not None:
try:
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5)
except asyncio.TimeoutError:
logger.warning("TX executor did not exit within 1.5s after stop")
except Exception:
logger.debug("TX executor raised on shutdown", exc_info=True)
self._close_session_sdr(session)
self._tx = None
await self._send_tx_status(app_id, "done")
async def _handle_tx_configure(self, msg: dict) -> None:
if self._tx is None:
return
self._tx.pending_config.update(msg.get("radio_config") or {})
# ------------------------------------------------------------------
# TX executor & watchdog
def _tx_executor_body(self, session: TxSession) -> None:
try:
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
except Exception as exc:
logger.exception("TX stream crashed")
# Schedule both the error frame and session teardown on the loop
# so ``self._tx`` clears, subsequent binary frames are rejected,
# and the SDR handle is released.
self._schedule(self._tx_crash_teardown(session, str(exc)))
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
n = int(num_samples)
# Honor stop requests: return silence one last time and let the driver
# exit its loop on the next iteration (pause_tx flips _enable_tx).
if session.stop_event.is_set():
return _silence(n)
# Max-duration watchdog.
if session.max_duration_s is not None and (time.monotonic() - session.started_at) >= float(
session.max_duration_s
):
session.stop_event.set()
try:
session.sdr.pause_tx()
except Exception:
pass
self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached"))
return _silence(n)
# Apply queued configure at buffer boundary.
if session.pending_config:
cfg = session.pending_config
session.pending_config = {}
try:
_apply_sdr_config(session.sdr, cfg)
except Exception as exc:
logger.debug("tx_configure apply failed: %s", exc)
try:
raw = session.in_queue.get(timeout=0.1)
except queue.Empty:
return self._underrun_fill(session, n)
arr = np.frombuffer(raw, dtype=np.float32)
if arr.size < 2 or arr.size % 2 != 0:
logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
return self._underrun_fill(session, n)
samples = arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64)
if samples.size < n:
out = np.zeros(n, dtype=np.complex64)
out[: samples.size] = samples
session.last_buffer = out
return out
if samples.size > n:
samples = samples[:n]
session.last_buffer = samples
if session.state == "armed":
session.state = "transmitting"
self._schedule(self._send_tx_status(session.app_id, "transmitting"))
return samples
def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray:
policy = session.underrun_policy
if policy == "zero":
return _silence(n)
if policy == "repeat" and session.last_buffer is not None:
buf = session.last_buffer
if buf.size == n:
return buf
if buf.size > n:
return buf[:n].copy()
out = np.zeros(n, dtype=np.complex64)
out[: buf.size] = buf
return out
# "pause" policy (default) or "repeat" before any buffer arrived.
if not session.underrun_flag.is_set():
session.underrun_flag.set()
session.stop_event.set()
try:
session.sdr.pause_tx()
except Exception:
pass
return _silence(n)
async def _tx_watchdog(self, session: TxSession) -> None:
# Poll the underrun flag so we can emit status + tear down cleanly
# when the callback flips the flag from the executor thread. Check
# underrun_flag before stop_event, since the "pause" path sets both.
while session is self._tx:
if session.underrun_flag.is_set():
await self._send_tx_status(session.app_id, "underrun")
await self._teardown_tx_after_underrun(session)
return
if session.stop_event.is_set():
return
await asyncio.sleep(0.05)
async def _tx_crash_teardown(self, session: TxSession, message: str) -> None:
# Called from the executor thread via _schedule when _stream_tx raises.
# Emit the error, mark stopped, drain the queue, release the SDR.
await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}")
if self._tx is not session:
return
session.stop_event.set()
self._drain_tx_queue(session)
self._close_session_sdr(session)
if self._tx is session:
self._tx = None
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
if self._tx is not session:
return
self._drain_tx_queue(session)
if session.task is not None:
try:
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0)
except asyncio.TimeoutError:
logger.warning("TX executor did not exit within 1s after underrun")
except Exception:
logger.debug("TX executor raised during underrun teardown", exc_info=True)
self._close_session_sdr(session)
if self._tx is session:
self._tx = None
def _drain_tx_queue(self, session: TxSession) -> None:
try:
while True:
session.in_queue.get_nowait()
except queue.Empty:
pass
def _schedule(self, coro) -> None:
loop = self._loop
if loop is None:
return
try:
asyncio.run_coroutine_threadsafe(coro, loop)
except Exception:
logger.debug("_schedule failed", exc_info=True)
# ==================================================================
# Helpers
def _close_session_sdr(self, session) -> None:
if session.sdr is None:
return
should_close = self._registry.release(session.device_key)
if should_close:
try:
session.sdr.close()
except Exception:
logger.debug("SDR close raised", exc_info=True)
async def _send_status(self, status: str, app_id: str) -> None:
try:
await self.ws.send_json({"type": "status", "status": status, "app_id": app_id})
except Exception as exc:
logger.debug("Status send failed: %s", exc)
async def _send_error(self, app_id: str, message: str) -> None:
try:
await self.ws.send_json({"type": "error", "app_id": app_id, "message": message})
except Exception as exc:
logger.debug("Error-frame send failed: %s", exc)
async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None:
payload: dict = {"type": "tx_status", "app_id": app_id, "state": state}
if message is not None:
payload["message"] = message
try:
await self.ws.send_json(payload)
except Exception as exc:
logger.debug("tx_status send failed: %s", exc)
# ---------------------------------------------------------------------------
# Helpers
_CONFIG_ATTR_MAP = {
"sample_rate": ("sample_rate", "rx_sample_rate"),
"center_frequency": ("center_freq", "rx_center_frequency"),
"center_freq": ("center_freq", "rx_center_frequency"),
"gain": ("gain", "rx_gain"),
"bandwidth": ("bandwidth", "rx_bandwidth"),
"tx_sample_rate": ("tx_sample_rate",),
"tx_center_frequency": ("tx_center_frequency", "tx_lo"),
"tx_gain": ("tx_gain",),
"tx_bandwidth": ("tx_bandwidth",),
}
def _is_stub_setter(method: Any) -> bool:
"""True when *method* is an unimplemented base-class stub.
The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain``
etc. as zero-argument ``NotImplementedError`` stubs. A driver (Pluto) that
actually transmits overrides them with a real ``(value, ...)`` signature.
Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply.
"""
return getattr(method, "__qualname__", "").startswith("SDR.")
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
"""Apply a radio_config dict to an SDR.
Prefers ``sdr.set_<attr>(value)`` when the driver implements it Pluto's
setters take ``_param_lock``, so routing through them keeps concurrent
RX + TX reconfigures from racing on shared native attributes. Falls back
to ``setattr`` for drivers (MockSDR, tests) that don't override the
base-class stubs.
"""
for key, value in cfg.items():
if value is None:
continue
attrs = _CONFIG_ATTR_MAP.get(key, (key,))
applied = False
for attr in attrs:
setter = getattr(sdr, f"set_{attr}", None)
if callable(setter) and not _is_stub_setter(setter):
try:
setter(value)
applied = True
break
except Exception as exc:
logger.debug("set_%s(%r) failed: %s", attr, value, exc)
# Fall through to setattr; some drivers may partially
# implement setters.
if applied:
continue
for attr in attrs:
if hasattr(sdr, attr):
try:
setattr(sdr, attr, value)
applied = True
break
except Exception as exc:
logger.debug("setattr %s=%r failed: %s", attr, value, exc)
if not applied:
logger.debug("radio_config key %r ignored (no matching attr)", key)
def _silence(num_samples: int) -> np.ndarray:
"""Return a ``num_samples``-length zero-filled complex64 buffer."""
return np.zeros(int(num_samples), dtype=np.complex64)
def _samples_to_interleaved_float32(samples: Any) -> bytes:
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
arr = np.asarray(samples)
if np.iscomplexobj(arr):
interleaved = np.empty(arr.size * 2, dtype=np.float32)
interleaved[0::2] = arr.real.astype(np.float32, copy=False).ravel()
interleaved[1::2] = arr.imag.astype(np.float32, copy=False).ravel()
return interleaved.tobytes()
return arr.astype(np.float32, copy=False).tobytes()
def _default_sdr_factory(device: str, identifier: str | None):
from ria_toolkit_oss.sdr import get_sdr_device
return get_sdr_device(device, ident=identifier)
# ---------------------------------------------------------------------------
# Top-level entry
async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
"""Connect to *ws_url* and run the streamer loop until cancelled."""
ws = WsClient(ws_url, token)
streamer = Streamer(ws, cfg=cfg)
await ws.run(
streamer.on_message,
streamer.build_heartbeat,
on_binary=streamer.on_binary,
)

View File

@ -0,0 +1,128 @@
"""Persistent WebSocket client for the streamer agent.
Handles connection lifecycle: connect, heartbeat, auto-reconnect on drop.
The caller drives the I/O loop via ``run()`` with a message handler callback.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Awaitable, Callable
logger = logging.getLogger("ria_agent.ws")
MessageHandler = Callable[[dict], Awaitable[None]]
HeartbeatBuilder = Callable[[], dict]
BinaryHandler = Callable[[bytes], Awaitable[None]]
class WsClient:
"""Persistent WebSocket connection with heartbeat and auto-reconnect.
``url`` should be a full ``wss://host/path`` (or ``ws://``) URL. ``token``
is sent as a bearer in the ``Authorization`` header on connect.
"""
def __init__(
self,
url: str,
token: str,
*,
heartbeat_interval: float = 30.0,
reconnect_pause: float = 5.0,
) -> None:
self.url = url
self.token = token
self.heartbeat_interval = heartbeat_interval
self.reconnect_pause = reconnect_pause
self._ws = None
self._stop = asyncio.Event()
# ------------------------------------------------------------------
async def _connect(self):
import websockets
headers = [("Authorization", f"Bearer {self.token}")] if self.token else None
# websockets >= 12 accepts additional_headers; fall back to extra_headers for older versions.
try:
return await websockets.connect(self.url, additional_headers=headers)
except TypeError:
return await websockets.connect(self.url, extra_headers=headers)
# ------------------------------------------------------------------
async def send_json(self, payload: dict) -> None:
if self._ws is None:
raise ConnectionError("WebSocket is not connected")
await self._ws.send(json.dumps(payload))
async def send_bytes(self, data: bytes) -> None:
if self._ws is None:
raise ConnectionError("WebSocket is not connected")
await self._ws.send(data)
def stop(self) -> None:
self._stop.set()
# ------------------------------------------------------------------
async def run(
self,
on_message: MessageHandler,
heartbeat: HeartbeatBuilder,
on_binary: BinaryHandler | None = None,
) -> None:
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
while not self._stop.is_set():
try:
self._ws = await self._connect()
logger.info("Connected to %s", self.url)
hb_task = asyncio.create_task(self._heartbeat_loop(heartbeat))
try:
async for raw in self._ws:
if isinstance(raw, bytes):
if on_binary is None:
logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
continue
try:
await on_binary(raw)
except Exception:
logger.exception("on_binary handler raised; dropping frame")
continue
try:
msg = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Malformed control frame: %r", raw[:200])
continue
await on_message(msg)
finally:
hb_task.cancel()
try:
await hb_task
except (asyncio.CancelledError, Exception):
pass
except asyncio.CancelledError:
raise
except Exception as exc:
if self._stop.is_set():
break
logger.warning("WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause)
finally:
try:
if self._ws is not None:
await self._ws.close()
except Exception:
pass
self._ws = None
if self._stop.is_set():
break
await asyncio.sleep(self.reconnect_pause)
async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
while True:
try:
await self.send_json(heartbeat())
except Exception as exc:
logger.debug("Heartbeat send failed: %s", exc)
return
await asyncio.sleep(self.heartbeat_interval)

View File

@ -0,0 +1,54 @@
"""
The annotations package contains tools and utilities for creating, managing, and processing annotations.
Provides automatic annotation generation using various signal detection algorithms:
- Energy-based detection (detect_signals_energy)
- CUSUM-based segmentation (annotate_with_cusum)
- Threshold-based qualification (threshold_qualifier)
- Signal isolation and extraction (isolate_signal)
- Occupied bandwidth analysis (calculate_occupied_bandwidth, calculate_nominal_bandwidth)
All detection functions return Recording objects with added annotations.
"""
__all__ = [
# Energy-based detection
"detect_signals_energy",
"calculate_occupied_bandwidth",
"calculate_nominal_bandwidth",
"calculate_full_detected_bandwidth",
"annotate_with_obw",
# CUSUM detection
"annotate_with_cusum",
# Threshold detection
"threshold_qualifier",
# Parallel signal separation (Phase 2)
"find_spectral_components",
"split_annotation_by_components",
"split_recording_annotations",
# Signal isolation
"isolate_signal",
# Annotation transforms
"remove_contained_boxes",
"is_annotation_contained",
# Dataset creation
"qualify_slice_from_annotations",
]
from .annotation_transforms import is_annotation_contained, remove_contained_boxes
from .cusum_annotator import annotate_with_cusum
from .energy_detector import (
annotate_with_obw,
calculate_full_detected_bandwidth,
calculate_nominal_bandwidth,
calculate_occupied_bandwidth,
detect_signals_energy,
)
from .parallel_signal_separator import (
find_spectral_components,
split_annotation_by_components,
split_recording_annotations,
)
from .qualify_slice import qualify_slice_from_annotations
from .signal_isolation import isolate_signal
from .threshold_qualifier import threshold_qualifier

View File

@ -0,0 +1,55 @@
from ria_toolkit_oss.data.annotation import Annotation
# TODO figure out how to transfer labels in the merge case
def remove_contained_boxes(annotations: list[Annotation]):
"""
Remove all annotations (bounding boxes) that are entirely contained within other boxes in the list.
:param annotations: A list of Annotation objects.
:type annotations: list[Annotation]
:returns: A new list of Annotation objects.
:rtype: list[Annotation]"""
output_boxes = []
for i in range(len(annotations)):
contained = False
for j in range(len(annotations)):
if i != j and is_annotation_contained(annotations[i], annotations[j]):
contained = True
break
if not contained:
output_boxes.append(annotations[i])
return output_boxes
def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
"""
Check if an annotation box is entirely contained within another annotation bounding box.
:param inner: The inner box.
:type inner: Annotation.
:param outer: The outer box.
:type outer: Annotation.
:returns: True if inner is within outer, false otherwise.
:rtype: bool
"""
inner_sample_stop = inner.sample_start + inner.sample_count
outer_sample_stop = outer.sample_start + outer.sample_count
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
return True
return False
def merge_annotations(annotations: list[Annotation], overlap_threshold) -> list[Annotation]:
raise NotImplementedError

View File

@ -0,0 +1,203 @@
import json
from typing import Optional
import numpy as np
from ria_toolkit_oss.data import Annotation, Recording
def annotate_with_cusum(
recording: Recording,
label: Optional[str] = "segment",
window_size: Optional[int] = 1,
min_duration: Optional[float] = None,
tolerance: Optional[int] = None,
annotation_type: Optional[str] = "standalone",
):
"""
Add annotations that divide the recording into distinct time segments.
This algorithm computes the cumulative sum of the sample magnitudes and
determines break points in the signal.
This tool can be used to find points where a signal turns on or off, or
changes between a low and high amplitude.
:param recording: A ``Recording`` object to annotate.
:type recording: ``ria_toolkit_oss.data.Recording``
:param label: Label for the detected segments.
:type label: str
:param window_size: The length (in samples) of the moving average window.
:type window_size: int
:param min_duration: The minimum duration (in ms) of a segment.
The algorithm will not produce annotations shorter than this length.
:type min_duration: float
:param tolerance: The minimum length (in samples) of a segment.
:type tolerance: int
:param annotation_type: Annotation type (standalone, parallel, intersection).
:type annotation_type: str
"""
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0)
# Create an object of the time segmenter
time_segmenter = TimeSegmenter(sample_rate, min_duration, window_size, tolerance)
change_points = time_segmenter.apply(recording.data[0])
time_segments_indices = np.append(np.insert(change_points, 0, 0), len(recording.data[0]))
annotations = []
for i in range(len(time_segments_indices) - 1):
# Build comment JSON with type metadata
comment_data = {
"type": annotation_type,
"generator": "cusum_annotator",
"params": {
"window_size": window_size,
"min_duration": min_duration,
"tolerance": tolerance,
},
}
f_min, f_max = detect_frequency(
signal=recording.data[0],
start=time_segments_indices[i],
stop=time_segments_indices[i + 1],
sample_rate=sample_rate,
)
annotations.append(
Annotation(
sample_start=time_segments_indices[i],
sample_count=time_segments_indices[i + 1] - time_segments_indices[i],
freq_lower_edge=center_frequency + f_min,
freq_upper_edge=center_frequency + f_max,
label=label,
comment=json.dumps(comment_data),
detail={"generator": "cusum_annotator"},
)
)
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
def _compute_cusum(_signal, sample_rate: int, tolerance: int = None, min_duration: float = -1):
"""
This function efficiently computes the cumulative sum of a give list (_signal), with an optional tolerance.
Args:
- _signal: array of iq samples.
- Tolerance: the least acceptable length of a block, Defaults to None.
Returns:
- cusum (array): Array of the cumulative sum of the given list
- sample_rate (int): __description_
- change_points (array): Array of the indices at which a change in the CUSUM direction happens.
- min_duration (float): The least acceptable time width of each segment (in ms). Defaults to -1.
"""
# efficiently calculate the running sum of the signal
# cusum = list(itertools.accumulate((_signal - np.mean(_signal))))
x = _signal - np.mean(_signal)
cusum = np.cumsum(x)
# 'diff' computes the differences between the consecutive values,
# then 'sign' determines if it is +ve or -ve.
change_indicators = np.sign(np.diff(cusum))
change_points = np.where(np.diff(change_indicators))[0] + 1
# Limit the change_points
# Reject those whose number of samples < minimum accepted #n of samples in (min duration) ms.
if min_duration is not None and min_duration > 0:
min_samples_wide = int(min_duration * sample_rate / 1000)
segments_lengths = np.diff(change_points)
segments_lengths = np.insert(segments_lengths, 0, change_points[0])
change_points = change_points[np.where(segments_lengths > min_samples_wide)[0]]
return cusum, change_points
def detect_frequency(signal, start, stop, sample_rate):
signal_segment = signal[start:stop]
if len(signal_segment) > 0:
fft_data = np.abs(np.fft.fftshift(np.fft.fft(signal_segment)))
fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate))
# Use a spectral threshold to find the 'height' of the orange block
spectral_thresh = np.max(fft_data) * 0.15
sig_indices = np.where(fft_data > spectral_thresh)[0]
if len(sig_indices) > 4:
return fft_freqs[sig_indices[0]], fft_freqs[sig_indices[-1]]
else:
return -sample_rate / 4, sample_rate / 4
else:
return -sample_rate / 4, sample_rate / 4
class TimeSegmenter:
"""Time Segmenter class, it creates a segmenter object with certain\
characteristics to easily split an input signal to segments based on\
the cumulative sum of deviations (of the signal mean)
"""
def __init__(
self, sample_rate: int, min_duration: float = 1, moving_average_window: int = 3, tolerance: int = None
):
"""_summary_
Args:
sample_rate (int): _description_
min_duration (float, optional): _description_. Defaults to 1.
moving_average_window (int, optional): _description_. Defaults to 3.
tolerance (int, optional): _description_. Defaults to None.
"""
self.sample_rate = sample_rate
self.min_duration = min_duration
self.moving_average_window = moving_average_window
self._moving_avg_filter = self._init_filter()
self.tolerance = tolerance
def _init_filter(self):
"""_summary_
Returns:
_type_: _description_
"""
return np.ones(self.moving_average_window) / self.moving_average_window
def _apply_filter(self, iqsignal: np.array):
"""_summary_
Args:
iqsignal (np.array): _description_
Returns:
_type_: _description_
"""
return np.convolve(abs(iqsignal), self._moving_avg_filter, mode="same")
def _create_segments(self, iq_signal: np.array, change_points: np.array):
"""_summary_
Args:
iq_signal (np.array): _description_
change_points (np.array): _description_
Returns:
_type_: _description_
"""
return np.split(iq_signal, change_points)
def apply(self, iq_signal: np.array):
"""_summary_
Args:
iq_signal (np.array): _description_
Returns:
_type_: _description_
"""
smoothed_signal = self._apply_filter(iq_signal)
_, change_points = _compute_cusum(smoothed_signal, self.sample_rate, self.tolerance, self.min_duration)
# segments = self._create_segments(iq_signal, change_points)
return change_points

View File

@ -0,0 +1,438 @@
"""
Energy-based signal detection and bandwidth analysis.
Provides automatic annotation generation using energy-based signal detection
and occupied bandwidth calculation following ITU-R SM.328 standard.
"""
import json
from typing import Tuple
import numpy as np
from scipy.signal import filtfilt
from ria_toolkit_oss.data import Annotation, Recording
def detect_signals_energy(
recording: Recording,
k: int = 10,
threshold_factor: float = 1.2,
window_size: int = 200,
min_distance: int = 5000,
label: str = "signal",
annotation_type: str = "standalone",
freq_method: str = "nbw",
nfft: int = None,
obw_power: float = 0.99,
) -> Recording:
"""
Detect signal bursts using energy-based method with adaptive noise floor estimation.
This algorithm smooths the signal with a moving average filter, estimates the noise
floor from k segments, applies a threshold to detect regions above noise, and merges
nearby detections. Detected time boundaries are then assigned frequency bounds based
on the selected frequency method.
Time Detection Algorithm:
1. Smooth signal using moving average (envelope detection)
2. Divide smoothed signal into k segments
3. Estimate noise floor as median of segment mean powers
4. Detect regions where power exceeds threshold_factor * noise_floor
5. Merge regions closer than min_distance samples
Frequency Bounding (freq_method):
- 'nbw': Nominal bandwidth (OBW + center frequency) - DEFAULT
- 'obw': Occupied bandwidth (99.99% power, includes siedelobes)
- 'full-detected': Lowest to highest spectral component
- 'full-bandwidth': Entire Nyquist span (center_freq ± sample_rate/2)
:param recording: Recording to analyze
:type recording: Recording
:param k: Number of segments for noise floor estimation (default: 10)
:type k: int
:param threshold_factor: Threshold multiplier above noise floor (typical: 1.2-2.0, default: 1.2)
:type threshold_factor: float
:param window_size: Moving average window size in samples (default: 200)
:type window_size: int
:param min_distance: Minimum distance between separate signals in samples (default: 5000)
:type min_distance: int
:param label: Label for detected annotations (default: "signal")
:type label: str
:param annotation_type: Annotation type (standalone, parallel, intersection, default: standalone)
:type annotation_type: str
:param freq_method: How to calculate frequency bounds (default: 'nbw')
:type freq_method: str
:param nfft: FFT size for frequency calculations (default: None)
:type nfft: int
:param obw_power: Power percentage for OBW (0.9999 = 99.99%, default: 0.99)
:type obw_power: float
:returns: New Recording with added annotations
:rtype: Recording
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import detect_signals_energy
>>> recording = load_recording("capture.sigmf")
>>> # Detect with NBW frequency bounds (default, best for real signals)
>>> annotated = detect_signals_energy(recording, label="burst")
>>> # Detect with OBW (more conservative, includes siedelobes)
>>> annotated = detect_signals_energy(
... recording, label="burst", freq_method="obw"
... )
>>> # Detect with full detected range (captures all spectral components)
>>> annotated = detect_signals_energy(
... recording, label="burst", freq_method="full-detected"
... )
"""
# Extract signal data (use first channel only)
signal = recording.data[0]
# Calculate smoothed signal power
kernel = np.ones(window_size) / window_size
smoothed_power = filtfilt(kernel, [1], np.abs(signal) ** 2)
# Estimate noise floor using segment-based median (robust to signal presence)
segments = np.array_split(smoothed_power, k)
noise_floor = np.median([np.mean(s) for s in segments])
# Detect signal boundaries (regions above threshold)
enter = noise_floor * threshold_factor
exit = enter * 0.8
boundaries = []
start = None
active = False
for i, p in enumerate(smoothed_power):
if not active and p > enter:
start = i
active = True
elif active and p < exit:
boundaries.append((start, i - start))
active = False
if active:
boundaries.append((start, len(smoothed_power) - start))
# Merge boundaries that are closer than min_distance
merged_boundaries = []
if boundaries:
start, length = boundaries[0]
for next_start, next_length in boundaries[1:]:
if next_start - (start + length) < min_distance:
# Merge with current boundary
length = next_start + next_length - start
else:
# Save current and start new boundary
merged_boundaries.append((start, length))
start, length = next_start, next_length
# Add final boundary
merged_boundaries.append((start, length))
# Create annotations from detected boundaries
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0)
# Validate frequency method
valid_freq_methods = ["nbw", "obw", "full-detected", "full-bandwidth"]
if freq_method not in valid_freq_methods:
raise ValueError(f"Invalid freq_method '{freq_method}'. " f"Must be one of: {', '.join(valid_freq_methods)}")
annotations = []
for start_sample, sample_count in merged_boundaries:
# Calculate frequency bounds based on method
freq_lower, freq_upper = calculate_frequency_bounds(
freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
)
# Build comment JSON with type metadata
comment_data = {
"type": annotation_type,
"generator": "energy_detector",
"freq_method": freq_method,
"params": {
"threshold_factor": threshold_factor,
"window_size": window_size,
"noise_floor": float(noise_floor),
"threshold": float(enter),
},
}
anno = Annotation(
sample_start=start_sample,
sample_count=sample_count,
freq_lower_edge=freq_lower,
freq_upper_edge=freq_upper,
label=label,
comment=json.dumps(comment_data),
detail={"generator": "energy_detector", "freq_method": freq_method},
)
annotations.append(anno)
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
def calculate_occupied_bandwidth(
signal: np.ndarray,
sampling_rate: float,
nfft: int = None,
power_percentage: float = 0.99,
):
if nfft is None:
nfft = max(65536, 2 ** int(np.floor(np.log2(len(signal)))))
window = np.blackman(len(signal))
spec = np.fft.fftshift(np.fft.fft(signal * window, n=nfft))
psd = np.abs(spec) ** 2
psd = psd / psd.sum() # normalize
freqs = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))
cdf = np.cumsum(psd)
tail = (1 - power_percentage) / 2
lower_idx = np.searchsorted(cdf, tail)
upper_idx = np.searchsorted(cdf, 1 - tail)
return freqs[upper_idx] - freqs[lower_idx], freqs[lower_idx], freqs[upper_idx]
def calculate_nominal_bandwidth(
signal: np.ndarray,
sampling_rate: float,
nfft: int = None,
power_percentage: float = 0.99,
) -> Tuple[float, float]:
"""
Calculate nominal bandwidth and center frequency.
Nominal bandwidth (NBW) is the occupied bandwidth along with the center
frequency of the signal's spectral occupancy. Useful for characterizing
signals with unknown or drifting center frequencies.
:param signal: Complex IQ signal samples
:type signal: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param nfft: FFT size
:type nfft: int
:param power_percentage: Fraction of power to contain
:type power_percentage: float
:returns: Tuple of (nominal_bandwidth_hz, center_frequency_hz)
:rtype: Tuple[float, float]
**Example**::
>>> from ria_toolkit_oss.annotations import calculate_nominal_bandwidth
>>> nbw, center = calculate_nominal_bandwidth(signal, sampling_rate=10e6)
>>> print(f"NBW: {nbw/1e6:.3f} MHz, Center: {center/1e6:.3f} MHz")
"""
bw, lower_freq, upper_freq = calculate_occupied_bandwidth(signal, sampling_rate, nfft, power_percentage)
# Center frequency is midpoint of occupied band
center_freq = (lower_freq + upper_freq) / 2
return lower_freq, upper_freq, center_freq
def calculate_full_detected_bandwidth(
signal: np.ndarray,
sampling_rate: float,
nfft: int = None,
start_offset: int = 1000,
) -> Tuple[float, float, float]:
"""
Calculate frequency range from lowest to highest spectral component.
Unlike OBW/NBW which define a power-based bandwidth, this calculates
the absolute frequency span from the lowest non-zero spectral component
to the highest non-zero component.
Useful for:
- Signals with spectral gaps
- Multiple parallel signals (captures all of them)
- Understanding total occupied spectrum vs. actual bandwidth
:param signal: Complex IQ signal samples
:type signal: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param nfft: FFT size
:type nfft: int
:param start_offset: Skip samples at start
:type start_offset: int
:returns: Tuple of (bandwidth_hz, lower_freq_hz, upper_freq_hz)
:rtype: Tuple[float, float, float]
**Example**::
>>> # Signal with two components at different frequencies
>>> bw, f_low, f_high = calculate_full_detected_bandwidth(
... signal, sampling_rate=10e6, nfft=65536
... )
>>> print(f"Full span: {f_low/1e6:.3f} to {f_high/1e6:.3f} MHz")
"""
# Validate input
if len(signal) < nfft + start_offset:
raise ValueError(
f"Signal too short: need {nfft + start_offset} samples, "
f"got {len(signal)}. Reduce nfft or start_offset."
)
# Extract segment
signal_segment = signal[start_offset : nfft + start_offset]
# Compute FFT and power spectral density
freq_spectrum = np.fft.fft(signal_segment, n=nfft)
psd = np.abs(freq_spectrum) ** 2
# Shift to center DC
psd_shifted = np.fft.fftshift(psd)
freq_bins = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))
# Find noise floor (mean of lowest 10% of bins) and all bins above noise floor
noise_floor = np.mean(np.sort(psd_shifted)[: int(len(psd_shifted) * 0.1)])
above_noise = np.where(psd_shifted > noise_floor * 1.5)[0]
if len(above_noise) == 0:
# No signal above noise, return zero bandwidth
return 0.0, 0.0, 0.0
# Get frequency range of signal components
lower_idx = above_noise[0]
upper_idx = above_noise[-1]
lower_freq = freq_bins[lower_idx]
upper_freq = freq_bins[upper_idx]
bandwidth = upper_freq - lower_freq
return bandwidth, lower_freq, upper_freq
def annotate_with_obw(
recording: Recording,
label: str = "signal",
annotation_type: str = "standalone",
nfft: int = None,
power_percentage: float = 0.99,
) -> Recording:
"""
Create a single annotation spanning the occupied bandwidth of the entire recording.
Analyzes the full recording to find its occupied bandwidth and creates an annotation
covering that frequency range for the entire time duration.
:param recording: Recording to analyze
:type recording: Recording
:param label: Annotation label
:type label: str
:param annotation_type: Annotation type
:type annotation_type: str
:param nfft: FFT size
:type nfft: int
:param power_percentage: Power percentage for OBW calculation
:type power_percentage: float
:returns: Recording with OBW annotation added
:rtype: Recording
**Example**::
>>> from ria_toolkit_oss.annotations import annotate_with_obw
>>> annotated = annotate_with_obw(recording, label="signal_obw")
"""
signal = recording.data[0]
sample_rate = recording.metadata["sample_rate"]
center_freq = recording.metadata.get("center_frequency", 0)
# Calculate OBW
obw, lower_offset, upper_offset = calculate_occupied_bandwidth(signal, sample_rate, nfft, power_percentage)
# Convert baseband offsets to absolute frequencies
freq_lower = center_freq + lower_offset
freq_upper = center_freq + upper_offset
# Create comment JSON
comment_data = {
"type": annotation_type,
"generator": "obw_annotator",
"obw_hz": float(obw),
"power_percentage": power_percentage,
"params": {"nfft": nfft},
}
# Create annotation spanning entire recording
anno = Annotation(
sample_start=0,
sample_count=len(signal),
freq_lower_edge=freq_lower,
freq_upper_edge=freq_upper,
label=label,
comment=json.dumps(comment_data),
detail={"generator": "obw_annotator", "obw_hz": float(obw)},
)
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + [anno])
def calculate_frequency_bounds(
freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
):
if freq_method == "full-bandwidth":
# Full Nyquist span
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
else:
# Extract segment for frequency analysis
segment_start = start_sample
segment_end = min(start_sample + sample_count, len(signal))
segment = signal[segment_start:segment_end]
if nfft is None or len(segment) >= nfft:
if freq_method == "nbw":
# Nominal bandwidth (OBW + center frequency)
try:
lower_freq, upper_freq, _ = calculate_nominal_bandwidth(segment, sample_rate, nfft, obw_power)
freq_lower = center_frequency + lower_freq
freq_upper = center_frequency + upper_freq
except (ValueError, IndexError):
# Fallback if calculation fails
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
elif freq_method == "obw":
# Occupied bandwidth
try:
_, f_lower, f_upper = calculate_occupied_bandwidth(segment, sample_rate, nfft, obw_power)
freq_lower = center_frequency + f_lower
freq_upper = center_frequency + f_upper
except (ValueError, IndexError):
# Fallback if calculation fails
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
elif freq_method == "full-detected":
# Full detected range (lowest to highest component)
try:
_, f_lower, f_upper = calculate_full_detected_bandwidth(segment, sample_rate, nfft)
freq_lower = center_frequency + f_lower
freq_upper = center_frequency + f_upper
except (ValueError, IndexError):
# Fallback if calculation fails
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
else:
# Segment too short for FFT, use full bandwidth
freq_lower = center_frequency - (sample_rate / 2)
freq_upper = center_frequency + (sample_rate / 2)
return freq_lower, freq_upper

View File

@ -0,0 +1,435 @@
"""
Parallel signal separation for multi-component frequency-offset signals.
Provides methods to detect and separate overlapping frequency-domain signals
that occupy the same time window but different frequency bands.
This module implements **spectral peak detection** to identify distinct frequency
components and split single time-domain annotations into frequency-specific
sub-annotations.
**Key Design Decisions** (per Codex review):
1. **Complex IQ Support**: Uses `scipy.signal.welch` with `return_onesided=False`
for proper complex signal handling. Window length automatically adapts to
signal length via `nperseg=min(nfft, len(signal))` to handle bursts <nfft.
2. **Frequency Representation**: Components are detected in **relative** frequency
(baseband, centered at 0 Hz). Caller must add RF center_frequency_hz when
writing to SigMF annotations. This separation of concerns avoids the frequency
context bug where absolute Hz would be meaningless for baseband processing.
3. **Bandwidth Estimation**: Dual strategy avoids -3dB limitations:
- Primary: -3dB rolloff for typical narrowband signals
- Fallback: Cumulative power (99% like OBW) for wide/OFDM signals
- Auto-fallback when -3dB bandwidth is anomalous
4. **Noise Floor**: Auto-estimated via `np.percentile(psd_db, 10)` from data
to adapt across hardware (Pluto vs. ThinkRF). User can override if needed.
5. **Filter Sizing (Optional FIR extraction)**: When extracting components,
uses Kaiser window FIR with proper stopband specification. Auto-sizes
numtaps based on desired transition bandwidth. Includes downsampling
guidance for long captures.
6. **CLI Surface**: Single `separate` subcommand for all separation operations.
Can be chained after any detector or used standalone on existing annotations.
Example:
Two WiFi channels captured simultaneously:
>>> from ria_toolkit_oss.annotations import find_spectral_components
>>> # Detect the two distinct channels (returns relative frequencies)
>>> components = find_spectral_components(signal, sampling_rate=20e6)
>>> print(f"Found {len(components)} components")
Found 2 components
The module is designed to work with detected time-domain annotations,
allowing splitting of overlapping signals into separate training samples.
"""
import json
from typing import List, Optional, Tuple
import numpy as np
from scipy import ndimage
from scipy import signal as scipy_signal
from ria_toolkit_oss.data import Annotation, Recording
def find_spectral_components(
signal_data: np.ndarray,
sampling_rate: float,
nfft: int = 65536,
noise_threshold_db: Optional[float] = None,
min_component_bw: float = 50e3,
time_percentile: float = 70.0,
) -> List[Tuple[float, float, float]]:
"""
Find distinct frequency components using spectral peak detection.
Identifies separate frequency components in a signal by analyzing the power
spectral density and finding peaks corresponding to distinct signals. This is
useful for separating parallel signals that occupy different frequency bands.
**Frequency Representation**: Returns frequencies in **baseband/relative** Hz
(centered at 0). To get absolute RF frequencies, add center_frequency_hz from
recording metadata to all returned values.
Algorithm:
1. Compute power spectral density using Welch (properly handles complex IQ)
2. Auto-estimate noise floor from data if not specified
3. Smooth PSD to reduce spurious peaks
4. Find local maxima above noise floor
5. Estimate bandwidth per peak using -3dB (fallback: cumulative power)
6. Filter components below minimum bandwidth threshold
:param signal_data: Complex IQ signal samples (np.complex64/128)
:type signal_data: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param nfft: FFT size / window length for Welch. Automatically capped at
signal length to handle bursts (default: 65536)
:type nfft: int
:param noise_threshold_db: Minimum SNR threshold in dB. If None (default),
auto-estimates as np.percentile(psd_db, 10).
Adapt this across hardware (Pluto: ~-100, ThinkRF: ~-60).
:type noise_threshold_db: Optional[float]
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz)
:type min_component_bw: float
:param power_threshold: Cumulative power threshold for fallback bandwidth
estimation (default: 0.99 = 99% power, like OBW)
:type power_threshold: float
:returns: List of (center_freq_hz, lower_freq_hz, upper_freq_hz) tuples.
**All frequencies are relative (baseband, 0-centered).**
Add recording metadata['center_frequency'] to get absolute RF frequencies.
:rtype: List[Tuple[float, float, float]]
:raises ValueError: If signal has fewer than 256 samples
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import find_spectral_components
>>> recording = load_recording("capture.sigmf")
>>> segment = recording.data[0][start:end]
>>> # Components in relative (baseband) frequency
>>> components = find_spectral_components(segment, sampling_rate=20e6)
>>> for center_rel, lower_rel, upper_rel in components:
... # Convert to absolute RF frequency
... center_abs = recording.metadata['center_frequency'] + center_rel
... print(f"Component @ {center_abs/1e9:.3f} GHz")
"""
# Validate input
min_samples = 256
if len(signal_data) < min_samples:
raise ValueError(f"Signal too short: need at least {min_samples} samples, " f"got {len(signal_data)}.")
# Compute PSD using Welch method for complex IQ signals
# CRITICAL: return_onesided=False for proper complex signal handling
nperseg = min(nfft, len(signal_data))
noverlap = nperseg // 2
# --- STFT ---
freqs, times, Zxx = scipy_signal.stft(
signal_data,
fs=sampling_rate,
window="blackman",
nperseg=nperseg,
noverlap=noverlap,
return_onesided=False,
boundary=None,
)
# Shift zero freq to center
Zxx = np.fft.fftshift(Zxx, axes=0)
freqs = np.fft.fftshift(freqs)
# Power spectrogram
power = np.abs(Zxx) ** 2
power_db = 10 * np.log10(power + 1e-12)
# --- Aggregate across time robustly ---
# Using percentile instead of mean prevents short signals from being diluted
freq_profile_db = np.percentile(power_db, time_percentile, axis=1)
# --- Noise floor estimation ---
if noise_threshold_db is None:
noise_threshold_db = np.percentile(freq_profile_db, 20)
threshold = noise_threshold_db + 3 # 3 dB above noise floor
# --- Smooth lightly (avoid merging nearby signals) ---
freq_profile_db = ndimage.gaussian_filter1d(freq_profile_db, sigma=1.5)
# --- Binary mask of significant frequencies ---
mask = freq_profile_db > threshold
# --- Find contiguous frequency regions ---
labeled, num_features = ndimage.label(mask)
components = []
for region_label in range(1, num_features + 1):
region_indices = np.where(labeled == region_label)[0]
if len(region_indices) == 0:
continue
lower_idx = region_indices[0]
upper_idx = region_indices[-1]
lower_freq = freqs[lower_idx]
upper_freq = freqs[upper_idx]
bw = upper_freq - lower_freq
if bw < min_component_bw:
continue
center_freq = (lower_freq + upper_freq) / 2
components.append((center_freq, lower_freq, upper_freq))
return components
def split_annotation_by_components(
annotation: Annotation,
signal: np.ndarray,
sampling_rate: float,
center_frequency_hz: float = 0.0,
nfft: int = 65536,
noise_threshold_db: Optional[float] = None,
min_component_bw: float = 50e3,
) -> List[Annotation]:
"""
Split an annotation into multiple annotations by detected frequency components.
Takes an existing annotation spanning multiple frequency components and
analyzes the frequency content to create separate sub-annotations for
each distinct frequency component.
**Use case**: Energy detection found a time window with 2-3 parallel WiFi
channels. This function splits it into separate annotations per channel.
**Frequency Handling**: `find_spectral_components` returns relative (baseband)
frequencies. This function adds `center_frequency_hz` to convert to absolute
RF frequencies for SigMF annotation bounds. This ensures correct frequency
context across baseband and RF domains.
:param annotation: Original annotation to split
:type annotation: Annotation
:param signal: Full signal array (complex IQ)
:type signal: np.ndarray
:param sampling_rate: Sample rate in Hz
:type sampling_rate: float
:param center_frequency_hz: RF center frequency to add to relative frequencies
from peak detection (default: 0.0 = baseband)
:type center_frequency_hz: float
:param nfft: FFT size for analysis (default: 65536, auto-capped at signal length)
:type nfft: int
:param noise_threshold_db: Noise floor threshold in dB. If None (default),
auto-estimates from data.
:type noise_threshold_db: Optional[float]
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz)
:type min_component_bw: float
:returns: List of new annotations (one per detected component).
Returns empty list if no components found or segment too short.
:rtype: List[Annotation]
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import split_annotation_by_components
>>> recording = load_recording("capture.sigmf")
>>> # Original annotation spans multiple channels
>>> original = recording.annotations[0]
>>> # Split using RF center frequency from metadata
>>> components = split_annotation_by_components(
... original,
... recording.data[0],
... recording.metadata['sample_rate'],
... center_frequency_hz=recording.metadata.get('center_frequency', 0.0)
... )
>>> print(f"Split into {len(components)} components")
Split into 2 components
**Algorithm**:
1. Extract segment corresponding to annotation time bounds
2. Find frequency components in that segment (returns relative frequencies)
3. Add center_frequency_hz to get absolute RF frequencies
4. Create new annotation for each component
5. Preserve original metadata (label, type, etc.)
6. Add component info to comment JSON
**Notes**:
- Original annotation is not modified
- Returns empty list if segment too short (<256 samples)
- Segments <nfft get auto-downsampled to nfft (see find_spectral_components)
- Each component inherits label from original
- Component frequencies in comment JSON are absolute (RF) frequencies
"""
# Extract segment corresponding to annotation time bounds
start_sample = annotation.sample_start
end_sample = min(start_sample + annotation.sample_count, len(signal))
segment = signal[start_sample:end_sample]
# Validate segment length is enough for spectral analysis
if len(segment) < 256:
return []
# Find components in this segment (returns relative/baseband frequencies)
try:
components = find_spectral_components(segment, sampling_rate, nfft, noise_threshold_db, min_component_bw)
except ValueError:
# Spectral analysis failed (e.g., not complex IQ)
return []
if not components:
# No components found
return []
# Create annotations for each component
new_annotations = []
for center_freq_rel, lower_freq_rel, upper_freq_rel in components:
# Convert relative (baseband) frequencies to absolute (RF) frequencies
center_freq_abs = center_frequency_hz + center_freq_rel
lower_freq_abs = center_frequency_hz + lower_freq_rel
upper_freq_abs = center_frequency_hz + upper_freq_rel
# Parse original annotation metadata
try:
comment_data = json.loads(annotation.comment)
except (json.JSONDecodeError, TypeError):
comment_data = {"type": "standalone"}
# Add component information (with absolute RF frequencies)
comment_data["split_from_annotation"] = True
comment_data["original_freq_bounds"] = {
"lower": float(annotation.freq_lower_edge),
"upper": float(annotation.freq_upper_edge),
}
comment_data["component_freq_bounds_rf"] = {
"center": float(center_freq_abs),
"lower": float(lower_freq_abs),
"upper": float(upper_freq_abs),
}
# Create new annotation with absolute RF frequency bounds
new_anno = Annotation(
sample_start=annotation.sample_start,
sample_count=annotation.sample_count,
freq_lower_edge=lower_freq_abs,
freq_upper_edge=upper_freq_abs,
label=annotation.label,
comment=json.dumps(comment_data),
detail={
"generator": "parallel_signal_separator",
"center_freq_hz": float(center_freq_abs),
},
)
new_annotations.append(new_anno)
return new_annotations
def split_recording_annotations(
recording: Recording,
indices: Optional[List[int]] = None,
nfft: int = 65536,
noise_threshold_db: Optional[float] = None,
min_component_bw: float = 50e3,
) -> Recording:
"""
Split multiple annotations in a recording by frequency components.
Processes specified annotations (or all if indices=None), replacing each
with its frequency-separated components. Uses RF center_frequency from
recording metadata for proper absolute frequency conversion.
:param recording: Recording to process
:type recording: Recording
:param indices: Annotation indices to split (None = all, default: None).
Use indices=[] to skip splitting (returns unchanged recording).
:type indices: Optional[List[int]]
:param nfft: FFT size for spectral analysis (default: 65536,
auto-capped at signal segment length)
:type nfft: int
:param noise_threshold_db: Noise floor threshold in dB. If None (default),
auto-estimates from each segment.
:type noise_threshold_db: Optional[float]
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz).
Components narrower than this are filtered out.
:type min_component_bw: float
:returns: New Recording with split annotations
:rtype: Recording
**Example**::
>>> from ria.io import load_recording
>>> from ria_toolkit_oss.annotations import split_recording_annotations
>>> recording = load_recording("capture.sigmf")
>>> # Split all annotations
>>> split_rec = split_recording_annotations(recording)
>>> print(f"Original: {len(recording.annotations)} annotations")
>>> print(f"Split: {len(split_rec.annotations)} annotations")
Original: 5 annotations
Split: 9 annotations
**Algorithm**:
1. For each annotation in indices (or all if None):
2. Call split_annotation_by_components with RF center_frequency
3. If components found, replace annotation with components
4. If no components found, keep original annotation
5. Annotations not in indices are kept unchanged
**Notes**:
- Original recording is not modified
- Returns empty Recording.annotations if recording has no annotations
- RF center_frequency from metadata ensures correct absolute frequencies
- If an annotation can't be split (too short, wrong format), original kept
"""
if indices is None:
# Split all annotations
indices = list(range(len(recording.annotations)))
if not recording.annotations:
# No annotations to split
return recording
signal = recording.data[0]
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0.0)
# Build new annotation list
new_annotations = []
for i, anno in enumerate(recording.annotations):
if i in indices:
# Attempt to split this annotation
try:
components = split_annotation_by_components(
anno,
signal,
sample_rate,
center_frequency_hz=center_frequency,
nfft=nfft,
noise_threshold_db=noise_threshold_db,
min_component_bw=min_component_bw,
)
if components:
# Split successful, use components
new_annotations.extend(components)
else:
# No components found, keep original
new_annotations.append(anno)
except Exception:
# Split failed for any reason, keep original
new_annotations.append(anno)
else:
# Not in split list, keep as-is
new_annotations.append(anno)
return Recording(data=recording.data, metadata=recording.metadata, annotations=new_annotations)

View File

@ -0,0 +1,35 @@
import numpy as np
from ria_toolkit_oss.data import Recording
def qualify_slice_from_annotations(recording: Recording, slice_length: int):
"""
Slice a recording into many smaller recordings,
discarding any slices which do not have annotations that apply to those samples.
Used together with an annotation based qualifier.
:param recording: The recording to slice.
:type recording: Recording
:param slice_length: The length in samples of a slice.
:type slice_length: int"""
if len(recording.annotations) == 0:
print("Warning, no annotations.")
annotation_mask = np.zeros(len(recording.data[0]))
for annotation in recording.annotations:
annotation_mask[annotation.sample_start : annotation.sample_start + annotation.sample_count] = 1
output_recordings = []
for i in range((len(recording.data[0]) // slice_length) - 1):
start_index = slice_length * i
end_index = slice_length * (i + 1)
if 1 in annotation_mask[start_index:end_index]:
sl = recording.data[:, start_index:end_index]
output_recordings.append(Recording(data=sl, metadata=recording.metadata))
return output_recordings

View File

@ -0,0 +1,97 @@
import numpy as np
from scipy.signal import butter, lfilter
from ria_toolkit_oss.data.annotation import Annotation
from ria_toolkit_oss.data.recording import Recording
def isolate_signal(recording: Recording, annotation: Annotation) -> Recording:
"""
Slice, filter and frequency shift the input recording according to the bounding box defined by the annotation.
:param recording: The input Recording to be sliced.
:type recording: Recording
:param annotation: The Annotation object defining the area of the recording to isolate.
:type annotation: Annotation
:param decimate: Decimate the input signal after filtering to reduce the sample rate.
:type decimate: bool
:returns: The subsection of the original recording defined by the annotation.
:rtype: Recording"""
sample_start = max(0, annotation.sample_start)
sample_stop = min(len(recording), annotation.sample_start + annotation.sample_count)
anno_base_center_freq = (annotation.freq_lower_edge + annotation.freq_upper_edge) / 2 - recording.metadata.get(
"center_frequency", 0
)
anno_bw = annotation.freq_upper_edge - annotation.freq_lower_edge
signal_slice = recording.data[0, sample_start:sample_stop]
# normalize
signal_slice = signal_slice / np.max(np.abs(signal_slice))
isolation_bw = anno_bw
# frequency shift the center of the box about zero
shifted_signal_slice = frequency_shift_iq_samples(
iq_samples=signal_slice,
sample_rate=recording.metadata["sample_rate"],
shift_frequency=-1 * anno_base_center_freq,
)
# filter
if isolation_bw < recording.metadata["sample_rate"] - 1:
filtered_signal = apply_complex_lowpass_filter(
signal=shifted_signal_slice, cutoff_frequency=isolation_bw, sample_rate=recording.metadata["sample_rate"]
)
else:
filtered_signal = shifted_signal_slice
output = Recording(data=[filtered_signal], metadata=recording.metadata)
return output
def frequency_shift_iq_samples(iq_samples, sample_rate, shift_frequency):
# Number of samples
num_samples = len(iq_samples)
# Create a time vector from 0 to the total duration in seconds
time_vector = np.arange(num_samples) / sample_rate
# Generate the complex exponential for the frequency shift
complex_exponential = np.exp(1j * 2 * np.pi * shift_frequency * time_vector)
# Apply the frequency shift to the IQ samples
shifted_samples = iq_samples * complex_exponential
return shifted_samples
# Function to apply a lowpass Butterworth filter to a complex signal
def apply_complex_lowpass_filter(signal, cutoff_frequency, sample_rate, order=5):
# Design the lowpass filter
b, a = design_complex_lowpass_filter(cutoff_frequency, sample_rate, order)
# Apply the lowpass filter
filtered_signal = lfilter(b, a, signal)
return filtered_signal
def design_complex_lowpass_filter(cutoff_frequency, sample_rate, order=5):
# Nyquist frequency for complex signals is the sample rate
nyquist = sample_rate
# Ensure the cutoff frequency is positive and within the Nyquist limit
if cutoff_frequency <= 0 or cutoff_frequency > nyquist:
raise ValueError("Cutoff frequency must be between 0 and the Nyquist frequency.")
# Normalize the cutoff frequency to the Nyquist frequency
cutoff_normalized = cutoff_frequency / nyquist
# Create a Butterworth lowpass filter
b, a = butter(order, cutoff_normalized, btype="low")
return b, a

View File

@ -0,0 +1,359 @@
"""
Temporal signal detection and boundary refinement via Hysteresis Thresholding.
Provides methods to detect signal bursts in the time domain by triggering on
smoothed power peaks and expanding boundaries to capture the full energy envelope.
This module implements a **dual-threshold trigger** to solve the 'chatter'
problem in noisy environments, ensuring that signal annotations encapsulate
the entire rise and fall of a burst rather than just the peak.
**Key Design Decisions**:
1. **Hysteresis Logic (Dual-Threshold)**:
- **Trigger**: High threshold (`threshold * max_power`) ensures high confidence
in signal presence.
- **Boundary**: Low threshold (`0.5 * trigger`) allows the annotation to
"crawl" outward, capturing the lower-energy start and end of the burst
often missed by simple single-threshold detectors.
2. **Temporal Smoothing**: Uses a moving average window (`window_size`) prior
- to thresholding. This prevents high-frequency noise spikes from causing
fragmented annotations and provides a more stable estimate of the
signal's power envelope.
3. **Spectral Profiling**: Once a temporal segment is isolated, the module
- performs an automated FFT analysis. It identifies the **90% spectral
occupancy** to define the frequency boundaries (`f_min`, `f_max`),
allowing the detector to work on narrowband and wideband signals without
manual frequency tuning.
4. **Baseband/RF Mapping**: Automatically handles the conversion from
- relative FFT bin frequencies to absolute RF frequencies by referencing
`recording.metadata["center_frequency"]`.
5. **False Positive Mitigation**: Implements a hard minimum duration check
- (10ms) to ignore transient hardware spikes or noise floor fluctuations
that do not constitute a valid signal burst.
The module is designed to be the primary "first-pass" detector for pulsed
waveforms (like ADS-B, Lora, or bursty FSK) before passing them to
classification or demodulation stages.
"""
import json
from typing import Optional
import numpy as np
from ria_toolkit_oss.data import Annotation, Recording
def _find_ranges(indices, max_gap):
"""
Groups individual indices into continuous temporal ranges.
Args:
indices: Array of indices where the signal exceeded a threshold.
max_gap: Maximum gap allowed between indices to consider them part
of the same range.
Returns:
A list of (start, stop) tuples representing detected signal segments.
"""
if len(indices) == 0:
return []
start = indices[0]
prev = indices[0]
ranges = []
for i in range(1, len(indices)):
if indices[i] - prev > max_gap:
ranges.append((start, prev))
start = indices[i]
prev = indices[i]
ranges.append((start, prev))
return ranges
def _expand_and_filter_ranges(
smoothed_power: np.ndarray,
initial_ranges: list[tuple[int, int]],
boundary_val: float,
min_duration_samples: int,
) -> list[tuple[int, int]]:
"""Apply hysteresis expansion and minimum-duration filtering."""
out: list[tuple[int, int]] = []
n = len(smoothed_power)
for start, stop in initial_ranges:
if (stop - start) < min_duration_samples:
continue
true_start = start
while true_start > 0 and smoothed_power[true_start] > boundary_val:
true_start -= 1
true_stop = stop
while true_stop < n - 1 and smoothed_power[true_stop] > boundary_val:
true_stop += 1
if (true_stop - true_start) >= min_duration_samples:
out.append((true_start, true_stop))
return out
def _merge_ranges(ranges: list[tuple[int, int]], max_gap: int) -> list[tuple[int, int]]:
"""Merge overlapping or near-adjacent ranges."""
if not ranges:
return []
ranges = sorted(ranges, key=lambda r: r[0])
merged = [ranges[0]]
for s, e in ranges[1:]:
last_s, last_e = merged[-1]
if s <= last_e + max_gap:
merged[-1] = (last_s, max(last_e, e))
else:
merged.append((s, e))
return merged
def _estimate_noise_floor(power: np.ndarray, quantile: float = 20.0) -> float:
"""Estimate baseline from the quieter portion of the envelope."""
return float(np.percentile(power, quantile))
def _estimate_group_gap(sample_rate: float) -> int:
"""Use a fixed temporal grouping gap instead of reusing the smoothing window."""
return max(1, int(0.001 * sample_rate))
def _estimate_spectral_bounds(signal_segment: np.ndarray, sample_rate: float) -> tuple[float, float]:
"""Estimate occupied bandwidth from a smoothed magnitude spectrum."""
if len(signal_segment) == 0:
return -sample_rate / 4, sample_rate / 4
window = np.hanning(len(signal_segment))
windowed = signal_segment * window
fft_data = np.abs(np.fft.fftshift(np.fft.fft(windowed)))
fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate))
# Smooth the spectrum so noise-like wideband bursts form a contiguous mask
# instead of thousands of tiny isolated runs.
spectral_smooth_bins = max(5, min(257, (len(signal_segment) // 512) | 1))
spectral_kernel = np.ones(spectral_smooth_bins, dtype=np.float64) / spectral_smooth_bins
smoothed_fft = np.convolve(fft_data, spectral_kernel, mode="same")
spectral_floor = float(np.percentile(smoothed_fft, 20))
spectral_peak = float(np.max(smoothed_fft))
spectral_ratio = spectral_peak / max(spectral_floor, 1e-12)
if spectral_ratio < 1.2:
return -sample_rate / 4, sample_rate / 4
spectral_thresh = spectral_floor + 0.1 * (spectral_peak - spectral_floor)
sig_indices = np.where(smoothed_fft > spectral_thresh)[0]
if len(sig_indices) == 0:
peak_idx = int(np.argmax(smoothed_fft))
bin_hz = sample_rate / len(signal_segment)
half_bins = max(1, int(np.ceil(10_000.0 / bin_hz)))
lo_idx = max(0, peak_idx - half_bins)
hi_idx = min(len(smoothed_fft) - 1, peak_idx + half_bins)
else:
runs = _find_ranges(sig_indices, max_gap=max(1, spectral_smooth_bins // 2))
peak_idx = int(np.argmax(smoothed_fft))
lo_idx, hi_idx = min(
runs,
key=lambda run: 0 if run[0] <= peak_idx <= run[1] else min(abs(run[0] - peak_idx), abs(run[1] - peak_idx)),
)
# Prevent extremely narrow tone boxes from collapsing to just a few bins.
min_total_bw_hz = 20_000.0
min_half_bins = max(1, int(np.ceil((min_total_bw_hz / 2) / (sample_rate / len(signal_segment)))))
center_idx = int(round((lo_idx + hi_idx) / 2))
lo_idx = max(0, min(lo_idx, center_idx - min_half_bins))
hi_idx = min(len(smoothed_fft) - 1, max(hi_idx, center_idx + min_half_bins))
return float(fft_freqs[lo_idx]), float(fft_freqs[hi_idx])
def threshold_qualifier(
recording: Recording,
threshold: float,
window_size: Optional[int] = None,
label: Optional[str] = None,
annotation_type: Optional[str] = "standalone",
channel: int = 0,
) -> Recording:
"""
Annotate a recording with bounding boxes for regions above a threshold.
Threshold is defined as a fraction of the maximum sample magnitude.
This algorithm searches for samples above the threshold and combines them into ranges if they
are within window_size of each other.
Detects and annotates signals using energy thresholding and spectral analysis.
The algorithm follows these steps:
1. Smooths power data using a moving average.
2. Identifies 'peak' regions exceeding a high trigger threshold.
3. Uses hysteresis to expand boundaries until power drops below a lower threshold.
4. Performs an FFT on each segment to determine frequency occupancy.
Args:
recording: The Recording object containing IQ or real signal data.
threshold: Sensitivity multiplier (0.0 to 1.0) applied to max power.
window_size: Size of the smoothing filter in samples. Defaults to 1ms worth of samples.
label: Custom string label for annotations.
annotation_type: Metadata string for the 'type' field in the annotation.
channel: Index of the channel to annotate. Defaults to 0.
Returns:
A new Recording object populated with detected Annotations.
"""
# Extract signal and metadata
sample_data = recording.data[channel]
sample_rate = recording.metadata["sample_rate"]
center_frequency = recording.metadata.get("center_frequency", 0)
if window_size is None:
window_size = max(64, int(sample_rate * 0.001))
# --- 1. SIGNAL CONDITIONING ---
# Convert to power (Magnitude squared)
power_data = np.abs(sample_data) ** 2
smoothing_window = np.ones(window_size) / window_size
smoothed_power = np.convolve(power_data, smoothing_window, mode="same")
group_gap_samples = _estimate_group_gap(sample_rate)
# Define thresholds using peak relative to baseline.
max_power = np.max(smoothed_power)
noise_floor = _estimate_noise_floor(smoothed_power)
dynamic_range_ratio = max_power / max(noise_floor, 1e-12)
# Soft early exit: keep a guard for low-contrast noise, but compute it from
# the quieter tail of the envelope so burst-heavy captures are not rejected.
if dynamic_range_ratio < 1.5:
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations)
trigger_val = noise_floor + threshold * (max_power - noise_floor)
boundary_val = noise_floor + 0.5 * threshold * (max_power - noise_floor)
# --- 2. INITIAL DETECTION ---
# Enforce an explicit minimum duration in seconds; this is stable across
# varying capture lengths and avoids over-fitting to recording length.
min_duration_samples = max(1, int(0.005 * sample_rate))
annotations = []
# Pass 1: Detect stronger bursts.
indices = np.where(smoothed_power > trigger_val)[0]
pass1_initial = _find_ranges(indices=indices, max_gap=group_gap_samples)
pass1_ranges = _expand_and_filter_ranges(
smoothed_power=smoothed_power,
initial_ranges=pass1_initial,
boundary_val=boundary_val,
min_duration_samples=min_duration_samples,
)
# Pass 2: Recover weaker bursts on residual power not already covered.
# This improves recall in mixed-amplitude captures.
# Expand each Pass-1 range by the smoothing window on both sides so the
# smoothing skirts of a strong burst are not re-detected as a weak burst
# immediately adjacent to it (mirrors the guard used in Pass 3).
mask = np.ones_like(smoothed_power, dtype=np.float32)
pass2_mask_expand = window_size
for s, e in pass1_ranges:
mask[max(0, s - pass2_mask_expand) : min(len(mask), e + pass2_mask_expand)] = 0.0
residual_power = smoothed_power * mask
residual_max = float(np.max(residual_power))
residual_ratio = residual_max / max(noise_floor, 1e-12)
pass2_ranges: list[tuple[int, int]] = []
if residual_ratio >= 2.0:
weak_threshold = max(0.3, threshold * 0.7)
weak_trigger = noise_floor + weak_threshold * (residual_max - noise_floor)
weak_boundary = noise_floor + 0.5 * weak_threshold * (residual_max - noise_floor)
weak_indices = np.where(residual_power > weak_trigger)[0]
pass2_initial = _find_ranges(indices=weak_indices, max_gap=group_gap_samples)
pass2_ranges = _expand_and_filter_ranges(
smoothed_power=residual_power,
initial_ranges=pass2_initial,
boundary_val=weak_boundary,
min_duration_samples=min_duration_samples,
)
# Pass 3: Detect sustained faint bursts via macro-window averaging.
# Targets bursts whose peak power is near the trigger level but whose
# *average* power is consistently elevated above the noise floor — these
# are missed by peak-based detection because only a few short spikes exceed
# the trigger, all too brief to pass the minimum-duration filter.
#
# The mask is applied to power_data *before* convolving so that bright
# burst energy does not bleed through the long window into adjacent regions,
# which would inflate macro_residual_max and push the trigger above the
# faint burst's average power.
macro_window_size = max(window_size * 16, int(sample_rate * 0.02))
macro_kernel = np.ones(macro_window_size, dtype=np.float64) / macro_window_size
# Expand each annotated range by half the macro window on both sides so that
# the long convolution cannot "see" the leading/trailing edges of already-
# annotated bursts, which would produce spurious short fragments in Pass 3.
macro_expand = macro_window_size * 2
masked_power_for_macro = power_data.copy()
n = len(masked_power_for_macro)
for s, e in pass1_ranges + pass2_ranges:
masked_power_for_macro[max(0, s - macro_expand) : min(n, e + macro_expand)] = 0.0
macro_residual = np.convolve(masked_power_for_macro, macro_kernel, mode="same")
macro_residual_max = float(np.max(macro_residual))
pass3_ranges: list[tuple[int, int]] = []
if macro_residual_max / max(noise_floor, 1e-12) >= 1.3:
macro_trigger = noise_floor + threshold * (macro_residual_max - noise_floor)
macro_boundary = noise_floor + 0.5 * threshold * (macro_residual_max - noise_floor)
macro_indices = np.where(macro_residual > macro_trigger)[0]
macro_initial = _find_ranges(indices=macro_indices, max_gap=group_gap_samples)
pass3_ranges = _expand_and_filter_ranges(
smoothed_power=macro_residual,
initial_ranges=macro_initial,
boundary_val=macro_boundary,
min_duration_samples=min_duration_samples,
)
all_ranges = _merge_ranges(pass1_ranges + pass2_ranges + pass3_ranges, max_gap=group_gap_samples)
for true_start, true_stop in all_ranges:
# --- 4. SPECTRAL ANALYSIS (Frequency Detection) ---
signal_segment = sample_data[true_start:true_stop]
f_min, f_max = _estimate_spectral_bounds(signal_segment, sample_rate)
# --- 5. ANNOTATION GENERATION ---
ann_label = label if label is not None else f"{int(threshold*100)}%"
# Pack metadata for the UI/Downstream processing
comment_data = {
"type": annotation_type,
"generator": "threshold_qualifier",
"params": {
"threshold": threshold,
"window_size": window_size,
},
}
anno = Annotation(
sample_start=true_start,
sample_count=true_stop - true_start,
freq_lower_edge=center_frequency + f_min,
freq_upper_edge=center_frequency + f_max,
label=ann_label,
comment=json.dumps(comment_data),
detail={"generator": "hysteresis_qualifier"},
)
annotations.append(anno)
# Return a new Recording object including the new annotations
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)

View File

@ -0,0 +1 @@
"""App runner: pull and run containerized RIA applications."""

View File

@ -0,0 +1,278 @@
"""Unified ``ria-app`` CLI.
Subcommands:
- ``ria-app pull <app>[:tag]`` pull a RIA app image from the configured registry.
- ``ria-app run <app>[:tag]`` pull (if needed) and run, auto-configuring
GPU/USB/network flags from image labels set by CI.
- ``ria-app list`` list locally cached RIA app images.
- ``ria-app stop <app>`` stop a running app container.
- ``ria-app logs <app>`` tail logs of a running app container.
- ``ria-app configure`` set default registry/namespace.
Image references resolve as::
my-classifier -> {registry}/{namespace}/my-classifier:latest
group/my-classifier -> {registry}/group/my-classifier:latest
host/group/app:tag -> host/group/app:tag (fully-qualified passthrough)
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import subprocess
import sys
from . import config as _config
_LABEL_PROFILE = "ria.profile"
_LABEL_HARDWARE = "ria.hardware"
_LABEL_APP = "ria.app"
def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]:
for exe in ("docker", "podman"):
if shutil.which(exe):
use_sudo = sudo_override or cfg.sudo
return ["sudo", exe] if use_sudo else [exe]
print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr)
sys.exit(2)
def _resolve_ref(app: str, cfg: _config.AppConfig) -> str:
ref = app if ":" in app.split("/")[-1] else f"{app}:latest"
slashes = ref.count("/")
if slashes >= 2:
return ref
if slashes == 1:
return f"{cfg.registry}/{ref}" if cfg.registry else ref
if not cfg.registry or not cfg.namespace:
print(
"error: app is not fully qualified and no default registry/namespace configured. "
"Run `ria-app configure` or pass a full image reference (registry/namespace/app:tag).",
file=sys.stderr,
)
sys.exit(2)
return f"{cfg.registry}/{cfg.namespace}/{ref}"
def _container_name(ref: str) -> str:
name = ref.rsplit("/", 1)[-1].split(":", 1)[0]
return f"ria-app-{name}"
def _inspect_labels(engine: list[str], ref: str) -> dict:
try:
out = subprocess.check_output(
[*engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref],
stderr=subprocess.DEVNULL,
)
except subprocess.CalledProcessError:
return {}
try:
return json.loads(out.decode().strip()) or {}
except json.JSONDecodeError:
return {}
def _gpu_available() -> bool:
if os.path.exists("/dev/nvidia0"):
return True
return shutil.which("nvidia-smi") is not None
def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]:
flags: list[str] = []
notes: list[str] = []
profile = (labels.get(_LABEL_PROFILE) or "").lower()
hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda"))
if wants_gpu and not no_gpu:
if _gpu_available():
flags += ["--gpus", "all"]
else:
notes.append(
"image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)"
)
if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
flags += ["--device", "/dev/bus/usb"]
if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net:
flags += ["--net", "host"]
return flags, notes
def _cmd_configure(args: argparse.Namespace) -> int:
cfg = _config.load()
if args.registry:
cfg.registry = args.registry
if args.namespace:
cfg.namespace = args.namespace
if args.sudo is not None:
cfg.sudo = args.sudo
path = _config.save(cfg)
print(f"Saved app config to {path}")
print(f" registry: {cfg.registry or '(unset)'}")
print(f" namespace: {cfg.namespace or '(unset)'}")
print(f" sudo: {cfg.sudo}")
return 0
def _cmd_pull(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
ref = _resolve_ref(args.app, cfg)
print(f"Pulling {ref}")
return subprocess.call([*engine, "pull", ref])
def _cmd_run(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
ref = _resolve_ref(args.app, cfg)
if not _inspect_labels(engine, ref):
rc = subprocess.call([*engine, "pull", ref])
if rc != 0:
return rc
labels = _inspect_labels(engine, ref)
no_gpu = args.no_gpu and not args.force_gpu
hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net)
if args.force_gpu and "--gpus" not in hw_flags:
hw_flags = ["--gpus", "all", *hw_flags]
cmd = [*engine, "run", "--rm"]
if not args.foreground:
cmd += ["-d"]
cmd += ["--name", args.name or _container_name(ref)]
cmd += hw_flags
if args.config:
cmd += ["-v", f"{args.config}:/config/config.yaml:ro", "-e", "RIA_CONFIG=/config/config.yaml"]
for env in args.env or []:
cmd += ["-e", env]
for vol in args.volume or []:
cmd += ["-v", vol]
for port in args.publish or []:
cmd += ["-p", port]
cmd += list(args.docker_args or [])
cmd += [ref]
cmd += list(args.app_args or [])
if args.dry_run:
print(" ".join(cmd))
return 0
label_str = ", ".join(f"{k}={v}" for k, v in labels.items() if k.startswith("ria.")) or "(no ria.* labels)"
print(f"Running {ref} [{label_str}]")
if hw_flags:
print(f" auto flags: {' '.join(hw_flags)}")
for note in notes:
print(f" note: {note}")
return subprocess.call(cmd)
def _cmd_list(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
return subprocess.call(
[
*engine,
"images",
"--filter",
f"label={_LABEL_APP}",
"--format",
"table {{.Repository}}:{{.Tag}}\t{{.ID}}\t{{.Size}}",
]
)
def _cmd_stop(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
name = args.name or _container_name(_resolve_ref(args.app, cfg))
return subprocess.call([*engine, "stop", name])
def _cmd_logs(args: argparse.Namespace) -> int:
cfg = _config.load()
engine = _engine(cfg, args.sudo)
name = args.name or _container_name(_resolve_ref(args.app, cfg))
cmd = [*engine, "logs"]
if args.follow:
cmd += ["-f"]
cmd += [name]
return subprocess.call(cmd)
def main() -> None:
parser = argparse.ArgumentParser(prog="ria-app")
parser.add_argument("--sudo", action="store_true", default=False, help="Run docker/podman via sudo")
sub = parser.add_subparsers(dest="command", required=True)
p_cfg = sub.add_parser("configure", help="Set default registry/namespace")
p_cfg.add_argument("--registry", default=None, help="Default container registry (e.g. registry.riahub.ai)")
p_cfg.add_argument("--namespace", default=None, help="Default namespace (e.g. qoherent)")
p_cfg.add_argument(
"--sudo",
dest="sudo",
action=argparse.BooleanOptionalAction,
default=None,
help="Persist sudo default (--sudo / --no-sudo)",
)
p_pull = sub.add_parser("pull", help="Pull an app image")
p_pull.add_argument("app", help="App name or image reference")
p_run = sub.add_parser("run", help="Run an app, auto-detecting hardware flags")
p_run.add_argument("app", help="App name or image reference")
p_run.add_argument("--name", default=None, help="Container name (default: ria-app-<app>)")
p_run.add_argument("--config", default=None, help="Path to config.yaml to mount into the container")
p_run.add_argument("-e", "--env", action="append", help="Extra env var (KEY=VALUE)")
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
p_run.add_argument("-p", "--publish", action="append", help="Publish port")
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU")
p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected")
p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb")
p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host")
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")
sub.add_parser("list", help="List locally cached RIA app images")
p_stop = sub.add_parser("stop", help="Stop a running app")
p_stop.add_argument("app", help="App name or image reference")
p_stop.add_argument("--name", default=None, help="Container name override")
p_logs = sub.add_parser("logs", help="Tail logs of a running app")
p_logs.add_argument("app", help="App name or image reference")
p_logs.add_argument("--name", default=None, help="Container name override")
p_logs.add_argument("-f", "--follow", action="store_true", help="Follow log output")
args = parser.parse_args()
dispatch = {
"configure": _cmd_configure,
"pull": _cmd_pull,
"run": _cmd_run,
"list": _cmd_list,
"stop": _cmd_stop,
"logs": _cmd_logs,
}
sys.exit(dispatch[args.command](args))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,51 @@
"""App runner configuration at ``~/.ria/toolkit.json``.
Schema::
{
"registry": "registry.riahub.ai",
"namespace": "qoherent"
}
"""
from __future__ import annotations
import json
import os
from dataclasses import asdict, dataclass
from pathlib import Path
_DEFAULT_PATH = Path(os.environ.get("RIA_TOOLKIT_CONFIG", str(Path.home() / ".ria" / "toolkit.json")))
@dataclass
class AppConfig:
registry: str = ""
namespace: str = ""
sudo: bool = False
def default_path() -> Path:
return _DEFAULT_PATH
def load(path: Path | None = None) -> AppConfig:
p = path or _DEFAULT_PATH
if not p.exists():
return AppConfig(
registry=os.environ.get("RIA_REGISTRY", ""),
namespace=os.environ.get("RIA_NAMESPACE", ""),
)
data = json.loads(p.read_text())
return AppConfig(
registry=data.get("registry", "") or os.environ.get("RIA_REGISTRY", ""),
namespace=data.get("namespace", "") or os.environ.get("RIA_NAMESPACE", ""),
sudo=bool(data.get("sudo", False)) or os.environ.get("RIA_DOCKER_SUDO", "") not in ("", "0", "false"),
)
def save(cfg: AppConfig, path: Path | None = None) -> Path:
p = path or _DEFAULT_PATH
p.parent.mkdir(parents=True, exist_ok=True)
p.write_text(json.dumps(asdict(cfg), indent=2))
return p

View File

@ -0,0 +1,8 @@
"""
The Data package contains abstract data types tailored for radio machine learning, such as ``Recording``, as well
as the abstract interfaces for the radio dataset and radio dataset builder framework.
"""
__all__ = ["Annotation", "Recording"]
from .annotation import Annotation
from .recording import Recording

View File

@ -7,8 +7,8 @@ from typing import Any, Optional
from packaging.version import Version from packaging.version import Version
from ria_toolkit_oss.datatypes.datasets.license.dataset_license import DatasetLicense from ria_toolkit_oss.data.datasets.license.dataset_license import DatasetLicense
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset
from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute
@ -21,7 +21,8 @@ class DatasetBuilder(ABC):
""" """
_url: str = abstract_attribute() _url: str = abstract_attribute()
_SHA256: str # SHA256 checksum. _SHA256: Optional[str] = None # SHA256 checksum.
_MD5: Optional[str] = None # MD5 checksum.
_name: str = abstract_attribute() _name: str = abstract_attribute()
_author: str = abstract_attribute() _author: str = abstract_attribute()
_license: DatasetLicense = abstract_attribute() _license: DatasetLicense = abstract_attribute()

View File

@ -109,13 +109,10 @@ def copy_file(original_source: str | os.PathLike, new_source: str | os.PathLike)
:return: None :return: None
""" """
original_file = h5py.File(original_source, "r") with h5py.File(original_source, "r") as original_file:
with h5py.File(new_source, "w") as new_file:
with h5py.File(new_source, "w") as new_file: for key in original_file.keys():
for key in original_file.keys(): original_file.copy(key, new_file)
original_file.copy(key, new_file)
original_file.close()
def make_empty_clone(original_source: str | os.PathLike, new_source: str | os.PathLike, example_length: int) -> None: def make_empty_clone(original_source: str | os.PathLike, new_source: str | os.PathLike, example_length: int) -> None:
@ -172,8 +169,10 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None:
with h5py.File(source, "a") as f: with h5py.File(source, "a") as f:
ds, md = f["data"], f["metadata/metadata"] ds, md = f["data"], f["metadata/metadata"]
m, c, n = ds.shape m, c, n = ds.shape
assert 0 <= idx <= m - 1 if not (0 <= idx <= m - 1):
assert len(ds) == len(md) raise IndexError(f"Index {idx} out of range [0, {m - 1}]")
if len(ds) != len(md):
raise ValueError("Data and metadata array lengths do not match")
new_ds = f.create_dataset( new_ds = f.create_dataset(
"data.temp", "data.temp",
@ -218,4 +217,3 @@ def overwrite_file(source: str | os.PathLike, new_data: np.ndarray) -> None:
ds_name = tuple(f.keys())[0] ds_name = tuple(f.keys())[0]
del f[ds_name] del f[ds_name]
f.create_dataset(ds_name, data=new_data) f.create_dataset(ds_name, data=new_data)
f.close()

View File

@ -7,11 +7,11 @@ from typing import Optional
import h5py import h5py
import numpy as np import numpy as np
from ria_toolkit_oss.datatypes.datasets.h5helpers import ( from ria_toolkit_oss.data.datasets.h5helpers import (
append_entry_inplace, append_entry_inplace,
copy_dataset_entry_by_index, copy_dataset_entry_by_index,
) )
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset
class IQDataset(RadioDataset, ABC): class IQDataset(RadioDataset, ABC):
@ -19,7 +19,7 @@ class IQDataset(RadioDataset, ABC):
radiofrequency (RF) signals represented as In-phase (I) and Quadrature (Q) samples. radiofrequency (RF) signals represented as In-phase (I) and Quadrature (Q) samples.
For machine learning tasks that involve processing spectrograms, please use For machine learning tasks that involve processing spectrograms, please use
ria_toolkit_oss.datatypes.datasets.SpectDataset instead. ria_toolkit_oss.data.datasets.SpectDataset instead.
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine
@ -169,8 +169,10 @@ class IQDataset(RadioDataset, ABC):
""" """
if split_factor is not None and example_length is not None: if split_factor is not None and example_length is not None:
# Raise warning and use split factor # Warn and use split factor
raise Warning("split_factor and example_length should not both be specified.") import warnings
warnings.warn("split_factor and example_length should not both be specified.")
if not inplace: if not inplace:
# ds = self.create_new_dataset(example_length=example_length) # ds = self.create_new_dataset(example_length=example_length)

View File

@ -12,7 +12,7 @@ import numpy as np
import pandas as pd import pandas as pd
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from ria_toolkit_oss.datatypes.datasets.h5helpers import ( from ria_toolkit_oss.data.datasets.h5helpers import (
append_entry_inplace, append_entry_inplace,
copy_file, copy_file,
copy_over_example, copy_over_example,
@ -29,7 +29,7 @@ class RadioDataset(ABC):
This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class
should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different
types of radio datasets. For example, see ria_toolkit_oss.datatypes.datasets.IQDataset, which is a radio dataset types of radio datasets. For example, see ria_toolkit_oss.data.datasets.IQDataset, which is a radio dataset
subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature) subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature)
samples. samples.
@ -255,7 +255,9 @@ class RadioDataset(ABC):
else: else:
classes_to_augment = classes_to_augment.encode("utf-8") classes_to_augment = classes_to_augment.encode("utf-8")
if classes_to_augment not in class_sizes: if classes_to_augment not in class_sizes:
raise ValueError(f"class name of {i} does not belong to the class key of {class_key}") raise ValueError(
f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
)
result_sizes = get_result_sizes( result_sizes = get_result_sizes(
level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
@ -375,7 +377,7 @@ class RadioDataset(ABC):
counters[key] = counters.get(key, 0) counters[key] = counters.get(key, 0)
idx = 0 idx = 0
with h5py.File(self.source, "a") as f: with h5py.File(self.source, "r") as f:
while idx < len(self): while idx < len(self):
labels = f["metadata/metadata"][class_key] labels = f["metadata/metadata"][class_key]
current_class = labels[idx] current_class = labels[idx]
@ -514,7 +516,7 @@ class RadioDataset(ABC):
idx = 0 idx = 0
with h5py.File(self.source, "a") as f: with h5py.File(self.source, "r") as f:
while idx < len(self): while idx < len(self):
labels = f["metadata/metadata"][class_key] labels = f["metadata/metadata"][class_key]
current_class = labels[idx] current_class = labels[idx]
@ -956,10 +958,8 @@ def get_result_sizes( # noqa: C901 # TODO: Simplify function
# Check that each class that will be augmented does not already suffice target_size # Check that each class that will be augmented does not already suffice target_size
for cls_name, target_size_value in zip(classes_to_augment, target_size): for cls_name, target_size_value in zip(classes_to_augment, target_size):
if class_sizes[cls_name] >= target_size_value: if class_sizes[cls_name] >= target_size_value:
raise ValueError( raise ValueError(f"""target_size of {target_size_value} is already sufficed for current size of
f"""target_size of {target_size_value} is already sufficed for current size of {class_sizes[cls_name]} for class: {cls_name}""")
{class_sizes[cls_name]} for class: {cls_name}"""
)
for index, class_name in enumerate(classes_to_augment): for index, class_name in enumerate(classes_to_augment):
result_sizes[class_name] = target_size[index] result_sizes[class_name] = target_size[index]

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os import os
from abc import ABC from abc import ABC
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset
class SpectDataset(RadioDataset, ABC): class SpectDataset(RadioDataset, ABC):
@ -13,7 +13,7 @@ class SpectDataset(RadioDataset, ABC):
radio signal spectrograms. radio signal spectrograms.
For machine learning tasks that involve processing on IQ samples, please use For machine learning tasks that involve processing on IQ samples, please use
ria_toolkit_oss.datatypes.datasets.IQDataset instead. ria_toolkit_oss.data.datasets.IQDataset instead.
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine

View File

@ -6,11 +6,8 @@ from typing import Optional
import numpy as np import numpy as np
from numpy.random import Generator from numpy.random import Generator
from ria_toolkit_oss.datatypes.datasets import RadioDataset from ria_toolkit_oss.data.datasets import RadioDataset
from ria_toolkit_oss.datatypes.datasets.h5helpers import ( from ria_toolkit_oss.data.datasets.h5helpers import copy_over_example, make_empty_clone
copy_over_example,
make_empty_clone,
)
def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]: def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]:
@ -31,7 +28,7 @@ def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDatase
cases. cases.
This function is deterministic, meaning it will always produce the same split. For a random split, see This function is deterministic, meaning it will always produce the same split. For a random split, see
ria_toolkit_oss.datatypes.datasets.random_split. ria_toolkit_oss.data.datasets.random_split.
:param dataset: Dataset to be split. :param dataset: Dataset to be split.
:type dataset: RadioDataset :type dataset: RadioDataset
@ -50,7 +47,7 @@ def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDatase
>>> import string >>> import string
>>> import numpy as np >>> import numpy as np
>>> import pandas as pd >>> import pandas as pd
>>> from ria_toolkit_oss.datatypes.datasets import split >>> from ria_toolkit_oss.data.datasets import split
First, let's generate some random data: First, let's generate some random data:
@ -126,7 +123,7 @@ def random_split(
training and test datasets. training and test datasets.
This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified. This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified.
If it is important to ensure the closest possible split, consider using ria_toolkit_oss.datatypes.datasets.split If it is important to ensure the closest possible split, consider using ria_toolkit_oss.data.datasets.split
instead. instead.
:param dataset: Dataset to be split. :param dataset: Dataset to be split.
@ -144,7 +141,7 @@ def random_split(
:rtype: list of RadioDataset :rtype: list of RadioDataset
See Also: See Also:
ria_toolkit_oss.datatypes.datasets.split: Usage is the same as for ``random_split()``. ria_toolkit_oss.data.datasets.split: Usage is the same as for ``random_split()``.
""" """
if not isinstance(dataset, RadioDataset): if not isinstance(dataset, RadioDataset):
raise ValueError(f"'dataset' must be RadioDataset or one of its subclasses, got {type(dataset)}.") raise ValueError(f"'dataset' must be RadioDataset or one of its subclasses, got {type(dataset)}.")
@ -247,7 +244,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None:
"""Ensure that each ID is present in one and only one sublist.""" """Ensure that each ID is present in one and only one sublist."""
all_elements = [item for sublist in list_of_lists for item in sublist] all_elements = [item for sublist in list_of_lists for item in sublist]
assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort() assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements))
def _generate_split_source_filenames( def _generate_split_source_filenames(

View File

@ -12,7 +12,7 @@ from typing import Any, Iterator, Optional
import numpy as np import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from ria_toolkit_oss.datatypes.annotation import Annotation from ria_toolkit_oss.data.annotation import Annotation
PROTECTED_KEYS = ["rec_id", "timestamp"] PROTECTED_KEYS = ["rec_id", "timestamp"]
@ -26,7 +26,7 @@ class Recording:
Metadata is stored in a dictionary of key value pairs, Metadata is stored in a dictionary of key value pairs,
to include information such as sample_rate and center_frequency. to include information such as sample_rate and center_frequency.
Annotations are a list of :class:`~ria_toolkit_oss.datatypes.Annotation`, Annotations are a list of :class:`~ria_toolkit_oss.data.Annotation`,
defining bounding boxes in time and frequency with labels and metadata. defining bounding boxes in time and frequency with labels and metadata.
Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide
@ -46,7 +46,7 @@ class Recording:
:param metadata: Additional information associated with the recording. :param metadata: Additional information associated with the recording.
:type metadata: dict, optional :type metadata: dict, optional
:param annotations: A collection of :class:`~ria_toolkit_oss.datatypes.Annotation` objects defining bounding boxes. :param annotations: A collection of :class:`~ria_toolkit_oss.data.Annotation` objects defining bounding boxes.
:type annotations: list of Annotations, optional :type annotations: list of Annotations, optional
:param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as :param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as
@ -66,7 +66,7 @@ class Recording:
**Examples:** **Examples:**
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording, Annotation >>> from ria_toolkit_oss.data import Recording, Annotation
>>> # Create an array of complex samples, just 1s in this case. >>> # Create an array of complex samples, just 1s in this case.
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
@ -146,7 +146,7 @@ class Recording:
self._metadata["timestamp"] = time.time() self._metadata["timestamp"] = time.time()
else: else:
if not isinstance(self._metadata["timestamp"], (int, float)): if not isinstance(self._metadata["timestamp"], (int, float)):
raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"])) raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}")
if "rec_id" not in self.metadata: if "rec_id" not in self.metadata:
self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"]) self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
@ -244,7 +244,7 @@ class Recording:
@property @property
def sample_rate(self) -> float | None: def sample_rate(self) -> float | None:
""" """
:return: Sample rate of the recording, or None is 'sample_rate' is not in metadata. :return: Sample rate of the recording, or None if 'sample_rate' is not in metadata.
:type: str :type: str
""" """
return self.metadata.get("sample_rate") return self.metadata.get("sample_rate")
@ -311,7 +311,7 @@ class Recording:
Create a recording and add metadata: Create a recording and add metadata:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording >>> from ria_toolkit_oss.data import Recording
>>> >>>
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -366,7 +366,7 @@ class Recording:
Create a recording and update metadata: Create a recording and update metadata:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -393,6 +393,7 @@ class Recording:
""" """
if key not in self.metadata: if key not in self.metadata:
self.add_to_metadata(key=key, value=value) self.add_to_metadata(key=key, value=value)
return
if not _is_jsonable(value): if not _is_jsonable(value):
raise ValueError("Value must be JSON serializable.") raise ValueError("Value must be JSON serializable.")
@ -420,7 +421,7 @@ class Recording:
Create a recording and add metadata: Create a recording and add metadata:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -444,7 +445,7 @@ class Recording:
'rec_id': 'fda0f41...'} # Example value 'rec_id': 'fda0f41...'} # Example value
""" """
if key not in PROTECTED_KEYS: if key not in PROTECTED_KEYS:
self._metadata.pop(key) self._metadata.pop(key, None)
else: else:
raise ValueError(f"Key {key} is protected and cannot be modified or removed.") raise ValueError(f"Key {key} is protected and cannot be modified or removed.")
@ -453,7 +454,7 @@ class Recording:
:param output_path: The output image path. Defaults to "images/signal.png". :param output_path: The output image path. Defaults to "images/signal.png".
:type output_path: str, optional :type output_path: str, optional
:param kwargs: Keyword arguments passed on to utils.view.view_sig. :param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_sig.
:type: dict of keyword arguments :type: dict of keyword arguments
**Examples:** **Examples:**
@ -461,7 +462,7 @@ class Recording:
Create a recording and view it as a plot in a .png image: Create a recording and view it as a plot in a .png image:
>>> import numpy >>> import numpy
>>> from utils.data import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -479,7 +480,7 @@ class Recording:
def simple_view(self, **kwargs) -> None: def simple_view(self, **kwargs) -> None:
"""Create a plot of various signal visualizations as a PNG or SVG image. """Create a plot of various signal visualizations as a PNG or SVG image.
:param kwargs: Keyword arguments passed on to utils.view.view_signal_simple.create_plots. :param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_signal_simple.view_simple_sig.
:type: dict of keyword arguments :type: dict of keyword arguments
**Examples:** **Examples:**
@ -487,7 +488,7 @@ class Recording:
Create a recording and view it as a plot in a .png image: Create a recording and view it as a plot in a .png image:
>>> import numpy >>> import numpy
>>> from utils.data import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -510,7 +511,7 @@ class Recording:
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_ The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
:param recording: The recording to be written to file. :param recording: The recording to be written to file.
:type recording: ria_toolkit_oss.datatypes.Recording :type recording: ria_toolkit_oss.data.Recording
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename. :param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
:type filename: os.PathLike or str, optional :type filename: os.PathLike or str, optional
:param path: The directory path to where the recording is to be saved. Defaults to recordings/. :param path: The directory path to where the recording is to be saved. Defaults to recordings/.
@ -544,7 +545,7 @@ class Recording:
Create a recording and save it to a .npy file: Create a recording and save it to a .npy file:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -559,6 +560,102 @@ class Recording:
to_npy(recording=self, filename=filename, path=path, overwrite=overwrite) to_npy(recording=self, filename=filename, path=path, overwrite=overwrite)
def to_wav(
self,
filename: Optional[str] = None,
path: Optional[os.PathLike | str] = None,
target_sample_rate: Optional[int] = 48000,
bits_per_sample: int = 32,
overwrite: bool = False,
) -> str:
"""Write recording to WAV file with embedded YAML metadata.
WAV format uses stereo audio with I (in-phase) in left channel and Q (quadrature) in right channel.
Metadata is stored in standard LIST INFO chunks with RF-specific metadata encoded as YAML
in the ICMT (comment) field for human readability.
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
:type filename: os.PathLike or str, optional
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
:type path: os.PathLike or str, optional
:param target_sample_rate: Sample rate stored in the WAV header when no sample_rate metadata
is present. IQ samples are written without decimation or interpolation. Default is 48000 Hz.
:type target_sample_rate: int, optional
:param bits_per_sample: Bits per sample (32 for float32, 16 for int16). Default is 32.
:type bits_per_sample: int, optional
:param overwrite: Whether to overwrite existing files. Default is False.
:type overwrite: bool, optional
:raises IOError: If there is an issue encountered during the file writing process.
:return: Path where the file was saved.
:rtype: str
**Examples:**
Create a recording and save it to a .wav file:
>>> import numpy
>>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.exp(1j * 2 * numpy.pi * 0.1 * numpy.arange(10000))
>>> metadata = {"sample_rate": 1e6, "center_frequency": 915e6}
>>> recording = Recording(data=samples, metadata=metadata)
>>> recording.to_wav()
"""
from ria_toolkit_oss.io.recording import to_wav
return to_wav(
recording=self,
filename=filename,
path=path,
target_sample_rate=target_sample_rate,
bits_per_sample=bits_per_sample,
overwrite=overwrite,
)
def to_blue(
self,
filename: Optional[str] = None,
path: Optional[os.PathLike | str] = None,
data_format: str = "CI",
overwrite: bool = False,
) -> str:
"""Write recording to MIDAS Blue file format.
MIDAS Blue is a legacy RF file format with a 512-byte binary header.
Commonly used with X-Midas and other RF/radar signal processing tools.
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
:type filename: os.PathLike or str, optional
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
:type path: os.PathLike or str, optional
:param data_format: Format code (default 'CI' = complex int16).
Common formats: 'CI' (complex int16), 'CF' (complex float32), 'CD' (complex float64).
Integer formats require the IQ samples to already be scaled within [-1, 1).
:type data_format: str, optional
:param overwrite: Whether to overwrite existing files. Default is False.
:type overwrite: bool, optional
:raises IOError: If there is an issue encountered during the file writing process.
:return: Path where the file was saved.
:rtype: str
**Examples:**
Create a recording and save it to a .blue file:
>>> import numpy
>>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = {"sample_rate": 1e6, "center_frequency": 2.44e9}
>>> recording = Recording(data=samples, metadata=metadata)
>>> recording.to_blue()
"""
from ria_toolkit_oss.io.recording import to_blue
return to_blue(recording=self, filename=filename, path=path, data_format=data_format, overwrite=overwrite)
def trim(self, num_samples: int, start_sample: Optional[int] = 0) -> Recording: def trim(self, num_samples: int, start_sample: Optional[int] = 0) -> Recording:
"""Trim Recording samples to a desired length, shifting annotations to maintain alignment. """Trim Recording samples to a desired length, shifting annotations to maintain alignment.
@ -577,7 +674,7 @@ class Recording:
Create a recording and trim it: Create a recording and trim it:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) >>> samples = numpy.ones(10000, dtype=numpy.complex64)
>>> metadata = { >>> metadata = {
@ -606,7 +703,14 @@ class Recording:
data = self.data[:, start_sample:end_sample] data = self.data[:, start_sample:end_sample]
new_annotations = copy.deepcopy(self.annotations) new_annotations = copy.deepcopy(self.annotations)
trimmed_annotations = []
for annotation in new_annotations: for annotation in new_annotations:
# skip annotations entirely outside the trim window
if annotation.sample_start + annotation.sample_count <= start_sample:
continue
if annotation.sample_start >= end_sample:
continue
# trim annotation if it goes outside the trim boundaries # trim annotation if it goes outside the trim boundaries
if annotation.sample_start < start_sample: if annotation.sample_start < start_sample:
annotation.sample_count = annotation.sample_count - (start_sample - annotation.sample_start) annotation.sample_count = annotation.sample_count - (start_sample - annotation.sample_start)
@ -617,8 +721,9 @@ class Recording:
# shift annotation to align with the new start point # shift annotation to align with the new start point
annotation.sample_start = annotation.sample_start - start_sample annotation.sample_start = annotation.sample_start - start_sample
trimmed_annotations.append(annotation)
return Recording(data=data, metadata=self.metadata, annotations=new_annotations) return Recording(data=data, metadata=self.metadata, annotations=trimmed_annotations)
def normalize(self) -> Recording: def normalize(self) -> Recording:
"""Scale the recording data, relative to its maximum value, so that the magnitude of the maximum sample is 1. """Scale the recording data, relative to its maximum value, so that the magnitude of the maximum sample is 1.
@ -631,7 +736,7 @@ class Recording:
Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1: Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1:
>>> import numpy >>> import numpy
>>> from ria_toolkit_oss.datatypes import Recording >>> from ria_toolkit_oss.data import Recording
>>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5 >>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5
>>> metadata = { >>> metadata = {
@ -647,7 +752,10 @@ class Recording:
>>> print(numpy.max(numpy.abs(normalized_recording.data))) >>> print(numpy.max(numpy.abs(normalized_recording.data)))
1 1
""" """
scaled_data = self.data / np.max(abs(self.data)) max_val = np.max(abs(self.data))
if max_val == 0:
raise ValueError("Cannot normalize a recording with all-zero data.")
scaled_data = self.data / max_val
return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations) return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations)
def __len__(self) -> int: def __len__(self) -> int:

View File

@ -1,8 +0,0 @@
"""
The datatypes package contains abstract data types tailored for radio machine learning.
"""
__all__ = ["Annotation", "Recording"]
from .annotation import Annotation
from .recording import Recording

View File

@ -2,3 +2,37 @@
The IO package contains utilities for input and output operations, such as loading and saving recordings to and from The IO package contains utilities for input and output operations, such as loading and saving recordings to and from
file. file.
""" """
__all__ = [
# Common:
"exists",
"copy",
"move",
"validate",
# Recording:
"save_recording",
"load_recording",
"to_sigmf",
"from_sigmf",
"to_npy",
"from_npy",
"from_npy_legacy",
"to_wav",
"from_wav",
"to_blue",
"from_blue",
]
from .common import copy, exists, move, validate
from .recording import (
from_blue,
from_npy,
from_npy_legacy,
from_sigmf,
from_wav,
load_recording,
to_blue,
to_npy,
to_sigmf,
to_wav,
)

View File

@ -0,0 +1,83 @@
"""
Utilities for common input/output operations.
"""
import os
import ria_toolkit_oss
def exists(fid: str | os.PathLike) -> bool:
"""Check if the file or directory exists.
.. todo::
This method is not yet implemented.
:param fid: The path to the file or directory to check for existence.
:type fid: str or os.PathLike
:return: True if the file or directory exists, False otherwise.
:rtype: bool
"""
raise NotImplementedError
def validate(fid: str | os.PathLike) -> bool:
"""Validate the contents of the file or directory to ensure it is not corrupted,
the correct format for its extension, and readable RIA.
.. todo::
This method is not yet implemented.
:param fid: The path to the file or directory to validate.
:type fid: str or os.PathLike
:return: True if the file or directory is valid and readable, False otherwise.
"""
raise NotImplementedError
def move(source_path: str | os.PathLike, destination_path: str | os.PathLike, copy: bool = False) -> None:
"""Recursively move a file or directory at source_path to destination_path.
.. todo::
This method is not yet implemented.
:param source_path: The path to the source file or directory.
:type source_path: str or os.PathLike
:param destination_path: The path to the destination directory.
:type destination_path: str or os.PathLike
:param copy: If True, perform a copy instead of a move. Default is False.
:type copy: bool, optional
:raises RuntimeError: If the move was unsuccessful.
:return: None
"""
if copy:
ria_toolkit_oss.io.common.copy(source_path=source_path, destination_path=destination_path)
return
raise NotImplementedError
def copy(source_path: str | os.PathLike, destination_path: str | os.PathLike) -> None:
"""Copy the file or directory at source_path to destination_path.
.. todo::
This function is not yet implemented.
:param source_path: The path to the source file or directory.
:type source_path: str or os.PathLike
:param destination_path: The path to the destination directory.
:type destination_path: str or os.PathLike
:raises RuntimeError: If the copy was unsuccessful.
:return: None
"""
raise NotImplementedError

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,26 @@
"""Orchestration layer for automated RF capture campaigns."""
from .campaign import (
CampaignConfig,
CaptureStep,
QAConfig,
RecorderConfig,
TransmitterConfig,
)
from .executor import CampaignExecutor, CampaignResult, StepResult
from .labeler import label_recording
from .qa import QAResult, check_recording
__all__ = [
"CampaignConfig",
"CaptureStep",
"QAConfig",
"RecorderConfig",
"TransmitterConfig",
"CampaignExecutor",
"CampaignResult",
"StepResult",
"label_recording",
"QAResult",
"check_recording",
]

View File

@ -0,0 +1,503 @@
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import yaml
# Allowed characters in campaign names when used as filename components.
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9_\-]")
# Reasonable RF bounds for consumer/research SDR hardware.
_FREQ_MIN_HZ = 1.0 # 1 Hz
_FREQ_MAX_HZ = 300e9 # 300 GHz
_GAIN_MIN_DB = -30.0
_GAIN_MAX_DB = 120.0
# ---------------------------------------------------------------------------
# Parsing helpers
# ---------------------------------------------------------------------------
def parse_duration(value: str | float | int) -> float:
"""Parse a duration string to seconds.
Accepts:
"30s" 30.0
"1.5m" or "1.5min" 90.0
"2h" 7200.0
30 (numeric) 30.0
"""
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
if not match:
raise ValueError(f"Cannot parse duration: '{value}'")
amount = float(match.group(1))
unit = (match.group(2) or "s").lower()
if unit in ("h", "hr"):
return amount * 3600
if unit in ("m", "min"):
return amount * 60
return amount
def parse_frequency(value: str | float | int) -> float:
"""Parse a frequency string to Hz.
Accepts:
"2.45GHz" 2_450_000_000.0
"40MHz" 40_000_000.0
"915e6" 915_000_000.0
2.45e9 (numeric) 2_450_000_000.0
"""
if isinstance(value, (int, float)):
result = float(value)
if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
raise ValueError(
f"Frequency {result:.3g} Hz is outside the supported range "
f"({_FREQ_MIN_HZ:.0f} Hz {_FREQ_MAX_HZ:.3g} Hz)"
)
return result
value = str(value).strip()
# Try bare numeric first (handles scientific notation like "915e6")
try:
result = float(value)
except ValueError:
pass
else:
if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
raise ValueError(
f"Frequency {result:.3g} Hz is outside the supported range "
f"({_FREQ_MIN_HZ:.0f} Hz {_FREQ_MAX_HZ:.3g} Hz): '{value}'"
)
return result
# Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
if match:
amount = float(match.group(1))
suffix = match.group(2).upper()
result = amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix]
if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
raise ValueError(
f"Frequency {result:.3g} Hz is outside the supported range "
f"({_FREQ_MIN_HZ:.0f} Hz {_FREQ_MAX_HZ:.3g} Hz): '{value}'"
)
return result
raise ValueError(f"Cannot parse frequency: '{value}'")
def parse_gain(value: str | float | int) -> float | str:
"""Parse a gain string.
Accepts:
"40dB" or "40 dB" 40.0
"auto" "auto"
40 (numeric) 40.0
"""
if isinstance(value, (int, float)):
result = float(value)
if not (_GAIN_MIN_DB <= result <= _GAIN_MAX_DB):
raise ValueError(f"Gain {result} dB is outside the supported range ({_GAIN_MIN_DB} {_GAIN_MAX_DB} dB)")
return result
value = str(value).strip()
if value.lower() == "auto":
return "auto"
match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
if not match:
raise ValueError(f"Cannot parse gain: '{value}'")
result = float(match.group(1))
if not (_GAIN_MIN_DB <= result <= _GAIN_MAX_DB):
raise ValueError(
f"Gain {result} dB is outside the supported range ({_GAIN_MIN_DB} {_GAIN_MAX_DB} dB): '{value}'"
)
return result
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
"""Parse a bandwidth string to MHz.
Accepts:
"20MHz" 20.0
"40MHz" 40.0
20 (numeric, assumed MHz) 20.0
None None
"""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
value = str(value).strip()
match = re.fullmatch(r"([\d.]+)\s*MHz?", value, re.IGNORECASE)
if match:
return float(match.group(1))
match = re.fullmatch(r"([\d.]+)", value)
if match:
return float(match.group(1))
raise ValueError(f"Cannot parse bandwidth: '{value}'")
# ---------------------------------------------------------------------------
# Config dataclasses
# ---------------------------------------------------------------------------
@dataclass
class RecorderConfig:
"""SDR recorder configuration."""
device: str
center_freq: float # Hz
sample_rate: float # Hz
gain: float | str # dB float, or "auto"
bandwidth: Optional[float] = None # Hz, None = match sample_rate
@classmethod
def from_dict(cls, d: dict) -> "RecorderConfig":
gain = parse_gain(d.get("gain", "auto"))
bandwidth_raw = d.get("bandwidth") or d.get("bandwidth_hz")
bandwidth = parse_frequency(bandwidth_raw) if bandwidth_raw else None
return cls(
device=str(d["device"]),
center_freq=parse_frequency(d["center_freq"]),
sample_rate=parse_frequency(d["sample_rate"]),
gain=gain,
bandwidth=bandwidth,
)
@dataclass
class CaptureStep:
"""A single timed capture within a transmitter schedule."""
duration: float # seconds
label: str # used as filename component
# WiFi-specific
channel: Optional[int] = None
bandwidth_mhz: Optional[float] = None # MHz
traffic: Optional[str] = None
# Bluetooth-specific
connection_interval_ms: Optional[float] = None
# Power (dBm), optional
power_dbm: Optional[float] = None
@classmethod
def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
duration = parse_duration(d["duration"])
label = d.get("label", "")
if not label and auto_label:
parts = []
if d.get("channel"):
parts.append(f"ch{d['channel']:02d}")
if d.get("bandwidth"):
bw = parse_bandwidth_mhz(d["bandwidth"])
parts.append(f"{int(bw)}mhz")
if d.get("traffic"):
parts.append(str(d["traffic"]).replace(" ", "_"))
label = "_".join(parts) if parts else "capture"
return cls(
duration=duration,
label=label,
channel=d.get("channel"),
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
traffic=d.get("traffic"),
connection_interval_ms=d.get("connection_interval_ms"),
power_dbm=float(d["power"].removesuffix("dBm").strip()) if d.get("power") else None,
)
@dataclass
class TransmitterConfig:
"""Configuration for a single transmitter device in the campaign."""
id: str
type: str # "wifi", "bluetooth", "sdr", "external"
control_method: str # "external_script" | "sdr" | "sdr_remote"
schedule: list[CaptureStep]
# For external_script control
script: Optional[str] = None # path to control script
device: Optional[str] = None # e.g. "/dev/wlan0"
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
sdr_remote: Optional[dict] = None
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
sdr_agent: Optional[dict] = None
@classmethod
def from_dict(cls, d: dict) -> "TransmitterConfig":
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
return cls(
id=str(d["id"]),
type=str(d["type"]),
control_method=str(d.get("control_method", "external_script")),
schedule=schedule,
script=d.get("script"),
device=d.get("device"),
sdr_remote=d.get("sdr_remote"),
sdr_agent=d.get("sdr_agent"),
)
@dataclass
class QAConfig:
"""Quality assurance thresholds."""
snr_threshold_db: float = 10.0
min_duration_s: float = 25.0
flag_for_review: bool = True
@classmethod
def from_dict(cls, d: dict) -> "QAConfig":
return cls(
snr_threshold_db=float(str(d.get("snr_threshold", "10")).rstrip("dB").strip()),
min_duration_s=parse_duration(d.get("min_duration", "25s")),
flag_for_review=bool(d.get("flag_for_review", True)),
)
@dataclass
class OutputConfig:
"""Where to save captured recordings."""
format: str = "sigmf"
path: str = "recordings"
device_id: Optional[str] = None # for device-profile campaigns
repo: Optional[str] = None
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
@classmethod
def from_dict(cls, d: dict) -> "OutputConfig":
return cls(
format=str(d.get("format", "sigmf")),
path=str(d.get("path", "recordings")),
device_id=d.get("device_id"),
repo=d.get("repo"),
folder=d.get("folder"),
)
@dataclass
class CampaignConfig:
"""Full campaign configuration parsed from YAML."""
name: str
recorder: RecorderConfig
transmitters: list[TransmitterConfig]
qa: QAConfig = field(default_factory=QAConfig)
output: OutputConfig = field(default_factory=OutputConfig)
mode: str = "controlled_testbed"
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------
@classmethod
def from_dict(cls, raw: dict) -> "CampaignConfig":
"""Build a CampaignConfig from a parsed dictionary.
Accepts the same structure as the campaign YAML, already loaded into
a Python dict (e.g. from a JSON HTTP request body).
Raises:
ValueError: If required fields are missing or malformed.
KeyError: If ``recorder`` key is absent.
"""
campaign_meta = raw.get("campaign", {})
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
if not transmitters:
raise ValueError("Campaign config must define at least one transmitter")
if "recorder" not in raw:
raise ValueError("Campaign config is missing required 'recorder' section")
raw_name = str(campaign_meta.get("name", "unnamed"))
safe_name = _SAFE_NAME_RE.sub("_", raw_name)
return cls(
name=safe_name,
mode=str(campaign_meta.get("mode", "controlled_testbed")),
loops=max(1, int(campaign_meta.get("loops", 1))),
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
@classmethod
def from_yaml(cls, path: str | Path) -> "CampaignConfig":
"""Load a full campaign config YAML.
Expected format::
campaign:
name: "wifi_capture_001"
mode: "controlled_testbed"
transmitters:
- id: "laptop_wifi"
type: "wifi"
control_method: "external_script"
script: "./scripts/wifi_control.sh"
device: "/dev/wlan0"
schedule:
- channel: 6
bandwidth: "20MHz"
traffic: "iperf_udp"
duration: "30s"
recorder:
device: "usrp_b210"
center_freq: "2.45GHz"
sample_rate: "40MHz"
gain: "40dB"
qa:
snr_threshold: "10dB"
min_duration: "25s"
flag_for_review: true
output:
format: "sigmf"
path: "./recordings"
"""
path = Path(path)
try:
with open(path) as f:
raw = yaml.safe_load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Campaign config not found: {path}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {path}: {e}")
campaign_meta = raw.get("campaign", {})
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
if not transmitters:
raise ValueError("Campaign config must define at least one transmitter")
if "recorder" not in raw:
raise ValueError(f"Campaign config is missing required 'recorder' section in {path}")
raw_name = str(campaign_meta.get("name", path.stem))
safe_name = _SAFE_NAME_RE.sub("_", raw_name)
return cls(
name=safe_name,
mode=str(campaign_meta.get("mode", "controlled_testbed")),
loops=max(1, int(campaign_meta.get("loops", 1))),
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=transmitters,
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
@classmethod
def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
"""Build a campaign config from an App 1 device profile YAML.
Expected format::
device:
name: "iPhone_13_WiFi"
type: "wifi"
protocol: "wifi_24ghz"
capture:
channels: [1, 6, 11] # WiFi only
bandwidth: "20MHz" # WiFi only
traffic_patterns: ["idle", "ping", "iperf_udp"]
duration_per_config: "30s"
recorder:
device: "usrp_b210"
center_freq: "2.45GHz"
sample_rate: "40MHz"
gain: "auto"
output:
path: "./recordings"
device_id: "iphone13_wifi_001"
For WiFi devices, schedule is expanded as channels × traffic_patterns.
For Bluetooth devices (no channels), schedule is traffic_patterns only.
"""
path = Path(path)
try:
with open(path) as f:
raw = yaml.safe_load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Device profile not found: {path}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {path}: {e}")
device = raw.get("device", {})
capture = raw.get("capture", {})
device_type = str(device.get("type", "wifi")).lower()
device_name = str(device.get("name", path.stem))
duration = parse_duration(capture.get("duration_per_config", "30s"))
traffic_patterns = capture.get("traffic_patterns", ["idle"])
# Build capture schedule
schedule: list[CaptureStep] = []
if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
channels = capture.get("channels", [6])
bw_str = capture.get("bandwidth", "20MHz")
bw_mhz = parse_bandwidth_mhz(bw_str)
for ch in channels:
for traffic in traffic_patterns:
label = f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}"
schedule.append(
CaptureStep(
duration=duration,
label=label,
channel=ch,
bandwidth_mhz=bw_mhz,
traffic=traffic,
)
)
else:
# Bluetooth / generic — no channels
for traffic in traffic_patterns:
schedule.append(
CaptureStep(
duration=duration,
label=traffic,
traffic=traffic,
)
)
device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
transmitter = TransmitterConfig(
id=device_id,
type=device_type,
control_method=str(capture.get("control_method", "external_script")),
schedule=schedule,
script=capture.get("script"),
device=capture.get("device"),
)
return cls(
name=f"enroll_{device_id}",
mode="controlled_testbed",
recorder=RecorderConfig.from_dict(raw["recorder"]),
transmitters=[transmitter],
qa=QAConfig.from_dict(raw.get("qa", {})),
output=OutputConfig.from_dict(raw.get("output", {})),
)
def total_capture_time_s(self) -> float:
"""Sum of all step durations across all transmitters and loops."""
return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops
def total_steps(self) -> int:
"""Total number of capture steps across all transmitters and loops."""
return sum(len(tx.schedule) for tx in self.transmitters) * self.loops

View File

@ -0,0 +1,570 @@
"""Campaign executor: runs a capture campaign end-to-end."""
from __future__ import annotations
import json
import logging
import subprocess
import threading
import time
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Callable, Optional
from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.io.recording import to_sigmf
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
from .labeler import build_output_filename, label_recording
from .qa import QAResult, check_recording
from .tx_executor import TxExecutor
logger = logging.getLogger(__name__)
# Device name aliases: campaign YAML names → get_sdr_device() names
_DEVICE_ALIASES = {
"usrp_b210": "usrp",
"usrp_b200": "usrp",
"usrp": "usrp",
"plutosdr": "pluto",
"pluto": "pluto",
"hackrf": "hackrf",
"hackrf_one": "hackrf",
"bladerf": "bladerf",
"rtlsdr": "rtlsdr",
"rtl_sdr": "rtlsdr",
"thinkrf": "thinkrf",
# Simulated device — no hardware required
"mock": "mock",
"sim": "mock",
}
@dataclass
class StepResult:
"""Outcome of a single capture step."""
transmitter_id: str
step_label: str
output_path: Optional[str]
qa: QAResult
capture_timestamp: float
error: Optional[str] = None
@property
def ok(self) -> bool:
return self.error is None and self.qa.passed
def to_dict(self) -> dict:
return {
"transmitter_id": self.transmitter_id,
"step_label": self.step_label,
"output_path": self.output_path,
"capture_timestamp": self.capture_timestamp,
"qa": self.qa.to_dict(),
"error": self.error,
}
@dataclass
class CampaignResult:
"""Aggregate outcome of a full campaign."""
campaign_name: str
steps: list[StepResult] = field(default_factory=list)
start_time: float = field(default_factory=time.time)
end_time: Optional[float] = None
@property
def total_steps(self) -> int:
return len(self.steps)
@property
def passed(self) -> int:
return sum(1 for s in self.steps if s.ok)
@property
def flagged(self) -> int:
return sum(1 for s in self.steps if not s.error and s.qa.flagged)
@property
def failed(self) -> int:
return sum(1 for s in self.steps if s.error or not s.qa.passed)
@property
def duration_s(self) -> float:
if self.end_time:
return self.end_time - self.start_time
return time.time() - self.start_time
def to_dict(self) -> dict:
return {
"campaign_name": self.campaign_name,
"total_steps": self.total_steps,
"passed": self.passed,
"flagged": self.flagged,
"failed": self.failed,
"duration_s": round(self.duration_s, 1),
"steps": [s.to_dict() for s in self.steps],
}
def write_report(self, path: str | Path) -> None:
"""Write a JSON QA report to disk."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
json.dump(self.to_dict(), f, indent=2)
logger.info(f"QA report written to {path}")
# ---------------------------------------------------------------------------
# External script interface
# ---------------------------------------------------------------------------
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
"""Run an external control script and return stdout.
The script is called as::
<script> <arg1> <arg2> ...
A non-zero return code raises RuntimeError.
Args:
script: Path to executable script. Must be an absolute path to an
existing regular file. Relative paths are rejected to prevent
accidentally executing files that are not the intended script.
*args: Positional arguments forwarded to the script.
timeout: Maximum seconds to wait.
Returns:
Script stdout as a string.
"""
if not Path(script).is_absolute():
raise RuntimeError(f"Script path must be absolute: {script}")
script_path = Path(script).resolve()
if not script_path.is_file():
raise RuntimeError(f"Script not found or is not a regular file: {script}")
cmd = [str(script_path), *args]
logger.debug(f"Running script: {' '.join(cmd)}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
)
except subprocess.TimeoutExpired:
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
except FileNotFoundError:
raise RuntimeError(f"Script not found: {script}")
if result.returncode != 0:
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
return result.stdout.strip()
# ---------------------------------------------------------------------------
# Campaign executor
# ---------------------------------------------------------------------------
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
For sdr_agent transmitters, returns the synthetic generation parameters
(modulation, order, symbol_rate, etc.) so recordings capture what was
transmitted. Returns None for control methods without signal-level params.
"""
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
if not sdr_agent_cfg:
return None
# Extract known signal-level fields; ignore infra fields
_INFRA_KEYS = {"node_id", "session_code"}
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
class CampaignExecutor:
"""Executes a :class:`CampaignConfig` end-to-end.
Initialises the SDR recorder once, then for each (transmitter, step):
1. Configures the transmitter (via external script or SDR TX)
2. Records IQ samples
3. Labels the recording with device/config metadata
4. Runs QA checks
5. Saves the recording to disk
6. Stops/resets the transmitter
Args:
config: Parsed campaign configuration.
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
called after each step completes. Useful for status reporting.
verbose: Enable debug logging.
"""
def __init__(
self,
config: CampaignConfig,
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
verbose: bool = False,
skip_local_tx: bool = False,
):
self.config = config
self.progress_cb = progress_cb
self.skip_local_tx = skip_local_tx
self._sdr = None
self._remote_tx_controllers: dict = {}
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
if verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def run(self) -> CampaignResult:
"""Execute the full campaign and return a :class:`CampaignResult`.
Initialises the SDR, runs all steps across all transmitters,
then closes the SDR. If SDR initialisation fails the exception
propagates immediately (nothing is captured).
"""
result = CampaignResult(campaign_name=self.config.name)
loops = self.config.loops
logger.info(
f"Starting campaign '{self.config.name}': "
f"{self.config.total_steps()} steps"
+ (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
)
self._init_sdr()
self._init_remote_tx_controllers()
try:
total = self.config.total_steps()
step_index = 0
for loop_idx in range(loops):
if loops > 1:
logger.info(f"Loop {loop_idx + 1}/{loops}")
for transmitter in self.config.transmitters:
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
for step in transmitter.schedule:
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
step_result = self._execute_step(transmitter, looped_step)
result.steps.append(step_result)
step_index += 1
if self.progress_cb:
self.progress_cb(step_index, total, step_result)
if step_result.error:
logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
elif step_result.qa.flagged:
logger.warning(
f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
)
else:
logger.info(
f"Step '{looped_step.label}' OK "
f"(SNR {step_result.qa.snr_db:.1f} dB, "
f"{step_result.qa.duration_s:.1f}s)"
)
finally:
self._close_sdr()
self._close_remote_tx_controllers()
self._close_tx_executors()
result.end_time = time.time()
logger.info(
f"Campaign complete: {result.passed}/{result.total_steps} passed, "
f"{result.flagged} flagged, {result.failed} failed"
)
return result
# ------------------------------------------------------------------
# SDR management
# ------------------------------------------------------------------
def _init_sdr(self) -> None:
"""Initialise and configure the SDR recorder."""
from ria_toolkit_oss.sdr import get_sdr_device
rec = self.config.recorder
device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
self._sdr = get_sdr_device(device_name)
gain = None if rec.gain == "auto" else float(rec.gain)
self._sdr.init_rx(
sample_rate=rec.sample_rate,
center_frequency=rec.center_freq,
gain=gain,
channel=0,
)
if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
self._sdr.set_rx_bandwidth(rec.bandwidth)
def _close_sdr(self) -> None:
if self._sdr is not None:
try:
self._sdr.close()
except Exception as e:
logger.warning(f"SDR close error: {e}")
self._sdr = None
# ------------------------------------------------------------------
# Remote Tx controller management
# ------------------------------------------------------------------
def _init_remote_tx_controllers(self) -> None:
"""Open SSH+ZMQ connections for all sdr_remote transmitters."""
from ria_toolkit_oss.remote_control import RemoteTransmitterController
for tx in self.config.transmitters:
if tx.control_method != "sdr_remote":
continue
cfg = tx.sdr_remote
if not cfg:
raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config")
logger.info(f"Connecting remote Tx controller for {tx.id}{cfg['host']}")
ctrl = RemoteTransmitterController(
host=cfg["host"],
ssh_user=cfg["ssh_user"],
ssh_key_path=cfg["ssh_key_path"],
zmq_port=int(cfg.get("zmq_port", 5556)),
)
ctrl.set_radio(
device_type=cfg["device_type"],
device_id=cfg.get("device_id", ""),
)
self._remote_tx_controllers[tx.id] = ctrl
def _close_remote_tx_controllers(self) -> None:
for tx_id, ctrl in list(self._remote_tx_controllers.items()):
try:
ctrl.close()
except Exception as exc:
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
self._remote_tx_controllers.clear()
def _close_tx_executors(self) -> None:
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
stop_event.set()
t.join(timeout=5.0)
self._tx_executors.clear()
def _record(self, duration_s: float) -> Recording:
"""Capture ``duration_s`` seconds of IQ samples."""
num_samples = int(duration_s * self.config.recorder.sample_rate)
return self._sdr.record(num_samples=num_samples)
# ------------------------------------------------------------------
# Step execution
# ------------------------------------------------------------------
def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
"""Run a single capture step.
Returns:
StepResult with QA outcome and output path (or error string).
"""
capture_timestamp = time.time()
output_path: Optional[str] = None
try:
self._start_transmitter(transmitter, step)
recording = self._record(step.duration)
self._stop_transmitter(transmitter, step)
except Exception as e:
# Best-effort stop on error
try:
self._stop_transmitter(transmitter, step)
except Exception:
pass
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=None,
qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
capture_timestamp=capture_timestamp,
error=str(e),
)
# Label recording
recording = label_recording(
recording=recording,
device_id=transmitter.id,
step=step,
capture_timestamp=capture_timestamp,
campaign_name=self.config.name,
tx_params=_extract_tx_params(transmitter),
)
# QA
qa_result = check_recording(recording, self.config.qa)
# Save
try:
output_path = self._save(recording, transmitter.id, step)
except Exception as e:
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=None,
qa=qa_result,
capture_timestamp=capture_timestamp,
error=f"Save failed: {e}",
)
return StepResult(
transmitter_id=transmitter.id,
step_label=step.label,
output_path=output_path,
qa=qa_result,
capture_timestamp=capture_timestamp,
)
# ------------------------------------------------------------------
# Transmitter control (external script interface)
# ------------------------------------------------------------------
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
"""Configure the transmitter for this step.
For ``external_script`` control method the script is called as::
<script> configure <step_params_json>
where ``step_params_json`` is a JSON object with channel, bandwidth,
traffic, etc. The script is responsible for applying the configuration
and returning promptly (i.e. not blocking for the capture duration).
For ``sdr_remote`` the remote ZMQ controller calls ``init_tx`` then
starts a background transmit thread that runs for the step duration.
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
return
params = self._step_params_json(transmitter, step)
_run_script(transmitter.script, "configure", params)
elif transmitter.control_method == "sdr":
logger.debug("SDR TX not yet implemented — skipping start")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is None:
raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'")
gain = step.power_dbm if step.power_dbm is not None else 0.0
ctrl.init_tx(
center_frequency=self.config.recorder.center_freq,
sample_rate=self.config.recorder.sample_rate,
gain=gain,
channel=step.channel or 0,
)
# Start transmission in background; _record() runs concurrently
ctrl.transmit_async(step.duration + 1.0)
elif transmitter.control_method == "sdr_agent":
if self.skip_local_tx:
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
return
if not transmitter.sdr_agent:
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
return
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
if step.power_dbm is not None:
step_dict["power_dbm"] = step.power_dbm
tx_config = {
"id": transmitter.id,
"sdr_agent": transmitter.sdr_agent,
"schedule": [step_dict],
}
rec = self.config.recorder
tx_device = transmitter.device or rec.device
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
stop_event = threading.Event()
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
self._tx_executors[transmitter.id] = (executor, stop_event, t)
t.start()
else:
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
"""Signal the transmitter to stop.
Calls ``<script> stop`` for external_script transmitters.
For ``sdr_remote``, waits for the background transmit thread to finish.
"""
if transmitter.control_method == "external_script":
if not transmitter.script:
return
try:
_run_script(transmitter.script, "stop")
except Exception as e:
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
elif transmitter.control_method == "sdr_remote":
ctrl = self._remote_tx_controllers.get(transmitter.id)
if ctrl is not None:
ctrl.wait_transmit(timeout=step.duration + 10.0)
elif transmitter.control_method == "sdr_agent":
entry = self._tx_executors.pop(transmitter.id, None)
if entry is not None:
_, stop_event, t = entry
stop_event.set()
t.join(timeout=step.duration + 10.0)
@staticmethod
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
"""Serialise step parameters to a JSON string for the control script."""
params: dict = {"device": transmitter.device or ""}
if step.channel is not None:
params["channel"] = step.channel
if step.bandwidth_mhz is not None:
params["bandwidth_mhz"] = step.bandwidth_mhz
if step.traffic is not None:
params["traffic"] = step.traffic
if step.power_dbm is not None:
params["power_dbm"] = step.power_dbm
return json.dumps(params)
# ------------------------------------------------------------------
# Output
# ------------------------------------------------------------------
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
"""Save a recording to disk and return the data file path."""
out = self.config.output
rel_filename = build_output_filename(device_id, step)
out_dir = Path(out.path).resolve()
# build_output_filename returns "<device_id>/<label>"
# to_sigmf needs filename (base) and path (dir) separately
parts = Path(rel_filename)
subdir = (out_dir / parts.parent).resolve()
# Prevent path traversal: the resolved subdir must stay within the configured output directory.
try:
subdir.relative_to(out_dir)
except ValueError:
raise RuntimeError(
f"Output path escape detected: '{subdir}' is outside configured output directory '{out_dir}'"
)
subdir.mkdir(parents=True, exist_ok=True)
base = parts.name
to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
return str(subdir / f"{base}.sigmf-data")

View File

@ -0,0 +1,86 @@
"""Timestamp-based labeling for captured recordings."""
from __future__ import annotations
from typing import Optional
from ria_toolkit_oss.data.recording import Recording
from .campaign import CaptureStep
def label_recording(
recording: Recording,
device_id: str,
step: CaptureStep,
capture_timestamp: float,
campaign_name: Optional[str] = None,
tx_params: Optional[dict] = None,
) -> Recording:
"""Apply device identity and capture configuration labels to a recording's metadata.
Labels are stored in the ``ria:*`` namespace when the recording is saved
as SigMF, via the existing ``update_metadata`` mechanism.
Args:
recording: The recording to label.
device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
step: The capture step that was active during this recording.
capture_timestamp: Unix timestamp (float) of when capture started.
campaign_name: Optional campaign name for cross-recording reference.
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
training pipelines know what was transmitted into the recording.
Returns:
The same recording with updated metadata.
"""
recording.update_metadata("device_id", device_id)
recording.update_metadata("capture_timestamp", capture_timestamp)
recording.update_metadata("step_label", step.label)
recording.update_metadata("step_duration_s", step.duration)
if campaign_name:
recording.update_metadata("campaign", campaign_name)
# WiFi-specific labels
if step.channel is not None:
recording.update_metadata("wifi_channel", step.channel)
if step.bandwidth_mhz is not None:
recording.update_metadata("wifi_bandwidth_mhz", step.bandwidth_mhz)
# Bluetooth-specific labels
if step.connection_interval_ms is not None:
recording.update_metadata("bt_connection_interval_ms", step.connection_interval_ms)
# Traffic pattern (WiFi + BT)
if step.traffic is not None:
recording.update_metadata("traffic_pattern", step.traffic)
# TX power
if step.power_dbm is not None:
recording.update_metadata("tx_power_dbm", step.power_dbm)
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
if tx_params:
for key, value in tx_params.items():
recording.update_metadata(f"tx_{key}", value)
return recording
def build_output_filename(device_id: str, step: CaptureStep) -> str:
"""Generate a deterministic filename for a labeled recording.
Format: ``<device_id>/<step_label>``
Args:
device_id: Device identifier string.
step: Capture step.
Returns:
Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
"""
safe_id = device_id.replace("/", "_").replace(" ", "_")
safe_label = step.label.replace("/", "_").replace(" ", "_")
return f"{safe_id}/{safe_label}"

View File

@ -0,0 +1,109 @@
"""QA metrics for captured RF recordings."""
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
from ria_toolkit_oss.data.recording import Recording
from .campaign import QAConfig
@dataclass
class QAResult:
"""Result of QA checks on a single recording."""
passed: bool
flagged: bool # True if any metric is below threshold (but not hard-failed)
snr_db: float
duration_s: float
issues: list[str] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"passed": self.passed,
"flagged": self.flagged,
"snr_db": round(self.snr_db, 2),
"duration_s": round(self.duration_s, 3),
"issues": self.issues,
}
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
"""Estimate SNR from IQ samples using PSD-based signal/noise separation.
Computes an FFT of the samples and assumes the top ``signal_fraction``
of power bins are signal and the remainder are noise. This is a
heuristic appropriate for a controlled testbed where a single dominant
signal is expected.
Args:
samples: 1-D complex array of IQ samples.
signal_fraction: Fraction of PSD bins to treat as signal (01).
Returns:
Estimated SNR in dB, or 0.0 if the noise floor is zero.
"""
n_fft = min(4096, len(samples))
window = np.hanning(n_fft)
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
psd_sorted = np.sort(psd)[::-1]
n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1)
signal_power = psd_sorted[:n_signal].mean()
noise_power = psd_sorted[n_signal:].mean()
if noise_power <= 0.0:
return 0.0
return float(10.0 * np.log10(signal_power / noise_power))
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
"""Run QA checks on a recording against the campaign QA config.
Checks performed:
- Duration: number of samples / sample_rate >= min_duration_s
- SNR: estimated SNR >= snr_threshold_db
Args:
recording: Recording to evaluate.
config: QA thresholds from the campaign config.
Returns:
QAResult with pass/flag status and per-metric details.
"""
issues: list[str] = []
flagged = False
# --- Duration check ---
sample_rate = recording.metadata.get("sample_rate", 1.0)
n_samples = recording.data.shape[-1]
duration_s = n_samples / sample_rate if sample_rate else 0.0
if duration_s < config.min_duration_s:
issues.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")
flagged = True
# --- SNR check ---
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
snr_db = estimate_snr_db(samples)
if snr_db < config.snr_threshold_db:
issues.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")
flagged = True
# In flag_for_review mode: flag but don't hard-fail
if config.flag_for_review:
passed = True # always accept; human reviews flagged recordings
else:
passed = not flagged
return QAResult(
passed=passed,
flagged=flagged,
snr_db=snr_db,
duration_s=duration_s,
issues=issues,
)

View File

@ -0,0 +1,299 @@
"""TX campaign executor — synthesises and transmits signals via a local SDR.
The TxExecutor receives a transmitter config dict (matching the
``sdr_agent`` control method's schema) and a step schedule, then for each
step builds a signal chain with the block generator and transmits it via
the local SDR device.
Supported modulations (``modulation`` field in config):
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
Example config dict (matches CampaignConfig transmitter with
``control_method: sdr_agent``)::
{
"id": "synthetic-tx",
"type": "sdr",
"control_method": "sdr_agent",
"sdr_agent": {
"modulation": "QPSK",
"order": 4,
"symbol_rate": 1000000,
"center_frequency": 0.0,
"filter": "rrc",
"rolloff": 0.35
},
"schedule": [
{"label": "step1", "duration": 10, "power_dbm": -10}
]
}
"""
from __future__ import annotations
import logging
import threading
from typing import Any
logger = logging.getLogger(__name__)
def _parse_hz(val: object) -> float:
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
if isinstance(val, (int, float)):
return float(val)
s = str(val).strip()
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
if s.endswith(suffix):
return float(s[: -len(suffix)]) * mult
return float(s)
def _parse_seconds(val: object) -> float:
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
if isinstance(val, (int, float)):
return float(val)
s = str(val).strip()
return float(s[:-1]) if s.endswith("s") else float(s)
# Mapping from modulation name → (PSK/QAM order, generator_type)
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
_MOD_TABLE: dict[str, tuple[int, str]] = {
"BPSK": (1, "psk"),
"QPSK": (2, "psk"),
"8PSK": (3, "psk"),
"16QAM": (4, "qam"),
"64QAM": (6, "qam"),
"256QAM": (8, "qam"),
}
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}
# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
# source buffer for the full tx_time, so only this many samples ever need to
# be in RAM regardless of step duration or sample rate.
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
_SYNTH_BLOCK_SAMPLES = 50_000
class TxExecutor:
"""Synthesise and transmit a signal campaign via a local SDR.
Args:
config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
modulation params, and ``schedule`` list of step dicts).
sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
stop_event: External event that aborts the TX loop mid-step.
"""
def __init__(
self,
config: dict,
sdr_device: str = "unknown",
stop_event: threading.Event | None = None,
) -> None:
self.config = config
self.sdr_device = sdr_device
self.stop_event = stop_event or threading.Event()
self._sdr: Any = None
def run(self) -> None:
"""Execute all steps in the schedule, transmitting for each step duration."""
agent_cfg: dict = self.config.get("sdr_agent") or {}
schedule: list[dict] = self.config.get("schedule") or []
if not schedule:
logger.warning("TxExecutor: no schedule steps — nothing to transmit")
return
modulation: str = agent_cfg.get("modulation", "QPSK").upper()
symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
filter_type: str = agent_cfg.get("filter", "rrc").lower()
rolloff: float = float(agent_cfg.get("rolloff", 0.35))
loops: int = max(1, int(self.config.get("loops", 1)))
# Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
sps = 8
sample_rate = symbol_rate * sps
self._init_sdr(sample_rate, center_freq)
try:
for loop_idx in range(loops):
if self.stop_event.is_set():
break
if loops > 1:
logger.info("TX loop %d/%d", loop_idx + 1, loops)
for step in schedule:
if self.stop_event.is_set():
break
looped_step = (
{**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
)
self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
finally:
self._close_sdr()
def _execute_step(
self,
step: dict,
modulation: str,
sps: int,
symbol_rate: float,
filter_type: str,
rolloff: float,
) -> None:
duration: float = _parse_seconds(step.get("duration", 10.0))
label: str = step.get("label", "step")
gain: float = float(step.get("power_dbm") or 0.0)
sample_rate = symbol_rate * sps
logger.info(
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
label,
duration,
modulation,
symbol_rate / 1e6,
sps,
filter_type,
)
num_samples = int(duration * sample_rate)
# Synthesise a short representative block. tx_recording() loops this
# buffer for the full tx_time using a 2 000-sample streaming callback,
# so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
if self._sdr is not None:
try:
# Apply gain update if SDR supports it
if hasattr(self._sdr, "set_tx_gain"):
self._sdr.set_tx_gain(gain)
self._sdr.tx_recording(signal, tx_time=duration)
except Exception as exc:
logger.error("TX step '%s' SDR error: %s", label, exc)
else:
# No SDR available — simulate by sleeping for the step duration.
logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
self.stop_event.wait(timeout=duration)
def _synthesise(
self,
modulation: str,
sps: int,
num_samples: int,
filter_type: str,
rolloff: float,
):
"""Build a block-generator chain and return IQ samples as a numpy array."""
try:
import numpy as np
from ria_toolkit_oss.signal.block_generator import (
BinarySource,
GMSKModulator,
Mapper,
OOKModulator,
OQPSKModulator,
RaisedCosineFilter,
RootRaisedCosineFilter,
Upsampling,
)
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
FSKModulator,
)
except ImportError as exc:
raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
# ── Special modulations with their own source-connected modulator ──
if modulation in ("OOK", "GMSK", "OQPSK"):
src = BinarySource()
if modulation == "OOK":
mod = OOKModulator(src, samples_per_symbol=sps)
elif modulation == "GMSK":
mod = GMSKModulator(src, samples_per_symbol=sps)
else:
mod = OQPSKModulator(src, samples_per_symbol=sps)
recording = mod.record(num_samples)
flat = np.asarray(recording.data).flatten().astype(np.complex64)
if len(flat) < num_samples:
flat = np.tile(flat, num_samples // len(flat) + 1)
return flat[:num_samples]
if modulation == "FSK":
symbol_rate = num_samples / sps
bits_per_sym = 1 # 2-FSK
num_bits = max(num_samples // sps, 128) * bits_per_sym
bits = BinarySource()((1, num_bits))
mod = FSKModulator(
num_bits_per_symbol=bits_per_sym,
frequency_spacing=symbol_rate * 0.5,
symbol_duration=1.0 / max(symbol_rate, 1.0),
sampling_frequency=symbol_rate * sps,
)
flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
if len(flat) < num_samples:
flat = np.tile(flat, num_samples // len(flat) + 1)
return flat[:num_samples]
# ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
if modulation not in _MOD_TABLE:
logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
modulation = "QPSK"
bits_per_sym, gen_type = _MOD_TABLE[modulation]
mod_family = "QAM" if gen_type == "qam" else "PSK"
source = BinarySource()
mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
upsampler = Upsampling(factor=sps)
mapper.connect_input([source])
upsampler.connect_input([mapper])
if filter_type in ("rrc",):
pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
pulse_filter.connect_input([upsampler])
recording = pulse_filter.record(num_samples)
elif filter_type in ("rc",):
pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
pulse_filter.connect_input([upsampler])
recording = pulse_filter.record(num_samples)
else:
# "none", "rect", "gaussian" — use upsampler output directly
recording = upsampler.record(num_samples)
flat = np.asarray(recording.data).flatten().astype(np.complex64)
if len(flat) < num_samples:
flat = np.tile(flat, num_samples // len(flat) + 1)
return flat[:num_samples]
def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
try:
from ria_toolkit_oss.sdr import get_sdr_device
self._sdr = get_sdr_device(self.sdr_device)
self._sdr.init_tx(
sample_rate=sample_rate,
center_frequency=center_freq,
gain=0,
channel=0,
gain_mode="manual",
)
logger.info(
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
)
except Exception as exc:
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
self._sdr = None
def _close_sdr(self) -> None:
if self._sdr is not None:
try:
self._sdr.close()
except Exception as exc:
logger.debug("TX SDR close error: %s", exc)
self._sdr = None

View File

@ -0,0 +1,6 @@
"""Remote SDR transmitter control via SSH + ZMQ."""
from .remote_transmitter import RemoteTransmitter
from .remote_transmitter_controller import RemoteTransmitterController
__all__ = ["RemoteTransmitter", "RemoteTransmitterController"]

View File

@ -0,0 +1,152 @@
"""Server-side ZMQ RPC receiver for SDR transmission.
Run this script on the Tx machine. The script binds a ZMQ REP socket and
waits for JSON-RPC commands from a :class:`RemoteTransmitterController`.
Requires: zmq, and ria-toolkit or utils installed for SDR support.
"""
from __future__ import annotations
import argparse
import io
import json
import logging
from contextlib import redirect_stderr, redirect_stdout
import zmq
logger = logging.getLogger(__name__)
class RemoteTransmitter:
"""Executes SDR Tx commands received over ZMQ.
Loads the appropriate SDR driver dynamically so the script can run on
machines that have only a subset of SDR libraries installed.
"""
def __init__(self) -> None:
self._sdr = None
def set_radio(self, radio_str: str, identifier: str = "") -> None:
"""Initialise the SDR radio.
Args:
radio_str: SDR type pluto | usrp | hackrf | bladerf.
identifier: Device-specific identifier (IP, serial, etc.).
"""
radio_str = radio_str.lower()
try:
if radio_str in ("pluto", "plutosdr"):
from ria_toolkit_oss.sdr.pluto import Pluto
self._sdr = Pluto(identifier)
elif radio_str in ("usrp",):
from ria_toolkit_oss.sdr.usrp import USRP
self._sdr = USRP(identifier)
elif radio_str in ("hackrf", "hackrf_one"):
from ria_toolkit_oss.sdr.hackrf import HackRF
self._sdr = HackRF(identifier)
elif radio_str in ("bladerf", "blade"):
from ria_toolkit_oss.sdr.blade import Blade
self._sdr = Blade(identifier)
else:
raise ValueError(f"Unknown SDR type: {radio_str!r}")
except ImportError as exc:
raise RuntimeError(f"SDR driver for '{radio_str}' is not installed: {exc}") from exc
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
if self._sdr is None:
raise RuntimeError("Call set_radio() before init_tx()")
self._sdr.init_tx(
center_frequency=center_frequency,
sample_rate=sample_rate,
gain=gain,
channel=channel,
)
def transmit(self, duration_s: float) -> None:
"""Transmit a continuous wave for ``duration_s`` seconds."""
if self._sdr is None:
raise RuntimeError("Call set_radio() and init_tx() before transmit()")
import time
# Transmit in a loop until duration has elapsed
end = time.monotonic() + duration_s
while time.monotonic() < end:
try:
self._sdr.tx_cw()
except AttributeError:
time.sleep(0.01)
def stop(self) -> None:
"""Stop transmission and close the SDR."""
if self._sdr is not None:
try:
self._sdr.close()
except Exception:
pass
self._sdr = None
def run_function(self, command_dict: dict) -> dict:
"""Dispatch a JSON-RPC command and return a response dict."""
out_buf = io.StringIO()
err_buf = io.StringIO()
fn = command_dict.get("function_name", "")
try:
with redirect_stdout(out_buf), redirect_stderr(err_buf):
if fn == "set_radio":
self.set_radio(
radio_str=command_dict["radio_str"],
identifier=command_dict.get("identifier", ""),
)
elif fn == "init_tx":
self.init_tx(
center_frequency=command_dict["center_frequency"],
sample_rate=command_dict["sample_rate"],
gain=command_dict["gain"],
channel=command_dict.get("channel", 0),
gain_mode=command_dict.get("gain_mode", "absolute"),
)
elif fn == "transmit":
self.transmit(duration_s=command_dict.get("duration_s", 1.0))
elif fn == "stop":
self.stop()
else:
raise ValueError(f"Unknown function: {fn!r}")
return {"status": True, "message": out_buf.getvalue(), "error_message": err_buf.getvalue()}
except Exception as exc:
logger.exception("Error executing %s", fn)
return {"status": False, "message": out_buf.getvalue(), "error_message": str(exc)}
def _serve(port: int) -> None:
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{port}")
logger.info("RemoteTransmitter listening on port %d", port)
tx = RemoteTransmitter()
while True:
raw = socket.recv()
cmd = json.loads(raw.decode())
response = tx.run_function(cmd)
socket.send(json.dumps(response).encode())
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="SDR Tx ZMQ server")
parser.add_argument("--port", type=int, default=5556)
args = parser.parse_args()
_serve(args.port)

View File

@ -0,0 +1,218 @@
"""Client-side SSH + ZMQ controller for a remote SDR transmitter.
Run this on the Rx machine (or hub). It SSH-es into the Tx machine,
starts :mod:`remote_transmitter` there, then sends JSON-RPC commands over
ZMQ.
Requires: paramiko, zmq.
"""
from __future__ import annotations
import json
import logging
import threading
import time
import paramiko
import zmq
logger = logging.getLogger(__name__)
_STARTUP_WAIT_S = 2.0 # seconds to wait for remote ZMQ server to bind
class RemoteTransmitterController:
"""SSH into a Tx machine, start the ZMQ server, and send commands.
Args:
host: IP or hostname of the Tx machine.
ssh_user: SSH username.
ssh_key_path: Path to SSH private key file.
zmq_port: ZMQ port that the remote transmitter will bind on.
"""
def __init__(
self,
host: str,
ssh_user: str,
ssh_key_path: str,
zmq_port: int = 5556,
) -> None:
self._host = host
self._zmq_port = zmq_port
self._ssh: paramiko.SSHClient | None = None
self._ssh_stdout = None
self._context: zmq.Context | None = None
self._socket: zmq.Socket | None = None
self._tx_thread: threading.Thread | None = None
self._lock = threading.Lock()
self._connect(host, ssh_user, ssh_key_path, zmq_port)
# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def _connect(self, host: str, ssh_user: str, ssh_key_path: str, zmq_port: int) -> None:
"""Open SSH tunnel, start remote server, connect ZMQ socket."""
try:
import paramiko
except ImportError as exc:
raise RuntimeError("paramiko is required for remote SDR control: pip install paramiko") from exc
try:
import zmq
except ImportError as exc:
raise RuntimeError("pyzmq is required for remote SDR control: pip install pyzmq") from exc
logger.info("SSH connecting to %s@%s", ssh_user, host)
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh.connect(hostname=host, username=ssh_user, key_filename=ssh_key_path)
cmd = f"python -m ria_toolkit_oss.remote_control.remote_transmitter --port {zmq_port}"
logger.info("Starting remote Tx server: %s", cmd)
_, self._ssh_stdout, _ = self._ssh.exec_command(cmd)
time.sleep(_STARTUP_WAIT_S)
self._context = zmq.Context()
self._socket = self._context.socket(zmq.REQ)
self._socket.connect(f"tcp://{host}:{zmq_port}")
logger.info("ZMQ connected to tcp://%s:%d", host, zmq_port)
def close(self) -> None:
"""Tear down ZMQ and SSH connections."""
if self._socket is not None:
try:
self._socket.close(linger=0)
except Exception:
pass
self._socket = None
if self._context is not None:
try:
self._context.term()
except Exception:
pass
self._context = None
if self._ssh_stdout is not None:
try:
self._ssh_stdout.channel.close()
except Exception:
pass
self._ssh_stdout = None
if self._ssh is not None:
try:
self._ssh.close()
except Exception:
pass
self._ssh = None
logger.info("RemoteTransmitterController closed")
# ------------------------------------------------------------------
# ZMQ dispatch
# ------------------------------------------------------------------
def _send(self, command: dict) -> dict:
"""Send a JSON-RPC command and return the response dict (thread-safe)."""
with self._lock:
if self._socket is None:
raise RuntimeError("Controller is closed")
self._socket.send(json.dumps(command).encode())
raw = self._socket.recv()
reply: dict = json.loads(raw.decode())
if not reply.get("status"):
raise RuntimeError(
f"Remote command '{command.get('function_name')}' failed: "
f"{reply.get('error_message', 'unknown error')}"
)
return reply
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def set_radio(self, device_type: str, device_id: str = "") -> None:
"""Initialise the SDR radio on the Tx machine.
Args:
device_type: SDR type ``pluto``, ``usrp``, ``hackrf``, ``bladerf``.
device_id: Device-specific identifier (IP, serial, etc.).
"""
logger.info("set_radio(%s, %r)", device_type, device_id)
self._send({"function_name": "set_radio", "radio_str": device_type, "identifier": device_id})
def init_tx(
self,
center_frequency: float,
sample_rate: float,
gain: float,
channel: int = 0,
gain_mode: str = "absolute",
) -> None:
"""Configure Tx parameters on the remote SDR.
Args:
center_frequency: Center frequency in Hz.
sample_rate: Sample rate in Hz.
gain: Tx gain in dB.
channel: RF channel index (default 0).
gain_mode: ``"absolute"`` (default) or ``"relative"``.
"""
logger.info(
"init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d",
center_frequency / 1e6,
sample_rate / 1e6,
gain,
channel,
)
self._send(
{
"function_name": "init_tx",
"center_frequency": center_frequency,
"sample_rate": sample_rate,
"gain": gain,
"channel": channel,
"gain_mode": gain_mode,
}
)
def transmit_async(self, duration_s: float) -> None:
"""Start a timed CW transmission in a background thread.
Returns immediately. Call :meth:`wait_transmit` after recording to
ensure the transmit thread has finished before the next step.
Args:
duration_s: Transmission duration in seconds.
"""
logger.info("transmit_async: %.1f s", duration_s)
def _run() -> None:
try:
self._send({"function_name": "transmit", "duration_s": duration_s})
except Exception as exc:
logger.warning("Background transmit error: %s", exc)
self._tx_thread = threading.Thread(target=_run, daemon=True, name="remote-tx")
self._tx_thread.start()
def wait_transmit(self, timeout: float | None = None) -> None:
"""Wait for the background transmit thread to finish.
Args:
timeout: Maximum seconds to wait. ``None`` = wait indefinitely.
"""
if self._tx_thread is not None:
self._tx_thread.join(timeout=timeout)
self._tx_thread = None
def stop(self) -> None:
"""Stop transmission and release the remote SDR, then close connections."""
logger.info("Sending stop to remote Tx")
try:
self._send({"function_name": "stop"})
except Exception as exc:
logger.warning("stop command error (may be normal if connection closed): %s", exc)
finally:
self.close()

View File

@ -4,6 +4,82 @@ It streamlines tasks involving signal reception and transmission, as well as com
operations such as detecting and configuring available devices. operations such as detecting and configuring available devices.
""" """
__all__ = ["SDR"] __all__ = [
"SDR",
"SDRError",
"SDRParameterError",
"SdrDisconnectedError",
"MockSDR",
"get_sdr_device",
"detect_available",
]
from .sdr import SDR from .mock import MockSDR
from .sdr import ( # noqa: F401
SDR,
SdrDisconnectedError,
SDRError,
SDRParameterError,
translate_disconnect,
)
_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = (
("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"),
("pluto", "ria_toolkit_oss.sdr.pluto", "Pluto"),
("hackrf", "ria_toolkit_oss.sdr.hackrf", "HackRF"),
("rtlsdr", "ria_toolkit_oss.sdr.rtlsdr", "RTLSDR"),
("usrp", "ria_toolkit_oss.sdr.usrp", "USRP"),
("blade", "ria_toolkit_oss.sdr.blade", "Blade"),
("thinkrf", "ria_toolkit_oss.sdr.thinkrf", "ThinkRF"),
)
def detect_available() -> dict[str, type]:
"""Return ``{device_name: driver_class}`` for every driver whose module imports cleanly.
Importability is a proxy for "the user has installed this driver's optional dependency".
It does not probe for physical hardware presence that requires actually instantiating
the driver, which can be slow and side-effectful.
"""
import importlib
out: dict[str, type] = {}
for name, module_path, cls_name in _DRIVER_CANDIDATES:
try:
mod = importlib.import_module(module_path)
out[name] = getattr(mod, cls_name)
except Exception:
continue
return out
def get_sdr_device(device_type: str, ident: str | None = None, tx: bool = False) -> SDR:
"""Return an SDR instance for *device_type*.
For ``"mock"`` / ``"sim"`` device types, returns a :class:`MockSDR`
immediately (no hardware required). For all real device types, delegates
to ``ria_toolkit_oss_cli.ria_toolkit_oss.common.get_sdr_device`` if the
CLI package is installed; otherwise raises ``ImportError`` with a helpful
message.
Args:
device_type: Device name (``"mock"``, ``"pluto"``, ``"usrp"``, ).
ident: Optional device identifier (IP address, serial number, ).
tx: If True, require TX capability.
"""
if device_type in ("mock", "sim"):
return MockSDR()
# Delegate real device types to the CLI package which holds the driver
# imports behind hardware-specific optional dependencies.
try:
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
get_sdr_device as _cli_get,
)
except ImportError as exc:
raise ImportError(
f"ria_toolkit_oss_cli is required to use hardware SDR device '{device_type}'. "
"Install it with: pip install ria-toolkit-oss-cli"
) from exc
return _cli_get(device_type, ident=ident, tx=tx)

View File

@ -1,12 +1,12 @@
import time import gc
import warnings import warnings
from typing import Optional from typing import Optional
import numpy as np import numpy as np
from bladerf import _bladerf from bladerf import _bladerf
from ria_toolkit_oss.datatypes import Recording from ria_toolkit_oss.data import Recording
from ria_toolkit_oss.sdr import SDR from ria_toolkit_oss.sdr import SDR, SDRError, SDRParameterError
class Blade(SDR): class Blade(SDR):
@ -22,7 +22,7 @@ class Blade(SDR):
""" """
if identifier != "": if identifier != "":
print(f"Warning, radio identifier {identifier} provided for Blade but will not be used.") warnings.warn(f"Blade: Identifier '{identifier}' will be ignored", UserWarning)
uut = self._probe_bladerf() uut = self._probe_bladerf()
@ -34,6 +34,7 @@ class Blade(SDR):
self.device = _bladerf.BladeRF(uut) self.device = _bladerf.BladeRF(uut)
self._print_versions(device=self.device) self._print_versions(device=self.device)
self.bytes_per_sample = 4
super().__init__() super().__init__()
@ -42,8 +43,10 @@ class Blade(SDR):
if board is not None: if board is not None:
board.close() board.close()
# TODO why does this create an error under any conditions? if error != 0:
raise OSError("Shutdown initiated with error code: {}".format(error)) raise OSError(f"BladeRF shutdown with error code: {error}")
else:
print("BladeRF shutdown successfully")
def _probe_bladerf(self): def _probe_bladerf(self):
device = None device = None
@ -85,24 +88,25 @@ class Blade(SDR):
:type sample_rate: int or float :type sample_rate: int or float
:param center_frequency: The center frequency of the recording. :param center_frequency: The center frequency of the recording.
:type center_frequency: int or float :type center_frequency: int or float
:param gain: The gain set for receiving on the BladeRF :param gain: The gain set for receiving on the BladeRF.
:type gain: int :type gain: int
:param channel: The channel the BladeRF is set to. :param channel: The channel the BladeRF is set to.
:type channel: int :type channel: int
:param buffer_size: The buffer size during receive. Defaults to 8192. :param buffer_size: The buffer size during receive. Defaults to 8192.
:type buffer_size: int :type buffer_size: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the SDR;
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (60). 'relative' means that gain should be a negative value, and it will be subtracted from the max gain (60).
:type gain_mode: str :type gain_mode: str
""" """
print("Initializing RX") print("Initializing RX")
# Configure BladeRF # Configure BladeRF
self._set_rx_channel(channel) self.set_rx_channel(channel)
self._set_rx_sample_rate(sample_rate) self.set_rx_sample_rate(sample_rate)
self._set_rx_center_frequency(center_frequency) self.set_rx_center_frequency(center_frequency)
self._set_rx_gain(channel, gain, gain_mode) self.set_rx_gain(channel, gain, gain_mode)
self._set_rx_buffer_size(buffer_size) self.set_rx_buffer_size(buffer_size)
bw = self.rx_sample_rate bw = self.rx_sample_rate
if bw < 200000: if bw < 200000:
@ -128,10 +132,8 @@ class Blade(SDR):
stream_timeout=3500000000, stream_timeout=3500000000,
) )
self.rx_ch.enable = True
self.bytes_per_sample = 4
print("Blade Starting RX...") print("Blade Starting RX...")
self.rx_ch.enable = True
self._enable_rx = True self._enable_rx = True
while self._enable_rx: while self._enable_rx:
@ -148,18 +150,34 @@ class Blade(SDR):
print("Blade RX Completed.") print("Blade RX Completed.")
self.rx_ch.enable = False self.rx_ch.enable = False
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None): def record(
self,
num_samples: Optional[int] = None,
rx_time: Optional[int | float] = None,
) -> Recording:
"""
Create a radio recording (iq samples and metadata) of a given length from the Blade.
Either num_samples or rx_time must be provided.
init_rx() must be called before record()
:param num_samples: The number of samples to record.
:type num_samples: int, optional
:param rx_time: The time to record.
:type rx_time: int or float, optional
returns: Recording object (iq samples and metadata)
"""
if not self._rx_initialized: if not self._rx_initialized:
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()") raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
if num_samples is not None and rx_time is not None: if num_samples is not None and rx_time is not None:
raise ValueError("Only input one of num_samples or rx_time") raise SDRParameterError("Only input one of num_samples or rx_time")
elif num_samples is not None: elif num_samples is not None:
self._num_samples_to_record = num_samples self._num_samples_to_record = num_samples
elif rx_time is not None: elif rx_time is not None:
self._num_samples_to_record = int(rx_time * self.rx_sample_rate) self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
else: else:
raise ValueError("Must provide input of one of num_samples or rx_time") raise SDRParameterError("Must provide input of one of num_samples or rx_time")
# Setup synchronous stream # Setup synchronous stream
self.device.sync_config( self.device.sync_config(
@ -171,11 +189,10 @@ class Blade(SDR):
stream_timeout=3500000000, stream_timeout=3500000000,
) )
self.rx_ch.enable = True
self.bytes_per_sample = 4
print("Blade Starting RX...") print("Blade Starting RX...")
self._enable_rx = True with self._param_lock:
self._enable_rx = True
self.rx_ch.enable = True
store_array = np.zeros( store_array = np.zeros(
(1, (self._num_samples_to_record // self.rx_buffer_size + 1) * self.rx_buffer_size), dtype=np.complex64 (1, (self._num_samples_to_record // self.rx_buffer_size + 1) * self.rx_buffer_size), dtype=np.complex64
@ -191,7 +208,8 @@ class Blade(SDR):
# Disable module # Disable module
print("Blade RX Completed.") print("Blade RX Completed.")
self.rx_ch.enable = False with self._param_lock:
self.rx_ch.enable = False
metadata = { metadata = {
"source": self.__class__.__name__, "source": self.__class__.__name__,
"sample_rate": self.rx_sample_rate, "sample_rate": self.rx_sample_rate,
@ -207,7 +225,7 @@ class Blade(SDR):
center_frequency: int | float, center_frequency: int | float,
gain: int, gain: int,
channel: int, channel: int,
buffer_size: Optional[int] = 8192, buffer_size: Optional[int] = 32768,
gain_mode: Optional[str] = "absolute", gain_mode: Optional[str] = "absolute",
): ):
""" """
@ -224,16 +242,24 @@ class Blade(SDR):
:param buffer_size: The buffer size during transmission. Defaults to 8192. :param buffer_size: The buffer size during transmission. Defaults to 8192.
:type buffer_size: int :type buffer_size: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the sdr,
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (60). 'relative' means that gain should be a negative value, and it will be subtracted from the max gain (60).
:type gain_mode: str :type gain_mode: str
:return: 0 if successful, -1 if there's an error.
:rtype: int
""" """
# Configure BladeRF # Configure BladeRF
self._set_tx_channel(channel) self.set_tx_channel(channel)
self._set_tx_sample_rate(sample_rate) self.set_tx_sample_rate(sample_rate)
self._set_tx_center_frequency(center_frequency) self.set_tx_center_frequency(center_frequency)
self._set_tx_gain(channel=channel, gain=gain, gain_mode=gain_mode) self.set_tx_gain(channel=channel, gain=gain, gain_mode=gain_mode)
self._set_tx_buffer_size(buffer_size) self.set_tx_buffer_size(buffer_size)
if self.tx_sample_rate >= 7.5e6 and self.tx_buffer_size < 65536:
warnings.warn(
"Blade: For high sample rates, a buffer size of 65536, 131072, or 262144 is recommended", UserWarning
)
bw = self.tx_sample_rate bw = self.tx_sample_rate
if bw < 200000: if bw < 200000:
@ -302,13 +328,13 @@ class Blade(SDR):
""" """
if num_samples is not None and tx_time is not None: if num_samples is not None and tx_time is not None:
raise ValueError("Only input one of num_samples or tx_time") raise SDRParameterError("Only input one of num_samples or tx_time")
elif num_samples is not None: elif num_samples is not None:
tx_time = num_samples / self.tx_sample_rate
elif tx_time is not None:
pass pass
elif tx_time is not None:
num_samples = int(tx_time * self.tx_sample_rate)
else: else:
tx_time = len(recording) / self.tx_sample_rate num_samples = len(recording)
if isinstance(recording, np.ndarray): if isinstance(recording, np.ndarray):
samples = recording samples = recording
@ -317,9 +343,15 @@ class Blade(SDR):
warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission") warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
samples = recording.data[0] samples = recording.data[0]
else: else:
raise TypeError("recording must be np.ndarray or Recording") raise SDRParameterError("recording must be np.ndarray or Recording")
samples = samples.astype(np.complex64, copy=False) samples = samples.astype(np.complex64, copy=False)
tx_bytes = self._convert_tx_samples(samples)
# Transmit in chunks
samples_sent = 0
len_samples = len(samples)
chunk_size = self.tx_buffer_size
# Setup stream # Setup stream
self.device.sync_config( self.device.sync_config(
@ -335,26 +367,21 @@ class Blade(SDR):
self.tx_ch.enable = True self.tx_ch.enable = True
print("Blade Starting TX...") print("Blade Starting TX...")
# Transmit samples - repeat as needed for the duration
start_time = time.time()
sample_index = 0
try: try:
while time.time() - start_time < tx_time: while samples_sent < num_samples:
# Get next chunk this_chunk_size = min(chunk_size, num_samples - samples_sent)
chunk_size = min(self.tx_buffer_size, len(samples) - sample_index)
if chunk_size == 0:
# Reached end, loop back
sample_index = 0
chunk_size = min(self.tx_buffer_size, len(samples))
chunk = samples[sample_index : sample_index + chunk_size] start_idx = (samples_sent % len_samples) * self.bytes_per_sample
sample_index += chunk_size end_idx = start_idx + this_chunk_size * self.bytes_per_sample
end_idx %= len_samples * self.bytes_per_sample
# Convert and transmit if end_idx > start_idx:
byte_array = self._convert_tx_samples(chunk) chunk_bytes_arr = tx_bytes[start_idx:end_idx]
self.device.sync_tx(byte_array, len(chunk)) else:
chunk_bytes_arr = tx_bytes[start_idx:] + tx_bytes[:end_idx]
self.device.sync_tx(chunk_bytes_arr, this_chunk_size)
samples_sent += this_chunk_size
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nTransmission interrupted by user") print("\nTransmission interrupted by user")
@ -384,76 +411,145 @@ class Blade(SDR):
byte_array = tx_samples.tobytes() byte_array = tx_samples.tobytes()
return byte_array return byte_array
def _set_rx_channel(self, channel): def set_rx_channel(self, channel):
if channel != 0 and channel != 1:
raise SDRParameterError("Channel must be either 0 or 1.")
self.rx_channel = channel self.rx_channel = channel
self.rx_ch = self.device.Channel(_bladerf.CHANNEL_RX(channel)) self.rx_ch = self.device.Channel(_bladerf.CHANNEL_RX(channel))
print(f"\nBlade channel = {self.rx_ch}") print(f"\nBlade channel = {self.rx_ch}")
def _set_rx_sample_rate(self, sample_rate): def set_rx_sample_rate(self, sample_rate):
self.rx_sample_rate = sample_rate """
self.rx_ch.sample_rate = self.rx_sample_rate Set the sample rate of the receiver.
print(f"Blade sample rate = {self.rx_ch.sample_rate}") Not callable during recording; Blade requires stream stop/restart to change sample rate.
"""
def _set_rx_center_frequency(self, center_frequency): with self._param_lock:
self.rx_center_frequency = center_frequency if hasattr(self, "rx_channel"):
self.rx_ch.frequency = center_frequency range_list = self.device.get_sample_rate_range(self.rx_channel)
print(f"Blade center frequency = {self.rx_ch.frequency}") min_rate, max_rate = range_list[0], range_list[1]
def _set_rx_gain(self, channel, gain, gain_mode):
rx_gain_min = self.device.get_gain_range(channel)[0]
rx_gain_max = self.device.get_gain_range(channel)[1]
if gain_mode == "relative":
if gain > 0:
raise ValueError(
"When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain."
)
else: else:
abs_gain = rx_gain_max + gain raise SDRError("Must set channel before setting center frequency")
else:
abs_gain = gain
if abs_gain < rx_gain_min or abs_gain > rx_gain_max: if sample_rate < min_rate or sample_rate > max_rate:
abs_gain = min(max(gain, rx_gain_min), rx_gain_max) raise SDRParameterError(
print(f"Gain {abs_gain} out of range for Blade.") f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB") f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
)
self.rx_gain = abs_gain self.rx_sample_rate = sample_rate
self.rx_ch.gain = abs_gain self.rx_ch.sample_rate = self.rx_sample_rate
print(f"Blade sample rate = {self.rx_ch.sample_rate}")
print(f"Blade gain = {self.rx_ch.gain}") def set_rx_center_frequency(self, center_frequency):
"""
Set the center frequency of the receiver.
Not callable during recording; Blade requires stream stop/restart to change center frequency.
"""
with self._param_lock:
if hasattr(self, "rx_channel"):
range_list = self.device.get_frequency_range(self.rx_channel)
min_rate, max_rate = range_list[0], range_list[1]
else:
raise SDRError("Must set channel before setting center frequency")
def _set_rx_buffer_size(self, buffer_size): if center_frequency < min_rate or center_frequency > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range: [{min_rate/1e9:.3f} - {max_rate/1e9:.3f} GHz]"
)
self.rx_center_frequency = center_frequency
self.rx_ch.frequency = center_frequency
print(f"Blade center frequency = {self.rx_ch.frequency}")
def set_rx_gain(self, channel, gain, gain_mode):
"""
Set the gain of the receiver.
Not callable during recording; Blade requires stream stop/restart to change gain.
"""
with self._param_lock:
rx_gain_min = self.device.get_gain_range(channel)[0]
rx_gain_max = self.device.get_gain_range(channel)[1]
if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
the gain relative to the maximum possible gain.")
else:
abs_gain = rx_gain_max + gain
else:
abs_gain = gain
if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
print(f"Gain {abs_gain} out of range for Blade.")
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
self.rx_gain = abs_gain
self.rx_ch.gain = abs_gain
print(f"Blade gain = {self.rx_ch.gain}")
def set_rx_buffer_size(self, buffer_size):
self.rx_buffer_size = buffer_size self.rx_buffer_size = buffer_size
def _set_tx_channel(self, channel): def set_tx_channel(self, channel):
if channel != 0 and channel != 1:
raise SDRParameterError("Channel must be either 0 or 1.")
self.tx_channel = channel self.tx_channel = channel
self.tx_ch = self.device.Channel(_bladerf.CHANNEL_TX(self.tx_channel)) self.tx_ch = self.device.Channel(_bladerf.CHANNEL_TX(self.tx_channel))
print(f"\nBlade channel = {self.tx_ch}") print(f"\nBlade channel = {self.tx_ch}")
def _set_tx_sample_rate(self, sample_rate): def set_tx_sample_rate(self, sample_rate):
if hasattr(self, "tx_channel"):
range_list = self.device.get_sample_rate_range(self.tx_channel)
min_rate, max_rate = range_list[0], range_list[1]
else:
raise SDRError("Must set channel before setting center frequency")
if sample_rate < min_rate or sample_rate > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
)
if sample_rate < min_rate or sample_rate > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
)
self.tx_sample_rate = sample_rate self.tx_sample_rate = sample_rate
self.tx_ch.sample_rate = self.tx_sample_rate self.tx_ch.sample_rate = self.tx_sample_rate
print(f"Blade sample rate = {self.tx_ch.sample_rate}") print(f"Blade sample rate = {self.tx_ch.sample_rate}")
def _set_tx_center_frequency(self, center_frequency): def set_tx_center_frequency(self, center_frequency):
if hasattr(self, "tx_channel"):
range_list = self.device.get_frequency_range(self.tx_channel)
min_rate, max_rate = range_list[0], range_list[1]
else:
raise SDRError("Must set channel before setting center frequency")
if center_frequency < min_rate or center_frequency > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range: [{min_rate/1e9:.3f} - {max_rate/1e9:.3f} GHz]"
)
self.tx_center_frequency = center_frequency self.tx_center_frequency = center_frequency
self.tx_ch.frequency = center_frequency self.tx_ch.frequency = center_frequency
print(f"Blade center frequency = {self.tx_ch.frequency}") print(f"Blade center frequency = {self.tx_ch.frequency}")
def _set_tx_gain(self, channel, gain, gain_mode): def set_tx_gain(self, channel, gain, gain_mode):
tx_gain_min = self.device.get_gain_range(channel)[0] tx_gain_min = self.device.get_gain_range(channel)[0]
tx_gain_max = self.device.get_gain_range(channel)[1] tx_gain_max = self.device.get_gain_range(channel)[1]
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise ValueError( raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
"When gain_mode = 'relative', gain must be < 0. This sets\ the gain relative to the maximum possible gain.")
the gain relative to the maximum possible gain."
)
else: else:
abs_gain = tx_gain_max + gain abs_gain = tx_gain_max + gain
else: else:
@ -469,7 +565,7 @@ class Blade(SDR):
print(f"Blade gain = {self.tx_ch.gain}") print(f"Blade gain = {self.tx_ch.gain}")
def _set_tx_buffer_size(self, buffer_size): def set_tx_buffer_size(self, buffer_size):
self.tx_buffer_size = buffer_size self.tx_buffer_size = buffer_size
def set_clock_source(self, source): def set_clock_source(self, source):
@ -499,4 +595,20 @@ class Blade(SDR):
print(f"BladeRF bias tee {state} on channel {channel}.") print(f"BladeRF bias tee {state} on channel {channel}.")
def close(self): def close(self):
self.device.close() if hasattr(self, "device") and self.device is not None:
try:
if hasattr(self, "tx_ch"):
self.tx_ch.enable = False
if hasattr(self, "rx_ch"):
self.rx_ch.enable = False
self.device.close()
except Exception as e:
print(f"Warning: error closing bladeRF: {e}")
finally:
del self.device
self.device = None
gc.collect()
def supports_dynamic_updates(self) -> dict:
return {"center_frequency": False, "sample_rate": False, "gain": False}

View File

@ -4,9 +4,9 @@ from typing import Optional
import numpy as np import numpy as np
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.sdr._external.libhackrf import HackRF as hrf from ria_toolkit_oss.sdr._external.libhackrf import HackRF as hrf
from ria_toolkit_oss.sdr.sdr import SDR from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
class HackRF(SDR): class HackRF(SDR):
@ -21,7 +21,7 @@ class HackRF(SDR):
""" """
if identifier != "": if identifier != "":
print(f"Warning, radio identifier {identifier} provided for HackRF but will not be used.") warnings.warn(f"HackRF: Identifier '{identifier}' will be ignored", UserWarning)
print("Initializing HackRF radio.") print("Initializing HackRF radio.")
try: try:
@ -33,8 +33,6 @@ class HackRF(SDR):
print("Failed to find HackRF radio.") print("Failed to find HackRF radio.")
raise e raise e
super().__init__()
def init_rx( def init_rx(
self, self,
sample_rate: int | float, sample_rate: int | float,
@ -60,18 +58,12 @@ class HackRF(SDR):
:param channel: The channel the HackRF is set to. (Not actually used) :param channel: The channel the HackRF is set to. (Not actually used)
:type channel: int :type channel: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the sdr,
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (40). 'relative' means that gain should be a negative value, and it will be subtracted from the max gain (40).
:type gain_mode: str :type gain_mode: str
""" """
print("Initializing RX") print("Initializing RX")
self.set_sample_rate(sample_rate=sample_rate)
self.rx_sample_rate = sample_rate self.set_center_frequency(center_frequency=center_frequency)
self.radio.sample_rate = int(sample_rate)
print(f"HackRF sample rate = {self.radio.sample_rate}")
self.rx_center_frequency = center_frequency
self.radio.center_freq = int(center_frequency)
print(f"HackRF center frequency = {self.radio.center_freq}")
# Distribute gain across amplifier stages # Distribute gain across amplifier stages
rx_gain_min = 0 rx_gain_min = 0
@ -79,7 +71,7 @@ class HackRF(SDR):
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise ValueError( raise SDRParameterError(
"When gain_mode = 'relative', gain must be < 0. This " "When gain_mode = 'relative', gain must be < 0. This "
"sets the gain relative to the maximum possible gain." "sets the gain relative to the maximum possible gain."
) )
@ -99,7 +91,9 @@ class HackRF(SDR):
self.rx_gain = abs_gain self.rx_gain = abs_gain
print(f"HackRF gain distribution: Amp={self.amp_enabled}, LNA={self.rx_lna_gain}dB, VGA={self.rx_vga_gain}dB") print(f"HackRF gain distribution: Amp={self.amp_enabled}, LNA={self.rx_lna_gain}dB, VGA={self.rx_vga_gain}dB")
print("To individually modify the HackRF gains, use set_gain_amp(), set_rx_lna_gain(), and set_rx_vga_gain().") print(
"To individually modify the HackRF gains, use set_gain_amp(), set_rx_lna_gain(), and set_rx_vga_gain().\n"
)
self._tx_initialized = False self._tx_initialized = False
self._rx_initialized = True self._rx_initialized = True
@ -122,13 +116,13 @@ class HackRF(SDR):
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()") raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
if num_samples is not None and rx_time is not None: if num_samples is not None and rx_time is not None:
raise ValueError("Only input one of num_samples or rx_time") raise SDRParameterError("Only input one of num_samples or rx_time")
elif num_samples is not None: elif num_samples is not None:
self._num_samples_to_record = num_samples self._num_samples_to_record = num_samples
elif rx_time is not None: elif rx_time is not None:
self._num_samples_to_record = int(rx_time * self.rx_sample_rate) self._num_samples_to_record = int(rx_time * self.sample_rate)
else: else:
raise ValueError("Must provide input of one of num_samples or rx_time") raise SDRParameterError("Must provide input of one of num_samples or rx_time")
print("HackRF Starting RX...") print("HackRF Starting RX...")
@ -137,18 +131,15 @@ class HackRF(SDR):
print("HackRF RX Completed.") print("HackRF RX Completed.")
# Create 1xN array for single-channel recording rx_complex = self.convert_rx_samples(rx_samples=all_samples)
store_array = np.zeros((1, self._num_samples_to_record), dtype=np.complex64)
store_array[0, :] = all_samples
metadata = { metadata = {
"source": self.__class__.__name__, "source": self.__class__.__name__,
"sample_rate": self.rx_sample_rate, "sample_rate": self.sample_rate,
"center_frequency": self.rx_center_frequency, "center_frequency": self.center_frequency,
"gain": self.rx_gain, "gain": self.rx_gain,
} }
return Recording(data=store_array, metadata=metadata) return Recording(data=rx_complex, metadata=metadata)
def init_tx( def init_tx(
self, self,
@ -174,22 +165,15 @@ class HackRF(SDR):
""" """
print("Initializing TX") print("Initializing TX")
self.tx_sample_rate = sample_rate self.set_sample_rate(sample_rate=sample_rate)
self.radio.sample_rate = int(sample_rate) self.set_center_frequency(center_frequency=center_frequency)
print(f"HackRF sample rate = {self.radio.sample_rate}")
self.tx_center_frequency = center_frequency
self.radio.center_freq = int(center_frequency)
print(f"HackRF center frequency = {self.radio.center_freq}")
tx_gain_min = 0 tx_gain_min = 0
tx_gain_max = 47 tx_gain_max = 47
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise ValueError( raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This \
"When gain_mode = 'relative', gain must be < 0. This \ sets the gain relative to the maximum possible gain.")
sets the gain relative to the maximum possible gain."
)
else: else:
abs_gain = tx_gain_max + gain abs_gain = tx_gain_max + gain
else: else:
@ -197,14 +181,14 @@ class HackRF(SDR):
if abs_gain < tx_gain_min or abs_gain > tx_gain_max: if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
abs_gain = min(max(gain, tx_gain_min), tx_gain_max) abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
print(f"Gain {gain} out of range for Pluto.") print(f"Gain {gain} out of range for HackRF.")
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB") print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
self.set_gain_amp(True) self.set_gain_amp(True)
self.set_tx_vga_gain(abs_gain) self.set_tx_vga_gain(abs_gain)
self.tx_gain = abs_gain self.tx_gain = abs_gain
print(f"HackRF gain distribution: Amp={self.amp_enabled}, VGA={self.tx_vga_gain}dB") print(f"HackRF gain distribution: Amp={self.amp_enabled}, VGA={self.tx_vga_gain}dB")
print("To individually modify the HackRF gains, use set_gain_amp() or set_tx_vga_gain().") print("To individually modify the HackRF gains, use set_gain_amp() or set_tx_vga_gain().\n")
self._tx_initialized = True self._tx_initialized = True
self._rx_initialized = False self._rx_initialized = False
@ -229,13 +213,13 @@ class HackRF(SDR):
:type tx_time: int or float, optional :type tx_time: int or float, optional
""" """
if num_samples is not None and tx_time is not None: if num_samples is not None and tx_time is not None:
raise ValueError("Only input one of num_samples or tx_time") raise SDRParameterError("Only input one of num_samples or tx_time")
elif num_samples is not None: elif num_samples is not None:
tx_time = num_samples / self.tx_sample_rate tx_time = num_samples / self.sample_rate
elif tx_time is not None: elif tx_time is not None:
pass pass
else: else:
tx_time = len(recording) / self.tx_sample_rate tx_time = len(recording) / self.sample_rate
if isinstance(recording, np.ndarray): if isinstance(recording, np.ndarray):
samples = recording samples = recording
@ -275,6 +259,62 @@ class HackRF(SDR):
self.radio.set_txvga_gain(vga_gain) self.radio.set_txvga_gain(vga_gain)
self.tx_vga_gain = vga_gain self.tx_vga_gain = vga_gain
def set_sample_rate(self, sample_rate):
if sample_rate < 2e6 or sample_rate > 20e6:
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{2:.3f} - {20:.3f} Msps]"
)
self.sample_rate = sample_rate
self.radio.sample_rate = int(sample_rate)
print(f"HackRF sample rate = {self.radio.sample_rate}")
def set_rx_sample_rate(self, sample_rate):
"""
Set the sample rate.
Not callable during recording; HackRF requires stream stop/restart to change sample rate.
"""
self.set_sample_rate(sample_rate=sample_rate)
def set_tx_sample_rate(self, sample_rate):
self.set_sample_rate(sample_rate=sample_rate)
def set_center_frequency(self, center_frequency):
with self._param_lock:
if center_frequency < 1e6 or center_frequency > 6e9:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range: [{1e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
)
self.center_frequency = center_frequency
self.radio.center_freq = int(center_frequency)
print(f"HackRF center frequency = {self.radio.center_freq}")
def set_rx_center_frequency(self, center_frequency):
"""
Set the center frequency. Callable during streaming.
"""
self.set_center_frequency(center_frequency=center_frequency)
def set_tx_center_frequency(self, center_frequency):
self.set_center_frequency(center_frequency=center_frequency)
def convert_rx_samples(self, rx_samples):
# Handle conversion depending on dtype
if np.issubdtype(rx_samples.dtype, np.complexfloating):
# Already complex: just normalize
rx_complex = rx_samples.astype(np.complex64) / 128.0
elif np.issubdtype(rx_samples.dtype, np.integer):
# Raw interleaved I/Q bytes: convert to complex
i_samples = rx_samples[0::2].astype(np.float32)
q_samples = rx_samples[1::2].astype(np.float32)
rx_complex = (i_samples + 1j * q_samples) / 128.0
else:
raise TypeError(f"Unexpected dtype from read_samples: {rx_samples.dtype}")
# Ensure 2D array: 1xN for single channel
return rx_complex.reshape((1, -1))
def set_clock_source(self, source): def set_clock_source(self, source):
self.radio.set_clock_source(source) self.radio.set_clock_source(source)
@ -288,7 +328,11 @@ class HackRF(SDR):
raise NotImplementedError("Underlying HackRF interface lacks bias-tee control") from exc raise NotImplementedError("Underlying HackRF interface lacks bias-tee control") from exc
def close(self): def close(self):
self.radio.close() try:
self.radio.close()
del self.radio
finally:
self._enable_rx = False
def _stream_rx(self, callback): def _stream_rx(self, callback):
""" """
@ -342,3 +386,6 @@ class HackRF(SDR):
def _stream_tx(self, callback): def _stream_tx(self, callback):
return super()._stream_tx(callback) return super()._stream_tx(callback)
def supports_dynamic_updates(self) -> dict:
return {"center_frequency": True, "sample_rate": False, "gain": False}

View File

@ -0,0 +1,131 @@
"""Simulated SDR device for testing without hardware.
Set ``recorder.device = "mock"`` (or ``"sim"``) in a campaign config to use
this driver. The inference loop can also use it by specifying ``device:
"mock"`` in the SDR start request.
The mock generates complex float32 AWGN samples normalised to [-1, 1].
It satisfies both interfaces used in this codebase:
- ``record(num_samples)`` / ``_stream_rx(callback)`` used by
``CampaignExecutor`` (inherits from ``SDR`` base class).
- ``rx(num_samples)`` PlutoSDR-style interface used by the controller
inference loop.
"""
from __future__ import annotations
import time
import numpy as np
from ria_toolkit_oss.sdr.sdr import SDR
_DEFAULT_BUFFER_SIZE = 4096
# Simulated sample rate throttle: sleep this long between buffers so the
# loop does not spin at 100% CPU. 10 ms ≈ 100 buffers/s which is fine for
# tests and campaign execution timing.
_SLEEP_PER_BUFFER_S = 0.01
class MockSDR(SDR):
"""Software-simulated SDR that generates AWGN noise.
Args:
buffer_size: Number of complex samples per streaming buffer.
seed: Optional RNG seed for reproducible output.
"""
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE, seed: int | None = None):
super().__init__()
self.rx_buffer_size: int = buffer_size
self._rng = np.random.default_rng(seed)
# Direct attribute aliases used by _apply_sdr_config in the controller.
self.center_freq: float = 2.45e9
self.sample_rate: float = 10e6
self.gain: float = 40.0
# ------------------------------------------------------------------
# Abstract method implementations
# ------------------------------------------------------------------
def init_rx(
self,
sample_rate: float,
center_frequency: float,
gain,
channel: int = 0,
gain_mode: str = "manual",
) -> None:
self.rx_sample_rate = float(sample_rate)
self.rx_center_frequency = float(center_frequency)
self.rx_gain = 40.0 if gain is None else float(gain)
# Mirror to the attribute names used by _apply_sdr_config.
self.sample_rate = self.rx_sample_rate
self.center_freq = self.rx_center_frequency
self.gain = self.rx_gain
self._rx_initialized = True
def init_tx(
self,
sample_rate: float,
center_frequency: float,
gain,
channel: int = 0,
gain_mode: str = "manual",
) -> None:
self.tx_sample_rate = float(sample_rate)
self.tx_center_frequency = float(center_frequency)
self.tx_gain = 40.0 if gain is None else float(gain)
self._tx_initialized = True
def _stream_rx(self, callback) -> None:
"""Generate 1-D AWGN buffers and pass each to *callback* until stopped.
Uses 1-D arrays so the base class ``_validate_buffer`` check does not
incorrectly flag them as corrupted (the (1, N) form triggers a false
positive in the all-same-value check).
"""
self._enable_rx = True
while self._enable_rx:
buf = self._awgn(self.rx_buffer_size)
callback(buf)
time.sleep(_SLEEP_PER_BUFFER_S)
def _stream_tx(self, callback) -> None:
self._enable_tx = True
while self._enable_tx:
callback(self.rx_buffer_size)
time.sleep(_SLEEP_PER_BUFFER_S)
def set_clock_source(self, source: str) -> None:
pass # no-op
def close(self) -> None:
self._enable_rx = False
self._enable_tx = False
self._rx_initialized = False
self._tx_initialized = False
# ------------------------------------------------------------------
# PlutoSDR-style interface used by the controller inference loop
# ------------------------------------------------------------------
def rx(self, num_samples: int) -> np.ndarray:
"""Return *num_samples* complex64 AWGN samples (PlutoSDR-style)."""
return self._awgn(num_samples)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _awgn(self, n: int) -> np.ndarray:
"""Return *n* normalised complex64 AWGN samples as a 1-D array."""
real = self._rng.standard_normal(n).astype(np.float32)
imag = self._rng.standard_normal(n).astype(np.float32)
buf = real + 1j * imag
peak = np.abs(buf).max()
if peak > 1e-9:
buf /= peak
return buf

View File

@ -7,8 +7,13 @@ from typing import Optional
import adi import adi
import numpy as np import numpy as np
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.sdr.sdr import SDR from ria_toolkit_oss.sdr.sdr import (
SDR,
SDRError,
SDRParameterError,
translate_disconnect,
)
class Pluto(SDR): class Pluto(SDR):
@ -28,6 +33,7 @@ class Pluto(SDR):
print(f"Initializing Pluto radio with identifier [{identifier}].") print(f"Initializing Pluto radio with identifier [{identifier}].")
try: try:
super().__init__() super().__init__()
self._tx_lock = threading.Lock()
if identifier is None: if identifier is None:
uri = "ip:pluto.local" uri = "ip:pluto.local"
@ -74,10 +80,12 @@ class Pluto(SDR):
:type center_frequency: int or float :type center_frequency: int or float
:param gain: The gain set for receiving on the Pluto :param gain: The gain set for receiving on the Pluto
:type gain: int :type gain: int
:param channel: The channel the Pluto is set to. Must be 0 or 1. 0 enables channel 1, 1 enables both channels. :param channel: The channel the Pluto is set to. Must be 0 or 1. 0
enables channel 1, 1 enables both channels.
:type channel: int :type channel: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the sdr,
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (74). 'relative' means that gain should be a negative value, and it will
be subtracted from the max gain (74).
:type gain_mode: str :type gain_mode: str
""" """
print("Initializing RX") print("Initializing RX")
@ -88,20 +96,7 @@ class Pluto(SDR):
self.set_rx_center_frequency(center_frequency=int(center_frequency)) self.set_rx_center_frequency(center_frequency=int(center_frequency))
print(f"Pluto center frequency = {self.radio.rx_lo}") print(f"Pluto center frequency = {self.radio.rx_lo}")
if channel == 0: self.set_rx_channel(channel=channel)
self.radio.rx_enabled_channels = [0]
print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
elif channel == 1:
if not self._mimo_capable:
raise ValueError(
"Dual RX channel requested (channel=1) but hardware is not MIMO-capable. "
"Dual RX/TX requires Pluto Rev C/D. Detected hardware: Rev B (single channel only)."
)
self.radio.rx_enabled_channels = [0, 1]
print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
else:
raise ValueError("Channel must be either 0 or 1.")
self.set_rx_gain(gain=gain, channel=channel, gain_mode=gain_mode) self.set_rx_gain(gain=gain, channel=channel, gain_mode=gain_mode)
if channel == 0: if channel == 0:
print(f"Pluto gain = {self.radio.rx_hardwaregain_chan0}") print(f"Pluto gain = {self.radio.rx_hardwaregain_chan0}")
@ -109,8 +104,6 @@ class Pluto(SDR):
self.set_rx_gain(gain=gain, channel=0, gain_mode=gain_mode) self.set_rx_gain(gain=gain, channel=0, gain_mode=gain_mode)
print(f"Pluto gain = {self.radio.rx_hardwaregain_chan0}, {self.radio.rx_hardwaregain_chan1}") print(f"Pluto gain = {self.radio.rx_hardwaregain_chan0}, {self.radio.rx_hardwaregain_chan1}")
self.set_rx_buffer_size(getattr(self, "rx_buffer_size", 1024))
self._rx_initialized = True self._rx_initialized = True
self._tx_initialized = False self._tx_initialized = False
@ -134,10 +127,12 @@ class Pluto(SDR):
:type center_frequency: int or float :type center_frequency: int or float
:param gain: The gain set for transmitting on the Pluto :param gain: The gain set for transmitting on the Pluto
:type gain: int :type gain: int
:param channel: The channel the Pluto is set to. Must be 0 or 1. 0 enables channel 1, 1 enables both channels. :param channel: The channel the Pluto is set to. Must be 0 or 1. 0
enables channel 1, 1 enables both channels.
:type channel: int :type channel: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the sdr,
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (0). 'relative' means that gain should be a negative value, and it will
be subtracted from the max gain (0).
:type gain_mode: str :type gain_mode: str
""" """
@ -149,20 +144,7 @@ class Pluto(SDR):
self.set_tx_center_frequency(center_frequency=int(center_frequency)) self.set_tx_center_frequency(center_frequency=int(center_frequency))
print(f"Pluto center frequency = {self.radio.tx_lo}") print(f"Pluto center frequency = {self.radio.tx_lo}")
if channel == 0: self.set_tx_channel(channel=channel)
self.radio.tx_enabled_channels = [0]
print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
elif channel == 1:
if not self._mimo_capable:
raise ValueError(
"Dual TX channel requested (channel=1) but hardware is not MIMO-capable. "
"Dual RX/TX requires Pluto Rev C/D. Detected hardware: Rev B (single channel only)."
)
self.radio.tx_enabled_channels = [0, 1]
print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
else:
raise ValueError("Channel must be either 0 or 1.")
self.set_tx_gain(gain=gain, channel=channel, gain_mode=gain_mode) self.set_tx_gain(gain=gain, channel=channel, gain_mode=gain_mode)
if channel == 0: if channel == 0:
print(f"Pluto gain = {self.radio.tx_hardwaregain_chan0}") print(f"Pluto gain = {self.radio.tx_hardwaregain_chan0}")
@ -179,16 +161,93 @@ class Pluto(SDR):
if not self._rx_initialized: if not self._rx_initialized:
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()") raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
# print("Starting rx...")
self._enable_rx = True self._enable_rx = True
while self._enable_rx is True: while self._enable_rx is True:
# collect complex signa from radio
signal = self.radio.rx() signal = self.radio.rx()
signal = self._convert_rx_samples(signal)
# send callback complex signal # send callback complex signal
callback(buffer=signal, metadata=None) callback(buffer=signal, metadata=None)
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None): def rx(self, num_samples: Optional[int] = None) -> np.ndarray:
"""PlutoSDR-style single-buffer capture returning a complex64 array.
Sets the radio buffer size to *num_samples* (if given) and returns one
buffer directly from ``self.radio.rx()``. Raises
:class:`SdrDisconnectedError` on USB/device drop so callers (e.g. the
streamer) can report the failure and stop cleanly instead of crashing.
"""
if num_samples is not None:
try:
self.set_rx_buffer_size(buffer_size=int(num_samples))
except Exception as exc:
raise translate_disconnect(exc) from exc
try:
samples = self.radio.rx()
except Exception as exc:
raise translate_disconnect(exc) from exc
return np.asarray(samples)
def _record_fast(self, num_samples):
"""Optimized single-buffer capture for ≤16M samples."""
self.set_rx_buffer_size(buffer_size=num_samples)
print("Pluto Starting RX...")
samples = self.radio.rx()
# Handle single/dual channel
if self.radio.rx_enabled_channels == [0]:
samples = [self._convert_rx_samples(samples)]
else:
samples = [self._convert_rx_samples(s) for s in samples]
print("Pluto RX Completed.")
metadata = {
"source": self.__class__.__name__,
"sample_rate": self.rx_sample_rate,
"center_frequency": self.rx_center_frequency,
"gain": self.rx_gain,
}
return Recording(data=samples, metadata=metadata)
def _record_chunked(self, num_samples):
"""Chunked streaming capture for >2M samples."""
# Use base class streaming with pre-allocation
chunk_size = 2_000_000 # 2M sample chunks (safe size)
self.set_rx_buffer_size(buffer_size=chunk_size)
self._max_num_buffers = (num_samples // chunk_size) + 1
self._num_buffers_processed = 0
self._accumulated_buffer = None
# Stream with accumulation callback
print("Pluto Starting RX...")
self._stream_rx(callback=self._accumulate_buffers_callback)
print("Pluto RX Completed.")
print(f"Corrupted buffer count: {self._corrupted_buffer_count}")
# Truncate to exact size
samples = self._accumulated_buffer[:, :num_samples]
samples_list = [self._convert_rx_samples(chan) for chan in samples]
metadata = {
"source": self.__class__.__name__,
"sample_rate": self.rx_sample_rate,
"center_frequency": self.rx_center_frequency,
"gain": self.rx_gain,
}
# Reset for next capture
self._accumulated_buffer = None
return Recording(data=samples_list, metadata=metadata)
def record(
self,
num_samples: Optional[int] = None,
rx_time: Optional[int | float] = None,
) -> Recording:
""" """
Create a radio recording (iq samples and metadata) of a given length from the SDR. Create a radio recording (iq samples and metadata) of a given length from the SDR.
Either num_samples or rx_time must be provided. Either num_samples or rx_time must be provided.
@ -205,38 +264,19 @@ class Pluto(SDR):
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()") raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
if num_samples is not None and rx_time is not None: if num_samples is not None and rx_time is not None:
raise ValueError("Only input one of num_samples or rx_time") raise SDRParameterError("Only input one of num_samples or rx_time")
elif num_samples is not None: elif num_samples is not None:
self._num_samples_to_record = num_samples self._num_samples_to_record = num_samples
elif rx_time is not None: elif rx_time is not None:
self._num_samples_to_record = int(rx_time * self.rx_sample_rate) self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
else: else:
raise ValueError("Must provide input of one of num_samples or rx_time") raise SDRParameterError("Must provide input of one of num_samples or rx_time")
if self._num_samples_to_record > 16000000: # Record in one go if there are less than 2,000,000 samples to record, record in chunks otherwise
raise NotImplementedError("Pluto record for num_samples>16M not implemented yet.") if self._num_samples_to_record <= 2_000_000:
self.radio.rx_buffer_size = self._num_samples_to_record return self._record_fast(self._num_samples_to_record)
print("Pluto Starting RX...")
samples = self.radio.rx()
if self.radio.rx_enabled_channels == [0]:
samples = self._convert_rx_samples(samples)
samples = [samples]
else: else:
channel1 = self._convert_rx_samples(samples[0]) return self._record_chunked(self._num_samples_to_record)
channel2 = self._convert_rx_samples(samples[1])
samples = [channel1, channel2]
print("Pluto RX Completed.")
metadata = {
"source": self.__class__.__name__,
"sample_rate": self.rx_sample_rate,
"center_frequency": self.rx_center_frequency,
"gain": self.rx_gain,
}
recording = Recording(data=samples, metadata=metadata)
return recording
def _format_tx_data(self, recording: Recording | np.ndarray | list): def _format_tx_data(self, recording: Recording | np.ndarray | list):
if isinstance(recording, np.ndarray): if isinstance(recording, np.ndarray):
@ -258,20 +298,16 @@ class Pluto(SDR):
data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)] data = [self._convert_tx_samples(samples), self._convert_tx_samples(samples)]
else: else:
if len(recording) > 2: if len(recording) > 2:
warnings.warn( warnings.warn("More recordings were provided than channels in the Pluto. \
"More recordings were provided than channels in the Pluto. \ Only the first two recordings will be used")
Only the first two recordings will be used"
)
sample0 = self._convert_tx_samples(recording.data[0]) sample0 = self._convert_tx_samples(recording.data[0])
sample1 = self._convert_tx_samples(recording.data[1]) sample1 = self._convert_tx_samples(recording.data[1])
data = [sample0, sample1] data = [sample0, sample1]
elif isinstance(recording, list): elif isinstance(recording, list):
if len(recording) > 2: if len(recording) > 2:
warnings.warn( warnings.warn("More recordings were provided than channels in the Pluto. \
"More recordings were provided than channels in the Pluto. \ Only the first two recordings will be used")
Only the first two recordings will be used"
)
if isinstance(recording[0], np.ndarray): if isinstance(recording[0], np.ndarray):
data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])] data = [self._convert_tx_samples(recording[0]), self._convert_tx_samples(recording[1])]
@ -289,8 +325,9 @@ class Pluto(SDR):
print("Pluto TX Completed.") print("Pluto TX Completed.")
def interrupt_transmit(self): def interrupt_transmit(self):
self.radio.tx_destroy_buffer() with self._tx_lock:
self.radio.tx_cyclic_buffer = False self.radio.tx_destroy_buffer()
self.radio.tx_cyclic_buffer = False
print("Pluto TX Completed.") print("Pluto TX Completed.")
def tx_recording(self, recording: Recording | np.ndarray | list, num_samples=None, tx_time=None, mode="timed"): def tx_recording(self, recording: Recording | np.ndarray | list, num_samples=None, tx_time=None, mode="timed"):
@ -310,92 +347,128 @@ class Pluto(SDR):
:type mode: str, optional :type mode: str, optional
""" """
if num_samples is not None and tx_time is not None: if num_samples is not None and tx_time is not None:
raise ValueError("Only input one of num_samples or tx_time") raise SDRParameterError("Only input one of num_samples or tx_time")
elif num_samples is not None: elif num_samples is not None:
tx_time = num_samples / self.tx_sample_rate tx_time = num_samples / self.tx_sample_rate
elif tx_time is not None: elif tx_time is not None:
pass pass
else: else:
tx_time = len(recording) / self.tx_sample_rate if isinstance(recording, Recording):
tx_time = recording.data.shape[-1] / self.tx_sample_rate
elif isinstance(recording, np.ndarray):
tx_time = recording.shape[-1] / self.tx_sample_rate
else:
tx_time = len(recording[0]) / self.tx_sample_rate
data = self._format_tx_data(recording=recording) data = self._format_tx_data(recording=recording)
try: with self._tx_lock:
if self.radio.tx_cyclic_buffer: try:
print("Destroying existing TX buffer...") if self.radio.tx_cyclic_buffer:
self.radio.tx_destroy_buffer() print("Destroying existing TX buffer...")
self.radio.tx_cyclic_buffer = False self.radio.tx_destroy_buffer()
except Exception as e: self.radio.tx_cyclic_buffer = False
print(f"Error while destroying TX buffer: {e}") except Exception as e:
print(f"Error while destroying TX buffer: {e}")
self.radio.tx_cyclic_buffer = True self.radio.tx_cyclic_buffer = True
print("Pluto Starting TX...") print("Pluto Starting TX...")
self.radio.tx(data_np=data) self.radio.tx(data_np=data)
if mode == "timed": if mode == "timed":
timeout_thread = threading.Thread(target=self._timeout_cyclic_buffer, args=([tx_time])) timeout_thread = threading.Thread(target=self._timeout_cyclic_buffer, args=([tx_time]))
timeout_thread.start() timeout_thread.start()
timeout_thread.join() timeout_thread.join()
def _stream_tx(self, callback): def _stream_tx(self, callback):
if self._tx_initialized is False: if self._tx_initialized is False:
raise RuntimeError("TX was not initialized, init_tx must be called before _stream_tx") raise RuntimeError("TX was not initialized, init_tx must be called before _stream_tx")
num_samples = 10000 if not hasattr(self, "tx_buffer_size"):
# TODO remove hardcode self.tx_buffer_size = 10000
self._enable_tx = True self._enable_tx = True
while self._enable_tx is True: while self._enable_tx is True:
buffer = self._convert_tx_samples(callback(num_samples)) buffer = self._convert_tx_samples(callback(self.tx_buffer_size))
self.radio.tx(buffer[0]) # pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input.
# Indexing ``buffer[0]`` was a latent bug for callbacks that
# returned 1-D samples (scalar → TypeError inside pyadi).
self.radio.tx(buffer)
def set_rx_center_frequency(self, center_frequency): def set_rx_center_frequency(self, center_frequency):
try: """
self.radio.rx_lo = int(center_frequency) Set the center frequency of the receiver. Callable during streaming.
self.rx_center_frequency = center_frequency """
except OSError as e: with self._param_lock:
_handle_OSError(e) if center_frequency < 70e6 or center_frequency > 6e9:
except ValueError as e: raise SDRParameterError(
_handle_OSError(e) f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
)
try:
self.radio.rx_lo = int(center_frequency)
self.rx_center_frequency = center_frequency
except OSError as e:
raise SDRError(e)
except ValueError:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
)
def set_rx_sample_rate(self, sample_rate): def set_rx_sample_rate(self, sample_rate):
self.rx_sample_rate = sample_rate """
Set the sample rate of the receiver. Callable during streaming.
"""
with self._param_lock:
min_rate, max_rate = 65.1e3, 61.44e6
if sample_rate < min_rate or sample_rate > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
)
# TODO add logic for limiting sample rate try:
# set the sample rate
self.radio.sample_rate = int(sample_rate)
self.rx_sample_rate = sample_rate
try: # set the front end filter width
self.radio.sample_rate = int(sample_rate) self.radio.rx_rf_bandwidth = int(sample_rate)
except OSError as e:
# set the front end filter width raise SDRError(e)
self.radio.rx_rf_bandwidth = int(sample_rate) except ValueError:
except OSError as e: raise SDRParameterError(
_handle_OSError(e) f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
except ValueError as e: f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
_handle_OSError(e) )
def set_rx_gain(self, gain, channel=0, gain_mode="absolute"): def set_rx_gain(self, gain, channel=0, gain_mode="absolute"):
rx_gain_min = 0 """
rx_gain_max = 74 Set the gain of the receiver. Callable during streaming.
"""
with self._param_lock:
rx_gain_min = 0
rx_gain_max = 74
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise ValueError( raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets \
"When gain_mode = 'relative', gain must be < 0. This sets \ the gain relative to the maximum possible gain.")
the gain relative to the maximum possible gain." else:
) abs_gain = rx_gain_max + gain
else: else:
abs_gain = rx_gain_max + gain abs_gain = gain
else:
abs_gain = gain
if abs_gain < rx_gain_min or abs_gain > rx_gain_max: if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
abs_gain = min(max(gain, rx_gain_min), rx_gain_max) abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
print(f"Gain {gain} out of range for Pluto.") print(f"Gain {gain} out of range for Pluto.")
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB") print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
self.rx_gain = abs_gain self.rx_gain = abs_gain
try:
if channel == 0: if channel == 0:
if abs_gain is None: if abs_gain is None:
self.radio.gain_control_mode_chan0 = "automatic" self.radio.gain_control_mode_chan0 = "automatic"
print("Using Pluto Automatic Gain Control.") print("Using Pluto Automatic Gain Control.")
@ -415,120 +488,150 @@ class Pluto(SDR):
self.radio.rx_hardwaregain_chan1 = abs_gain # dB self.radio.rx_hardwaregain_chan1 = abs_gain # dB
except Exception as e: except Exception as e:
print("Failed to use channel 1 on the PlutoSDR. \nThis is only available for revC versions.") print("Failed to use channel 1 on the PlutoSDR.\nThis is only available for revC versions.")
raise e raise e
else: else:
raise ValueError(f"Pluto channel must be 0 or 1 but was {channel}.") raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.")
except OSError as e:
_handle_OSError(e)
except ValueError as e:
_handle_OSError(e)
def set_rx_channel(self, channel): def set_rx_channel(self, channel):
if channel == 0: if channel == 0:
self.radio.rx_enabled_channels = [0] self.radio.rx_enabled_channels = [0]
print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
elif channel == 1: elif channel == 1:
if not self._mimo_capable:
raise SDRParameterError(
"Dual RX channel requested (channel=1) but hardware is not MIMO-capable. "
"Dual RX/TX requires Pluto Rev C/D. Detected hardware: Rev B (single channel only)."
)
self.radio.rx_enabled_channels = [0, 1] self.radio.rx_enabled_channels = [0, 1]
print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
else: else:
raise ValueError("Channel must be either 0 or 1.") raise SDRParameterError("Channel must be either 0 or 1.")
def set_rx_buffer_size(self, buffer_size): print(f"Pluto channel(s) = {self.radio.rx_enabled_channels}")
def set_rx_buffer_size(self, buffer_size: int):
if buffer_size is None: if buffer_size is None:
raise ValueError("Buffer_size must be provided.") raise SDRParameterError("Buffer_size must be provided.")
buffer_size = int(buffer_size)
if buffer_size <= 0: if buffer_size <= 0:
raise ValueError("Buffer_size must be a positive integer.") raise SDRParameterError("Buffer_size must be a positive integer.")
self.rx_buffer_size = buffer_size
if hasattr(self, "radio"): if hasattr(self, "radio"):
try: try:
self.radio.rx_buffer_size = buffer_size self.radio.rx_buffer_size = buffer_size
except OSError as e: except Exception as e:
_handle_OSError(e) raise SDRError(e)
except ValueError as e:
_handle_OSError(e)
def set_tx_center_frequency(self, center_frequency): def set_tx_center_frequency(self, center_frequency):
try: # ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent
self.radio.tx_lo = int(center_frequency) # RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX
self.tx_center_frequency = center_frequency # setters at the same time. Serialize with ``_param_lock`` — RX setters hold
# the same reentrant lock — so native attribute writes don't interleave.
with self._param_lock:
if center_frequency < 70e6 or center_frequency > 6e9:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
)
except OSError as e: try:
_handle_OSError(e) self.radio.tx_lo = int(center_frequency)
except ValueError as e: self.tx_center_frequency = center_frequency
_handle_OSError(e) except OSError as e:
raise SDRError(e)
except ValueError:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
)
def set_tx_sample_rate(self, sample_rate): def set_tx_sample_rate(self, sample_rate):
try: # ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's
self.radio.sample_rate = sample_rate # ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock``
self.tx_sample_rate = sample_rate # so full-duplex sessions can't interleave writes.
with self._param_lock:
min_rate, max_rate = 65.1e3, 61.44e6
if sample_rate < min_rate or sample_rate > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
)
except OSError as e: try:
_handle_OSError(e) self.radio.sample_rate = sample_rate
except ValueError as e: self.tx_sample_rate = sample_rate
_handle_OSError(e) except OSError as e:
raise SDRError(e)
except ValueError:
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
)
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"): def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
tx_gain_min = -89 # Serialize with RX setters: see ``set_tx_sample_rate`` above.
tx_gain_max = 0 with self._param_lock:
tx_gain_min = -89
tx_gain_max = 0
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise ValueError( raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
"When gain_mode = 'relative', gain must be < 0. This sets\ the gain relative to the maximum possible gain.")
the gain relative to the maximum possible gain." else:
) abs_gain = tx_gain_max + gain
else: else:
abs_gain = tx_gain_max + gain abs_gain = gain
else:
abs_gain = gain
if abs_gain < tx_gain_min or abs_gain > tx_gain_max: if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
abs_gain = min(max(gain, tx_gain_min), tx_gain_max) abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
print(f"Gain {gain} out of range for Pluto.") print(f"Gain {gain} out of range for Pluto.")
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB") print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
try: try:
self.tx_gain = abs_gain self.tx_gain = abs_gain
if channel == 0: if channel == 0:
self.radio.tx_hardwaregain_chan0 = int(abs_gain) self.radio.tx_hardwaregain_chan0 = int(abs_gain)
elif channel == 1: elif channel == 1:
self.radio.tx_hardwaregain_chan1 = int(abs_gain) self.radio.tx_hardwaregain_chan1 = int(abs_gain)
else: else:
raise ValueError(f"Pluto channel must be 0 or 1 but was {channel}.") raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.")
except OSError as e: except Exception as e:
_handle_OSError(e) raise SDRError(e)
except ValueError as e:
_handle_OSError(e)
def set_tx_channel(self, channel): def set_tx_channel(self, channel):
if channel == 1: if channel == 0:
self.radio.tx_enabled_channels = [0, 1]
print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
elif channel == 0:
self.radio.tx_enabled_channels = [0] self.radio.tx_enabled_channels = [0]
print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}") elif channel == 1:
if not self._mimo_capable:
raise SDRParameterError(
"Dual TX channel requested (channel=1) but hardware is not MIMO-capable. "
"Dual RX/TX requires Pluto Rev C/D. Detected hardware: Rev B (single channel only)."
)
self.radio.tx_enabled_channels = [0, 1]
else: else:
raise ValueError("Channel must be either 0 or 1.") raise SDRParameterError("Channel must be either 0 or 1.")
def set_tx_buffer_size(self, buffer_size): print(f"Pluto channel(s) = {self.radio.tx_enabled_channels}")
raise NotImplementedError
def set_tx_buffer_size(self, buffer_size: int):
if buffer_size is None:
raise SDRParameterError("Buffer_size must be provided.")
if buffer_size <= 0:
raise SDRParameterError("Buffer_size must be a positive integer.")
self.tx_buffer_size = buffer_size
def close(self): def close(self):
if not hasattr(self, "radio"):
return
if self.radio.tx_cyclic_buffer: if self.radio.tx_cyclic_buffer:
self.radio.tx_destroy_buffer() self.radio.tx_destroy_buffer()
del self.radio del self.radio
def shutdown(self):
del self.radio
def _convert_rx_samples(self, samples): def _convert_rx_samples(self, samples):
return samples / (2**11) return samples / (2**11)
@ -538,6 +641,9 @@ class Pluto(SDR):
def set_clock_source(self, source): def set_clock_source(self, source):
raise NotImplementedError raise NotImplementedError
def supports_dynamic_updates(self) -> dict:
return {"center_frequency": True, "sample_rate": True, "gain": True}
def _handle_OSError(e): def _handle_OSError(e):

View File

@ -11,8 +11,8 @@ try:
except ImportError as exc: # pragma: no cover - dependency provided by end user except ImportError as exc: # pragma: no cover - dependency provided by end user
raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.sdr.sdr import SDR from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
class RTLSDR(SDR): class RTLSDR(SDR):
@ -45,8 +45,7 @@ class RTLSDR(SDR):
print(f"Initialized RTL-SDR with identifier [{identifier}].") print(f"Initialized RTL-SDR with identifier [{identifier}].")
except Exception as e: except Exception as e:
print(f"Failed to find RTL-SDR with identifier [{identifier}].") raise RuntimeError(f"RTL-SDR: Failed to find device with identifier '{identifier}'\nError: {e}")
raise e
def init_rx( def init_rx(
self, self,
@ -55,18 +54,18 @@ class RTLSDR(SDR):
gain: Optional[int], gain: Optional[int],
channel: int, channel: int,
gain_mode: Optional[str] = "absolute", gain_mode: Optional[str] = "absolute",
buffer_size: Optional[int] = 256_000,
bias_t: bool = False, bias_t: bool = False,
): ):
if channel not in (0, None): if channel not in (0, None):
raise ValueError("RTL-SDR supports only channel 0 for RX.") raise SDRParameterError("RTL-SDR supports only channel 0 for RX.")
self.set_rx_sample_rate(sample_rate=sample_rate) self.set_rx_sample_rate(sample_rate=sample_rate)
self.set_rx_center_frequency(center_frequency=center_frequency) self.set_rx_center_frequency(center_frequency=center_frequency)
self.set_rx_gain(gain=gain, gain_mode=gain_mode) self.set_rx_gain(gain=gain, gain_mode=gain_mode)
self.rx_buffer_size = int(buffer_size or self.rx_buffer_size)
self.rx_channel = 0 self.rx_channel = 0
self.rx_buffer_size = self._calculate_optimal_buffer_size(sample_rate)
print(f"RTL-SDR buffer: {self.rx_buffer_size} samples for {sample_rate/1e6:.1f} MS/s")
if bias_t: if bias_t:
self.set_bias_tee(True) self.set_bias_tee(True)
@ -78,58 +77,98 @@ class RTLSDR(SDR):
return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain} return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}
def set_rx_sample_rate(self, sample_rate): def set_rx_sample_rate(self, sample_rate):
"""
Set the sample rate of the receiver.
Not callable during recording; RTL-SDR requires stream stop/restart to change sample rate.
"""
if not ((sample_rate > 230e3 and sample_rate < 300e3) or (sample_rate > 900 and sample_rate < 3.2e6)):
raise SDRParameterError(
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{2:.3f} - {20:.3f} Msps]"
)
self.radio.sample_rate = float(sample_rate) self.radio.sample_rate = float(sample_rate)
self.rx_sample_rate = self.radio.sample_rate self.rx_sample_rate = self.radio.sample_rate
print(f"RTL RX Sample Rate = {self.radio.get_sample_rate()}") print(f"RTL RX Sample Rate = {self.radio.get_sample_rate()}")
def set_rx_center_frequency(self, center_frequency): def set_rx_center_frequency(self, center_frequency):
self.radio.center_freq = float(center_frequency) """
self.rx_center_frequency = self.radio.center_freq Set the center frequency of the receiver.
print(f"RTL RX Center Frequency = {self.radio.get_center_freq()}") Not callable during recording; RTL-SDR requires stream stop/restart to change center frequency.
"""
with self._param_lock:
min_rate, max_rate = 25e6, 1.75e9
if center_frequency < min_rate or center_frequency > max_rate:
raise SDRParameterError(
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
f"out of range: [{min_rate/1e9:.3f} - {max_rate/1e9:.3f} GHz]"
)
self.radio.center_freq = float(center_frequency)
self.rx_center_frequency = self.radio.center_freq
print(f"RTL RX Center Frequency = {self.radio.get_center_freq()}")
def set_rx_gain(self, gain, gain_mode="absolute"): def set_rx_gain(self, gain, gain_mode="absolute"):
available_gains = self.radio.get_gains() """
Set the gain of the receiver. Callable during streaming.
"""
with self._param_lock:
available_gains = self.radio.get_gains()
if gain is None: if gain is None:
self.radio.gain = "auto" self.radio.gain = "auto"
self.rx_gain = "auto" self.rx_gain = "auto"
else:
if not available_gains:
warnings.warn(
"No gain table reported by RTL-SDR; applying requested gain directly.",
RuntimeWarning,
)
target_gain = gain
else: else:
min_gain = min(available_gains) if not available_gains:
max_gain = max(available_gains) warnings.warn(
"No gain table reported by RTL-SDR; applying requested gain directly.",
if gain_mode == "relative": RuntimeWarning,
if gain > 0:
raise ValueError(
"When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain."
)
target_gain = max_gain + gain
else:
target_gain = gain
if target_gain < min_gain or target_gain > max_gain:
print(
f"Requested gain {target_gain} dB out of range;\
clamping to valid span {min_gain}-{max_gain} dB."
) )
target_gain = min(max(target_gain, min_gain), max_gain) target_gain = gain
else:
min_gain = min(available_gains)
max_gain = max(available_gains)
target_gain = min(available_gains, key=lambda g: abs(g - target_gain)) if gain_mode == "relative":
if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
target_gain = max_gain + gain
else:
target_gain = gain
self.radio.set_gain(target_gain) if target_gain < min_gain or target_gain > max_gain:
self.rx_gain = self.radio.get_gain() print(f"Requested gain {target_gain} dB out of range;\
clamping to valid span {min_gain}-{max_gain} dB.")
target_gain = min(max(target_gain, min_gain), max_gain)
print(f"RTL RX Gain = {self.radio.get_gain()}") target_gain = min(available_gains, key=lambda g: abs(g - target_gain))
print(f"Available RTL RX Gains: {available_gains}")
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None): self.radio.set_gain(target_gain)
self.rx_gain = self.radio.get_gain()
print(f"RTL RX Gain = {self.radio.get_gain()}")
print(f"Available RTL RX Gains: {available_gains}")
def _calculate_optimal_buffer_size(self, sample_rate):
"""USB packet alignment for stability."""
# RTL-SDR USB transfers in 16k chunks
min_size = 16384
max_size = 262144 # 256k
# Target: 50ms of data per buffer
target = int(sample_rate * 0.05)
# Round up to 16k boundary
size = ((target + 16383) // 16384) * 16384
return max(min_size, min(size, max_size))
def record(
self,
num_samples: Optional[int] = None,
rx_time: Optional[int | float] = None,
) -> Recording:
""" """
Create a radio recording (iq samples and metadata) of a given length from the RTL-SDR. Create a radio recording (iq samples and metadata) of a given length from the RTL-SDR.
Either num_samples or rx_time must be provided. Either num_samples or rx_time must be provided.
@ -147,13 +186,13 @@ class RTLSDR(SDR):
raise RuntimeError("RX was not initialized. init_rx() must be called before record().") raise RuntimeError("RX was not initialized. init_rx() must be called before record().")
if num_samples is not None and rx_time is not None: if num_samples is not None and rx_time is not None:
raise ValueError("Only input one of num_samples or rx_time") raise SDRParameterError("Only input one of num_samples or rx_time")
elif num_samples is not None: elif num_samples is not None:
pass pass
elif rx_time is not None: elif rx_time is not None:
num_samples = int(rx_time * self.rx_sample_rate) num_samples = int(rx_time * self.rx_sample_rate)
else: else:
raise ValueError("Must provide input of one of num_samples or rx_time") raise SDRParameterError("Must provide input of one of num_samples or rx_time")
# RTL-SDR has USB buffer limitations - use consistent 256k chunks # RTL-SDR has USB buffer limitations - use consistent 256k chunks
# Always read full chunks to avoid USB overflow issues with partial reads # Always read full chunks to avoid USB overflow issues with partial reads
@ -232,6 +271,10 @@ class RTLSDR(SDR):
def close(self): def close(self):
try: try:
self.radio.close() self.radio.close()
del self.radio
finally: finally:
self._enable_rx = False self._enable_rx = False
self._enable_tx = False self._enable_tx = False
def supports_dynamic_updates(self) -> dict:
return {"center_frequency": False, "sample_rate": False, "gain": True}

View File

@ -1,5 +1,6 @@
import math import math
import pickle import pickle
import threading
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
@ -7,7 +8,7 @@ from typing import Optional
import numpy as np import numpy as np
import zmq import zmq
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
class SDR(ABC): class SDR(ABC):
@ -27,17 +28,27 @@ class SDR(ABC):
self._tx_initialized = False self._tx_initialized = False
self._enable_rx = False self._enable_rx = False
self._enable_tx = False self._enable_tx = False
self._accumulated_buffer = None self._accumulated_buffer = None
self._max_num_buffers = None self._max_num_buffers = None
self._num_buffers_processed = 0 self._num_buffers_processed = 0
self._accumulated_buffer = None
self._last_buffer = None self._last_buffer = None
self._corrupted_buffer_count = 0
self.rx_sample_rate = None self.rx_sample_rate = None
self.rx_center_frequency = None self.rx_center_frequency = None
self.rx_gain = None self.rx_gain = None
self.tx_sample_rate = None self.tx_sample_rate = None
self.tx_center_frequency = None self.tx_center_frequency = None
self.tx_gain = None self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock
# Pending config consumed by rx() on first call and by _apply_sdr_config
# in the agent inference loop. Subclasses that need different defaults
# (e.g. MockSDR) can overwrite these in their own __init__.
self.center_freq: float = 2.4e9
self.sample_rate: float = 10e6
self.gain: float = 40.0
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording: def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
""" """
@ -71,7 +82,6 @@ class SDR(ABC):
self._max_num_buffers = num_buffers self._max_num_buffers = num_buffers
self._num_buffers_processed = 0 self._num_buffers_processed = 0
self._num_buffers_processed = 0
self._last_buffer = None self._last_buffer = None
self._accumulated_buffer = None self._accumulated_buffer = None
print("Starting stream") print("Starting stream")
@ -94,8 +104,35 @@ class SDR(ABC):
# reset to record again # reset to record again
self._accumulated_buffer = None self._accumulated_buffer = None
self._num_buffers_processed = 0
return recording return recording
def rx(self, num_samples: int) -> "np.ndarray":
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
This is the interface used by the agent inference loop. On first call,
``init_rx()`` is invoked automatically using the values stored in
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
``_apply_sdr_config``). Subsequent calls stream directly.
Subclasses may override this for hardware-native capture APIs (e.g.
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
``self.radio.rx()``).
"""
if not self._rx_initialized:
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
self.init_rx(
sample_rate=self.sample_rate,
center_frequency=self.center_freq,
gain=gain,
channel=0,
)
recording = self.record(num_samples=num_samples)
# Recording.data is either a list of 1-D arrays (one per channel) or a
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
data = recording.data
return data[0] if hasattr(data, "__getitem__") else data
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000): def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
""" """
Stream iq samples as interleaved bytes via zmq. Stream iq samples as interleaved bytes via zmq.
@ -110,21 +147,23 @@ class SDR(ABC):
:return: The trimmed Recording. :return: The trimmed Recording.
:rtype: Recording :rtype: Recording
""" """
try:
self._previous_buffer = None
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
self._num_buffers_processed = 0
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.bind(self.zmq_address)
self._previous_buffer = None self._stream_rx(
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size) self._zmq_bytestream_callback,
self._num_buffers_processed = 0 )
self.zmq_address = _generate_full_zmq_address(str(zmq_address)) finally:
self.context = zmq.Context() if hasattr(self, "socket"):
self.socket = self.context.socket(zmq.PUB) self.socket.close()
self.socket.bind(self.zmq_address) if hasattr(self, "context"):
self.context.destroy()
self._stream_rx(
self._zmq_bytestream_callback,
)
self.context.destroy()
self.socket.close()
def _accumulate_buffers_callback(self, buffer, metadata=None): def _accumulate_buffers_callback(self, buffer, metadata=None):
""" """
@ -134,62 +173,72 @@ class SDR(ABC):
# save the buffer until max reached # save the buffer until max reached
# return a recording # return a recording
buffer = np.array(buffer) # make it 1d # Validate buffer
if len(buffer.shape) == 1: if not self._validate_buffer(buffer):
buffer = np.array([buffer]) print("Warning: Corrupted buffer detected, skipping")
self._corrupted_buffer_count += 1
return # Skip this buffer
# it runs these checks each time, is that an efficiency issue? if isinstance(buffer, np.ndarray):
if buffer.ndim == 1:
if self._max_num_buffers is None: buffer = buffer[np.newaxis, :] # make shape (1, N)
# default then
# this should probably print, but that would happen every buffer...
raise ValueError("Number of buffers for block capture not set.")
# add the given buffer to the pre-allocated buffer
if metadata is not None:
self.received_metadata = metadata
# TODO optimize, pre-allocate
if self._accumulated_buffer is not None:
self._accumulated_buffer = np.concatenate((self._accumulated_buffer, buffer), axis=1)
else: else:
# the first time buffer = np.array(buffer) # make it 1d
self._accumulated_buffer = buffer.copy() if len(buffer.shape) == 1:
buffer = np.array([buffer])
self._num_buffers_processed = self._num_buffers_processed + 1 # First call: pre-allocate if we know the final size
if self._accumulated_buffer is None:
# Check that _max_num_buffers is set
if self._max_num_buffers is None:
raise ValueError("Number of buffers for block capture not set.")
if self._num_samples_to_record is None:
raise ValueError("Number of samples not set before RX start.")
if metadata is not None:
self.received_metadata = metadata
# Preallocate once (avoid np.zeros; use np.empty for speed)
num_channels = buffer.shape[0]
self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
self._write_position = 0
print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")
# Write new buffer into pre-allocated array
n = buffer.shape[1]
start = self._write_position
end = min(start + n, self._num_samples_to_record)
samples_to_write = end - start
if samples_to_write > 0:
self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
self._write_position = end
# Check if we're done
self._num_buffers_processed += 1
if self._num_buffers_processed >= self._max_num_buffers: if self._num_buffers_processed >= self._max_num_buffers:
self.stop() self.stop()
if self._last_buffer is not None: def _validate_buffer(self, buffer):
if (buffer == self._last_buffer).all(): """Check for obviously corrupt data."""
print("\033[93mWarning: Buffer Overflow Detected\033[0m") # Check for all zeros
self._last_buffer = buffer.copy() if np.all(buffer == 0):
else: return False
self._last_buffer = buffer.copy() # Check for all same value
if np.all(buffer == buffer[0]):
# print("Number of buffers received: " + str(self._num_buffers_processed)) return False
return True
def _zmq_bytestream_callback(self, buffer, metadata=None): def _zmq_bytestream_callback(self, buffer, metadata=None):
# push to ZMQ port # push to ZMQ port
data = np.array(buffer).tobytes() # convert to bytes for transport data = np.array(buffer).tobytes() # convert to bytes for transport
self.socket.send(data) self.socket.send(data)
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
self._num_buffers_processed = self._num_buffers_processed + 1 self._num_buffers_processed = self._num_buffers_processed + 1
if self._max_num_buffers is not None: if self._max_num_buffers is not None:
if self._num_buffers_processed >= self._max_num_buffers: if self._num_buffers_processed >= self._max_num_buffers:
self.pause_rx() self.pause_rx()
if self._previous_buffer is not None:
if (buffer == self._previous_buffer).all():
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
# TODO: I suggest we think about moving this part to the top of this function
# and skip the rest of the function in case of overflow.
# like, it's not necessary to stream repeated IQ data anyways!
self._previous_buffer = buffer.copy()
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers): def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
""" """
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle. Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
@ -229,7 +278,7 @@ class SDR(ABC):
self.stop() self.stop()
if self._last_buffer is not None: if self._last_buffer is not None:
if (buffer == self._last_buffer).all(): if np.array_equal(buffer, self._last_buffer):
print("\033[93mWarning: Buffer Overflow Detected\033[0m") print("\033[93mWarning: Buffer Overflow Detected\033[0m")
self._last_buffer = buffer.copy() self._last_buffer = buffer.copy()
else: else:
@ -265,7 +314,7 @@ class SDR(ABC):
elif num_samples is not None: elif num_samples is not None:
self._num_samples_to_transmit = num_samples self._num_samples_to_transmit = num_samples
elif tx_time is not None: elif tx_time is not None:
self._num_samples_to_transmit = tx_time * self.tx_sample_rate self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
else: else:
self._num_samples_to_transmit = len(recording) self._num_samples_to_transmit = len(recording)
@ -373,6 +422,58 @@ class SDR(ABC):
""" """
return self.tx_gain return self.tx_gain
def set_rx_sample_rate(self):
"""
Set the sample rate of the receiver.
"""
raise NotImplementedError
def set_rx_center_frequency(self):
"""
Set the center frequency of the receiver.
"""
raise NotImplementedError
def set_rx_gain(self):
"""
Set the gain setting of the receiver.
"""
raise NotImplementedError
def set_tx_sample_rate(self):
"""
Set the sample rate of the transmitter.
"""
raise NotImplementedError
def set_tx_center_frequency(self):
"""
Set the center frequency of the transmitter.
"""
raise NotImplementedError
def set_tx_gain(self):
"""
Set the gain setting of the transmitter.
"""
raise NotImplementedError
def supports_dynamic_updates(self) -> dict:
"""
Report which parameters can be updated during streaming.
Returns:
dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
"""
return {"center_frequency": False, "sample_rate": False, "gain": False}
def __del__(self):
"""Cleanup on garbage collection."""
try:
self.close()
except Exception:
pass
@abstractmethod @abstractmethod
def close(self): def close(self):
pass pass
@ -442,3 +543,69 @@ def _verify_sample_format(samples):
""" """
return np.max(np.abs(samples)) <= 1 return np.max(np.abs(samples)) <= 1
class SDRError(Exception):
"""Base exception for SDR errors."""
pass
class SDRParameterError(SDRError):
"""Invalid parameter (sample rate, freq, gain)."""
pass
class SDROverflowError(SDRError):
"""Buffer overflow detected."""
pass
class SdrDisconnectedError(SDRError):
"""Raised when the SDR device disappears mid-operation (USB unplug, network drop)."""
pass
# Substrings that strongly indicate a device has disappeared rather than a
# transient / recoverable error. Checked case-insensitively against str(exc).
_DISCONNECT_MARKERS = (
"no such device",
"device not found",
"not found",
"broken pipe",
"disconnected",
"no device",
"device unplugged",
"usb",
"i/o error",
"input/output error",
"errno 19", # ENODEV
"errno 5", # EIO
)
def translate_disconnect(exc: BaseException) -> BaseException:
"""Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*.
Drivers wrap their native-API calls with::
try:
return self.radio.rx()
except Exception as exc:
raise translate_disconnect(exc) from exc
The caller (e.g. the streamer) can then catch ``SdrDisconnectedError``
specifically and report it to the hub rather than crashing the loop.
"""
if isinstance(exc, SdrDisconnectedError):
return exc
msg = str(exc).lower()
if any(marker in msg for marker in _DISCONNECT_MARKERS):
return SdrDisconnectedError(str(exc))
# OSError subclass with ENODEV / EIO errno is also a disconnect signal.
if isinstance(exc, OSError) and getattr(exc, "errno", None) in (5, 19):
return SdrDisconnectedError(str(exc))
return exc

View File

@ -36,7 +36,7 @@ except SyntaxError as exc: # pragma: no cover - Python 2/3 compatibility issue
print("Manual fix: Run `python scripts/fix_pyrf_python3.py` from ria-toolkit-oss directory") print("Manual fix: Run `python scripts/fix_pyrf_python3.py` from ria-toolkit-oss directory")
raise exc raise exc
from ria_toolkit_oss.sdr.sdr import SDR from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
class ThinkRF(SDR): class ThinkRF(SDR):
@ -51,7 +51,7 @@ class ThinkRF(SDR):
super().__init__() super().__init__()
if identifier is None: if identifier is None:
raise ValueError("ThinkRF requires an IP address or hostname identifier") raise SDRParameterError("ThinkRF requires an IP address or hostname identifier")
self.identifier = identifier self.identifier = identifier
try: try:
@ -90,7 +90,7 @@ class ThinkRF(SDR):
mode = capture_mode.lower() mode = capture_mode.lower()
if mode not in {"block", "stream"}: if mode not in {"block", "stream"}:
raise ValueError("capture_mode must be either 'block' or 'stream'") raise SDRParameterError("capture_mode must be either 'block' or 'stream'")
self._rfe_mode = rfe_mode self._rfe_mode = rfe_mode
self._attenuation = int(max(0, min(attenuation, 30))) self._attenuation = int(max(0, min(attenuation, 30)))
@ -113,10 +113,12 @@ class ThinkRF(SDR):
decimation: Optional[int] = None, decimation: Optional[int] = None,
): ):
if channel not in (0, None): if channel not in (0, None):
raise ValueError("ThinkRF devices expose a single receive channel") raise SDRParameterError("ThinkRF supports only channel 0 for RX.")
stream_mode = getattr(self, "_capture_mode", "block") == "stream" stream_mode = getattr(self, "_capture_mode", "block") == "stream"
actual_decimation, actual_sample_rate = self.set_rx_sample_rate(sample_rate=sample_rate, decimation=decimation) actual_decimation, _ = self.set_rx_sample_rate(
sample_rate=sample_rate, decimation=decimation, stream_mode=stream_mode
)
self.radio.reset() self.radio.reset()
self.radio.scpiset(":SYSTEM:FLUSH") self.radio.scpiset(":SYSTEM:FLUSH")
@ -127,15 +129,7 @@ class ThinkRF(SDR):
self.radio.rfe_mode(self._rfe_mode) self.radio.rfe_mode(self._rfe_mode)
self.set_rx_center_frequency(center_frequency=center_frequency) self.set_rx_center_frequency(center_frequency=center_frequency)
self.set_rx_gain(gain=gain, gain_mode=gain_mode, actual_decimation=actual_decimation)
attenuation = self._attenuation if gain is None else int(gain) # gain
attenuation = max(0, min(attenuation, 30))
self.radio.attenuator(attenuation)
gain_profile = self._gain_profile
if gain_mode and isinstance(gain_mode, str) and gain_mode.upper() in {"LOW", "MEDIUM", "HIGH", "VLOW"}:
gain_profile = gain_mode.upper()
self.radio.gain(gain_profile.lower()) # WSA.gain() expects lowercase
self.radio.decimation(actual_decimation) self.radio.decimation(actual_decimation)
if stream_mode: if stream_mode:
@ -153,14 +147,6 @@ class ThinkRF(SDR):
self.radio.scpiset(f":TRACE:BLOCK:PACKETS {self._packets_per_block}") self.radio.scpiset(f":TRACE:BLOCK:PACKETS {self._packets_per_block}")
self.radio.scpiset(":TRACE:BLOCK:DATA?") self.radio.scpiset(":TRACE:BLOCK:DATA?")
self.rx_gain = {
"attenuation_dB": attenuation,
"profile": gain_profile,
"decimation": actual_decimation,
"rfe_mode": self._rfe_mode,
"spp": self._samples_per_packet,
"ppb": self._packets_per_block,
}
self.rx_buffer_size = self._samples_per_packet self.rx_buffer_size = self._samples_per_packet
self.rx_channel = 0 self.rx_channel = 0
@ -168,6 +154,10 @@ class ThinkRF(SDR):
self._tx_initialized = False self._tx_initialized = False
def set_rx_sample_rate(self, sample_rate, decimation, stream_mode): def set_rx_sample_rate(self, sample_rate, decimation, stream_mode):
"""
Set the sample rate of the receiver.
Not callable during recording; ThinkRF requires stream stop/restart to change sample rate.
"""
# Enforce sample rate / decimation # Enforce sample rate / decimation
# Note: decimation parameter takes precedence if provided # Note: decimation parameter takes precedence if provided
actual_decimation, actual_sample_rate = self.enforce_sample_rate(sample_rate, decimation) actual_decimation, actual_sample_rate = self.enforce_sample_rate(sample_rate, decimation)
@ -188,9 +178,32 @@ class ThinkRF(SDR):
return actual_decimation, actual_sample_rate return actual_decimation, actual_sample_rate
def set_rx_center_frequency(self, center_frequency): def set_rx_center_frequency(self, center_frequency):
self.radio.freq(int(center_frequency)) """
self.rx_center_frequency = self.radio.freq Set the center frequency of the receiver. Callable during streaming.
print(f"ThinkRF RX Center Frequency = {self.radio.freq}") """
with self._param_lock:
self.radio.freq(int(center_frequency))
self.rx_center_frequency = self.radio.freq
print(f"ThinkRF RX Center Frequency = {self.radio.freq}")
def set_rx_gain(self, gain, gain_mode, actual_decimation):
attenuation = self._attenuation if gain is None else int(gain) # gain
attenuation = max(0, min(attenuation, 30))
self.radio.attenuator(attenuation)
gain_profile = self._gain_profile
if gain_mode and isinstance(gain_mode, str) and gain_mode.upper() in {"LOW", "MEDIUM", "HIGH", "VLOW"}:
gain_profile = gain_mode.upper()
self.radio.gain(gain_profile.lower()) # WSA.gain() expects lowercase
self.rx_gain = {
"attenuation_dB": attenuation,
"profile": gain_profile,
"decimation": actual_decimation,
"rfe_mode": self._rfe_mode,
"spp": self._samples_per_packet,
"ppb": self._packets_per_block,
}
def _stream_rx(self, callback): def _stream_rx(self, callback):
if not self._rx_initialized: if not self._rx_initialized:
@ -379,10 +392,8 @@ class ThinkRF(SDR):
actual_sample_rate = self.BASE_SAMPLE_RATE / decimation actual_sample_rate = self.BASE_SAMPLE_RATE / decimation
if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference if abs(actual_sample_rate - requested_sample_rate) > 1e3: # More than 1 kHz difference
print( print(f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \
f"ThinkRF: Requested {requested_sample_rate/1e6:.2f} MS/s → \ Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)")
Using decimation={decimation} ({actual_sample_rate/1e6:.2f} MS/s)"
)
return decimation, actual_sample_rate return decimation, actual_sample_rate
@ -431,7 +442,7 @@ class ThinkRF(SDR):
For decimation 1 or 2, block captures are limited by onboard RAM. For decimation 1 or 2, block captures are limited by onboard RAM.
""" """
if decimation <= 2 and num_samples > self.MAX_ONBOARD_SAMPLES: if decimation <= 2 and num_samples > self.MAX_ONBOARD_SAMPLES:
raise ValueError( raise SDRParameterError(
f"ThinkRF: Cannot capture {num_samples} samples at decimation {decimation}. " f"ThinkRF: Cannot capture {num_samples} samples at decimation {decimation}. "
f"Onboard RAM limit is ~{self.MAX_ONBOARD_SAMPLES} samples for dec 1/2. " f"Onboard RAM limit is ~{self.MAX_ONBOARD_SAMPLES} samples for dec 1/2. "
f"Either reduce num_samples or use stream mode (increase decimation to >=4)." f"Either reduce num_samples or use stream mode (increase decimation to >=4)."
@ -446,3 +457,6 @@ class ThinkRF(SDR):
"fstop": int(center_frequency) + half, "fstop": int(center_frequency) + half,
"amplitude": -100, "amplitude": -100,
} }
def supports_dynamic_updates(self) -> dict:
return {"center_frequency": True, "sample_rate": False, "gain": False}

View File

@ -6,8 +6,8 @@ from typing import Optional
import numpy as np import numpy as np
import uhd import uhd
from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.data.recording import Recording
from ria_toolkit_oss.sdr.sdr import SDR from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
class USRP(SDR): class USRP(SDR):
@ -40,7 +40,7 @@ class USRP(SDR):
channel: int, channel: int,
gain: int, gain: int,
gain_mode: Optional[str] = "absolute", gain_mode: Optional[str] = "absolute",
rx_buffer_size: int = 960000, rx_buffer_size: Optional[int] = None,
): ):
""" """
Initializes the USRP for receiving. Initializes the USRP for receiving.
@ -54,7 +54,7 @@ class USRP(SDR):
:param channel: The channel the USRP is set to. :param channel: The channel the USRP is set to.
:type channel: int :type channel: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the sdr,
'relative' means that gain should be a negative value, and it will be subtracted from the max gain. 'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
:type gain_mode: str :type gain_mode: str
:param rx_buffer_size: Internal buffer size for receiving samples. Defaults to 960000. :param rx_buffer_size: Internal buffer size for receiving samples. Defaults to 960000.
:type rx_buffer_size: int :type rx_buffer_size: int
@ -63,8 +63,6 @@ class USRP(SDR):
:rtype: dict :rtype: dict
""" """
self.rx_buffer_size = rx_buffer_size
# build USRP object # build USRP object
usrp_args = _generate_usrp_config_string(sample_rate=sample_rate, device_dict=self.device_dict) usrp_args = _generate_usrp_config_string(sample_rate=sample_rate, device_dict=self.device_dict)
self.usrp = uhd.usrp.MultiUSRP(usrp_args) self.usrp = uhd.usrp.MultiUSRP(usrp_args)
@ -72,7 +70,7 @@ class USRP(SDR):
# check if channel arg is valid # check if channel arg is valid
max_num_channels = self.usrp.get_rx_num_channels() max_num_channels = self.usrp.get_rx_num_channels()
if channel + 1 > max_num_channels: if channel + 1 > max_num_channels:
raise IOError(f"Channel {channel} not valid for device with {max_num_channels} channels.") raise SDRParameterError(f"Channel {channel} not valid for device with {max_num_channels} channels.")
self.set_rx_sample_rate(sample_rate=sample_rate, channel=channel) self.set_rx_sample_rate(sample_rate=sample_rate, channel=channel)
self.set_rx_center_frequency(center_frequency=center_frequency, channel=channel) self.set_rx_center_frequency(center_frequency=center_frequency, channel=channel)
@ -81,6 +79,20 @@ class USRP(SDR):
self.rx_channel = channel self.rx_channel = channel
print(f"USRP RX Channel = {self.rx_channel}") print(f"USRP RX Channel = {self.rx_channel}")
stream_args = uhd.usrp.StreamArgs("fc32", "sc16")
stream_args.channels = [self.rx_channel]
self.metadata = uhd.types.RXMetadata()
self.rx_stream = self.usrp.get_rx_stream(stream_args)
if rx_buffer_size is None: # In case it's none
self.rx_buffer_size = self.rx_stream.get_max_num_samps()
else:
self.rx_buffer_size = rx_buffer_size
# set timeout based on buffer size and sample rate, with a safety factor of 5
self.timeout = (self.rx_buffer_size / self.rx_sample_rate) * 5
# flag to prevent user from calling certain functions before this one. # flag to prevent user from calling certain functions before this one.
self._rx_initialized = True self._rx_initialized = True
self._tx_initialized = False self._tx_initialized = False
@ -88,68 +100,74 @@ class USRP(SDR):
return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain} return {"sample_rate": self.rx_sample_rate, "center_frequency": self.rx_center_frequency, "gain": self.rx_gain}
def set_rx_sample_rate(self, sample_rate, channel=0): def set_rx_sample_rate(self, sample_rate, channel=0):
"""
Set the sample rate of the receiver. Callable during streaming.
"""
# check if sample rate arg is valid # check if sample rate arg is valid
# Note: B200/B210 devices auto-adjust master clock rate, so get_rx_rates() returns # Note: B200/B210 devices auto-adjust master clock rate, so get_rx_rates() returns
# the range for the CURRENT master clock, not the maximum possible range. # the range for the CURRENT master clock, not the maximum possible range.
# Skip validation for B-series devices and let UHD handle it. # Skip validation for B-series devices and let UHD handle it.
device_type = self.device_dict.get("type", "").lower() with self._param_lock:
if device_type not in ["b200", "b210"]: device_type = self.device_dict.get("type", "").lower()
sample_rate_range = self.usrp.get_rx_rates() if device_type not in ["b200", "b210"]:
if sample_rate < sample_rate_range.start() or sample_rate > sample_rate_range.stop(): sample_rate_range = self.usrp.get_rx_rates()
raise IOError( min_rate, max_rate = sample_rate_range.start(), sample_rate_range.stop()
f"Sample rate {sample_rate} not valid for this USRP.\nValid\ if sample_rate < min_rate or sample_rate > max_rate:
range is {sample_rate_range.start()}\ raise SDRParameterError(
to {sample_rate_range.stop()}." f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
) f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
self.usrp.set_rx_rate(sample_rate, channel) )
self.rx_sample_rate = self.usrp.get_rx_rate(channel)
print(f"USRP RX Sample Rate = {self.rx_sample_rate}") self.usrp.set_rx_rate(sample_rate, channel)
self.rx_sample_rate = self.usrp.get_rx_rate(channel)
print(f"USRP RX Sample Rate = {self.rx_sample_rate}")
def set_rx_center_frequency(self, center_frequency, channel=0): def set_rx_center_frequency(self, center_frequency, channel=0):
center_frequency_range = self.usrp.get_rx_freq_range() """
if center_frequency < center_frequency_range.start() or center_frequency > center_frequency_range.stop(): Set the center frequency of the receiver. Callable during streaming.
raise IOError( """
f"Center frequency {center_frequency} out of range for USRP.\ with self._param_lock:
\nValid range is {center_frequency_range.start()} \ center_frequency_range = self.usrp.get_rx_freq_range()
to {center_frequency_range.stop()}." min_rate, max_rate = center_frequency_range.start(), center_frequency_range.stop()
) if center_frequency < min_rate or center_frequency > max_rate:
self.usrp.set_rx_freq(uhd.libpyuhd.types.tune_request(center_frequency), channel) raise SDRParameterError(
self.rx_center_frequency = self.usrp.get_rx_freq(channel) f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
print(f"USRP RX Center Frequency = {self.rx_center_frequency}") f"out of range: [{min_rate/1e9:.3f} - {max_rate/1e9:.3f} GHz]"
)
self.usrp.set_rx_freq(uhd.libpyuhd.types.tune_request(center_frequency), channel)
self.rx_center_frequency = self.usrp.get_rx_freq(channel)
print(f"USRP RX Center Frequency = {self.rx_center_frequency}")
def set_rx_gain(self, gain, gain_mode="absolute", channel=0): def set_rx_gain(self, gain, gain_mode="absolute", channel=0):
# check if gain arg is valid """
gain_range = self.usrp.get_rx_gain_range() Set the gain of the receiver. Callable during streaming.
if gain_mode == "relative": """
if gain > 0: with self._param_lock:
raise ValueError( # check if gain arg is valid
"When gain_mode = 'relative', gain must be < 0. This sets\ gain_range = self.usrp.get_rx_gain_range()
the gain relative to the maximum possible gain." if gain_mode == "relative":
) if gain > 0:
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
the gain relative to the maximum possible gain.")
else:
# set gain relative to max
abs_gain = gain_range.stop() + gain
else: else:
# set gain relative to max abs_gain = gain
abs_gain = gain_range.stop() + gain if abs_gain < gain_range.start() or abs_gain > gain_range.stop():
else: print(f"Gain {abs_gain} out of range for this USRP.")
abs_gain = gain print(f"Gain range: {gain_range.start()} to {gain_range.stop()} dB")
if abs_gain < gain_range.start() or abs_gain > gain_range.stop(): abs_gain = min(max(abs_gain, gain_range.start()), gain_range.stop())
print(f"Gain {abs_gain} out of range for this USRP.") self.usrp.set_rx_gain(abs_gain, channel)
print(f"Gain range: {gain_range.start()} to {gain_range.stop()} dB") self.rx_gain = self.usrp.get_rx_gain(channel)
abs_gain = min(max(abs_gain, gain_range.start()), gain_range.stop()) print(f"USRP RX Gain = {self.rx_gain}")
self.usrp.set_rx_gain(abs_gain, channel)
self.rx_gain = self.usrp.get_rx_gain(channel)
print(f"USRP RX Gain = {self.rx_gain}")
def _stream_rx(self, callback): def _stream_rx(self, callback):
if not self._rx_initialized: if not self._rx_initialized:
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()") raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
stream_args = uhd.usrp.StreamArgs("fc32", "sc16") # send command to start the rx stream
stream_args.channels = [self.rx_channel]
self.metadata = uhd.types.RXMetadata()
self.rx_stream = self.usrp.get_rx_stream(stream_args)
stream_command = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont) stream_command = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont)
stream_command.stream_now = True stream_command.stream_now = True
self.rx_stream.issue_stream_cmd(stream_command) self.rx_stream.issue_stream_cmd(stream_command)
@ -160,19 +178,19 @@ class USRP(SDR):
receive_buffer = np.zeros((1, self.rx_buffer_size), dtype=np.complex64) receive_buffer = np.zeros((1, self.rx_buffer_size), dtype=np.complex64)
while self._enable_rx: while self._enable_rx:
self.rx_stream.recv(receive_buffer, self.metadata, self.timeout)
# 1 is the timeout #TODO maybe set this intelligently based on the desired sample rate
self.rx_stream.recv(receive_buffer, self.metadata, 1)
# TODO set metadata correctly, sending real sample rate plus any error codes # TODO set metadata correctly, sending real sample rate plus any error codes
# sending complex signal # sending complex signal
callback(buffer=receive_buffer, metadata=self.metadata) callback(buffer=receive_buffer, metadata=self.metadata)
if self.metadata.error_code != uhd.types.RXMetadataErrorCode.none: if self.metadata.error_code != uhd.types.RXMetadataErrorCode.none:
print(f"Error while receiving samples: {self.metadata.strerror()}") if self.metadata.error_code == uhd.types.RXMetadataErrorCode.overflow:
print("\033[93mWarning: Buffer Overflow Detected.\033[0m")
if self.metadata.error_code == uhd.types.RXMetadataErrorCode.timeout: if self.metadata.error_code == uhd.types.RXMetadataErrorCode.timeout:
print("Stopping receive due to timeout error.") print("\033[91Stopping receive due to timeout error.\033[0m")
self.stop() self.stop()
# stop streaming
wait_time = 0.1 wait_time = 0.1
stop_time = self.usrp.get_time_now() + wait_time stop_time = self.usrp.get_time_now() + wait_time
stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont) stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont)
@ -180,10 +198,14 @@ class USRP(SDR):
stop_cmd.time_spec = stop_time stop_cmd.time_spec = stop_time
self.rx_stream.issue_stream_cmd(stop_cmd) self.rx_stream.issue_stream_cmd(stop_cmd)
time.sleep(wait_time) # TODO figure out what a realistic wait time is here. time.sleep(wait_time) # TODO figure out what a realistic wait time is here.
del self.rx_stream
print("USRP RX Completed.") print("USRP RX Completed.")
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None): def record(
self,
num_samples: Optional[int] = None,
rx_time: Optional[int | float] = None,
) -> Recording:
""" """
Create a radio recording (iq samples and metadata) of a given length from the USRP. Create a radio recording (iq samples and metadata) of a given length from the USRP.
Either num_samples or rx_time must be provided. Either num_samples or rx_time must be provided.
@ -200,41 +222,31 @@ class USRP(SDR):
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()") raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
if num_samples is not None and rx_time is not None: if num_samples is not None and rx_time is not None:
raise ValueError("Only input one of num_samples or rx_time") raise SDRParameterError("Only input one of num_samples or rx_time")
elif num_samples is not None: elif num_samples is not None:
pass pass
elif rx_time is not None: elif rx_time is not None:
num_samples = int(rx_time * self.rx_sample_rate) num_samples = int(rx_time * self.rx_sample_rate)
else: else:
raise ValueError("Must provide input of one of num_samples or rx_time") raise SDRParameterError("Must provide input of one of num_samples or rx_time")
stream_args = uhd.usrp.StreamArgs("fc32", "sc16")
stream_args.channels = [self.rx_channel]
self.metadata = uhd.types.RXMetadata()
self.rx_stream = self.usrp.get_rx_stream(stream_args)
# send command to start the rx stream
stream_command = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont) stream_command = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont)
stream_command.stream_now = True stream_command.stream_now = True
self.rx_stream.issue_stream_cmd(stream_command) self.rx_stream.issue_stream_cmd(stream_command)
# receive loop # receive loop
self._enable_rx = True self._enable_rx = True
print("USRP Starting RX...")
store_array = np.zeros((1, (num_samples // self.rx_buffer_size + 1) * self.rx_buffer_size), dtype=np.complex64) store_array = np.zeros((1, (num_samples // self.rx_buffer_size + 1) * self.rx_buffer_size), dtype=np.complex64)
receive_buffer = np.zeros((1, self.rx_buffer_size), dtype=np.complex64) receive_buffer = np.zeros((1, self.rx_buffer_size), dtype=np.complex64)
print("USRP Starting RX...")
# write complex samples to receive buffer
for i in range(num_samples // self.rx_buffer_size + 1): for i in range(num_samples // self.rx_buffer_size + 1):
self.rx_stream.recv(receive_buffer, self.metadata, self.timeout)
# write samples to receive buffer
# they should already be complex
# 1 is the timeout #TODO maybe set this intelligently based on the desired sample rate
self.rx_stream.recv(receive_buffer, self.metadata, 1)
# TODO set metadata correctly, sending real sample rate plus any error codes
# sending complex signal
store_array[:, i * self.rx_buffer_size : (i + 1) * self.rx_buffer_size] = receive_buffer store_array[:, i * self.rx_buffer_size : (i + 1) * self.rx_buffer_size] = receive_buffer
# stop streaming
wait_time = 0.1 wait_time = 0.1
stop_time = self.usrp.get_time_now() + wait_time stop_time = self.usrp.get_time_now() + wait_time
stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont) stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont)
@ -242,7 +254,7 @@ class USRP(SDR):
stop_cmd.time_spec = stop_time stop_cmd.time_spec = stop_time
self.rx_stream.issue_stream_cmd(stop_cmd) self.rx_stream.issue_stream_cmd(stop_cmd)
time.sleep(wait_time) # TODO figure out what a realistic wait time is here. time.sleep(wait_time) # TODO figure out what a realistic wait time is here.
del self.rx_stream
print("USRP RX Completed.") print("USRP RX Completed.")
metadata = { metadata = {
"source": self.__class__.__name__, "source": self.__class__.__name__,
@ -273,7 +285,7 @@ class USRP(SDR):
:param channel: The channel the USRP is set to. :param channel: The channel the USRP is set to.
:type channel: int :type channel: int
:param gain_mode: 'absolute' passes gain directly to the sdr, :param gain_mode: 'absolute' passes gain directly to the sdr,
'relative' means that gain should be a negative value, and it will be subtracted from the max gain. 'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
:type gain_mode: str :type gain_mode: str
""" """
@ -287,7 +299,7 @@ class USRP(SDR):
# check if channel arg is valid # check if channel arg is valid
max_num_channels = self.usrp.get_rx_num_channels() max_num_channels = self.usrp.get_rx_num_channels()
if channel + 1 > max_num_channels: if channel + 1 > max_num_channels:
raise IOError(f"Channel {channel} not valid for device with {max_num_channels} channels.") raise SDRParameterError(f"Channel {channel} not valid for device with {max_num_channels} channels.")
self.set_tx_sample_rate(sample_rate=sample_rate, channel=channel) self.set_tx_sample_rate(sample_rate=sample_rate, channel=channel)
self.set_tx_center_frequency(center_frequency=center_frequency, channel=channel) self.set_tx_center_frequency(center_frequency=center_frequency, channel=channel)
@ -313,23 +325,26 @@ class USRP(SDR):
device_type = self.device_dict.get("type", "").lower() device_type = self.device_dict.get("type", "").lower()
if device_type not in ["b200", "b210"]: if device_type not in ["b200", "b210"]:
sample_rate_range = self.usrp.get_tx_rates() sample_rate_range = self.usrp.get_tx_rates()
if sample_rate < sample_rate_range.start() or sample_rate > sample_rate_range.stop(): min_rate, max_rate = sample_rate_range.start(), sample_rate_range.stop()
raise IOError( if sample_rate < min_rate or sample_rate > max_rate:
f"Sample rate {sample_rate} not valid for this USRP.\nValid\ raise SDRParameterError(
range is {sample_rate_range.start()} to {sample_rate_range.stop()}." f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
) )
self.usrp.set_tx_rate(sample_rate, channel) self.usrp.set_tx_rate(sample_rate, channel)
self.tx_sample_rate = self.usrp.get_tx_rate(channel) self.tx_sample_rate = self.usrp.get_tx_rate(channel)
print(f"USRP TX Sample Rate = {self.tx_sample_rate}") print(f"USRP TX Sample Rate = {self.tx_sample_rate}")
def set_tx_center_frequency(self, center_frequency, channel=0): def set_tx_center_frequency(self, center_frequency, channel=0):
center_frequency_range = self.usrp.get_tx_freq_range() center_frequency_range = self.usrp.get_tx_freq_range()
if center_frequency < center_frequency_range.start() or center_frequency > center_frequency_range.stop(): min_rate, max_rate = center_frequency_range.start(), center_frequency_range.stop()
raise IOError( if center_frequency < min_rate or center_frequency > max_rate:
f"Center frequency {center_frequency} out of range for USRP.\ raise SDRParameterError(
\nValid range is {center_frequency_range.start()}\ f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
to {center_frequency_range.stop()}." f"out of range: [{min_rate/1e9:.3f} - {max_rate/1e9:.3f} GHz]"
) )
self.usrp.set_tx_freq(uhd.types.TuneRequest(center_frequency), channel) self.usrp.set_tx_freq(uhd.types.TuneRequest(center_frequency), channel)
self.tx_center_frequency = self.usrp.get_tx_freq(channel) self.tx_center_frequency = self.usrp.get_tx_freq(channel)
print(f"USRP TX Center Frequency = {self.tx_center_frequency}") print(f"USRP TX Center Frequency = {self.tx_center_frequency}")
@ -339,10 +354,8 @@ class USRP(SDR):
gain_range = self.usrp.get_tx_gain_range() gain_range = self.usrp.get_tx_gain_range()
if gain_mode == "relative": if gain_mode == "relative":
if gain > 0: if gain > 0:
raise ValueError( raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
"When gain_mode = 'relative', gain must be < 0. This sets\ the gain relative to the maximum possible gain.")
the gain relative to the maximum possible gain."
)
else: else:
# set gain relative to max # set gain relative to max
abs_gain = gain_range.stop() + gain abs_gain = gain_range.stop() + gain
@ -358,7 +371,13 @@ class USRP(SDR):
print(f"USRP TX Gain = {self.tx_gain}") print(f"USRP TX Gain = {self.tx_gain}")
def close(self): def close(self):
pass self._tx_initialized = False
self._rx_initialized = False
if hasattr(self, "rx_stream"):
del self.rx_stream
if hasattr(self, "usrp"):
del self.usrp
self.usrp = None
def _stream_tx(self, callback): def _stream_tx(self, callback):
@ -439,6 +458,9 @@ class USRP(SDR):
print(f"USRP clock source set to {self.usrp.get_clock_source(0)}") print(f"USRP clock source set to {self.usrp.get_clock_source(0)}")
def supports_dynamic_updates(self) -> dict:
return {"center_frequency": True, "sample_rate": True, "gain": True}
def _create_device_dict(identifier_value=None): def _create_device_dict(identifier_value=None):
""" """

View File

@ -0,0 +1,5 @@
"""RT-OSS HTTP server for RIA Hub integration."""
from .app import create_app
__all__ = ["create_app"]

View File

@ -0,0 +1,48 @@
"""FastAPI application factory for the RT-OSS HTTP server."""
from fastapi import Depends, FastAPI
from .auth import require_api_key
from .routers import conductor, inference
def create_app(api_key: str = "") -> FastAPI:
"""Create and configure the RT-OSS FastAPI application.
Args:
api_key: Secret key required in the ``X-API-Key`` request header.
Pass an empty string to disable authentication (development only).
Returns:
Configured FastAPI application instance.
"""
app = FastAPI(
title="RIA Toolkit OSS Server",
version="0.1.0",
description=(
"HTTP API for RT-OSS campaign orchestration and RF zone inference. "
"All endpoints (except /health) require the X-API-Key header when "
"an API key is configured."
),
)
app.state.api_key = api_key
app.include_router(
conductor.router,
prefix="/conductor",
tags=["Conductor"],
dependencies=[Depends(require_api_key)],
)
app.include_router(
inference.router,
prefix="/inference",
tags=["Inference"],
dependencies=[Depends(require_api_key)],
)
@app.get("/health", tags=["Health"])
async def health():
"""Health check — always returns 200."""
return {"status": "ok"}
return app

View File

@ -0,0 +1,36 @@
"""API key authentication dependency."""
import hmac
import logging
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
logger = logging.getLogger(__name__)
async def require_api_key(
request: Request,
api_key: str | None = Depends(_api_key_header),
) -> None:
"""FastAPI dependency that enforces X-API-Key header authentication.
If no API key is configured on the server (empty string), all requests
are allowed this is intended for local development only.
"""
expected: str = request.app.state.api_key
if not expected:
return # dev mode: no key set, allow all
if not hmac.compare_digest(api_key or "", expected):
client = getattr(request.client, "host", "unknown")
logger.warning(
"Authentication failure from %s%s %s",
client,
request.method,
request.url.path,
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)

View File

@ -0,0 +1,47 @@
"""CLI entry point for the RT-OSS HTTP server.
Usage:
ria-server # default: 0.0.0.0:8080, no auth
RT_OSS_API_KEY=secret ria-server # enforce X-API-Key header
RT_OSS_PORT=9000 ria-server # custom port
Environment variables:
RT_OSS_API_KEY Shared secret for X-API-Key auth (empty = dev mode, no auth)
RT_OSS_PORT TCP port to listen on (default: 8080)
RT_OSS_HOST Bind address (default: 0.0.0.0)
"""
from __future__ import annotations
import os
def serve() -> None:
try:
import uvicorn
except ImportError:
raise SystemExit(
"uvicorn is required to run the RT-OSS server.\n" "Install it with: pip install 'ria-toolkit-oss[server]'"
)
from .app import create_app
api_key = os.environ.get("RT_OSS_API_KEY", "")
host = os.environ.get("RT_OSS_HOST", "0.0.0.0")
port = int(os.environ.get("RT_OSS_PORT", "8080"))
app = create_app(api_key=api_key)
if not api_key:
print(
"\n"
"╔══════════════════════════════════════════════════════════════╗\n"
"║ WARNING: RT_OSS_API_KEY is not set. ║\n"
"║ The server is running with NO authentication. ║\n"
"║ Anyone who can reach this port has full API access. ║\n"
"║ Set RT_OSS_API_KEY=<secret> before exposing to a network. ║\n"
"╚══════════════════════════════════════════════════════════════╝\n",
flush=True,
)
uvicorn.run(app, host=host, port=port)

View File

@ -0,0 +1,114 @@
"""Pydantic request and response models for the RT-OSS HTTP server."""
from __future__ import annotations
from pathlib import Path
from pydantic import BaseModel, field_validator
# ---------------------------------------------------------------------------
# Conductor
# ---------------------------------------------------------------------------
class DeployRequest(BaseModel):
config: dict
class DeployResponse(BaseModel):
campaign_id: str
class CampaignStatusResponse(BaseModel):
campaign_id: str
status: str
config_name: str
progress: int
total_steps: int
started_at: float
ended_at: float | None = None
result: dict | None = None
error: str | None = None
class CancelResponse(BaseModel):
campaign_id: str
cancelled: bool
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
class SdrConfig(BaseModel):
device: str
center_freq: float
sample_rate: float
gain: float | str = "auto"
class LoadModelRequest(BaseModel):
model_path: str
label_map: dict[str, int] # class_name -> class_index
@field_validator("model_path")
@classmethod
def validate_model_path(cls, v: str) -> str:
p = Path(v)
if ".." in p.parts:
raise ValueError("model_path must not contain path traversal components")
if p.suffix.lower() != ".onnx":
raise ValueError("model_path must point to an .onnx file")
# Resolve to catch symlink-based traversal; return the resolved absolute path
# so callers always work with the real filesystem location.
resolved = p.resolve()
if resolved.suffix.lower() != ".onnx":
raise ValueError("Resolved model_path must point to an .onnx file")
return str(resolved)
class LoadModelResponse(BaseModel):
loaded: bool
model_path: str
num_classes: int
class StartInferenceRequest(BaseModel):
sdr_config: SdrConfig
class StartInferenceResponse(BaseModel):
running: bool
class StopInferenceResponse(BaseModel):
stopped: bool
class ConfigureRequest(BaseModel):
"""Partial SDR reconfiguration — only supplied fields are updated."""
center_freq: float | None = None
sample_rate: float | None = None
gain: float | str | None = None
class ConfigureResponse(BaseModel):
configured: bool
class InferenceStatusResponse(BaseModel):
"""Latest inference result as returned by GET /inference/status.
When ``idle`` is True the radio is scanning but no signal was detected.
``device_id`` is the raw prediction label from the model's label map.
The frontend is responsible for mapping device_id to a human name and
determining whether the device is authorized.
"""
timestamp: float
idle: bool = False
device_id: str | None = None # prediction label; None when idle
confidence: float = 0.0
snr_db: float = 0.0

View File

@ -0,0 +1,112 @@
"""Conductor routes: campaign deployment, status, and cancellation."""
from __future__ import annotations
import threading
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, status
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
from ..models import (
CampaignStatusResponse,
CancelResponse,
DeployRequest,
DeployResponse,
)
from ..state import (
CampaignCancelledError,
CampaignState,
get_campaign,
set_campaign,
update_campaign,
)
router = APIRouter()
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
def cb(step_index: int, total_steps: int, step_result: Any) -> None:
update_campaign(campaign_id, progress=step_index)
if cancel_event.is_set():
raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")
return cb
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
state = get_campaign(campaign_id)
try:
result = CampaignExecutor(
config=cfg,
progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
).run()
update_campaign(
campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
)
except CampaignCancelledError:
update_campaign(campaign_id, status="cancelled", ended_at=time.time())
except Exception as e:
update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
@router.post("/deploy", response_model=DeployResponse)
async def deploy(request: DeployRequest):
"""Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
Cancellation takes effect at step boundaries, not mid-capture.
External scripts are not permitted in server-deployed campaigns. Configure
transmitters without the ``script`` field, or run campaigns via the CLI.
"""
try:
cfg = CampaignConfig.from_dict(request.config)
except (ValueError, KeyError) as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
if cfg.transmitters and any(t.script for t in cfg.transmitters):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="External scripts are not permitted in server-deployed campaigns. "
"Remove the 'script' field from all transmitters, or run the campaign via the CLI.",
)
campaign_id = str(uuid.uuid4())
cancel_event = threading.Event()
thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
set_campaign(
CampaignState(
campaign_id=campaign_id,
status="running",
config_name=cfg.name,
cancel_event=cancel_event,
thread=thread,
total_steps=cfg.total_steps(),
)
)
thread.start()
return DeployResponse(campaign_id=campaign_id)
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
async def get_status(campaign_id: str):
"""Get the status and progress of a deployed campaign."""
state = get_campaign(campaign_id)
if not state:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
return CampaignStatusResponse(**state.to_dict())
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
async def cancel(campaign_id: str):
"""Request cancellation. Takes effect at the next step boundary."""
state = get_campaign(campaign_id)
if not state:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
if state.status != "running":
return CancelResponse(campaign_id=campaign_id, cancelled=False)
state.cancel_event.set()
return CancelResponse(campaign_id=campaign_id, cancelled=True)

View File

@ -0,0 +1,253 @@
"""Inference routes: model loading, inference loop control, and status polling."""
from __future__ import annotations
import logging
import threading
import time
from pathlib import Path
import numpy as np
from fastapi import APIRouter, HTTPException, status
from scipy.special import softmax
from ..models import (
ConfigureRequest,
ConfigureResponse,
InferenceStatusResponse,
LoadModelRequest,
LoadModelResponse,
StartInferenceRequest,
StartInferenceResponse,
StopInferenceResponse,
)
from ..state import InferenceState, get_inference, set_inference
router = APIRouter()
logger = logging.getLogger(__name__)
_INFERENCE_NUM_SAMPLES = 4096
# Prediction labels that mean "no signal detected" — UI should treat these as idle.
_IDLE_LABELS: frozenset[str] = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
def _load_onnx_session(model_path: str):
try:
import onnxruntime as ort
except ImportError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
)
resolved = Path(model_path).resolve()
if not resolved.is_file():
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Model file not found: {model_path}",
)
try:
return ort.InferenceSession(str(resolved), providers=["CPUExecutionProvider"])
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
"""Reshape complex IQ samples to float32 matching the model's expected input.
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
"""
iq = samples.astype(np.complex64)
i_ch, q_ch = iq.real, iq.imag
if len(expected_shape) == 2:
n = expected_shape[1] // 2
interleaved = np.empty(expected_shape[1], dtype=np.float32)
interleaved[0::2] = i_ch[:n]
interleaved[1::2] = q_ch[:n]
return interleaved.reshape(1, -1)
elif len(expected_shape) == 3:
n = expected_shape[2]
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
else:
raise ValueError(f"Unsupported model input shape: {expected_shape}")
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
state.stop_event.set()
if state.thread and state.thread.is_alive():
state.thread.join(timeout=timeout)
if state.thread.is_alive():
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
def _apply_sdr_config(sdr, config: dict) -> None:
"""Re-initialise the SDR receiver with updated parameters."""
gain = config.get("gain")
if gain == "auto":
gain = None
elif gain is not None:
gain = float(gain)
kwargs: dict = {}
if config.get("center_freq") is not None:
kwargs["center_frequency"] = float(config["center_freq"])
if config.get("sample_rate") is not None:
kwargs["sample_rate"] = float(config["sample_rate"])
if gain is not None:
kwargs["gain"] = gain
if kwargs:
sdr.init_rx(**kwargs, channel=0)
def _inference_loop(state: InferenceState, sdr) -> None:
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
state.sdr = sdr
state.set_running(True)
session = state.session
input_name = session.get_inputs()[0].name
expected_shape = tuple(
d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
)
try:
while not state.stop_event.is_set():
# Apply any pending SDR reconfiguration before the next capture.
pending = state.pop_pending_config()
if pending:
try:
_apply_sdr_config(sdr, pending)
except Exception as exc:
logger.warning("SDR reconfigure failed: %s", exc)
recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
snr_db = estimate_snr_db(samples)
try:
model_input = _preprocess_samples(samples, expected_shape)
logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
probs = softmax(logits)
pred_idx = int(np.argmax(probs))
prediction = state.index_to_label.get(pred_idx, str(pred_idx))
except Exception as exc:
logger.warning("Inference prediction failed: %s", exc)
continue
is_idle = prediction in _IDLE_LABELS
state.set_latest(
{
"timestamp": time.time(),
"idle": is_idle,
"device_id": prediction if not is_idle else None,
"confidence": round(float(probs[pred_idx]), 4),
"snr_db": round(snr_db, 2),
}
)
finally:
state.sdr = None
try:
sdr.close()
except Exception:
pass
state.set_running(False)
@router.post("/load", response_model=LoadModelResponse)
async def load_model(request: LoadModelRequest):
"""Load an ONNX model. Stops any running inference first.
``label_map`` maps class names to integer indices (e.g. ``{"iphone13_wifi_001": 0}``).
``enrolled_devices`` enriches status responses with human names and authorization flags.
"""
existing = get_inference()
if existing and existing.get_running():
_stop_current_inference(existing)
session = _load_onnx_session(request.model_path)
set_inference(
InferenceState(
model_path=request.model_path,
label_map=request.label_map,
index_to_label={v: k for k, v in request.label_map.items()},
session=session,
)
)
return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
@router.post("/start", response_model=StartInferenceResponse)
async def start_inference(request: StartInferenceRequest):
"""Start continuous inference. Requires a model to be loaded first."""
state = get_inference()
if not state:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
)
if state.get_running():
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")
try:
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
from ria_toolkit_oss.sdr import get_sdr_device
except ImportError as e:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
sdr_cfg = request.sdr_config
# Merge any pending configure request on top of the start config.
pending = state.pop_pending_config() or {}
center_freq = float(pending.get("center_freq") or sdr_cfg.center_freq)
sample_rate = float(pending.get("sample_rate") or sdr_cfg.sample_rate)
raw_gain = pending.get("gain") if "gain" in pending else sdr_cfg.gain
gain = None if raw_gain == "auto" else float(raw_gain)
try:
sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
sdr.init_rx(sample_rate=sample_rate, center_frequency=center_freq, gain=gain, channel=0)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
state.stop_event.clear()
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
state.thread.start()
return StartInferenceResponse(running=True)
@router.post("/stop", response_model=StopInferenceResponse)
async def stop_inference():
"""Stop the running inference loop."""
state = get_inference()
if not state or not state.get_running():
return StopInferenceResponse(stopped=False)
_stop_current_inference(state)
return StopInferenceResponse(stopped=True)
@router.post("/configure", response_model=ConfigureResponse)
async def configure_inference(request: ConfigureRequest):
"""Update SDR parameters (center_freq, sample_rate, gain) on the fly.
If inference is running the change is applied at the next capture boundary.
If inference is not running the config is stored and applied when it starts.
Only fields present in the request body are updated.
"""
state = get_inference()
if not state:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="No model loaded. Call POST /inference/load first.",
)
pending = {k: v for k, v in request.model_dump().items() if v is not None}
if pending:
state.set_pending_config(pending)
return ConfigureResponse(configured=bool(pending))
@router.get("/status", response_model=InferenceStatusResponse | None)
async def inference_status():
"""Return the latest inference result, or null if no model is loaded."""
state = get_inference()
if not state:
return None
latest = state.get_latest()
return InferenceStatusResponse(**latest) if latest else None

View File

@ -0,0 +1,121 @@
"""In-memory state for running campaigns and inference sessions."""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional
class CampaignCancelledError(Exception):
"""Raised by the progress callback when a cancel is requested."""
@dataclass
class CampaignState:
campaign_id: str
status: str # "running" | "completed" | "failed" | "cancelled"
config_name: str
cancel_event: threading.Event
thread: threading.Thread
total_steps: int = 0
progress: int = 0
result: Optional[dict] = None
error: Optional[str] = None
started_at: float = field(default_factory=time.time)
ended_at: Optional[float] = None
def to_dict(self) -> dict:
return {
"campaign_id": self.campaign_id,
"status": self.status,
"config_name": self.config_name,
"progress": self.progress,
"total_steps": self.total_steps,
"result": self.result,
"error": self.error,
"started_at": self.started_at,
"ended_at": self.ended_at,
}
@dataclass
class InferenceState:
model_path: str
label_map: dict[str, int] # class_name -> class_index
index_to_label: dict[int, str] # reverse: class_index -> class_name
session: Any # onnxruntime.InferenceSession
stop_event: threading.Event = field(default_factory=threading.Event)
thread: Optional[threading.Thread] = None
sdr: Any = None # live SDR object while inference is running
running: bool = False
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
_latest: Optional[dict] = field(default=None, repr=False)
_pending_sdr_config: Optional[dict] = field(default=None, repr=False)
def set_latest(self, result: dict) -> None:
with self._lock:
self._latest = result
def get_latest(self) -> Optional[dict]:
with self._lock:
return self._latest
def set_pending_config(self, config: dict) -> None:
with self._lock:
self._pending_sdr_config = config
def pop_pending_config(self) -> Optional[dict]:
with self._lock:
cfg = self._pending_sdr_config
self._pending_sdr_config = None
return cfg
def set_running(self, value: bool) -> None:
with self._lock:
self.running = value
def get_running(self) -> bool:
with self._lock:
return self.running
# ---------------------------------------------------------------------------
# Module-level stores
# ---------------------------------------------------------------------------
_campaigns: dict[str, CampaignState] = {}
_campaigns_lock = threading.Lock()
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
with _campaigns_lock:
return _campaigns.get(campaign_id)
def set_campaign(state: CampaignState) -> None:
with _campaigns_lock:
_campaigns[state.campaign_id] = state
def update_campaign(campaign_id: str, **kwargs) -> None:
with _campaigns_lock:
state = _campaigns.get(campaign_id)
if state:
for k, v in kwargs.items():
setattr(state, k, v)
def get_inference() -> Optional[InferenceState]:
with _inference_lock:
return _inference
def set_inference(state: Optional[InferenceState]) -> None:
global _inference
with _inference_lock:
_inference = state

View File

@ -0,0 +1,7 @@
"""
The Signal Package provides a comprehensive suite of tools for signal generation and processing.
"""
from .recordable import Recordable
__all__ = ["Recordable"]

View File

@ -0,0 +1,405 @@
"""
.. todo:: Need to add some information here about signal generation and the signal generators in this module.
"""
import warnings
from typing import Optional
import numpy as np
import scipy.signal
from scipy.signal import butter
from scipy.signal import chirp as sci_chirp
from scipy.signal import hilbert, lfilter
from ria_toolkit_oss.data.recording import Recording
def sine(
sample_rate: Optional[int] = 1000,
length: Optional[int] = 1000,
frequency: Optional[float] = 1000,
amplitude: Optional[float] = 1,
baseband_phase: Optional[float] = 0,
rf_phase: Optional[float] = 0,
dc_offset: Optional[float] = 0,
) -> Recording:
"""Generate a basic sine wave signal.
:param sample_rate: The number of samples per second (Hz). Defaults to 1,000.
:type sample_rate: int, optional
:param length: Number of samples in the recording. Defaults to 1,000.
:type length: int, optional
:param frequency: The frequency of the sine wave (Hz). Defaults to 1,000.
:type frequency: float, optional
:param amplitude: Amplitude of the sine wave. Defaults to 1.
:type amplitude: float, optional
:param baseband_phase: Phase offset in radians, relative to the sine wave frequency. Defaults to 0.
:type baseband_phase: float, optional
:param rf_phase: Phase offset in radians of the complex samples. Defaults to 0.
:type rf_phase: float, optional
:param dc_offset: DC offset (average of the sine wave). Defaults to 0.
:type dc_offset: float, optional
:return: A Recording object containing the generated sine wave signal.
:rtype: Recording
Examples:
.. todo:: Usage examples coming soon!
"""
if sample_rate < 1:
raise ValueError("sample_rate must be > 1")
total_time = length / sample_rate
t = np.linspace(0, total_time, length, endpoint=False)
sine_wave = amplitude * np.sin(2 * np.pi * frequency * t + baseband_phase) + dc_offset
complex_sine_wave = sine_wave * np.exp(1j * rf_phase)
metadata = {
"signal": "sine",
"source": "synth",
"sample_rate": sample_rate,
"length": length,
"signal_frequency": frequency,
"amplitude": amplitude,
"baseband_phase": baseband_phase,
"rf_phase": rf_phase,
"dc_offset": dc_offset,
}
return Recording(data=complex_sine_wave, metadata=metadata)
def square(
sample_rate: Optional[int] = 1000,
length: Optional[int] = 1000,
frequency: Optional[float] = 1,
amplitude: Optional[float] = 1,
duty_cycle: Optional[float] = 0.5,
baseband_phase: Optional[float] = 0,
rf_phase: Optional[float] = 0,
dc_offset: Optional[float] = 0,
) -> Recording:
"""Generate a square wave signal.
:param sample_rate: The number of samples per second (Hz). Defaults to 1,000.
:type sample_rate: int, optional
:param length: Number of samples in the recording. Defaults to 1,000.
:type length: int, optional
:param frequency: The frequency of the square wave (Hz). Defaults to 1.
:type frequency: float, optional
:param amplitude: The amplitude of the square wave. Defaults to 1.
:type amplitude: float, optional
:param duty_cycle: The duty cycle of the square wave as a decimal in the range [0, 1]. Defaults to 0.5.
:param baseband_phase: Phase offset in radians, relative to the square wave frequency. Defaults to 0.
:type baseband_phase: float, optional
:param rf_phase: Phase offset in radians of the complex samples. Defaults to 0.
:type rf_phase: float, optional
:param dc_offset: DC offset. If dc_offset is 0 but duty_cycle is not 0.5, the actual dc offset may not be
exactly 0. Defaults to 0.
:type dc_offset: float, optional
:return: A Recording object containing the generated square wave signal.
:rtype: Recording
Examples:
.. todo:: Usage examples coming soon!
"""
if sample_rate < 1:
raise ValueError("sample_rate must be > 1")
t = np.arange(length)
square_wave = amplitude * scipy.signal.square(
2 * np.pi * frequency * (t / sample_rate - (baseband_phase / (2 * np.pi))), duty=duty_cycle
)
square_wave = square_wave + dc_offset
complex_square_wave = square_wave * np.exp(1j * rf_phase)
metadata = {
"signal": "square",
"source": "synth",
"sample_rate": sample_rate,
"length": length,
"signal_frequency": frequency,
"amplitude": amplitude,
"baseband_phase": baseband_phase,
"duty_cycle": duty_cycle,
"rf_phase": rf_phase,
"dc_offset": dc_offset,
}
return Recording(data=complex_square_wave, metadata=metadata)
def sawtooth(
sample_rate: Optional[int] = 1000,
length: Optional[int] = 1000,
frequency: Optional[float] = 1,
amplitude: Optional[float] = 1,
baseband_phase: Optional[float] = 0,
rf_phase: Optional[float] = 0,
dc_offset: Optional[float] = 0,
) -> Recording:
"""Generate a sawtooth wave signal.
:param sample_rate: The number of samples per second (Hz). Defaults to 1,000.
:type sample_rate: int, optional
:param length: Number of samples in the recording. Defaults to 1,000.
:type length: int, optional
:param frequency: The frequency of the sawtooth wave (Hz). Defaults to 1.
:type frequency: float, optional
:param amplitude: Amplitude of the sawtooth wave. Defaults to 1.
:type amplitude: float, optional
:param baseband_phase: Phase offset in radians, relative to the wave frequency. Defaults to 0.
:type baseband_phase: float, optional
:param rf_phase: Phase offset in radians of the complex samples. Defaults to 0.
:type rf_phase: float, optional
:param dc_offset: DC offset (average of the wave). Defaults to 0.
:type dc_offset: float, optional
:return: A Recording object containing the generated sawtooth signal.
:rtype: Recording
Examples:
.. todo:: Usage examples coming soon!
"""
if sample_rate < 1:
raise ValueError("sample_rate must be > 1")
t = np.arange(length)
saw_wave = amplitude * scipy.signal.sawtooth(
2 * np.pi * frequency * (t / sample_rate - (baseband_phase / (2 * np.pi)))
)
saw_wave = saw_wave + dc_offset
complex_sine_wave = saw_wave * np.exp(1j * rf_phase)
metadata = {
"signal": "sawtooth",
"source": "synth",
"sample_rate": sample_rate,
"length": length,
"signal_frequency": frequency,
"amplitude": amplitude,
"baseband_phase": baseband_phase,
"rf_phase": rf_phase,
"dc_offset": dc_offset,
}
return Recording(data=complex_sine_wave, metadata=metadata)
def noise(
sample_rate: Optional[int] = 1000,
length: Optional[int] = 1000,
rms_power: Optional[float] = 0.2,
dc_offset: Optional[float] = 0,
) -> Recording:
"""Generate a Gaussian white noise (GWN) wave signal.
:param sample_rate: The number of samples per second (Hz). Defaults to 1,000.
:type sample_rate: int, optional
:param length: Number of samples in the recording. Defaults to 1,000.
:type length: int, optional
:param rms_power: Root-Mean-Square power of the generated signal. Defaults to 0.2.
:type rms_power: float, optional
:param dc_offset: DC offset (average of the wave). Defaults to 0.
:type dc_offset: float, optional
:return: A Recording object containing the generated noise signal.
:rtype: Recording
Examples:
.. todo:: Usage examples coming soon!
"""
if sample_rate < 1:
raise ValueError("sample_rate must be > 1")
variance = rms_power**2
magnitude = np.random.normal(loc=0, scale=np.sqrt(variance), size=length)
magnitude2 = np.clip(magnitude, -1, 1)
# TODO figure out a better way to make it conform to [-1,1]
if not np.array_equal(magnitude, magnitude2):
warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")
phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
complex_awgn = magnitude2 * np.exp(1j * phase)
complex_awgn = complex_awgn + dc_offset
metadata = {
"signal": "awgn",
"source": "synth",
"sample_rate": sample_rate,
"length": length,
"amplitude": np.max(np.abs(complex_awgn)),
"dc_offset": dc_offset,
}
return Recording(data=complex_awgn, metadata=metadata)
def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float] = 0) -> Recording:
"""Generator a sinusoidal waveform with a linear frequency sweep.
Start and end frequencies are chosen based on the maximum frequency range that can be covered without aliasing,
which is determined by the sample rate. To chirp over a larger frequency range, increase the sample rate.
Chirps are often used in radar, sonar, and communication systems because they can effectively cover a wide
frequency range and are useful for testing and measurement purposes.
:param sample_rate: The number of samples per second (Hz).
:type sample_rate: int
:param num_samples: The number of samples in the chirp.
:type num_samples: int
:param center_frequency: The center frequency of the chirp.
:type center_frequency: float, optional
:return: A Recording object containing the generated noise signal.
:rtype: Recording
Examples:
.. todo:: Usage examples coming soon!
"""
# Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
if num_samples < 2:
raise ValueError("num_samples must be >= 2 for chirp generation")
chirp_start_frequency = center_frequency - sample_rate / 4
chirp_end_frequency = center_frequency + sample_rate / 4
t = np.arange(num_samples) / int(sample_rate)
f_t = chirp_start_frequency + (chirp_end_frequency - chirp_start_frequency) * t / t[-1]
complex_samples = np.exp(2.0j * np.pi * f_t * t)
metadata = {"sample_rate": sample_rate, "num_samples:": num_samples}
return Recording(data=complex_samples, metadata=metadata)
def lfm_chirp_complex(
sample_rate: int, width: int, chirp_period: float, sigfc: int | float, total_time: float, chirp_type: str
):
"""
Generate a complex linearly frequency modulated chirp signal.
:param sample_rate:
"""
# Time vector for one chirp
chirp_length = int(chirp_period * sample_rate)
t_chirp = np.linspace(0, chirp_period, chirp_length)
if len(t_chirp) > chirp_length:
t_chirp = t_chirp[:chirp_length]
# Generate one chirp from 0 Hz to the full width
if chirp_type == "up":
baseband_chirp = sci_chirp(t_chirp, f0=0, f1=width, t1=chirp_period, method="linear")
elif chirp_type == "down":
baseband_chirp = sci_chirp(t_chirp, f0=width, f1=0, t1=chirp_period, method="linear")
elif chirp_type == "up_down":
half_duration = chirp_period / 2
t_up_half, t_down_half = np.array_split(t_chirp, 2)
up_part = sci_chirp(t_up_half, f0=0, t1=half_duration, f1=width, method="linear")
down_part = np.flip(up_part)
baseband_chirp = np.concatenate([up_part, down_part])
else:
raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.")
# Generate the full signal by tiling the windowed chirp
num_chirps = round(total_time / chirp_period)
full_signal = np.tile(baseband_chirp, num_chirps)
# Create an analytic signal (complex with no negative frequency components)
analytic_signal = hilbert(full_signal)
# Shift the chirp to the signal center frequency
t_full = np.linspace(0, total_time, len(analytic_signal))
complex_chirp = analytic_signal * np.exp(1j * 2 * np.pi * (sigfc - width / 2) * t_full)
nyquist = 0.5 * sample_rate # Nyquist frequency
normal_cutoff = width / nyquist # Normalize cutoff
b, a = butter(8, normal_cutoff, btype="low", analog=False)
filtered_chirp = lfilter(b, a, complex_chirp)
metadata = {
"source": "basic_signal_generator",
"sample_rate": sample_rate,
"width": width,
"chirp_period": chirp_period,
"chirp_center_frequency": sigfc,
"total_time": total_time,
"filter": "low_pass",
}
return Recording(data=filtered_chirp, metadata=metadata)
def complex_sine(sample_rate, length, frequency):
"""
Generates a complex sine wave.
:param sample_rate: The number of samples per second (Hz). Defaults to 1,000.
:type sample_rate: int, optional
:param length: Number of samples in the recording. Defaults to 1,000.
:type length: int, optional
:param frequency: The frequency of the square wave (Hz). Defaults to 1.
:type frequency: float, optional
"""
if sample_rate < 1:
raise ValueError("sample_rate must be > 1")
total_time = length / sample_rate
t = np.linspace(0, total_time, length, endpoint=False)
power_factor = np.random.uniform(-8, 0)
complex_sine_wave = (10**power_factor) * np.exp(1j * 2 * np.pi * frequency * t)
metadata = {
"signal": "complex_sine",
"source": "synth",
"sample_rate": sample_rate,
"length": length,
"signal_frequency": frequency,
"power_factor": power_factor,
}
return Recording(data=complex_sine_wave, metadata=metadata)
def birdie(sample_rate, length, frequency):
"""
Generates a complex sine wave for birdies in demos.
:param sample_rate: The number of samples per second (Hz). Defaults to 1,000.
:type sample_rate: int, optional
:param length: Number of samples in the recording. Defaults to 1,000.
:type length: int, optional
:param frequency: The frequency of the square wave (Hz). Defaults to 1.
:type frequency: float, optional
"""
if sample_rate < 1:
raise ValueError("sample_rate must be > 1")
total_time = length / sample_rate
t = np.linspace(0, total_time, length, endpoint=False)
power_factor = np.random.uniform(-2.5, -0.5)
complex_sine_wave = (10**power_factor) * np.exp(1j * 2 * np.pi * frequency * t)
metadata = {
"signal": "complex_sine",
"source": "synth",
"sample_rate": sample_rate,
"length": length,
"signal_frequency": frequency,
"power_factor": power_factor,
}
return Recording(data=complex_sine_wave, metadata=metadata)

View File

@ -0,0 +1,63 @@
# RIA Block Signal Generator
Welcome to the RIA block generator! These modular signal processing blocks can be used together to create synthetic radio signals, and it is easy to add new blocks.
These instructions apply to using the block system within python, and not to the front end GUI.
# Overview
A block can be a SourceBlock or a ProcessBlock. Either of these can also be a RecordableBlock, or not.
SourceBlocks produce samples, and have no input.
ProcessBlocks process samples. They also provide a .process() method that can be used to directly operate on samples without using the block system.
RecordableBlocks provide a .record() method to create a recording. Some blocks, such as the RandomBinarySource produce non IQ sample formats such as bits, which is why they are not recordable.
Blocks are connected in a tree structure terminating in a final RecordableBlock. Blocks may have multiple inputs but can only have one output, and this output cannot be connected to the inputs of more than one block.
# Getting Started
Let's create a block flow tree to create a QPSK signal, add a LFM jamming signal, and add some noise.
First, imports:
```
from ria_toolkit_oss.signal.block_generator import RandomBinarySource, Mapper, Upsampling, RaisedCosineFilter, FrequencyShift, LFMChirpSource, Add, AWGNSource\
sample_rate = 1000000
```
Create the random binary source block:
```
source = RandomBinarySource()
```
Create a constellation mapper block to convert bits to QPSK symbols, connecting its input to the source block.
```
mapper = Mapper(input=[source], constellation_type="PSK", num_bits_per_symbol=2)
```
Add an upsampling block and a raised cosine filter for pulse shaping:
```
upsampler = Upsampling(input = [mapper], factor = 4)
filter = RaisedCosineFilter(input=[upsampler], span_in_symbols=100, upsampling_factor=4, beta=0.1)
```
Create another branch of the block tree for the LFM jamming source and frequency shifter:
```
jammer=LFMChirpSource(sample_rate=sample_rate, bandwidth=sample_rate/2, chirp_period=0.01, chirp_type='up')
f_shift = FrequencyShift(input = [jammer], shift_frequency=100000, sampling_rate=sample_rate)
```
Sum the two signals with an Add block:
```
adder = Add(input=[filter, f_shift])
```
Add another branch to create noise:
```
awgn_source = AWGNSource(variance = 0.05)
adder2 = Add(input = [adder, awgn_source])
```
Finally create a recording at the terminal block in the tree:
```
recording = mapper.record(100000)
recording.view()
recording.to_sigmf()
```

View File

@ -0,0 +1,83 @@
"""
RIA Block-Based Signal Generator Module
This module provides a flexible framework for simulating communication systems using configurable blocks. It includes:
- Various block types: filters, mappers, modulators, demodulators, and channels
- Easy-to-use classes for creating custom signal processing chains
- Pre-configured generators for common use cases
Key features:
- Modular design for building complex systems
- Customizable block parameters
- Ready-to-use generators for quick prototyping
Usage:
1. Import desired blocks
2. Configure block parameters
3. Connect blocks to create a processing chain
4. Run simulations with custom or provided input signals
For detailed examples and API reference, see the documentation.
"""
from .basic import Add, FrequencyShift, MultiplyConstant, PhaseShift
from .generators import PAMGenerator, PSKGenerator, QAMGenerator, SignalGenerator
from .mapping import Mapper, SymbolDemapper
from .process_block import ProcessBlock
from .pulse_shaping import (
GaussianFilter,
RaisedCosineFilter,
RectFilter,
RootRaisedCosineFilter,
SincFilter,
Upsampling,
)
from .recordable_block import RecordableBlock
from .siso_channel import AWGNChannel, FlatRayleigh
from .source import (
AWGNSource,
BinarySource,
ConstantSource,
LFMChirpSource,
RecordingSource,
SawtoothSource,
SineSource,
SquareSource,
)
from .source_block import SourceBlock
from .symbol_modulation import GMSKModulator, OOKModulator, OQPSKModulator
__all__ = [
"Add",
"FrequencyShift",
"MultiplyConstant",
"PhaseShift",
"PAMGenerator",
"PSKGenerator",
"QAMGenerator",
"SignalGenerator",
"Mapper",
"SymbolDemapper",
"GMSKModulator",
"OOKModulator",
"OQPSKModulator",
"RaisedCosineFilter",
"RootRaisedCosineFilter",
"SincFilter",
"RectFilter",
"GaussianFilter",
"Upsampling",
"AWGNChannel",
"FlatRayleigh",
"AWGNSource",
"ConstantSource",
"LFMChirpSource",
"BinarySource",
"RecordingSource",
"SawtoothSource",
"SineSource",
"SquareSource",
]

View File

@ -0,0 +1,6 @@
from .add import Add
from .frequency_shift import FrequencyShift
from .multiply_constant import MultiplyConstant
from .phase_shift import PhaseShift
__all__ = ["Add", "FrequencyShift", "MultiplyConstant", "PhaseShift"]

View File

@ -0,0 +1,67 @@
import numpy as np
from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.process_block import ProcessBlock
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
class Add(RecordableBlock, ProcessBlock):
"""
Add Block
Sums the input from two blocks.
Input type: [BASEBAND_SIGNAL, BASEBAND_SIGNAL]
Output type: BASEBAND_SIGNAL
"""
def __init__(self):
super().__init__()
def connect_input(self, input):
datatype = input[0].output_type
for input_block in input:
if input_block.output_type != datatype:
print(input_block.output_type)
raise ValueError(
f"'Add' block requires inputs to have the same datatype but got \
{'[' + ',' .join(f'{block.__class__.__name__}({block.output_type()})' for block in input) + ']'}"
) # TODO make this print the strings not numbers
return super().connect_input(input)
def _get_input_samples(self, block, num_samples):
"""
Request n samples from a block and validate the correct shape of CxN samples was received.
"""
samples = block.get_samples(num_samples)
if len(samples) != num_samples:
raise ValueError(f"Block {self.__class__.__name__} requested {num_samples} \
from block {block.__class__.__name__} but got {len(samples)}.")
return samples
@property
def input_type(self):
return [DataType.BASEBAND_SIGNAL, DataType.BASEBAND_SIGNAL]
@property
def output_type(self):
return DataType.BASEBAND_SIGNAL
def __call__(self, samples: list[np.array]):
"""
Add two signals together.
:param samples: A list containing two sample arrays of the same length.
:type samples: list of np.array
:returns: An array of output samples.
:rtype: np.array"""
if len(samples) != 2:
raise ValueError("Input must be a list of two input arrays.")
if len(samples[0]) != len(samples[1]):
raise ValueError(f"Input arrays must be equal length but were {len(samples[0])} and {len(samples[1])}")
return samples[0] + samples[1]

View File

@ -0,0 +1,56 @@
from typing import Optional
import numpy as np
from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.process_block import ProcessBlock
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
class FrequencyShift(ProcessBlock, RecordableBlock):
"""
Frequency Shift Block
Applies a frequency shift the input signal.
Input type: BASEBAND_SIGNAL
Output type: BASEBAND_SIGNAL
:param shift_frequency: The frequency to shift the signal by.
:type shift_frequency: float
:param sample_rate: The sample rate to use in frequency calculations.
:type sample_rate: float.
WARNING: This block does not include any anti-aliasing filters.
It is the responsiblity of the user to ensure proper
filtering is performed before/after this block to prevent aliasing.
"""
def __init__(self, shift_frequency: Optional[float] = 100000, sampling_rate: Optional[float] = 1000000):
self.shift_frequency = shift_frequency
self.sampling_rate = sampling_rate
super().__init__()
@property
def input_type(self) -> DataType:
return [DataType.BASEBAND_SIGNAL]
@property
def output_type(self) -> DataType:
return DataType.BASEBAND_SIGNAL
def __call__(self, samples: list[np.array]):
"""
Frequency shift input samples by the previously intialized shift frequency.
:param samples: A list containing a single array of complex samples.
:type samples: list of np.array
:returns: Processed samples.
:rtype: np.array
"""
signal = samples[0]
num_samples = len(signal)
t = np.arange(num_samples) / self.sampling_rate
carrier = np.exp(1j * 2 * np.pi * self.shift_frequency * t)
return signal * carrier

View File

@ -0,0 +1,41 @@
from typing import Optional
from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.process_block import ProcessBlock
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
class MultiplyConstant(ProcessBlock, RecordableBlock):
"""
MultiplyConstant Block
Multiply the input samples by a constant.
Input Type: BASEBAND_SIGNAL
Output Type: BASEBAND_SIGNAL
:param multiplier: The value to multiply the samples by.
:type multiplier: float.
"""
def __init__(self, multiplier: Optional[float] = 0.5):
self.multiplier = multiplier
@property
def input_type(self):
return [DataType.BASEBAND_SIGNAL]
@property
def output_type(self):
return DataType.BASEBAND_SIGNAL
def __call__(self, samples):
"""
Multiply an array of complex samples by the previously initialised value.
:param samples: A list containing a single array of complex samples.
:type samples: list of np.array
:returns: Processed samples.
:rtype: np.array"""
return samples[0] * self.multiplier

View File

@ -0,0 +1,40 @@
from typing import Optional
import numpy as np
from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.process_block import ProcessBlock
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
class PhaseShift(ProcessBlock, RecordableBlock):
"""
PhaseShift Block
Apply a complex phase shift to the input signal.
:param phase: The complex phase shift in radians.
:type phase: float."""
def __init__(self, phase: Optional[float] = 0):
self.phase = phase
super().__init__()
@property
def input_type(self):
return [DataType.BASEBAND_SIGNAL]
@property
def output_type(self):
return DataType.BASEBAND_SIGNAL
def __call__(self, samples):
"""
Phase shift an array of complex samples by the previously initialised phase.
:param samples: A list containing a single array of complex samples.
:type samples: list of np.array
:returns: Processed samples.
:rtype: np.array"""
return samples[0] * np.exp(1j * self.phase)

View File

@ -0,0 +1,122 @@
import json
from abc import ABC, abstractmethod
import numpy as np
from ria_toolkit_oss.signal.block_generator.data_types import DataType
class Block(ABC):
"""
Abstract base class for signal processing blocks.
This class defines the interface for all signal processing blocks,
including input and output data types and the call method for processing.
"""
@property
@abstractmethod
def input_type(self) -> DataType:
"""
Get the input data type for the block.
:return: The input data type.
:rtype: DataType
"""
pass
@property
@abstractmethod
def output_type(self) -> DataType:
"""
Get the output data type for the block.
:return: The output data type.
:rtype: DataType
"""
pass
@abstractmethod
def get_samples(self, num_samples) -> np.ndarray:
"""
Process the input data and produce output.
:param args: Positional arguments for the processing method.
:param kwargs: Keyword arguments for the processing method.
:return: The processed output data.
:rtype: numpy array
"""
pass
def _get_metadata(self):
metadata = {}
for key, value in vars(self).items():
try:
# Try to serialize the value to check if it's JSON serializable
json.dumps(value)
metadata[f"BlockGenerator:{self.__class__.__name__}:{key}"] = value
except (TypeError, ValueError):
# If the value is not JSON serializable, skip it
continue
for block in self.input:
metadata = self._combine_dicts_and_handle_double_keys(block._get_metadata(), metadata)
return metadata
# TODO improve this
def _combine_dicts_and_handle_double_keys(self, source_dict, other_dict):
for key, value in source_dict.items():
# Find the last colon in the key
last_colon_index = key.rfind(":")
# Ensure there's at least one colon in the key
if last_colon_index == -1:
# If no colon, just append "(1)"
new_key = f"{key}(1)"
else:
# Extract the prefix and the part after the last colon
prefix = key[:last_colon_index]
suffix = key[last_colon_index + 1 :]
# Check if the suffix has a number inside parentheses
if suffix.startswith("(") and suffix.endswith(")") and suffix[1:-1].isdigit():
# Extract the number inside the parentheses and increment it
number = int(suffix[1:-1]) + 1
new_key = f"{prefix}({number})"
else:
# No number at the end, so just append "(1)"
new_key = f"{key}(1)"
# Ensure the new key is unique in both dictionaries
while new_key in other_dict:
# Find the last parentheses to extract the current number
last_paren_index = new_key.rfind(")")
prefix = new_key[:last_paren_index]
suffix = new_key[last_paren_index + 1 :]
# Extract the number in parentheses and increment it
if suffix.startswith("(") and suffix.endswith(")") and suffix[1:-1].isdigit():
number = int(suffix[1:-1]) + 1
else:
number = 1 # Default to 1 if no number in parentheses
# Create the new key with the incremented number
new_key = f"{prefix}({number})"
# Update the other dictionary with the new key
other_dict[new_key] = value
return other_dict
@abstractmethod
def __call__(self, *args, **kwargs) -> np.ndarray:
"""
Process the input data and produce output.
:param args: Positional arguments for the processing method.
:param kwargs: Keyword arguments for the processing method.
:return: The processed output data.
:rtype: numpy array
"""
pass

View File

@ -0,0 +1,112 @@
import numpy as np
from ria_toolkit_oss.signal.block_generator.continuous_modulation.demodulator import (
Demodulator,
)
from ria_toolkit_oss.signal.block_generator.data_types import DataType
class CoherentCorrelator(Demodulator):
"""
A correlator for coherent detection that performs frequency downconversion via correlation.
This class implements a coherent correlator by multiplying the received passband signal
with a reference carrier and integrating (or convolving with an optional matched filter)
over one symbol period. The reference carrier can be generated in one of two ways:
- If 'per_symbol' is True, the carrier reference is generated for each symbol separately
(i.e. a time vector that resets to zero for every symbol).
- If 'per_symbol' is False, a continuous time vector is used over the entire signal.
Optionally, a pulse-shaping filter (subclass of PulseShapingFilter) can be provided. When set,
each symbol's downconverted product is first convolved with the matched filter (via its
`apply_matched_filter` method) before integration. If not provided, a simple integration (sum)
is performed.
:param carrier_frequency: The carrier frequency (Hz) used for demodulation.
:param symbol_duration: The duration (seconds) of one symbol period.
:param sampling_rate: The sampling rate (Hz) of the received signal.
:param per_symbol: If True, uses a per-symbol time vector; if False, uses a continuous time vector.
"""
def __init__(
self,
carrier_frequency: float,
symbol_duration: float,
sampling_rate: float,
per_symbol: bool = True,
):
self.carrier_frequency = carrier_frequency
self.symbol_duration = symbol_duration
self.sampling_rate = sampling_rate
self.samples_per_symbol = int(self.symbol_duration * self.sampling_rate)
self.per_symbol = per_symbol
@property
def input_type(self) -> DataType:
"""The correlator expects a passband signal as input."""
return DataType.PASSBAND_SIGNAL
@property
def output_type(self) -> DataType:
"""The correlator produces decision statistics (typically complex or real values)."""
return DataType.BITS
def __call__(self, signal: np.ndarray) -> np.ndarray:
"""
Correlate the input passband signal with a reference carrier to produce decision statistics.
The input signal is assumed to be a 2D numpy array of shape (batch_size, total_samples),
where total_samples is an integer multiple of the number of samples per symbol.
Depending on the 'per_symbol' flag, the reference carrier is generated as:
- If True: a per-symbol time vector (from 0 to symbol_duration) is used.
- If False: a continuous time vector for the entire signal is used.
If a pulse shaping filter is provided (self.filter is not None), the symbol's product
(signal multiplied by the reference carrier) is convolved with the filter via its
`apply_matched_filter` method before integration.
:param signal: The input passband signal (shape: (batch_size, total_samples)).
:return: A 2D numpy array of decision statistics with shape (batch_size, num_symbols).
:raises ValueError: If the total number of samples is not an integer multiple of samples_per_symbol.
"""
batch_size, total_samples = signal.shape
samples_per_symbol = self.samples_per_symbol
if total_samples % samples_per_symbol != 0:
raise ValueError(
"The total number of samples in the signal must be an integer multiple of the samples per symbol."
)
num_symbols = total_samples // samples_per_symbol
# Reshape the signal into symbols: shape (batch_size, num_symbols, samples_per_symbol)
symbols = signal.reshape(batch_size, num_symbols, samples_per_symbol)
if self.per_symbol:
# Generate per-symbol time vector (from 0 to symbol_duration)
t_symbol = np.arange(samples_per_symbol) / self.sampling_rate
if np.iscomplexobj(signal):
reference = np.exp(-1j * 2 * np.pi * self.carrier_frequency * t_symbol)
else:
reference = np.cos(2 * np.pi * self.carrier_frequency * t_symbol)
# Multiply each symbol with the reference (broadcasted) to obtain the product.
product = symbols * reference[None, None, :]
else:
# Use a continuous time vector for the entire signal.
t_full = np.arange(total_samples) / self.sampling_rate
if np.iscomplexobj(signal):
reference_full = np.exp(-1j * 2 * np.pi * self.carrier_frequency * t_full)
else:
reference_full = np.cos(2 * np.pi * self.carrier_frequency * t_full)
reference_full = reference_full.reshape(1, num_symbols, samples_per_symbol)
product = symbols * reference_full
decision_stats = np.sum(product, axis=2)
return decision_stats
def __str__(self) -> str:
"""Return a string representation of the CoherentCorrelator."""
return (
f"CoherentCorrelator(carrier_frequency={self.carrier_frequency}, "
f"symbol_duration={self.symbol_duration}, sampling_rate={self.sampling_rate} "
)

Some files were not shown because too many files have changed in this diff Show More