Compare commits
90 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2baae2f63e | |||
| 4df5455af4 | |||
| 2881aaf06e | |||
| 50d04161b7 | |||
| 07c72294f5 | |||
| c9b19949ad | |||
| 53e8e5adb6 | |||
| a502dd97a9 | |||
| 34b67c0c17 | |||
| 39d5d74d6a | |||
| 8a66860d33 | |||
| 4d3aaf6ec8 | |||
| 4aea2841be | |||
| 4c2c9c0288 | |||
|
Mmuq
|
a68a325cb4 | ||
| 50438558d4 | |||
| c27a5944c7 | |||
| 062a0e766f | |||
| cdcc03327b | |||
| 8d2f9eebaf | |||
| 6019a38b8b | |||
| ef772a3755 | |||
| 98f63b622b | |||
| db000da517 | |||
| c043eb0377 | |||
| d6c66d2a07 | |||
| 414597e940 | |||
| 1fa9ab2495 | |||
| ab4cb0ea5a | |||
| 22b035dbee | |||
| 8e23558d90 | |||
| ae07eef885 | |||
| 912fc54f25 | |||
| b884397f1f | |||
| f03825a6db | |||
| e506d26450 | |||
| 138fdeb68b | |||
| dae9510981 | |||
|
Mmuq
|
d3a7e9ef0f | ||
|
Mmuq
|
93ae08bc91 | ||
| 0642dcc2db | |||
|
J
|
ea8ed56a7d | ||
| 84a7893c8f | |||
| 8e542919a8 | |||
| 27049f00ea | |||
| 78ecd171bd | |||
| 6fb73c1daa | |||
| 83515d6e3f | |||
| 638fe5df1f | |||
| efc0948110 | |||
|
J
|
5035f0654a | ||
|
J
|
8c247f9f7a | ||
|
J
|
b955256479 | ||
|
J
|
20fe86d399 | ||
|
J
|
87bc78e063 | ||
| 8f39c4d855 | |||
| 195db4a27d | |||
|
J
|
84f3a63e8b | ||
|
J
|
a268f2ab25 | ||
| 5718e109b5 | |||
| d81c61c3cf | |||
| 54b9bd4fc8 | |||
| 1005228d69 | |||
| 3eb084bc08 | |||
| c7b88f1f14 | |||
| e1025794a8 | |||
| cfa1da9f4d | |||
| 0e3e022084 | |||
| 2182899162 | |||
| da9a0b07bd | |||
| 3e9ac43800 | |||
| 44df45160e | |||
| f67c995846 | |||
| c36fdcf607 | |||
|
F
|
5d909c4a22 | ||
|
F
|
5cfced8855 | ||
|
F
|
ee2ce3b1f4 | ||
|
F
|
5b1c51797b | ||
| 9a960e2f29 | |||
|
F
|
2bb2d9d5a7 | ||
|
F
|
11d9532b5c | ||
| 7335dc4c52 | |||
| 019b0c6f4b | |||
|
Mmuq
|
e41f061caa | ||
|
Mmuq
|
16ac8dbfb6 | ||
| af3ae03baf | |||
| 5c0c20619f | |||
| 4ee8ee5fe0 | |||
| f7eedfa2bd | |||
| fc6a1824a5 |
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -52,6 +52,7 @@ tests/sdr/
|
||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
docs/build/
|
docs/build/
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
# Jupyter Notebook
|
# Jupyter Notebook
|
||||||
.ipynb_checkpoints
|
.ipynb_checkpoints
|
||||||
|
|
|
||||||
36
CHANGELOG.md
Normal file
36
CHANGELOG.md
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Changelog
|
||||||
|
|
||||||
|
## [0.1.0] - 2026-02-20
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- **Dual-Threshold Detection:** Logic to capture the start and end of signals, not just the peak.
|
||||||
|
- **Signal Smoothing & Noise Filters:** Prevents detections from breaking into fragments and ignores short interference spikes.
|
||||||
|
- **Auto-Frequency Calculation:** Automatically adjusts bounding boxes to fit signal frequency ranges tightly.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- **Signal Power Detection:** Switched from raw signal strength to power for improved accuracy.
|
||||||
|
- **CLI Workflow:** `Clear` and `Remove` commands now modify files directly (in-place) to avoid redundant copies.
|
||||||
|
- **Metadata Logic:** Updated labels to show detection percentages and overhauled internal metadata cleaning.
|
||||||
|
- **Viewer UI:** Moved legend outside the plot, added a black background, and adjusted transparency for better spectrogram visibility.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- Prevented redundant `_annotated` suffixes in file naming patterns.
|
||||||
|
- Simplified internal math to increase processing speed and precision.
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
---
|
||||||
|
## [0.1.1] - 2026-03-20
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Campaign orchestration** — new `orchestration` module that manages the full lifecycle of an RF data collection campaign: SDR capture, automatic labeling, QA checks, and dataset packaging.
|
||||||
|
- **HTTP inference server** — `ria-server` command starts a REST API server for deploying campaigns and controlling live inference from external systems such as the RIA Hub platform.
|
||||||
|
- **Campaign CLI** — `ria campaign` commands for starting, monitoring, and managing campaigns from the terminal.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **Visualization layout** — recording and dataset views have been reformatted with improved sizing, repositioned titles, and updated Qoherent branding.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
@ -159,7 +159,7 @@ Finally, RIA Toolkit OSS can be installed directly from the source code. This ap
|
||||||
Once the project is installed, you can import modules, functions, and classes from the Toolkit for use in your Python code. For example, you can use the following import statement to access the `Recording` object:
|
Once the project is installed, you can import modules, functions, and classes from the Toolkit for use in your Python code. For example, you can use the following import statement to access the `Recording` object:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
```
|
```
|
||||||
|
|
||||||
Additional usage information is provided in the project documentation: [RIA Toolkit OSS Documentation](https://ria-toolkit-oss.readthedocs.io/).
|
Additional usage information is provided in the project documentation: [RIA Toolkit OSS Documentation](https://ria-toolkit-oss.readthedocs.io/).
|
||||||
|
|
|
||||||
1083
docs/_build/html/_sources/intro/getting_started.rst.txt
vendored
Normal file
1083
docs/_build/html/_sources/intro/getting_started.rst.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
29
docs/source/_static/custom.css
Normal file
29
docs/source/_static/custom.css
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
/* Change the hex values below to customize heading colours */
|
||||||
|
|
||||||
|
.rst-content h1 { color: #2c3e50; }
|
||||||
|
.rst-content h2,
|
||||||
|
.rst-content h2 a { color: #ffffff !important; font-size: 22px !important; }
|
||||||
|
|
||||||
|
.rst-content h3,
|
||||||
|
.rst-content h3 a { color: #ffffff !important; font-size: 16px !important; }
|
||||||
|
|
||||||
|
.rst-content h3 code { font-size: inherit !important; }
|
||||||
|
|
||||||
|
.rst-content .admonition.warning {
|
||||||
|
background: #1a1a2e !important;
|
||||||
|
border-left: 4px solid #c0392b !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.rst-content .admonition.warning .admonition-title {
|
||||||
|
background: #c0392b !important;
|
||||||
|
color: #ffffff !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.rst-content .admonition.warning p {
|
||||||
|
color: #ffffff !important;
|
||||||
|
}
|
||||||
|
.rst-content h4 { color: #404040; }
|
||||||
|
|
||||||
|
.highlight * { color: #ffffff !important; }
|
||||||
|
|
||||||
|
.ria-cmd { color: #2980b9 !important; }
|
||||||
8
docs/source/_static/custom.js
Normal file
8
docs/source/_static/custom.js
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
document.addEventListener('DOMContentLoaded', function () {
|
||||||
|
document.querySelectorAll('.highlight pre').forEach(function (pre) {
|
||||||
|
pre.innerHTML = pre.innerHTML.replace(
|
||||||
|
/((?:^|\n|>))(ria)(?=[ \t]|<)/g,
|
||||||
|
'$1<span class="ria-cmd">$2</span>'
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -12,9 +12,9 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||||
|
|
||||||
project = 'ria-toolkit-oss'
|
project = 'ria-toolkit-oss'
|
||||||
copyright = '2025, Qoherent Inc'
|
copyright = '2026, Qoherent Inc'
|
||||||
author = 'Qoherent Inc.'
|
author = 'Qoherent Inc.'
|
||||||
release = '0.1.4'
|
release = '0.1.5'
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||||
|
|
@ -73,3 +73,6 @@ def setup(app):
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||||
|
|
||||||
html_theme = 'sphinx_rtd_theme'
|
html_theme = 'sphinx_rtd_theme'
|
||||||
|
html_static_path = ['_static']
|
||||||
|
html_css_files = ['custom.css']
|
||||||
|
html_js_files = ['custom.js']
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
.. _examples:
|
.. _sdr_examples:
|
||||||
|
|
||||||
############
|
############
|
||||||
SDR Examples
|
SDR Examples
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ In this example, we initialize the `Blade` SDR, configure it to record a signal
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr.blade import Blade
|
from ria_toolkit_oss.sdr.blade import Blade
|
||||||
|
|
||||||
my_radio = Blade()
|
my_radio = Blade()
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ Code
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr.blade import Blade
|
from ria_toolkit_oss.sdr.blade import Blade
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -11,15 +11,15 @@ The Radio Dataset Framework provides a software interface to access and manipula
|
||||||
the need for users to interface with the source files directly. Instead, users initialize and interact with a Python
|
the need for users to interface with the source files directly. Instead, users initialize and interact with a Python
|
||||||
object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes.
|
object, while the complexities of efficient data retrieval and source file manipulation are managed behind the scenes.
|
||||||
|
|
||||||
Utils includes an abstract class called :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`, which defines common properties and
|
Ria Toolkit OSS includes an abstract class called :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset`, which defines common properties and
|
||||||
behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset` can be considered a blueprint for all
|
behaviors for all radio datasets. :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset` can be considered a blueprint for all
|
||||||
other radio dataset classes. This class is then subclassed to define more specific blueprints for different types
|
other radio dataset classes. This class is then subclassed to define more specific blueprints for different types
|
||||||
of radio datasets. For example, :py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset`, which is tailored for machine learning tasks
|
of radio datasets. For example, :py:obj:`ria_toolkit_oss.data.datasets.IQDataset`, which is tailored for machine learning tasks
|
||||||
involving the processing of signals represented as IQ (In-phase and Quadrature) samples.
|
involving the processing of signals represented as IQ (In-phase and Quadrature) samples.
|
||||||
|
|
||||||
Then, in the various project backends, there are concrete dataset classes, which inherit from both Utils and the base
|
Then, in the various project backends, there are concrete dataset classes, which inherit from both Ria Toolkit OSS and the base
|
||||||
dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both
|
dataset class from the respective backend. For example, the :py:obj:`TorchIQDataset` class extends both
|
||||||
:py:obj:`ria_toolkit_oss.datatypes.datasets.IQDataset` from Utils and :py:obj:`torch.ria_toolkit_oss.datatypes.IterableDataset` from
|
:py:obj:`ria_toolkit_oss.data.datasets.IQDataset` from Ria Toolkit OSS and :py:obj:`torch.ria_toolkit_oss.data.IterableDataset` from
|
||||||
PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend.
|
PyTorch, providing a concrete dataset class tailored for IQ datasets and optimized for the PyTorch backend.
|
||||||
|
|
||||||
Dataset initialization
|
Dataset initialization
|
||||||
|
|
@ -130,7 +130,7 @@ Dataset processing and manipulation
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent,
|
All radio datasets support methods tailored specifically for radio processing. These methods are backend-independent,
|
||||||
inherited from the blueprints in Utils like :py:obj:`ria_toolkit_oss.datatypes.datasets.RadioDataset`.
|
inherited from the blueprints in Ria Toolkit OSS like :py:obj:`ria_toolkit_oss.data.datasets.RadioDataset`.
|
||||||
|
|
||||||
For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset:
|
For example, we can trim down the length of the examples from 1,024 to 512 samples, and then augment the dataset:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
Dataset License SubModule
|
Dataset License SubModule
|
||||||
=========================
|
=========================
|
||||||
|
|
||||||
.. automodule:: ria_toolkit_oss.datatypes.datasets.license
|
.. automodule:: ria_toolkit_oss.data.datasets.license
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
Datatypes Package (ria_toolkit_oss.datatypes)
|
Datatypes Package (ria_toolkit_oss.data)
|
||||||
=============================================
|
=============================================
|
||||||
|
|
||||||
.. |br| raw:: html
|
.. |br| raw:: html
|
||||||
|
|
||||||
<br />
|
<br />
|
||||||
|
|
||||||
.. automodule:: ria_toolkit_oss.datatypes
|
.. automodule:: ria_toolkit_oss.data
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
@ -13,7 +13,7 @@ Datatypes Package (ria_toolkit_oss.datatypes)
|
||||||
Radio Dataset SubPackage
|
Radio Dataset SubPackage
|
||||||
------------------------
|
------------------------
|
||||||
|
|
||||||
.. automodule:: ria_toolkit_oss.datatypes.datasets
|
.. automodule:: ria_toolkit_oss.data.datasets
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
@ -21,5 +21,5 @@ Radio Dataset SubPackage
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
Dataset License SubModule <ria_toolkit_oss.datatypes.datasets.license>
|
Dataset License SubModule <ria_toolkit_oss.data.datasets.license>
|
||||||
Radio Datasets <radio_datasets>
|
Radio Datasets <radio_datasets>
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ class and function signatures, and doctest examples where available.
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
:caption: Contents:
|
:caption: Contents:
|
||||||
|
|
||||||
Datatypes Package <datatypes/ria_toolkit_oss.datatypes>
|
Data Package <data/ria_toolkit_oss.data>
|
||||||
SDR Package <ria_toolkit_oss.sdr>
|
SDR Package <ria_toolkit_oss.sdr>
|
||||||
IO Package <ria_toolkit_oss.io>
|
IO Package <ria_toolkit_oss.io>
|
||||||
Transforms Package <ria_toolkit_oss.transforms>
|
Transforms Package <ria_toolkit_oss.transforms>
|
||||||
|
|
|
||||||
|
|
@ -40,34 +40,44 @@ Limitations
|
||||||
- USB 3.0 connectivity is required for optimal performance; using USB 2.0 will significantly limit data
|
- USB 3.0 connectivity is required for optimal performance; using USB 2.0 will significantly limit data
|
||||||
transfer rates.
|
transfer rates.
|
||||||
|
|
||||||
Set up instructions (Linux, Radioconda)
|
Set up instructions (Linux)
|
||||||
---------------------------------------
|
---------------------------
|
||||||
|
|
||||||
1. Activate your Radioconda environment.
|
No additional Python packages are required for BladeRF beyond the base RIA Toolkit OSS installation.
|
||||||
|
|
||||||
|
1. Install the system library:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
conda activate <your-env-name>
|
sudo apt install libbladerf-dev
|
||||||
|
|
||||||
2. Install the base dependencies and drivers (*Easy method*):
|
For a more complete installation including CLI tools and FPGA images, use the Nuand PPA:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo add-apt-repository ppa:nuandllc/bladerf
|
sudo add-apt-repository ppa:nuandllc/bladerf
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install bladerf
|
sudo apt-get install bladerf libbladerf-dev
|
||||||
sudo apt-get install libbladerf-dev
|
sudo apt-get install bladerf-fpga-hostedxa4 # Necessary for BladeRF 2.0 Micro xA4
|
||||||
sudo apt-get install bladerf-fpga-hostedxa4 # Necessary for installation of bladeRF 2.0 Micro A4.
|
|
||||||
|
|
||||||
3. Install a ``udev`` rule by creating a link into your Radioconda installation:
|
2. Install udev rules:
|
||||||
|
|
||||||
|
For most users:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf1.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf1.rules
|
sudo udevadm control --reload
|
||||||
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf2.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf2.rules
|
sudo udevadm trigger
|
||||||
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bootloader.rules /etc/udev/rules.d/88-radioconda-nuand-bootloader.rules
|
|
||||||
sudo udevadm control --reload
|
For **Radioconda** users, create symlinks from your conda environment instead:
|
||||||
sudo udevadm trigger
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf1.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf1.rules
|
||||||
|
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bladerf2.rules /etc/udev/rules.d/88-radioconda-nuand-bladerf2.rules
|
||||||
|
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/88-nuand-bootloader.rules /etc/udev/rules.d/88-radioconda-nuand-bootloader.rules
|
||||||
|
sudo udevadm control --reload
|
||||||
|
sudo udevadm trigger
|
||||||
|
|
||||||
Further Information
|
Further Information
|
||||||
-------------------
|
-------------------
|
||||||
|
|
|
||||||
|
|
@ -39,39 +39,44 @@ Limitations
|
||||||
- Bandwidth is limited to 20 MHz.
|
- Bandwidth is limited to 20 MHz.
|
||||||
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
|
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
|
||||||
|
|
||||||
Set up instructions (Linux, Radioconda)
|
Set up instructions (Linux)
|
||||||
---------------------------------------
|
---------------------------
|
||||||
|
|
||||||
1. Activate your Radioconda environment:
|
HackRF is supported out of the box after installing RIA Toolkit OSS.
|
||||||
|
|
||||||
|
1. Ensure ``libhackrf`` is installed at the system level. On most Ubuntu installations this is already
|
||||||
|
present. If not:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
conda activate <your-env-name>
|
sudo apt install libhackrf-dev
|
||||||
|
|
||||||
2. Install the System Package (Ubuntu / Debian):
|
2. Install udev rules to allow non-root device access:
|
||||||
|
|
||||||
|
For most users:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo apt-get update
|
sudo udevadm control --reload
|
||||||
sudo apt-get install hackrf
|
sudo udevadm trigger
|
||||||
|
|
||||||
3. Install a ``udev`` rule by creating a link into your Radioconda installation:
|
For **Radioconda** users, create a symlink from your conda environment instead:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/53-hackrf.rules /etc/udev/rules.d/53-radioconda-hackrf.rules
|
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/53-hackrf.rules /etc/udev/rules.d/53-radioconda-hackrf.rules
|
||||||
sudo udevadm control --reload
|
sudo udevadm control --reload
|
||||||
sudo udevadm trigger
|
sudo udevadm trigger
|
||||||
|
|
||||||
Make sure your user account belongs to the plugdev group in order to access your device:
|
Make sure your user account belongs to the ``plugdev`` group in order to access your device:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo usermod -a -G plugdev <user>
|
sudo usermod -a -G plugdev <user>
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
You may have to restart your system for changes to take effect.
|
You may have to restart your system for group membership changes to take effect.
|
||||||
|
|
||||||
Further information
|
Further information
|
||||||
-------------------
|
-------------------
|
||||||
|
|
|
||||||
|
|
@ -43,34 +43,34 @@ Limitations
|
||||||
affect stability.
|
affect stability.
|
||||||
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
|
- USB 2.0 connectivity might limit data transfer rates compared to USB 3.0 or Ethernet-based SDRs.
|
||||||
|
|
||||||
Set up instructions (Linux, Radioconda)
|
Set up instructions (Linux)
|
||||||
---------------------------------------
|
---------------------------
|
||||||
|
|
||||||
1. Activate your Radioconda environment:
|
The PlutoSDR is supported out of the box after installing RIA Toolkit OSS. The required Python package
|
||||||
|
(``pyadi-iio``) is included in the toolkit's dependencies.
|
||||||
|
|
||||||
|
1. Ensure ``libiio`` is installed at the system level. On most Ubuntu installations this is already present.
|
||||||
|
If not:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
conda activate <your-env-name>
|
sudo apt install libiio-dev libiio-utils libiio0
|
||||||
|
|
||||||
2. Install system dependencies:
|
.. note::
|
||||||
|
|
||||||
|
PlutoSDR devices are discoverable over both USB and network (mDNS). Network discovery uses Avahi — if
|
||||||
|
``avahi-daemon`` is not running, network discovery will be skipped but USB discovery still works.
|
||||||
|
|
||||||
|
2. Install a ``udev`` rule to allow non-root device access:
|
||||||
|
|
||||||
|
For most users:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo apt-get update
|
sudo udevadm control --reload
|
||||||
sudo apt-get install -y \
|
sudo udevadm trigger
|
||||||
build-essential \
|
|
||||||
git \
|
|
||||||
libxml2-dev \
|
|
||||||
bison \
|
|
||||||
flex \
|
|
||||||
libcdk5-dev \
|
|
||||||
cmake \
|
|
||||||
libusb-1.0-0-dev \
|
|
||||||
libavahi-client-dev \
|
|
||||||
libavahi-common-dev \
|
|
||||||
libaio-dev
|
|
||||||
|
|
||||||
3. Install a ``udev`` rule by creating a link into your Radioconda installation:
|
For **Radioconda** users, create a symlink from your conda environment instead:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
|
@ -78,11 +78,18 @@ Set up instructions (Linux, Radioconda)
|
||||||
sudo udevadm control --reload
|
sudo udevadm control --reload
|
||||||
sudo udevadm trigger
|
sudo udevadm trigger
|
||||||
|
|
||||||
Once you can talk to the hardware, you may want to perform the post-install steps detailed on the `PlutoSDR Documentation <https://wiki.analog.com/university/tools/pluto>`_.
|
Once you can communicate with the hardware, you may want to perform the post-install steps detailed on
|
||||||
|
the `PlutoSDR Documentation <https://wiki.analog.com/university/tools/pluto>`_.
|
||||||
|
|
||||||
4. (Optional) Building ``libiio`` or ``libad9361-iio`` from source:
|
3. (Optional) Building ``libiio`` or ``libad9361-iio`` from source:
|
||||||
|
|
||||||
This step is only required if you want the latest version of these libraries not provided in Radioconda.
|
This step is only required if you need a version not available via ``apt``. First install build
|
||||||
|
dependencies:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
sudo apt-get install -y build-essential git libxml2-dev bison flex libcdk5-dev cmake \
|
||||||
|
libusb-1.0-0-dev libavahi-client-dev libavahi-common-dev libaio-dev
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,18 +30,10 @@ Limitations
|
||||||
- Sensitivity and performance can vary depending on the specific model and components.
|
- Sensitivity and performance can vary depending on the specific model and components.
|
||||||
- Requires external software for signal processing and analysis.
|
- Requires external software for signal processing and analysis.
|
||||||
|
|
||||||
Set up instructions (Linux, Radioconda)
|
Set up instructions (Linux)
|
||||||
---------------------------------------
|
---------------------------
|
||||||
|
|
||||||
1. Activate your Radioconda environment:
|
1. If you previously had RTL-SDR drivers installed, purge them first:
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
conda activate <your-env-name>
|
|
||||||
|
|
||||||
2. Purge drivers:
|
|
||||||
|
|
||||||
If you already have other drivers installed, purge them from your system.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
|
@ -53,47 +45,95 @@ If you already have other drivers installed, purge them from your system.
|
||||||
sudo rm -rvf /usr/local/include/rtl_*
|
sudo rm -rvf /usr/local/include/rtl_*
|
||||||
sudo rm -rvf /usr/local/bin/rtl_*
|
sudo rm -rvf /usr/local/bin/rtl_*
|
||||||
|
|
||||||
3. Install RTL-SDR Blog drivers:
|
2. Install build dependencies:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo apt-get install libusb-1.0-0-dev git cmake pkg-config build-essential
|
sudo apt install libusb-1.0-0-dev git cmake pkg-config build-essential
|
||||||
git clone https://github.com/osmocom/rtl-sdr
|
|
||||||
cd rtl-sdr
|
3. Build ``librtlsdr`` from source:
|
||||||
mkdir build
|
|
||||||
cd build
|
The standard ``librtlsdr`` package available via ``apt`` is missing symbols required by the Python
|
||||||
cmake ../ -DINSTALL_UDEV_RULES=ON
|
bindings. Build from the **rtl-sdr-blog fork**:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/rtlsdrblog/rtl-sdr-blog.git
|
||||||
|
cd rtl-sdr-blog
|
||||||
|
mkdir build && cd build
|
||||||
|
cmake .. -DINSTALL_UDEV_RULES=ON
|
||||||
make
|
make
|
||||||
sudo make install
|
sudo make install
|
||||||
sudo cp ../rtl-sdr.rules /etc/udev/rules.d/
|
sudo cp ../rtl-sdr.rules /etc/udev/rules.d/
|
||||||
sudo ldconfig
|
sudo ldconfig
|
||||||
|
|
||||||
4. Blacklist the DVB-T modules that would otherwise claim the device:
|
.. important::
|
||||||
|
|
||||||
|
Do not use the osmocom ``rtl-sdr`` repository or the Ubuntu ``librtlsdr-dev`` apt package. Neither
|
||||||
|
provides the ``rtlsdr_set_dithering`` symbol that the Python bindings require.
|
||||||
|
|
||||||
|
4. Blacklist the kernel DVB driver:
|
||||||
|
|
||||||
|
The kernel DVB-T driver (``dvb_usb_rtl28xxu``) claims the RTL-SDR device and prevents ``librtlsdr``
|
||||||
|
from accessing it.
|
||||||
|
|
||||||
|
For most users:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
echo 'blacklist dvb_usb_rtl28xxu' | sudo tee /etc/modprobe.d/blacklist-rtlsdr.conf
|
||||||
|
sudo modprobe -r dvb_usb_rtl28xxu
|
||||||
|
|
||||||
|
For **Radioconda** users, a blacklist configuration is already provided in your conda environment:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf
|
sudo ln -s $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf /etc/modprobe.d/radioconda-rtl-sdr-blacklist.conf
|
||||||
sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p')
|
sudo modprobe -r $(cat $CONDA_PREFIX/etc/modprobe.d/rtl-sdr-blacklist.conf | sed -n -e 's/^blacklist //p')
|
||||||
|
|
||||||
.. note::
|
If ``modprobe -r`` fails with "Module is in use", unplug the RTL-SDR dongle, run the command again,
|
||||||
|
then plug it back in. Alternatively, reboot — the blacklist takes effect on next boot.
|
||||||
|
|
||||||
In addition to the Radioconda blacklist file, some systems also require
|
.. note::
|
||||||
manually blacklisting the following DVB-T modules to prevent them from
|
|
||||||
claiming the device:
|
|
||||||
|
|
||||||
- ``dvb_usb_rtl28xxu``
|
Some systems also require blacklisting additional DVB-T modules. Add these entries to your
|
||||||
- ``rtl2832``
|
blacklist configuration if needed:
|
||||||
- ``rtl2830``
|
|
||||||
|
|
||||||
Add these entries to ``rtlsdr.conf`` (or create the file at
|
- ``rtl2832``
|
||||||
``/etc/modprobe.d/rtlsdr.conf``) if they are not already present.
|
- ``rtl2830``
|
||||||
|
|
||||||
5. Install a udev rule by creating a link into your radioconda installation:
|
5. Reload udev rules:
|
||||||
|
|
||||||
|
For most users (rules are installed by the build step above):
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
sudo udevadm control --reload
|
||||||
|
sudo udevadm trigger
|
||||||
|
|
||||||
|
For **Radioconda** users, create a symlink from your conda environment instead:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules
|
sudo ln -s $CONDA_PREFIX/lib/udev/rules.d/rtl-sdr.rules /etc/udev/rules.d/radioconda-rtl-sdr.rules
|
||||||
sudo udevadm control --reload
|
sudo udevadm control --reload
|
||||||
sudo udevadm trigger
|
sudo udevadm trigger
|
||||||
|
|
||||||
|
6. Install Python packages:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install pyrtlsdr==0.3.0
|
||||||
|
pip install setuptools==69.5.1
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
``pyrtlsdr`` 0.4.0 references a ``rtlsdr_set_dithering`` symbol not present in standard
|
||||||
|
``librtlsdr`` builds. Version 0.3.0 works correctly.
|
||||||
|
|
||||||
|
``pyrtlsdr`` 0.3.0 depends on ``pkg_resources``, which was removed in ``setuptools`` >= 82.
|
||||||
|
Pinning to 69.5.1 ensures ``pkg_resources`` is available.
|
||||||
|
|
||||||
Further Information
|
Further Information
|
||||||
-------------------
|
-------------------
|
||||||
- `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_
|
- `RTL-SDR Official Website <https://www.rtl-sdr.com/>`_
|
||||||
|
|
|
||||||
|
|
@ -39,18 +39,48 @@ Limitations
|
||||||
Set up instructions (Linux)
|
Set up instructions (Linux)
|
||||||
---------------------------------
|
---------------------------------
|
||||||
|
|
||||||
Install PyRF
|
ThinkRF devices require the ``pyrf`` package, which is written in Python 2 syntax and must be patched
|
||||||
|
after installation to work with Python 3.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
``lib2to3`` was fully removed in Python 3.13. ThinkRF support is currently limited to
|
||||||
|
**Python 3.12 and below**.
|
||||||
|
|
||||||
|
1. Install ``lib2to3``:
|
||||||
|
|
||||||
|
On some distributions (including Ubuntu 24.04+), ``lib2to3`` is not included by default:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
pip install 'pyrf>=2.8.0'
|
sudo apt install python3-lib2to3
|
||||||
|
|
||||||
Convert PyRF scripts to Python 3
|
2. Install ``pyrf``:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
cd ../scripts
|
pip install pyrf
|
||||||
./convert_pyrf_to_python3.sh
|
|
||||||
|
3. Patch ``pyrf`` for Python 3:
|
||||||
|
|
||||||
|
The ``pyrf`` package contains Python 2 syntax throughout (e.g., ``dict.iteritems()``, ``print``
|
||||||
|
statements). Run the following to automatically convert the entire package to Python 3:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
python -c "
|
||||||
|
from lib2to3.refactor import RefactoringTool, get_fixers_from_package
|
||||||
|
import pyrf, os
|
||||||
|
pyrf_path = os.path.dirname(pyrf.__file__)
|
||||||
|
fixers = get_fixers_from_package('lib2to3.fixes')
|
||||||
|
tool = RefactoringTool(fixers)
|
||||||
|
tool.refactor_dir(pyrf_path, write=True)
|
||||||
|
print('Done')
|
||||||
|
"
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
This patches the entire ``pyrf`` package in place, which is required for the driver to fully load.
|
||||||
|
|
||||||
Further Information
|
Further Information
|
||||||
-------------------
|
-------------------
|
||||||
|
|
|
||||||
|
|
@ -41,48 +41,111 @@ Limitations
|
||||||
- Compatibility with certain software tools may vary depending on the version of the UHD.
|
- Compatibility with certain software tools may vary depending on the version of the UHD.
|
||||||
- Price range can be a consideration, especially for high-end models.
|
- Price range can be a consideration, especially for high-end models.
|
||||||
|
|
||||||
Set up instructions (Linux, Radioconda)
|
Set up instructions (Linux)
|
||||||
---------------------------------------
|
---------------------------
|
||||||
|
|
||||||
1. Activate your Radioconda environment:
|
USRP devices require the UHD (USRP Hardware Driver) library with Python bindings. There is no pip-installable
|
||||||
|
UHD package — it must either be installed via conda or built from source.
|
||||||
|
|
||||||
|
**Option A: Install via conda (recommended for conda environments)**
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
conda activate <your-env-name>
|
conda install conda-forge::uhd
|
||||||
|
|
||||||
2. Install UHD and Python bindings:
|
**Option B: Build from source (required for pip/venv environments)**
|
||||||
|
|
||||||
.. code-block:: bash
|
The Python bindings must target the same Python version used in your virtual environment.
|
||||||
|
|
||||||
conda install conda-forge::uhd
|
1. Install build dependencies:
|
||||||
|
|
||||||
3. Download UHD images:
|
.. code-block:: bash
|
||||||
|
|
||||||
|
sudo apt install cmake build-essential libboost-all-dev libusb-1.0-0-dev \
|
||||||
|
python3-dev python3-numpy libncurses-dev
|
||||||
|
|
||||||
|
2. Install the Mako template library into your virtual environment (used by UHD's build system):
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install mako
|
||||||
|
|
||||||
|
3. Clone and build UHD with your virtual environment activated:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
git clone https://github.com/EttusResearch/uhd.git
|
||||||
|
cd uhd
|
||||||
|
git checkout v4.7.0.0
|
||||||
|
cd host
|
||||||
|
mkdir build && cd build
|
||||||
|
cmake -DENABLE_PYTHON_API=ON -DPYTHON_EXECUTABLE=$(which python3) ..
|
||||||
|
make -j$(nproc)
|
||||||
|
sudo make install
|
||||||
|
sudo ldconfig
|
||||||
|
|
||||||
|
.. important::
|
||||||
|
|
||||||
|
Run the ``cmake`` command with your virtual environment activated so ``$(which python3)`` points
|
||||||
|
to the correct interpreter. Before running ``make``, verify the cmake output includes::
|
||||||
|
|
||||||
|
-- * LibUHD - Python API → must say "Enabling"
|
||||||
|
-- Python interpreter: .../your-venv/bin/python3
|
||||||
|
|
||||||
|
If "LibUHD - Python API" is not listed under enabled components, the Python bindings will not be
|
||||||
|
built. The build typically takes 10–30 minutes.
|
||||||
|
|
||||||
|
4. Copy the Python bindings into your virtual environment if ``import uhd`` fails after installation:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cp -r ~/uhd/host/build/python/uhd ~/.venv/lib/python3.XX/site-packages/
|
||||||
|
|
||||||
|
Replace ``python3.XX`` with your Python version (e.g., ``python3.12``).
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
If you have a pre-existing UHD installation built against a different Python version, you will see
|
||||||
|
a circular import error. The bindings must match the Python version in your virtual environment exactly.
|
||||||
|
|
||||||
|
**After either installation method:**
|
||||||
|
|
||||||
|
1. Download UHD FPGA/firmware images:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
uhd_images_downloader
|
uhd_images_downloader
|
||||||
|
|
||||||
4. Verify access to your device:
|
2. Verify device access:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
uhd_find_devices
|
uhd_find_devices
|
||||||
|
|
||||||
For USB devices only (e.g. B series), install a ``udev`` rule by creating a link into your Radioconda installation.
|
For USB devices (e.g. B-series), install a ``udev`` rule.
|
||||||
|
|
||||||
.. code-block:: bash
|
For most users:
|
||||||
|
|
||||||
sudo ln -s $CONDA_PREFIX/lib/uhd/utils/uhd-usrp.rules /etc/udev/rules.d/radioconda-uhd-usrp.rules
|
.. code-block:: bash
|
||||||
sudo udevadm control --reload
|
|
||||||
sudo udevadm trigger
|
|
||||||
|
|
||||||
5. (Optional) Update firmware/FPGA images:
|
sudo udevadm control --reload
|
||||||
|
sudo udevadm trigger
|
||||||
|
|
||||||
.. code-block:: bash
|
For **Radioconda** users, create a symlink from your conda environment instead:
|
||||||
|
|
||||||
uhd_usrp_probe
|
.. code-block:: bash
|
||||||
|
|
||||||
This will ensure your device is running the latest firmware and FPGA versions.
|
sudo ln -s $CONDA_PREFIX/lib/uhd/utils/uhd-usrp.rules /etc/udev/rules.d/radioconda-uhd-usrp.rules
|
||||||
|
sudo udevadm control --reload
|
||||||
|
sudo udevadm trigger
|
||||||
|
|
||||||
|
3. (Optional) Update firmware/FPGA images:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
uhd_usrp_probe
|
||||||
|
|
||||||
|
This will ensure your device is running the latest firmware and FPGA versions.
|
||||||
|
|
||||||
Further information
|
Further information
|
||||||
-------------------
|
-------------------
|
||||||
|
|
|
||||||
3136
poetry.lock
generated
3136
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "ria-toolkit-oss"
|
name = "ria-toolkit-oss"
|
||||||
version = "0.1.4"
|
version = "0.1.5"
|
||||||
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
||||||
license = { text = "AGPL-3.0-only" }
|
license = { text = "AGPL-3.0-only" }
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
@ -49,7 +49,8 @@ dependencies = [
|
||||||
"pyzmq (>=27.1.0,<28.0.0)",
|
"pyzmq (>=27.1.0,<28.0.0)",
|
||||||
"pyyaml (>=6.0.3,<7.0.0)",
|
"pyyaml (>=6.0.3,<7.0.0)",
|
||||||
"click (>=8.1.0,<9.0.0)",
|
"click (>=8.1.0,<9.0.0)",
|
||||||
"matplotlib (>=3.8.0,<4.0.0)"
|
"matplotlib (>=3.8.0,<4.0.0)",
|
||||||
|
"paramiko (>=3.5.1)"
|
||||||
]
|
]
|
||||||
|
|
||||||
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
# [project.optional-dependencies] Commented out to prevent Tox tests from failing
|
||||||
|
|
@ -85,15 +86,26 @@ build-backend = "poetry.core.masonry.api"
|
||||||
[tool.poetry.group.test.dependencies]
|
[tool.poetry.group.test.dependencies]
|
||||||
pytest = "^8.0.0"
|
pytest = "^8.0.0"
|
||||||
tox = "^4.19.0"
|
tox = "^4.19.0"
|
||||||
|
fastapi = ">=0.111,<1.0"
|
||||||
|
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
|
||||||
|
onnxruntime = {version = ">=1.17,<2.0", python = ">=3.11"}
|
||||||
|
httpx = ">=0.27,<1.0"
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
sphinx = "^7.2.6"
|
sphinx = "^7.2.6"
|
||||||
sphinx-rtd-theme = "^2.0.0"
|
sphinx-rtd-theme = "^2.0.0"
|
||||||
sphinx-autobuild = "^2024.2.4"
|
sphinx-autobuild = "^2024.2.4"
|
||||||
|
|
||||||
|
[tool.poetry.group.agent]
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
[tool.poetry.group.agent.dependencies]
|
||||||
|
requests = ">=2.28,<3.0"
|
||||||
|
websockets = ">=12.0,<14.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
flake8 = "^7.1.0"
|
flake8 = "^7.1.0"
|
||||||
black = "^24.3.0"
|
black = "^26.3.1"
|
||||||
isort = "^5.13.2"
|
isort = "^5.13.2"
|
||||||
pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
|
pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
|
||||||
|
|
||||||
|
|
@ -105,6 +117,14 @@ pylint = "^3.2.6" # For pyreverse, to automate the creation of UML diagrams
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
ria = "ria_toolkit_oss_cli.cli:cli"
|
ria = "ria_toolkit_oss_cli.cli:cli"
|
||||||
ria-tools = "ria_toolkit_oss_cli.cli:cli"
|
ria-tools = "ria_toolkit_oss_cli.cli:cli"
|
||||||
|
ria-server = "ria_toolkit_oss.server.cli:serve"
|
||||||
|
ria-agent = "ria_toolkit_oss.agent.cli:main"
|
||||||
|
ria-app = "ria_toolkit_oss.app.cli:main"
|
||||||
|
|
||||||
|
[tool.poetry.group.server.dependencies]
|
||||||
|
fastapi = ">=0.111,<1.0"
|
||||||
|
uvicorn = {version = ">=0.29,<1.0", extras = ["standard"]}
|
||||||
|
onnxruntime = {version = ">=1.17,<2.0", python = ">=3.11"}
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 119
|
line-length = 119
|
||||||
|
|
@ -127,5 +147,13 @@ exclude = '''
|
||||||
)/
|
)/
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
pythonpath = ["src"]
|
||||||
|
filterwarnings = [
|
||||||
|
# FastAPI emits this internally when handling 422 responses; the constant
|
||||||
|
# is not yet renamed in the installed starlette version, so we can't migrate.
|
||||||
|
"ignore:'HTTP_422_UNPROCESSABLE_ENTITY' is deprecated:DeprecationWarning",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
|
||||||
225
scripts/pluto_tx_smoke.py
Executable file
225
scripts/pluto_tx_smoke.py
Executable file
|
|
@ -0,0 +1,225 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Transmit a continuous tone through the agent's TX pipeline on a real Pluto.
|
||||||
|
|
||||||
|
End-to-end smoke test for the Pluto + Streamer TX path. Drives the same
|
||||||
|
``Streamer`` the hub talks to, but in-process with a logging ``FakeWs`` so
|
||||||
|
the script is self-contained — no hub required.
|
||||||
|
|
||||||
|
Default: 100 kHz baseband tone × 2 450 MHz LO → carrier at 2 450.1 MHz,
|
||||||
|
continuous until you Ctrl-C (or the ``--duration`` timer fires). A spectrum
|
||||||
|
analyzer tuned to 2 450.1 MHz should show a clean CW spike as long as
|
||||||
|
``tx_status: transmitting`` prints.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python3 scripts/pluto_tx_smoke.py # auto-discover Pluto
|
||||||
|
python3 scripts/pluto_tx_smoke.py --identifier 192.168.3.1
|
||||||
|
python3 scripts/pluto_tx_smoke.py --frequency 2.4e9 --gain -20 --duration 60
|
||||||
|
|
||||||
|
Flags map 1:1 onto the agent's ``radio_config``:
|
||||||
|
|
||||||
|
--identifier Pluto IP or hostname (omitted → ip:pluto.local).
|
||||||
|
--frequency TX LO in Hz. Default 2 450 MHz.
|
||||||
|
--gain Pluto TX gain in dB. Pluto range is ``[-89, 0]``; more negative
|
||||||
|
= more attenuation = less power. Default -30.
|
||||||
|
--sample-rate Baseband sample rate. Default 1 MHz.
|
||||||
|
--tone Baseband tone offset in Hz. Default 100 kHz; set 0 for DC
|
||||||
|
(unmodulated carrier at exactly --frequency, but Pluto's
|
||||||
|
LO leakage will dominate).
|
||||||
|
--buffer-size Complex samples per WS frame. Default 4096.
|
||||||
|
--duration Stop after this many seconds (0 = run until Ctrl-C).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingFakeWs:
|
||||||
|
"""In-process stand-in for the hub's WebSocket.
|
||||||
|
|
||||||
|
Prints every ``tx_status`` + ``error`` frame the Streamer emits so the
|
||||||
|
operator can watch the lifecycle (armed → transmitting → done) on stdout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def send_json(self, payload: dict) -> None:
|
||||||
|
t = payload.get("type")
|
||||||
|
if t == "tx_status":
|
||||||
|
state = payload.get("state")
|
||||||
|
msg = payload.get("message")
|
||||||
|
tail = f" — {msg}" if msg else ""
|
||||||
|
print(f"[tx_status] {state}{tail}")
|
||||||
|
elif t == "error":
|
||||||
|
print(f"[error] {payload.get('message')}")
|
||||||
|
|
||||||
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
|
# Agent side won't send RX bytes in this script (no RX session).
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _make_iq_frame(
|
||||||
|
buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float = 0.0
|
||||||
|
) -> tuple[bytes, float]:
|
||||||
|
"""Return ``(interleaved_float32_bytes, next_phase)`` for a sine tone.
|
||||||
|
|
||||||
|
Emitting one continuous phase-coherent tone requires threading the phase
|
||||||
|
across frames; the returned ``next_phase`` should be fed back as
|
||||||
|
``phase_offset`` on the next call so the sinusoid doesn't glitch at frame
|
||||||
|
boundaries. Amplitude is 0.7 to leave some headroom below the [-1, 1] cap
|
||||||
|
that ``_verify_sample_format`` polices elsewhere in the toolkit.
|
||||||
|
"""
|
||||||
|
n = np.arange(buffer_size, dtype=np.float64)
|
||||||
|
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
|
||||||
|
amp = 0.7
|
||||||
|
iq = amp * (np.cos(phase) + 1j * np.sin(phase))
|
||||||
|
iq = iq.astype(np.complex64)
|
||||||
|
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = iq.real
|
||||||
|
interleaved[1::2] = iq.imag
|
||||||
|
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
|
||||||
|
return interleaved.tobytes(), next_phase
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pluto_factory(identifier: str | None):
|
||||||
|
def factory(device: str, _ident: str | None):
|
||||||
|
if device != "pluto":
|
||||||
|
raise ValueError(f"this script only drives pluto; got device={device!r}")
|
||||||
|
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||||
|
|
||||||
|
return Pluto(identifier=identifier)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(args: argparse.Namespace) -> int:
|
||||||
|
ws = LoggingFakeWs()
|
||||||
|
cfg = AgentConfig(
|
||||||
|
tx_enabled=True,
|
||||||
|
# Pluto's TX gain range is [-89, 0]. Cap at 0 so a fat-fingered
|
||||||
|
# --gain=+5 still gets rejected at the agent boundary rather than
|
||||||
|
# turned into mystery attenuation by Pluto's setter.
|
||||||
|
tx_max_gain_db=0.0,
|
||||||
|
tx_max_duration_s=float(args.duration) if args.duration > 0 else None,
|
||||||
|
)
|
||||||
|
streamer = Streamer(ws=ws, sdr_factory=_make_pluto_factory(args.identifier), cfg=cfg)
|
||||||
|
|
||||||
|
await streamer.on_message(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "smoke",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "pluto",
|
||||||
|
"identifier": args.identifier,
|
||||||
|
"tx_sample_rate": int(args.sample_rate),
|
||||||
|
"tx_center_frequency": int(args.frequency),
|
||||||
|
"tx_gain": int(args.gain),
|
||||||
|
"buffer_size": int(args.buffer_size),
|
||||||
|
# "repeat" keeps the last buffer on the air if we ever stall,
|
||||||
|
# so a continuous carrier stays up even when Python GC or
|
||||||
|
# asyncio scheduling briefly pauses the producer.
|
||||||
|
"underrun_policy": "repeat",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Abort if tx_start was rejected by an interlock (no session → nothing to do).
|
||||||
|
if streamer._tx is None:
|
||||||
|
print("tx_start rejected — see [tx_status] line above for the reason.", file=sys.stderr)
|
||||||
|
return 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Transmitting at {args.frequency/1e6:.3f} MHz with "
|
||||||
|
f"{args.tone/1e3:.1f} kHz baseband tone at gain {args.gain} dB. "
|
||||||
|
f"{'Running for ' + str(args.duration) + 's' if args.duration > 0 else 'Run until Ctrl-C'}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Arrange a clean shutdown on Ctrl-C.
|
||||||
|
stop = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
try:
|
||||||
|
loop.add_signal_handler(sig, stop.set)
|
||||||
|
except NotImplementedError:
|
||||||
|
# add_signal_handler is not available on Windows event loops.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Produce buffers at the nominal sample-rate pace. We deliberately stay
|
||||||
|
# slightly ahead of the radio — queue is bounded at 8, so backpressure
|
||||||
|
# flows naturally.
|
||||||
|
phase = 0.0
|
||||||
|
buffer_dt = args.buffer_size / args.sample_rate
|
||||||
|
# Aim for one buffer every ``buffer_dt * 0.5`` seconds so the queue stays
|
||||||
|
# topped up. The queue's own backpressure keeps us from spinning.
|
||||||
|
produce_interval = buffer_dt * 0.5
|
||||||
|
try:
|
||||||
|
|
||||||
|
async def producer():
|
||||||
|
nonlocal phase
|
||||||
|
while not stop.is_set():
|
||||||
|
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase)
|
||||||
|
await streamer.on_binary(frame)
|
||||||
|
await asyncio.sleep(produce_interval)
|
||||||
|
|
||||||
|
producer_task = asyncio.create_task(producer())
|
||||||
|
|
||||||
|
if args.duration > 0:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(stop.wait(), timeout=args.duration)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
await stop.wait()
|
||||||
|
|
||||||
|
stop.set()
|
||||||
|
producer_task.cancel()
|
||||||
|
try:
|
||||||
|
await producer_task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
await streamer.on_message({"type": "tx_stop", "app_id": "smoke"})
|
||||||
|
|
||||||
|
print("TX session closed.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description="End-to-end TX smoke test: agent → Pluto continuous tone.",
|
||||||
|
)
|
||||||
|
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)")
|
||||||
|
p.add_argument("--frequency", type=float, default=3_410_000_000.0, help="TX LO in Hz (default 2.45 GHz)")
|
||||||
|
p.add_argument("--gain", type=float, default=-0.0, help="TX gain in dB; Pluto range [-89, 0] (default -30)")
|
||||||
|
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)")
|
||||||
|
p.add_argument(
|
||||||
|
"--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz; 0 = DC (default 100 kHz)"
|
||||||
|
)
|
||||||
|
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)")
|
||||||
|
p.add_argument(
|
||||||
|
"--duration", type=float, default=60.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)"
|
||||||
|
)
|
||||||
|
p.add_argument("--log-level", default="INFO")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, args.log_level.upper(), logging.INFO),
|
||||||
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return asyncio.run(_run(args))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
return 130
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
230
scripts/pluto_tx_ws_smoke.py
Executable file
230
scripts/pluto_tx_ws_smoke.py
Executable file
|
|
@ -0,0 +1,230 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Full-stack TX smoke test: localhost mock-hub → WS → agent → real Pluto.
|
||||||
|
|
||||||
|
Same radio output as ``pluto_tx_smoke.py`` (continuous tone at 2 450.1 MHz),
|
||||||
|
but drives the agent through the *real* WebSocket path instead of calling
|
||||||
|
handlers in-process. Proves that the hub-driven path behaves identically:
|
||||||
|
|
||||||
|
mock hub ── ws:// ──▶ WsClient.run() ──▶ Streamer.on_message
|
||||||
|
└▶ Streamer.on_binary
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
real Pluto
|
||||||
|
|
||||||
|
This is the most rigorous check short of pointing the real ``ria-agent stream``
|
||||||
|
at a live ria-hub. If a tone appears on the spectrum analyzer here but *not*
|
||||||
|
when ria-hub drives it, the fault is above the WS decoder (registration,
|
||||||
|
capability gate, TX operator, hub's binary-frame publisher); everything
|
||||||
|
downstream of ``ws.recv()`` is this script's code path.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python3 scripts/pluto_tx_ws_smoke.py # default 30s tone
|
||||||
|
python3 scripts/pluto_tx_ws_smoke.py --identifier 192.168.3.1
|
||||||
|
python3 scripts/pluto_tx_ws_smoke.py --duration 0 # until Ctrl-C
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent.config import AgentConfig
|
||||||
|
from ria_toolkit_oss.agent.streamer import Streamer
|
||||||
|
from ria_toolkit_oss.agent.ws_client import WsClient
|
||||||
|
|
||||||
|
|
||||||
|
def _make_iq_frame(buffer_size: int, tone_hz: float, sample_rate: float, phase_offset: float) -> tuple[bytes, float]:
|
||||||
|
n = np.arange(buffer_size, dtype=np.float64)
|
||||||
|
phase = 2.0 * np.pi * tone_hz / sample_rate * n + phase_offset
|
||||||
|
amp = 0.7
|
||||||
|
iq = (amp * (np.cos(phase) + 1j * np.sin(phase))).astype(np.complex64)
|
||||||
|
interleaved = np.empty(buffer_size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = iq.real
|
||||||
|
interleaved[1::2] = iq.imag
|
||||||
|
next_phase = (2.0 * np.pi * tone_hz / sample_rate * buffer_size + phase_offset) % (2.0 * np.pi)
|
||||||
|
return interleaved.tobytes(), next_phase
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pluto_factory(identifier: str | None):
|
||||||
|
def factory(device: str, _ident: str | None):
|
||||||
|
if device != "pluto":
|
||||||
|
raise ValueError(f"this script only drives pluto; got device={device!r}")
|
||||||
|
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||||
|
|
||||||
|
return Pluto(identifier=identifier)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
async def _mock_hub_handler(ws, args, stop: asyncio.Event):
|
||||||
|
"""Server side of the WS. Sends tx_start, streams IQ, then tx_stop."""
|
||||||
|
# Drain the first heartbeat so the log is clean; we don't need to gate on
|
||||||
|
# it for a localhost smoke test.
|
||||||
|
try:
|
||||||
|
first = await asyncio.wait_for(ws.recv(), timeout=2.0)
|
||||||
|
if isinstance(first, str):
|
||||||
|
payload = json.loads(first)
|
||||||
|
if payload.get("type") == "heartbeat":
|
||||||
|
caps = payload.get("capabilities")
|
||||||
|
print(f"[mock-hub] agent heartbeat: capabilities={caps} " f"tx_enabled={payload.get('tx_enabled')}")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("[mock-hub] warning: no heartbeat received in first 2s")
|
||||||
|
|
||||||
|
# Arm the agent's TX path.
|
||||||
|
await ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "tx_start",
|
||||||
|
"app_id": "ws-smoke",
|
||||||
|
"radio_config": {
|
||||||
|
"device": "pluto",
|
||||||
|
"identifier": args.identifier,
|
||||||
|
"tx_sample_rate": int(args.sample_rate),
|
||||||
|
"tx_center_frequency": int(args.frequency),
|
||||||
|
"tx_gain": int(args.gain),
|
||||||
|
"buffer_size": int(args.buffer_size),
|
||||||
|
"underrun_policy": "repeat",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"[mock-hub] sent tx_start at {args.frequency/1e6:.3f} MHz, " f"gain={args.gain} dB")
|
||||||
|
|
||||||
|
# Producer: push IQ frames at a steady clip. Use a concurrent receiver so
|
||||||
|
# tx_status frames show up in real time rather than being queued behind
|
||||||
|
# the sends.
|
||||||
|
phase = 0.0
|
||||||
|
buffer_dt = args.buffer_size / args.sample_rate
|
||||||
|
|
||||||
|
async def receiver():
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await ws.recv()
|
||||||
|
if isinstance(msg, str):
|
||||||
|
print(f"[mock-hub] ← {msg}")
|
||||||
|
except (websockets.ConnectionClosed, asyncio.CancelledError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
recv_task = asyncio.create_task(receiver())
|
||||||
|
try:
|
||||||
|
deadline = None if args.duration <= 0 else (asyncio.get_event_loop().time() + args.duration)
|
||||||
|
while not stop.is_set():
|
||||||
|
if deadline is not None and asyncio.get_event_loop().time() >= deadline:
|
||||||
|
break
|
||||||
|
frame, phase = _make_iq_frame(args.buffer_size, args.tone, args.sample_rate, phase)
|
||||||
|
try:
|
||||||
|
await ws.send(frame)
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
break
|
||||||
|
# Slightly ahead of real-time; WS backpressure handles the rest.
|
||||||
|
await asyncio.sleep(buffer_dt * 0.5)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await ws.send(json.dumps({"type": "tx_stop", "app_id": "ws-smoke"}))
|
||||||
|
print("[mock-hub] sent tx_stop")
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
pass
|
||||||
|
# Give the agent a moment to emit `tx_status: done` before we tear down.
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
recv_task.cancel()
|
||||||
|
try:
|
||||||
|
await recv_task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(args: argparse.Namespace) -> int:
|
||||||
|
stop = asyncio.Event()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
try:
|
||||||
|
loop.add_signal_handler(sig, stop.set)
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Start the mock hub on a local port.
|
||||||
|
async def handler(ws):
|
||||||
|
try:
|
||||||
|
await _mock_hub_handler(ws, args, stop)
|
||||||
|
finally:
|
||||||
|
stop.set()
|
||||||
|
|
||||||
|
server = await websockets.serve(handler, "127.0.0.1", 0)
|
||||||
|
port = server.sockets[0].getsockname()[1]
|
||||||
|
print(f"[mock-hub] listening on ws://127.0.0.1:{port}")
|
||||||
|
|
||||||
|
# Run the agent — exactly as ``ria-agent stream`` would, just with a
|
||||||
|
# different URL and an in-memory AgentConfig instead of one loaded from
|
||||||
|
# ``~/.ria/agent.json``.
|
||||||
|
client = WsClient(
|
||||||
|
f"ws://127.0.0.1:{port}",
|
||||||
|
token="",
|
||||||
|
heartbeat_interval=5.0,
|
||||||
|
reconnect_pause=0.5,
|
||||||
|
)
|
||||||
|
streamer = Streamer(
|
||||||
|
ws=client,
|
||||||
|
sdr_factory=_make_pluto_factory(args.identifier),
|
||||||
|
cfg=AgentConfig(tx_enabled=True, tx_max_gain_db=0.0),
|
||||||
|
)
|
||||||
|
client_task = asyncio.create_task(
|
||||||
|
client.run(
|
||||||
|
on_message=streamer.on_message,
|
||||||
|
heartbeat=streamer.build_heartbeat,
|
||||||
|
on_binary=streamer.on_binary,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await stop.wait()
|
||||||
|
finally:
|
||||||
|
client.stop()
|
||||||
|
client_task.cancel()
|
||||||
|
try:
|
||||||
|
await client_task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
|
||||||
|
print("Done.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description="Full-stack TX smoke: localhost mock-hub → WS → agent → Pluto.",
|
||||||
|
)
|
||||||
|
p.add_argument("--identifier", default=None, help="Pluto IP/hostname (default: auto-discover pluto.local)")
|
||||||
|
p.add_argument("--frequency", type=float, default=2_450_000_000.0, help="TX LO in Hz (default 2.45 GHz)")
|
||||||
|
p.add_argument("--gain", type=float, default=0.0, help="TX gain in dB; Pluto range [-89, 0] (default 0)")
|
||||||
|
p.add_argument("--sample-rate", type=float, default=1_000_000.0, help="Baseband sample rate (default 1 Msps)")
|
||||||
|
p.add_argument("--tone", type=float, default=100_000.0, help="Baseband tone offset in Hz (default 100 kHz)")
|
||||||
|
p.add_argument("--buffer-size", type=int, default=4096, help="Complex samples per frame (default 4096)")
|
||||||
|
p.add_argument(
|
||||||
|
"--duration", type=float, default=30.0, help="Seconds to transmit; 0 = run until Ctrl-C (default 30)"
|
||||||
|
)
|
||||||
|
p.add_argument("--log-level", default="INFO")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, args.log_level.upper(), logging.INFO),
|
||||||
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return asyncio.run(_run(args))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
return 130
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
26
src/ria_toolkit_oss/agent/__init__.py
Normal file
26
src/ria_toolkit_oss/agent/__init__.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""RIA Toolkit agent package.
|
||||||
|
|
||||||
|
Provides two execution modes:
|
||||||
|
|
||||||
|
- **Legacy long-poll executor** (`NodeAgent` in :mod:`legacy_executor`) — an
|
||||||
|
HTTP long-polling agent that runs ONNX inference locally on the host.
|
||||||
|
- **Streamer** (:mod:`streamer`) — a thin WebSocket client that opens an SDR
|
||||||
|
and streams raw IQ to the RIA Hub server, which performs all inference.
|
||||||
|
|
||||||
|
Back-compat: ``from ria_toolkit_oss.agent import NodeAgent`` and the ``main``
|
||||||
|
entry point continue to work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .legacy_executor import NodeAgent
|
||||||
|
from .legacy_executor import main as _legacy_main
|
||||||
|
|
||||||
|
__all__ = ["NodeAgent", "main"]
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Unified CLI entry point. Dispatches to streamer/legacy subcommands."""
|
||||||
|
from .cli import main as _cli_main
|
||||||
|
|
||||||
|
_cli_main()
|
||||||
212
src/ria_toolkit_oss/agent/cli.py
Normal file
212
src/ria_toolkit_oss/agent/cli.py
Normal file
|
|
@ -0,0 +1,212 @@
|
||||||
|
"""Unified ``ria-agent`` CLI.
|
||||||
|
|
||||||
|
Subcommands:
|
||||||
|
|
||||||
|
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
||||||
|
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
||||||
|
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
||||||
|
- ``ria-agent register --hub URL --api-key KEY`` — register with the hub and
|
||||||
|
save credentials (and optional TX interlocks) to ``~/.ria/agent.json``.
|
||||||
|
|
||||||
|
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
||||||
|
long-poll behavior for back-compatibility with existing deployments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from . import config as _config
|
||||||
|
from .hardware import available_devices
|
||||||
|
from .legacy_executor import main as _legacy_main
|
||||||
|
from .namegen import generate_agent_name
|
||||||
|
|
||||||
|
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_detect(_args: argparse.Namespace) -> int:
|
||||||
|
devices = available_devices()
|
||||||
|
if not devices:
|
||||||
|
print("No SDR drivers available (install ria-toolkit-oss[all-sdr] or per-driver extras).")
|
||||||
|
return 0
|
||||||
|
for name in devices:
|
||||||
|
print(name)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_register(args: argparse.Namespace) -> int:
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
hub_url = args.hub.rstrip("/")
|
||||||
|
url = f"{hub_url}/screens/agents/register"
|
||||||
|
name = args.name or generate_agent_name()
|
||||||
|
body = json.dumps({"name": name}).encode()
|
||||||
|
req = urllib.request.Request(
|
||||||
|
url,
|
||||||
|
data=body,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-API-Key": args.api_key,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req) as resp:
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
except Exception as e:
|
||||||
|
print(f"error: registration failed: {e}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
agent_id = data["agent_id"]
|
||||||
|
token = data["token"]
|
||||||
|
|
||||||
|
cfg = _config.load()
|
||||||
|
cfg.hub_url = hub_url
|
||||||
|
cfg.agent_id = agent_id
|
||||||
|
cfg.token = token
|
||||||
|
cfg.api_key = args.api_key
|
||||||
|
cfg.name = name
|
||||||
|
cfg.insecure = bool(args.insecure)
|
||||||
|
cfg.tx_enabled = bool(getattr(args, "allow_tx", False))
|
||||||
|
if (v := getattr(args, "tx_max_gain_db", None)) is not None:
|
||||||
|
cfg.tx_max_gain_db = float(v)
|
||||||
|
if (v := getattr(args, "tx_max_duration_s", None)) is not None:
|
||||||
|
cfg.tx_max_duration_s = float(v)
|
||||||
|
freq_ranges = getattr(args, "tx_freq_range", None) or []
|
||||||
|
if freq_ranges:
|
||||||
|
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
|
||||||
|
path = _config.save(cfg)
|
||||||
|
|
||||||
|
print(f"Registered agent: {agent_id}")
|
||||||
|
if cfg.tx_enabled:
|
||||||
|
caps: list[str] = []
|
||||||
|
if cfg.tx_max_gain_db is not None:
|
||||||
|
caps.append(f"gain<={cfg.tx_max_gain_db} dB")
|
||||||
|
if cfg.tx_max_duration_s is not None:
|
||||||
|
caps.append(f"duration<={cfg.tx_max_duration_s} s")
|
||||||
|
if cfg.tx_allowed_freq_ranges:
|
||||||
|
caps.append(f"freq in {cfg.tx_allowed_freq_ranges}")
|
||||||
|
tail = f" ({', '.join(caps)})" if caps else ""
|
||||||
|
print(f"TX enabled{tail}")
|
||||||
|
print(f"Credentials saved to {path}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_stream(args: argparse.Namespace) -> int:
|
||||||
|
from .streamer import run_streamer
|
||||||
|
|
||||||
|
cfg = _config.load()
|
||||||
|
url = args.url or _derive_ws_url(cfg.hub_url, cfg.agent_id)
|
||||||
|
token = args.token or cfg.token
|
||||||
|
if not url:
|
||||||
|
print("error: --url is required (or run `ria-agent register` first)", file=sys.stderr)
|
||||||
|
return 2
|
||||||
|
if getattr(args, "allow_tx", False):
|
||||||
|
cfg.tx_enabled = True
|
||||||
|
try:
|
||||||
|
asyncio.run(run_streamer(url, token, cfg=cfg))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_ws_url(hub_url: str, agent_id: str) -> str:
|
||||||
|
if not hub_url:
|
||||||
|
return ""
|
||||||
|
base = hub_url.rstrip("/")
|
||||||
|
if base.startswith("https://"):
|
||||||
|
base = "wss://" + base[len("https://") :]
|
||||||
|
elif base.startswith("http://"):
|
||||||
|
base = "ws://" + base[len("http://") :]
|
||||||
|
suffix = f"/screens/agent/ws?agent_id={agent_id}" if agent_id else "/screens/agent/ws"
|
||||||
|
return base + suffix
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
# Back-compat: if the first non-flag token matches a known legacy flag,
|
||||||
|
# or there is no subcommand at all, dispatch to the legacy CLI.
|
||||||
|
argv = sys.argv[1:]
|
||||||
|
if not argv or (argv[0].startswith("--") and argv[0] in _LEGACY_ALIASES):
|
||||||
|
_legacy_main()
|
||||||
|
return
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(prog="ria-agent")
|
||||||
|
sub = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
sub.add_parser("run", help="Legacy long-poll agent (NodeAgent)")
|
||||||
|
sub.add_parser("detect", help="List available SDR drivers")
|
||||||
|
|
||||||
|
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
||||||
|
p_reg.add_argument("--hub", required=True, help="RIA Hub URL (e.g. http://whitehorse:3005)")
|
||||||
|
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
|
||||||
|
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
|
||||||
|
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--allow-tx",
|
||||||
|
dest="allow_tx",
|
||||||
|
action="store_true",
|
||||||
|
help="Opt this agent in to TX (required for any transmission from the hub)",
|
||||||
|
)
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--tx-max-gain-db",
|
||||||
|
dest="tx_max_gain_db",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Reject tx_start frames whose tx_gain exceeds this cap (dB)",
|
||||||
|
)
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--tx-max-duration-s",
|
||||||
|
dest="tx_max_duration_s",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Auto-stop any TX session after this many seconds",
|
||||||
|
)
|
||||||
|
p_reg.add_argument(
|
||||||
|
"--tx-freq-range",
|
||||||
|
dest="tx_freq_range",
|
||||||
|
type=float,
|
||||||
|
nargs=2,
|
||||||
|
action="append",
|
||||||
|
metavar=("LO", "HI"),
|
||||||
|
default=None,
|
||||||
|
help="Allowed TX center-frequency range in Hz (repeat for multiple bands)",
|
||||||
|
)
|
||||||
|
|
||||||
|
p_stream = sub.add_parser("stream", help="Run the WebSocket IQ streamer")
|
||||||
|
p_stream.add_argument("--url", default=None, help="Override WebSocket URL")
|
||||||
|
p_stream.add_argument("--token", default=None, help="Override bearer token")
|
||||||
|
p_stream.add_argument("--log-level", default="INFO")
|
||||||
|
p_stream.add_argument(
|
||||||
|
"--allow-tx",
|
||||||
|
dest="allow_tx",
|
||||||
|
action="store_true",
|
||||||
|
help="Runtime override: enable TX for this process without writing config",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unknown extras are forwarded to the legacy CLI when command == "run".
|
||||||
|
args, extras = parser.parse_known_args(argv)
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, getattr(args, "log_level", "INFO"), logging.INFO),
|
||||||
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.command == "run":
|
||||||
|
sys.argv = [sys.argv[0], *extras]
|
||||||
|
_legacy_main()
|
||||||
|
return
|
||||||
|
if args.command == "detect":
|
||||||
|
sys.exit(_cmd_detect(args))
|
||||||
|
if args.command == "register":
|
||||||
|
sys.exit(_cmd_register(args))
|
||||||
|
if args.command == "stream":
|
||||||
|
sys.exit(_cmd_stream(args))
|
||||||
|
|
||||||
|
parser.error(f"unknown command: {args.command}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
89
src/ria_toolkit_oss/agent/config.py
Normal file
89
src/ria_toolkit_oss/agent/config.py
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""Agent configuration stored at ``~/.ria/agent.json``.
|
||||||
|
|
||||||
|
Schema::
|
||||||
|
|
||||||
|
{
|
||||||
|
"hub_url": "https://riahub.example.com",
|
||||||
|
"agent_id": "agent-abc123",
|
||||||
|
"token": "rha_xxxx",
|
||||||
|
"name": "lab-bench-1",
|
||||||
|
"insecure": false,
|
||||||
|
"tx_enabled": false,
|
||||||
|
"tx_max_gain_db": null,
|
||||||
|
"tx_max_duration_s": null,
|
||||||
|
"tx_allowed_freq_ranges": null
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_default_path() -> Path:
|
||||||
|
return Path(os.environ.get("RIA_AGENT_CONFIG", str(Path.home() / ".ria" / "agent.json")))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentConfig:
|
||||||
|
hub_url: str = ""
|
||||||
|
agent_id: str = ""
|
||||||
|
token: str = ""
|
||||||
|
name: str = ""
|
||||||
|
insecure: bool = False
|
||||||
|
api_key: str = ""
|
||||||
|
tx_enabled: bool = False
|
||||||
|
tx_max_gain_db: float | None = None
|
||||||
|
tx_max_duration_s: float | None = None
|
||||||
|
tx_allowed_freq_ranges: list[list[float]] | None = None
|
||||||
|
extra: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def default_path() -> Path:
|
||||||
|
return _resolve_default_path()
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_ranges(raw) -> list[list[float]] | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
out: list[list[float]] = []
|
||||||
|
for pair in raw:
|
||||||
|
lo, hi = pair
|
||||||
|
out.append([float(lo), float(hi)])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def load(path: Path | None = None) -> AgentConfig:
|
||||||
|
p = path or _resolve_default_path()
|
||||||
|
if not p.exists():
|
||||||
|
return AgentConfig()
|
||||||
|
data = json.loads(p.read_text())
|
||||||
|
known = {f for f in AgentConfig.__dataclass_fields__ if f != "extra"}
|
||||||
|
extra = {k: v for k, v in data.items() if k not in known}
|
||||||
|
return AgentConfig(
|
||||||
|
hub_url=data.get("hub_url", ""),
|
||||||
|
agent_id=data.get("agent_id", ""),
|
||||||
|
token=data.get("token", ""),
|
||||||
|
name=data.get("name", ""),
|
||||||
|
insecure=bool(data.get("insecure", False)),
|
||||||
|
api_key=data.get("api_key", ""),
|
||||||
|
tx_enabled=bool(data.get("tx_enabled", False)),
|
||||||
|
tx_max_gain_db=(float(v) if (v := data.get("tx_max_gain_db")) is not None else None),
|
||||||
|
tx_max_duration_s=(float(v) if (v := data.get("tx_max_duration_s")) is not None else None),
|
||||||
|
tx_allowed_freq_ranges=_coerce_ranges(data.get("tx_allowed_freq_ranges")),
|
||||||
|
extra=extra,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save(cfg: AgentConfig, path: Path | None = None) -> Path:
|
||||||
|
p = path or _resolve_default_path()
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
data = asdict(cfg)
|
||||||
|
extra = data.pop("extra", {}) or {}
|
||||||
|
data.update(extra)
|
||||||
|
p.write_text(json.dumps(data, indent=2))
|
||||||
|
os.chmod(p, 0o600)
|
||||||
|
return p
|
||||||
54
src/ria_toolkit_oss/agent/hardware.py
Normal file
54
src/ria_toolkit_oss/agent/hardware.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Hardware detection and heartbeat payload construction for the streamer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ria_toolkit_oss.sdr import detect_available
|
||||||
|
|
||||||
|
from .config import AgentConfig
|
||||||
|
|
||||||
|
|
||||||
|
def available_devices() -> list[str]:
|
||||||
|
"""Return a sorted list of device names whose driver modules import cleanly."""
|
||||||
|
return sorted(detect_available().keys())
|
||||||
|
|
||||||
|
|
||||||
|
def heartbeat_payload(
|
||||||
|
status: str = "idle",
|
||||||
|
app_id: str | None = None,
|
||||||
|
*,
|
||||||
|
cfg: AgentConfig | None = None,
|
||||||
|
sessions: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build the JSON body of a periodic heartbeat frame.
|
||||||
|
|
||||||
|
*cfg* drives the ``capabilities`` list and the ``tx_enabled`` flag. If not
|
||||||
|
supplied, the heartbeat advertises RX-only with ``tx_enabled=False`` —
|
||||||
|
matching the pre-TX shape.
|
||||||
|
"""
|
||||||
|
c = cfg or AgentConfig()
|
||||||
|
capabilities = ["rx"]
|
||||||
|
if c.tx_enabled:
|
||||||
|
capabilities.append("tx")
|
||||||
|
|
||||||
|
payload: dict = {
|
||||||
|
"type": "heartbeat",
|
||||||
|
"hardware": available_devices(),
|
||||||
|
"status": status,
|
||||||
|
"capabilities": capabilities,
|
||||||
|
"tx_enabled": bool(c.tx_enabled),
|
||||||
|
}
|
||||||
|
# Surface configured interlock values so the hub can pre-filter UI controls
|
||||||
|
# before sending a tx_start that would be rejected. Only included when TX
|
||||||
|
# is opted in AND the operator set a cap.
|
||||||
|
if c.tx_enabled:
|
||||||
|
if c.tx_max_gain_db is not None:
|
||||||
|
payload["tx_max_gain_db"] = float(c.tx_max_gain_db)
|
||||||
|
if c.tx_max_duration_s is not None:
|
||||||
|
payload["tx_max_duration_s"] = float(c.tx_max_duration_s)
|
||||||
|
if c.tx_allowed_freq_ranges:
|
||||||
|
payload["tx_allowed_freq_ranges"] = [[float(lo), float(hi)] for lo, hi in c.tx_allowed_freq_ranges]
|
||||||
|
if app_id:
|
||||||
|
payload["app_id"] = app_id
|
||||||
|
if sessions:
|
||||||
|
payload["sessions"] = sessions
|
||||||
|
return payload
|
||||||
1000
src/ria_toolkit_oss/agent/legacy_executor.py
Normal file
1000
src/ria_toolkit_oss/agent/legacy_executor.py
Normal file
File diff suppressed because it is too large
Load Diff
147
src/ria_toolkit_oss/agent/namegen.py
Normal file
147
src/ria_toolkit_oss/agent/namegen.py
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
"""Generate random human-readable agent names.
|
||||||
|
|
||||||
|
Produces names in the form ``adjective-colour-animal``, e.g.
|
||||||
|
``swift-teal-falcon`` or ``brave-coral-otter``. All words are chosen
|
||||||
|
to be friendly and inoffensive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
ADJECTIVES: list[str] = [
|
||||||
|
"brave",
|
||||||
|
"bright",
|
||||||
|
"calm",
|
||||||
|
"clever",
|
||||||
|
"cool",
|
||||||
|
"daring",
|
||||||
|
"eager",
|
||||||
|
"fair",
|
||||||
|
"fancy",
|
||||||
|
"fast",
|
||||||
|
"fierce",
|
||||||
|
"gentle",
|
||||||
|
"grand",
|
||||||
|
"happy",
|
||||||
|
"jolly",
|
||||||
|
"keen",
|
||||||
|
"kind",
|
||||||
|
"lively",
|
||||||
|
"lucky",
|
||||||
|
"mighty",
|
||||||
|
"noble",
|
||||||
|
"plucky",
|
||||||
|
"proud",
|
||||||
|
"quick",
|
||||||
|
"quiet",
|
||||||
|
"sharp",
|
||||||
|
"shiny",
|
||||||
|
"sleek",
|
||||||
|
"smart",
|
||||||
|
"steady",
|
||||||
|
"stellar",
|
||||||
|
"strong",
|
||||||
|
"sturdy",
|
||||||
|
"sunny",
|
||||||
|
"sure",
|
||||||
|
"swift",
|
||||||
|
"tall",
|
||||||
|
"vivid",
|
||||||
|
"warm",
|
||||||
|
"wise",
|
||||||
|
]
|
||||||
|
|
||||||
|
COLOURS: list[str] = [
|
||||||
|
"amber",
|
||||||
|
"aqua",
|
||||||
|
"azure",
|
||||||
|
"beige",
|
||||||
|
"blue",
|
||||||
|
"bronze",
|
||||||
|
"coral",
|
||||||
|
"copper",
|
||||||
|
"crimson",
|
||||||
|
"cyan",
|
||||||
|
"denim",
|
||||||
|
"gold",
|
||||||
|
"green",
|
||||||
|
"grey",
|
||||||
|
"indigo",
|
||||||
|
"ivory",
|
||||||
|
"jade",
|
||||||
|
"lemon",
|
||||||
|
"lilac",
|
||||||
|
"lime",
|
||||||
|
"maroon",
|
||||||
|
"mint",
|
||||||
|
"navy",
|
||||||
|
"olive",
|
||||||
|
"onyx",
|
||||||
|
"peach",
|
||||||
|
"pearl",
|
||||||
|
"plum",
|
||||||
|
"red",
|
||||||
|
"rose",
|
||||||
|
"ruby",
|
||||||
|
"rust",
|
||||||
|
"sage",
|
||||||
|
"sand",
|
||||||
|
"scarlet",
|
||||||
|
"silver",
|
||||||
|
"slate",
|
||||||
|
"steel",
|
||||||
|
"teal",
|
||||||
|
"violet",
|
||||||
|
]
|
||||||
|
|
||||||
|
ANIMALS: list[str] = [
|
||||||
|
"badger",
|
||||||
|
"bear",
|
||||||
|
"bison",
|
||||||
|
"crane",
|
||||||
|
"deer",
|
||||||
|
"dolphin",
|
||||||
|
"eagle",
|
||||||
|
"elk",
|
||||||
|
"falcon",
|
||||||
|
"finch",
|
||||||
|
"fox",
|
||||||
|
"gecko",
|
||||||
|
"hawk",
|
||||||
|
"heron",
|
||||||
|
"horse",
|
||||||
|
"ibis",
|
||||||
|
"jaguar",
|
||||||
|
"jay",
|
||||||
|
"kite",
|
||||||
|
"koala",
|
||||||
|
"lark",
|
||||||
|
"lion",
|
||||||
|
"lynx",
|
||||||
|
"marten",
|
||||||
|
"moose",
|
||||||
|
"newt",
|
||||||
|
"orca",
|
||||||
|
"osprey",
|
||||||
|
"otter",
|
||||||
|
"owl",
|
||||||
|
"panda",
|
||||||
|
"puma",
|
||||||
|
"raven",
|
||||||
|
"robin",
|
||||||
|
"salmon",
|
||||||
|
"seal",
|
||||||
|
"shark",
|
||||||
|
"stork",
|
||||||
|
"swift",
|
||||||
|
"wolf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_agent_name() -> str:
|
||||||
|
"""Return a random ``adjective-colour-animal`` name."""
|
||||||
|
adj = random.choice(ADJECTIVES)
|
||||||
|
col = random.choice(COLOURS)
|
||||||
|
ani = random.choice(ANIMALS)
|
||||||
|
return f"{adj}-{col}-{ani}"
|
||||||
747
src/ria_toolkit_oss/agent/streamer.py
Normal file
747
src/ria_toolkit_oss/agent/streamer.py
Normal file
|
|
@ -0,0 +1,747 @@
|
||||||
|
"""IQ-streaming agent.
|
||||||
|
|
||||||
|
Listens for control messages from the RIA Hub over a persistent WebSocket.
|
||||||
|
Supports:
|
||||||
|
|
||||||
|
- An **RX session** (hub sends ``start``/``stop``/``configure``; agent opens
|
||||||
|
the SDR, loops ``sdr.rx()`` and ships raw interleaved float32 IQ).
|
||||||
|
- A **TX session** (hub sends ``tx_start``/``tx_stop``/``tx_configure`` plus
|
||||||
|
binary IQ frames; agent feeds them into ``sdr._stream_tx``). Phase 3 wires
|
||||||
|
up the session plumbing and rejects TX when ``cfg.tx_enabled`` is False;
|
||||||
|
Phase 4 implements the full TX loop.
|
||||||
|
|
||||||
|
Both sessions can run concurrently on the same physical SDR (FDD) — a
|
||||||
|
ref-counted SDR registry shares one driver instance when RX and TX name the
|
||||||
|
same ``(device, identifier)``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .config import AgentConfig
|
||||||
|
from .hardware import heartbeat_payload
|
||||||
|
from .ws_client import WsClient
|
||||||
|
|
||||||
|
logger = logging.getLogger("ria_agent.streamer")
|
||||||
|
|
||||||
|
_DEFAULT_BUFFER_SIZE = 1024
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RxSession:
|
||||||
|
app_id: str
|
||||||
|
sdr: Any
|
||||||
|
device_key: tuple[str, str | None]
|
||||||
|
buffer_size: int
|
||||||
|
task: asyncio.Task | None = None
|
||||||
|
pending_config: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TxSession:
|
||||||
|
app_id: str
|
||||||
|
sdr: Any
|
||||||
|
device_key: tuple[str, str | None]
|
||||||
|
buffer_size: int
|
||||||
|
task: Any = None # concurrent.futures.Future from run_in_executor
|
||||||
|
pending_config: dict = field(default_factory=dict)
|
||||||
|
underrun_policy: str = "pause"
|
||||||
|
last_buffer: np.ndarray | None = None
|
||||||
|
stop_event: threading.Event = field(default_factory=threading.Event)
|
||||||
|
started_at: float = 0.0
|
||||||
|
max_duration_s: float | None = None
|
||||||
|
state: str = "armed"
|
||||||
|
# Thread-safe queue of inbound interleaved-float32 IQ frames. Bounded so
|
||||||
|
# hub-side over-production triggers WS backpressure rather than memory
|
||||||
|
# growth in the agent.
|
||||||
|
in_queue: "queue.Queue[bytes]" = field(default_factory=lambda: queue.Queue(maxsize=8))
|
||||||
|
# Set by the TX callback when it hits an underrun while policy=="pause";
|
||||||
|
# asyncio side flips the session state and emits tx_status.
|
||||||
|
underrun_flag: threading.Event = field(default_factory=threading.Event)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SDR registry (ref-counted so one Pluto handle serves RX + TX simultaneously)
|
||||||
|
|
||||||
|
|
||||||
|
class _SdrRegistry:
|
||||||
|
def __init__(self, factory):
|
||||||
|
self._factory = factory
|
||||||
|
self._instances: dict[tuple[str, str | None], tuple[Any, int]] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def acquire(self, device: str, identifier: str | None) -> tuple[Any, tuple[str, str | None]]:
|
||||||
|
key = (device, identifier)
|
||||||
|
with self._lock:
|
||||||
|
if key in self._instances:
|
||||||
|
sdr, rc = self._instances[key]
|
||||||
|
self._instances[key] = (sdr, rc + 1)
|
||||||
|
return sdr, key
|
||||||
|
# Build outside the lock: driver init can be slow and we don't want to
|
||||||
|
# block concurrent releases on unrelated devices.
|
||||||
|
sdr = self._factory(device, identifier)
|
||||||
|
with self._lock:
|
||||||
|
if key in self._instances:
|
||||||
|
# Raced another acquirer; discard our duplicate and share theirs.
|
||||||
|
other_sdr, rc = self._instances[key]
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._instances[key] = (other_sdr, rc + 1)
|
||||||
|
return other_sdr, key
|
||||||
|
self._instances[key] = (sdr, 1)
|
||||||
|
return sdr, key
|
||||||
|
|
||||||
|
def release(self, key: tuple[str, str | None]) -> bool:
|
||||||
|
"""Decrement refcount. Returns True if the caller owns the last reference
|
||||||
|
and should close the SDR."""
|
||||||
|
with self._lock:
|
||||||
|
sdr, rc = self._instances.get(key, (None, 0))
|
||||||
|
if sdr is None:
|
||||||
|
return False
|
||||||
|
if rc <= 1:
|
||||||
|
del self._instances[key]
|
||||||
|
return True
|
||||||
|
self._instances[key] = (sdr, rc - 1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def refcount(self, key: tuple[str, str | None]) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._instances.get(key, (None, 0))[1]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Streamer
|
||||||
|
|
||||||
|
|
||||||
|
class Streamer:
|
||||||
|
"""Main streamer loop.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ws:
|
||||||
|
Connected :class:`WsClient`.
|
||||||
|
sdr_factory:
|
||||||
|
Callable ``(device, identifier) -> SDR``. Defaults to the helper in
|
||||||
|
:mod:`ria_toolkit_oss.sdr`. Injectable for tests.
|
||||||
|
cfg:
|
||||||
|
:class:`AgentConfig` for interlocks (``tx_enabled`` and caps) and
|
||||||
|
heartbeat capabilities. Defaults to an empty ``AgentConfig()`` which
|
||||||
|
leaves TX disabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ws,
|
||||||
|
sdr_factory=None,
|
||||||
|
cfg: AgentConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.ws = ws
|
||||||
|
self._cfg = cfg or AgentConfig()
|
||||||
|
self._registry = _SdrRegistry(sdr_factory or _default_sdr_factory)
|
||||||
|
self._rx: RxSession | None = None
|
||||||
|
self._tx: TxSession | None = None
|
||||||
|
# Pending radio_config accepted via ``configure`` before ``start``.
|
||||||
|
self._standalone_pending_config: dict = {}
|
||||||
|
# Cached asyncio event loop, set the first time a handler runs. Used
|
||||||
|
# to schedule async callbacks from the TX executor thread.
|
||||||
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Back-compat read-only shims for callers that check ``._sdr`` etc.
|
||||||
|
# Writes to these attributes are not supported — use the session objects.
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _sdr(self):
|
||||||
|
return self._rx.sdr if self._rx is not None else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _pending_config(self) -> dict:
|
||||||
|
return self._rx.pending_config if self._rx is not None else self._standalone_pending_config
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# WsClient wiring
|
||||||
|
|
||||||
|
def build_heartbeat(self) -> dict:
|
||||||
|
status = "streaming" if (self._rx is not None or self._tx is not None) else "idle"
|
||||||
|
app_id: str | None = None
|
||||||
|
if self._rx is not None:
|
||||||
|
app_id = self._rx.app_id
|
||||||
|
elif self._tx is not None:
|
||||||
|
app_id = self._tx.app_id
|
||||||
|
|
||||||
|
sessions: dict[str, dict] = {}
|
||||||
|
if self._rx is not None:
|
||||||
|
sessions["rx"] = {"app_id": self._rx.app_id, "state": "streaming"}
|
||||||
|
if self._tx is not None:
|
||||||
|
sessions["tx"] = {"app_id": self._tx.app_id, "state": self._tx.state}
|
||||||
|
|
||||||
|
return heartbeat_payload(
|
||||||
|
status=status,
|
||||||
|
app_id=app_id,
|
||||||
|
cfg=self._cfg,
|
||||||
|
sessions=sessions or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advisory / keepalive message types we accept and ignore without warning.
|
||||||
|
_IGNORED_MESSAGE_TYPES = frozenset({"tx_data_available"})
|
||||||
|
|
||||||
|
async def on_message(self, msg: dict) -> None:
|
||||||
|
t = msg.get("type")
|
||||||
|
if t in self._IGNORED_MESSAGE_TYPES:
|
||||||
|
logger.debug("Ignoring advisory message: %r", t)
|
||||||
|
return
|
||||||
|
handler = {
|
||||||
|
"start": self._handle_rx_start,
|
||||||
|
"stop": self._handle_rx_stop,
|
||||||
|
"configure": self._handle_rx_configure,
|
||||||
|
"tx_start": self._handle_tx_start,
|
||||||
|
"tx_stop": self._handle_tx_stop,
|
||||||
|
"tx_configure": self._handle_tx_configure,
|
||||||
|
}.get(t)
|
||||||
|
if handler is None:
|
||||||
|
logger.warning("Unknown server message type: %r", t)
|
||||||
|
return
|
||||||
|
await handler(msg)
|
||||||
|
|
||||||
|
async def on_binary(self, data: bytes) -> None:
|
||||||
|
tx = self._tx
|
||||||
|
if tx is None:
|
||||||
|
logger.debug("Dropping %d-byte binary frame: no TX session", len(data))
|
||||||
|
return
|
||||||
|
# Backpressure: if the TX queue is full, await briefly so the hub's
|
||||||
|
# ``await ws.send`` throttles naturally via TCP. We don't block
|
||||||
|
# indefinitely — a 2s stall means something else is wrong.
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(None, lambda: tx.in_queue.put(data, timeout=2.0))
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning("TX queue stalled; dropping frame")
|
||||||
|
|
||||||
|
# ==================================================================
|
||||||
|
# RX
|
||||||
|
|
||||||
|
async def _handle_rx_start(self, msg: dict) -> None:
|
||||||
|
if self._rx is not None:
|
||||||
|
logger.warning("start received while already streaming — ignoring")
|
||||||
|
return
|
||||||
|
|
||||||
|
app_id = msg.get("app_id") or ""
|
||||||
|
radio_config = dict(msg.get("radio_config") or {})
|
||||||
|
device = radio_config.pop("device", None)
|
||||||
|
identifier = radio_config.pop("identifier", None)
|
||||||
|
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
||||||
|
if not device:
|
||||||
|
await self._send_error(app_id, "start missing radio_config.device")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
sdr, device_key = self._registry.acquire(device, identifier)
|
||||||
|
_apply_sdr_config(sdr, radio_config)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to open SDR %r", device)
|
||||||
|
await self._send_error(app_id, f"SDR init failed: {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Inherit any pending config that was queued before start.
|
||||||
|
pending = dict(self._standalone_pending_config)
|
||||||
|
self._standalone_pending_config = {}
|
||||||
|
|
||||||
|
session = RxSession(
|
||||||
|
app_id=app_id,
|
||||||
|
sdr=sdr,
|
||||||
|
device_key=device_key,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
pending_config=pending,
|
||||||
|
)
|
||||||
|
self._rx = session
|
||||||
|
await self._send_status("streaming", app_id)
|
||||||
|
session.task = asyncio.create_task(self._capture_loop(session), name="ria-streamer-capture")
|
||||||
|
|
||||||
|
async def _handle_rx_stop(self, msg: dict) -> None:
|
||||||
|
session = self._rx
|
||||||
|
if session is None:
|
||||||
|
return
|
||||||
|
if session.task is not None:
|
||||||
|
session.task.cancel()
|
||||||
|
try:
|
||||||
|
await session.task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
app_id = session.app_id
|
||||||
|
self._rx = None
|
||||||
|
await self._send_status("idle", app_id)
|
||||||
|
|
||||||
|
async def _handle_rx_configure(self, msg: dict) -> None:
|
||||||
|
cfg = dict(msg.get("radio_config") or {})
|
||||||
|
if self._rx is not None:
|
||||||
|
self._rx.pending_config.update(cfg)
|
||||||
|
else:
|
||||||
|
self._standalone_pending_config.update(cfg)
|
||||||
|
logger.debug("Queued configure: %s", cfg)
|
||||||
|
|
||||||
|
async def _capture_loop(self, session: RxSession) -> None:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if session.pending_config:
|
||||||
|
cfg = session.pending_config
|
||||||
|
session.pending_config = {}
|
||||||
|
try:
|
||||||
|
_apply_sdr_config(session.sdr, cfg)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Applying configure failed: %s", exc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
samples = await loop.run_in_executor(None, session.sdr.rx, session.buffer_size)
|
||||||
|
except Exception as exc:
|
||||||
|
from ria_toolkit_oss.sdr import SdrDisconnectedError
|
||||||
|
|
||||||
|
if isinstance(exc, SdrDisconnectedError):
|
||||||
|
logger.warning("SDR disconnected: %s", exc)
|
||||||
|
await self._send_error(session.app_id, f"SDR disconnected: {exc}")
|
||||||
|
else:
|
||||||
|
logger.exception("SDR rx error")
|
||||||
|
await self._send_error(session.app_id, f"SDR capture failed: {exc}")
|
||||||
|
break
|
||||||
|
|
||||||
|
payload = _samples_to_interleaved_float32(samples)
|
||||||
|
try:
|
||||||
|
await self.ws.send_bytes(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Send failed: %s — ending capture", exc)
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
# If the loop died on its own (e.g. SDR disconnect), clear the
|
||||||
|
# session handle so future ``start`` messages can proceed.
|
||||||
|
if self._rx is session:
|
||||||
|
self._rx = None
|
||||||
|
|
||||||
|
# ==================================================================
|
||||||
|
# TX
|
||||||
|
|
||||||
|
async def _handle_tx_start(self, msg: dict) -> None: # noqa: C901
|
||||||
|
app_id = msg.get("app_id") or ""
|
||||||
|
radio_config = dict(msg.get("radio_config") or {})
|
||||||
|
|
||||||
|
# --- interlocks (agent-enforced; never trust the hub alone) ---
|
||||||
|
if not self._cfg.tx_enabled:
|
||||||
|
await self._send_tx_status(app_id, "error", "tx disabled on this agent")
|
||||||
|
return
|
||||||
|
tx_gain = radio_config.get("tx_gain")
|
||||||
|
if (
|
||||||
|
self._cfg.tx_max_gain_db is not None
|
||||||
|
and tx_gain is not None
|
||||||
|
and float(tx_gain) > float(self._cfg.tx_max_gain_db)
|
||||||
|
):
|
||||||
|
await self._send_tx_status(
|
||||||
|
app_id,
|
||||||
|
"error",
|
||||||
|
f"tx_gain {tx_gain} exceeds cap {self._cfg.tx_max_gain_db}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
tx_freq = radio_config.get("tx_center_frequency")
|
||||||
|
if self._cfg.tx_allowed_freq_ranges and tx_freq is not None:
|
||||||
|
f = float(tx_freq)
|
||||||
|
if not any(float(lo) <= f <= float(hi) for lo, hi in self._cfg.tx_allowed_freq_ranges):
|
||||||
|
await self._send_tx_status(
|
||||||
|
app_id,
|
||||||
|
"error",
|
||||||
|
f"tx_center_frequency {tx_freq} outside allowed ranges",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._tx is not None:
|
||||||
|
await self._send_tx_status(app_id, "error", "tx already active on this agent")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- device ---
|
||||||
|
device = radio_config.pop("device", None)
|
||||||
|
identifier = radio_config.pop("identifier", None)
|
||||||
|
buffer_size = int(radio_config.pop("buffer_size", _DEFAULT_BUFFER_SIZE))
|
||||||
|
underrun_policy = str(radio_config.pop("underrun_policy", "pause"))
|
||||||
|
if underrun_policy not in ("pause", "zero", "repeat"):
|
||||||
|
await self._send_tx_status(app_id, "error", f"invalid underrun_policy {underrun_policy!r}")
|
||||||
|
return
|
||||||
|
if not device:
|
||||||
|
await self._send_tx_status(app_id, "error", "tx_start missing radio_config.device")
|
||||||
|
return
|
||||||
|
|
||||||
|
device_key: tuple[str, str | None] | None = None
|
||||||
|
sdr: Any = None
|
||||||
|
try:
|
||||||
|
sdr, device_key = self._registry.acquire(device, identifier)
|
||||||
|
_apply_sdr_config(sdr, radio_config)
|
||||||
|
# init_tx is mandatory for any driver that exposes it: drivers
|
||||||
|
# that gate _stream_tx on _tx_initialized (Pluto, HackRF, USRP,
|
||||||
|
# …) crash with a confusing "TX was not initialized" error 2 s
|
||||||
|
# later in the executor thread if we skip it. Treat the three
|
||||||
|
# required keys as a hard contract — a missing one is a hub-side
|
||||||
|
# manifest bug and we want it surfaced immediately, not papered
|
||||||
|
# over with stale radio state.
|
||||||
|
if hasattr(sdr, "init_tx"):
|
||||||
|
init_args = {k: radio_config.get(f"tx_{k}") for k in ("sample_rate", "center_frequency", "gain")}
|
||||||
|
missing = [f"tx_{k}" for k, v in init_args.items() if v is None]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"tx_start missing required radio_config keys: {missing}")
|
||||||
|
sdr.init_tx(
|
||||||
|
sample_rate=init_args["sample_rate"],
|
||||||
|
center_frequency=init_args["center_frequency"],
|
||||||
|
gain=init_args["gain"],
|
||||||
|
channel=radio_config.get("tx_channel", 0),
|
||||||
|
gain_mode=radio_config.get("tx_gain_mode", "manual"),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if device_key is not None:
|
||||||
|
if self._registry.release(device_key):
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.exception("Failed to init TX on %r", device)
|
||||||
|
await self._send_tx_status(app_id, "error", f"tx init failed: {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._loop = asyncio.get_running_loop()
|
||||||
|
session = TxSession(
|
||||||
|
app_id=app_id,
|
||||||
|
sdr=sdr,
|
||||||
|
device_key=device_key,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
underrun_policy=underrun_policy,
|
||||||
|
started_at=time.monotonic(),
|
||||||
|
max_duration_s=self._cfg.tx_max_duration_s,
|
||||||
|
)
|
||||||
|
self._tx = session
|
||||||
|
await self._send_tx_status(app_id, "armed")
|
||||||
|
session.task = self._loop.run_in_executor(None, self._tx_executor_body, session)
|
||||||
|
# Spawn a small watchdog that transitions armed → transmitting when
|
||||||
|
# the first buffer has been consumed, and surfaces underrun / max-
|
||||||
|
# duration terminations back to the hub.
|
||||||
|
asyncio.create_task(self._tx_watchdog(session))
|
||||||
|
|
||||||
|
async def _handle_tx_stop(self, msg: dict) -> None:
|
||||||
|
session = self._tx
|
||||||
|
if session is None:
|
||||||
|
return
|
||||||
|
app_id = session.app_id
|
||||||
|
session.stop_event.set()
|
||||||
|
try:
|
||||||
|
session.sdr.pause_tx()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("pause_tx raised during stop", exc_info=True)
|
||||||
|
# Wake the executor thread if it's blocked on ``queue.get``.
|
||||||
|
self._drain_tx_queue(session)
|
||||||
|
if session.task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.5)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("TX executor did not exit within 1.5s after stop")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("TX executor raised on shutdown", exc_info=True)
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
self._tx = None
|
||||||
|
await self._send_tx_status(app_id, "done")
|
||||||
|
|
||||||
|
async def _handle_tx_configure(self, msg: dict) -> None:
|
||||||
|
if self._tx is None:
|
||||||
|
return
|
||||||
|
self._tx.pending_config.update(msg.get("radio_config") or {})
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# TX executor & watchdog
|
||||||
|
|
||||||
|
def _tx_executor_body(self, session: TxSession) -> None:
|
||||||
|
try:
|
||||||
|
session.sdr._stream_tx(lambda n: self._tx_callback(session, n))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("TX stream crashed")
|
||||||
|
# Schedule both the error frame and session teardown on the loop
|
||||||
|
# so ``self._tx`` clears, subsequent binary frames are rejected,
|
||||||
|
# and the SDR handle is released.
|
||||||
|
self._schedule(self._tx_crash_teardown(session, str(exc)))
|
||||||
|
|
||||||
|
def _tx_callback(self, session: TxSession, num_samples) -> np.ndarray:
|
||||||
|
n = int(num_samples)
|
||||||
|
# Honor stop requests: return silence one last time and let the driver
|
||||||
|
# exit its loop on the next iteration (pause_tx flips _enable_tx).
|
||||||
|
if session.stop_event.is_set():
|
||||||
|
return _silence(n)
|
||||||
|
|
||||||
|
# Max-duration watchdog.
|
||||||
|
if session.max_duration_s is not None and (time.monotonic() - session.started_at) >= float(
|
||||||
|
session.max_duration_s
|
||||||
|
):
|
||||||
|
session.stop_event.set()
|
||||||
|
try:
|
||||||
|
session.sdr.pause_tx()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._schedule(self._send_tx_status(session.app_id, "done", "max duration reached"))
|
||||||
|
return _silence(n)
|
||||||
|
|
||||||
|
# Apply queued configure at buffer boundary.
|
||||||
|
if session.pending_config:
|
||||||
|
cfg = session.pending_config
|
||||||
|
session.pending_config = {}
|
||||||
|
try:
|
||||||
|
_apply_sdr_config(session.sdr, cfg)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("tx_configure apply failed: %s", exc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = session.in_queue.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
return self._underrun_fill(session, n)
|
||||||
|
|
||||||
|
arr = np.frombuffer(raw, dtype=np.float32)
|
||||||
|
if arr.size < 2 or arr.size % 2 != 0:
|
||||||
|
logger.warning("Malformed TX frame: %d floats (must be non-zero even count)", arr.size)
|
||||||
|
return self._underrun_fill(session, n)
|
||||||
|
samples = arr[0::2].astype(np.complex64) + 1j * arr[1::2].astype(np.complex64)
|
||||||
|
if samples.size < n:
|
||||||
|
out = np.zeros(n, dtype=np.complex64)
|
||||||
|
out[: samples.size] = samples
|
||||||
|
session.last_buffer = out
|
||||||
|
return out
|
||||||
|
if samples.size > n:
|
||||||
|
samples = samples[:n]
|
||||||
|
session.last_buffer = samples
|
||||||
|
if session.state == "armed":
|
||||||
|
session.state = "transmitting"
|
||||||
|
self._schedule(self._send_tx_status(session.app_id, "transmitting"))
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def _underrun_fill(self, session: TxSession, n: int) -> np.ndarray:
|
||||||
|
policy = session.underrun_policy
|
||||||
|
if policy == "zero":
|
||||||
|
return _silence(n)
|
||||||
|
if policy == "repeat" and session.last_buffer is not None:
|
||||||
|
buf = session.last_buffer
|
||||||
|
if buf.size == n:
|
||||||
|
return buf
|
||||||
|
if buf.size > n:
|
||||||
|
return buf[:n].copy()
|
||||||
|
out = np.zeros(n, dtype=np.complex64)
|
||||||
|
out[: buf.size] = buf
|
||||||
|
return out
|
||||||
|
# "pause" policy (default) or "repeat" before any buffer arrived.
|
||||||
|
if not session.underrun_flag.is_set():
|
||||||
|
session.underrun_flag.set()
|
||||||
|
session.stop_event.set()
|
||||||
|
try:
|
||||||
|
session.sdr.pause_tx()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return _silence(n)
|
||||||
|
|
||||||
|
async def _tx_watchdog(self, session: TxSession) -> None:
|
||||||
|
# Poll the underrun flag so we can emit status + tear down cleanly
|
||||||
|
# when the callback flips the flag from the executor thread. Check
|
||||||
|
# underrun_flag before stop_event, since the "pause" path sets both.
|
||||||
|
while session is self._tx:
|
||||||
|
if session.underrun_flag.is_set():
|
||||||
|
await self._send_tx_status(session.app_id, "underrun")
|
||||||
|
await self._teardown_tx_after_underrun(session)
|
||||||
|
return
|
||||||
|
if session.stop_event.is_set():
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
async def _tx_crash_teardown(self, session: TxSession, message: str) -> None:
|
||||||
|
# Called from the executor thread via _schedule when _stream_tx raises.
|
||||||
|
# Emit the error, mark stopped, drain the queue, release the SDR.
|
||||||
|
await self._send_tx_status(session.app_id, "error", f"tx stream crashed: {message}")
|
||||||
|
if self._tx is not session:
|
||||||
|
return
|
||||||
|
session.stop_event.set()
|
||||||
|
self._drain_tx_queue(session)
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
if self._tx is session:
|
||||||
|
self._tx = None
|
||||||
|
|
||||||
|
async def _teardown_tx_after_underrun(self, session: TxSession) -> None:
|
||||||
|
if self._tx is not session:
|
||||||
|
return
|
||||||
|
self._drain_tx_queue(session)
|
||||||
|
if session.task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.wrap_future(session.task), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("TX executor did not exit within 1s after underrun")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("TX executor raised during underrun teardown", exc_info=True)
|
||||||
|
self._close_session_sdr(session)
|
||||||
|
if self._tx is session:
|
||||||
|
self._tx = None
|
||||||
|
|
||||||
|
def _drain_tx_queue(self, session: TxSession) -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
session.in_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _schedule(self, coro) -> None:
|
||||||
|
loop = self._loop
|
||||||
|
if loop is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
asyncio.run_coroutine_threadsafe(coro, loop)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("_schedule failed", exc_info=True)
|
||||||
|
|
||||||
|
# ==================================================================
|
||||||
|
# Helpers
|
||||||
|
|
||||||
|
def _close_session_sdr(self, session) -> None:
|
||||||
|
if session.sdr is None:
|
||||||
|
return
|
||||||
|
should_close = self._registry.release(session.device_key)
|
||||||
|
if should_close:
|
||||||
|
try:
|
||||||
|
session.sdr.close()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("SDR close raised", exc_info=True)
|
||||||
|
|
||||||
|
async def _send_status(self, status: str, app_id: str) -> None:
|
||||||
|
try:
|
||||||
|
await self.ws.send_json({"type": "status", "status": status, "app_id": app_id})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Status send failed: %s", exc)
|
||||||
|
|
||||||
|
async def _send_error(self, app_id: str, message: str) -> None:
|
||||||
|
try:
|
||||||
|
await self.ws.send_json({"type": "error", "app_id": app_id, "message": message})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Error-frame send failed: %s", exc)
|
||||||
|
|
||||||
|
async def _send_tx_status(self, app_id: str, state: str, message: str | None = None) -> None:
|
||||||
|
payload: dict = {"type": "tx_status", "app_id": app_id, "state": state}
|
||||||
|
if message is not None:
|
||||||
|
payload["message"] = message
|
||||||
|
try:
|
||||||
|
await self.ws.send_json(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("tx_status send failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
|
||||||
|
_CONFIG_ATTR_MAP = {
|
||||||
|
"sample_rate": ("sample_rate", "rx_sample_rate"),
|
||||||
|
"center_frequency": ("center_freq", "rx_center_frequency"),
|
||||||
|
"center_freq": ("center_freq", "rx_center_frequency"),
|
||||||
|
"gain": ("gain", "rx_gain"),
|
||||||
|
"bandwidth": ("bandwidth", "rx_bandwidth"),
|
||||||
|
"tx_sample_rate": ("tx_sample_rate",),
|
||||||
|
"tx_center_frequency": ("tx_center_frequency", "tx_lo"),
|
||||||
|
"tx_gain": ("tx_gain",),
|
||||||
|
"tx_bandwidth": ("tx_bandwidth",),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_stub_setter(method: Any) -> bool:
|
||||||
|
"""True when *method* is an unimplemented base-class stub.
|
||||||
|
|
||||||
|
The ``SDR`` abstract base defines ``set_rx_sample_rate`` / ``set_tx_gain``
|
||||||
|
etc. as zero-argument ``NotImplementedError`` stubs. A driver (Pluto) that
|
||||||
|
actually transmits overrides them with a real ``(value, ...)`` signature.
|
||||||
|
Comparing ``__qualname__`` against ``SDR.`` lets us skip the stubs cheaply.
|
||||||
|
"""
|
||||||
|
return getattr(method, "__qualname__", "").startswith("SDR.")
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_sdr_config(sdr: Any, cfg: dict) -> None:
|
||||||
|
"""Apply a radio_config dict to an SDR.
|
||||||
|
|
||||||
|
Prefers ``sdr.set_<attr>(value)`` when the driver implements it — Pluto's
|
||||||
|
setters take ``_param_lock``, so routing through them keeps concurrent
|
||||||
|
RX + TX reconfigures from racing on shared native attributes. Falls back
|
||||||
|
to ``setattr`` for drivers (MockSDR, tests) that don't override the
|
||||||
|
base-class stubs.
|
||||||
|
"""
|
||||||
|
for key, value in cfg.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
attrs = _CONFIG_ATTR_MAP.get(key, (key,))
|
||||||
|
applied = False
|
||||||
|
for attr in attrs:
|
||||||
|
setter = getattr(sdr, f"set_{attr}", None)
|
||||||
|
if callable(setter) and not _is_stub_setter(setter):
|
||||||
|
try:
|
||||||
|
setter(value)
|
||||||
|
applied = True
|
||||||
|
break
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("set_%s(%r) failed: %s", attr, value, exc)
|
||||||
|
# Fall through to setattr; some drivers may partially
|
||||||
|
# implement setters.
|
||||||
|
if applied:
|
||||||
|
continue
|
||||||
|
for attr in attrs:
|
||||||
|
if hasattr(sdr, attr):
|
||||||
|
try:
|
||||||
|
setattr(sdr, attr, value)
|
||||||
|
applied = True
|
||||||
|
break
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("setattr %s=%r failed: %s", attr, value, exc)
|
||||||
|
if not applied:
|
||||||
|
logger.debug("radio_config key %r ignored (no matching attr)", key)
|
||||||
|
|
||||||
|
|
||||||
|
def _silence(num_samples: int) -> np.ndarray:
|
||||||
|
"""Return a ``num_samples``-length zero-filled complex64 buffer."""
|
||||||
|
return np.zeros(int(num_samples), dtype=np.complex64)
|
||||||
|
|
||||||
|
|
||||||
|
def _samples_to_interleaved_float32(samples: Any) -> bytes:
|
||||||
|
"""Convert complex IQ samples (any numeric dtype) to interleaved float32 bytes."""
|
||||||
|
arr = np.asarray(samples)
|
||||||
|
if np.iscomplexobj(arr):
|
||||||
|
interleaved = np.empty(arr.size * 2, dtype=np.float32)
|
||||||
|
interleaved[0::2] = arr.real.astype(np.float32, copy=False).ravel()
|
||||||
|
interleaved[1::2] = arr.imag.astype(np.float32, copy=False).ravel()
|
||||||
|
return interleaved.tobytes()
|
||||||
|
return arr.astype(np.float32, copy=False).tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def _default_sdr_factory(device: str, identifier: str | None):
|
||||||
|
from ria_toolkit_oss.sdr import get_sdr_device
|
||||||
|
|
||||||
|
return get_sdr_device(device, ident=identifier)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level entry
|
||||||
|
|
||||||
|
|
||||||
|
async def run_streamer(ws_url: str, token: str, *, cfg: AgentConfig | None = None) -> None:
|
||||||
|
"""Connect to *ws_url* and run the streamer loop until cancelled."""
|
||||||
|
ws = WsClient(ws_url, token)
|
||||||
|
streamer = Streamer(ws, cfg=cfg)
|
||||||
|
await ws.run(
|
||||||
|
streamer.on_message,
|
||||||
|
streamer.build_heartbeat,
|
||||||
|
on_binary=streamer.on_binary,
|
||||||
|
)
|
||||||
128
src/ria_toolkit_oss/agent/ws_client.py
Normal file
128
src/ria_toolkit_oss/agent/ws_client.py
Normal file
|
|
@ -0,0 +1,128 @@
|
||||||
|
"""Persistent WebSocket client for the streamer agent.
|
||||||
|
|
||||||
|
Handles connection lifecycle: connect, heartbeat, auto-reconnect on drop.
|
||||||
|
The caller drives the I/O loop via ``run()`` with a message handler callback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger("ria_agent.ws")
|
||||||
|
|
||||||
|
MessageHandler = Callable[[dict], Awaitable[None]]
|
||||||
|
HeartbeatBuilder = Callable[[], dict]
|
||||||
|
BinaryHandler = Callable[[bytes], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
class WsClient:
|
||||||
|
"""Persistent WebSocket connection with heartbeat and auto-reconnect.
|
||||||
|
|
||||||
|
``url`` should be a full ``wss://host/path`` (or ``ws://``) URL. ``token``
|
||||||
|
is sent as a bearer in the ``Authorization`` header on connect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
token: str,
|
||||||
|
*,
|
||||||
|
heartbeat_interval: float = 30.0,
|
||||||
|
reconnect_pause: float = 5.0,
|
||||||
|
) -> None:
|
||||||
|
self.url = url
|
||||||
|
self.token = token
|
||||||
|
self.heartbeat_interval = heartbeat_interval
|
||||||
|
self.reconnect_pause = reconnect_pause
|
||||||
|
self._ws = None
|
||||||
|
self._stop = asyncio.Event()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
async def _connect(self):
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
headers = [("Authorization", f"Bearer {self.token}")] if self.token else None
|
||||||
|
# websockets >= 12 accepts additional_headers; fall back to extra_headers for older versions.
|
||||||
|
try:
|
||||||
|
return await websockets.connect(self.url, additional_headers=headers)
|
||||||
|
except TypeError:
|
||||||
|
return await websockets.connect(self.url, extra_headers=headers)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
async def send_json(self, payload: dict) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
raise ConnectionError("WebSocket is not connected")
|
||||||
|
await self._ws.send(json.dumps(payload))
|
||||||
|
|
||||||
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
raise ConnectionError("WebSocket is not connected")
|
||||||
|
await self._ws.send(data)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._stop.set()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
on_message: MessageHandler,
|
||||||
|
heartbeat: HeartbeatBuilder,
|
||||||
|
on_binary: BinaryHandler | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Main loop: connect, heartbeat, dispatch messages, reconnect on drop."""
|
||||||
|
while not self._stop.is_set():
|
||||||
|
try:
|
||||||
|
self._ws = await self._connect()
|
||||||
|
logger.info("Connected to %s", self.url)
|
||||||
|
hb_task = asyncio.create_task(self._heartbeat_loop(heartbeat))
|
||||||
|
try:
|
||||||
|
async for raw in self._ws:
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
if on_binary is None:
|
||||||
|
logger.debug("Discarding unexpected %d-byte binary frame", len(raw))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await on_binary(raw)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("on_binary handler raised; dropping frame")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
msg = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Malformed control frame: %r", raw[:200])
|
||||||
|
continue
|
||||||
|
await on_message(msg)
|
||||||
|
finally:
|
||||||
|
hb_task.cancel()
|
||||||
|
try:
|
||||||
|
await hb_task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
if self._stop.is_set():
|
||||||
|
break
|
||||||
|
logger.warning("WS error: %s — reconnecting in %.1fs", exc, self.reconnect_pause)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
if self._ws is not None:
|
||||||
|
await self._ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._ws = None
|
||||||
|
if self._stop.is_set():
|
||||||
|
break
|
||||||
|
await asyncio.sleep(self.reconnect_pause)
|
||||||
|
|
||||||
|
async def _heartbeat_loop(self, heartbeat: HeartbeatBuilder) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self.send_json(heartbeat())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Heartbeat send failed: %s", exc)
|
||||||
|
return
|
||||||
|
await asyncio.sleep(self.heartbeat_interval)
|
||||||
54
src/ria_toolkit_oss/annotations/__init__.py
Normal file
54
src/ria_toolkit_oss/annotations/__init__.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""
|
||||||
|
The annotations package contains tools and utilities for creating, managing, and processing annotations.
|
||||||
|
|
||||||
|
Provides automatic annotation generation using various signal detection algorithms:
|
||||||
|
- Energy-based detection (detect_signals_energy)
|
||||||
|
- CUSUM-based segmentation (annotate_with_cusum)
|
||||||
|
- Threshold-based qualification (threshold_qualifier)
|
||||||
|
- Signal isolation and extraction (isolate_signal)
|
||||||
|
- Occupied bandwidth analysis (calculate_occupied_bandwidth, calculate_nominal_bandwidth)
|
||||||
|
|
||||||
|
All detection functions return Recording objects with added annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Energy-based detection
|
||||||
|
"detect_signals_energy",
|
||||||
|
"calculate_occupied_bandwidth",
|
||||||
|
"calculate_nominal_bandwidth",
|
||||||
|
"calculate_full_detected_bandwidth",
|
||||||
|
"annotate_with_obw",
|
||||||
|
# CUSUM detection
|
||||||
|
"annotate_with_cusum",
|
||||||
|
# Threshold detection
|
||||||
|
"threshold_qualifier",
|
||||||
|
# Parallel signal separation (Phase 2)
|
||||||
|
"find_spectral_components",
|
||||||
|
"split_annotation_by_components",
|
||||||
|
"split_recording_annotations",
|
||||||
|
# Signal isolation
|
||||||
|
"isolate_signal",
|
||||||
|
# Annotation transforms
|
||||||
|
"remove_contained_boxes",
|
||||||
|
"is_annotation_contained",
|
||||||
|
# Dataset creation
|
||||||
|
"qualify_slice_from_annotations",
|
||||||
|
]
|
||||||
|
|
||||||
|
from .annotation_transforms import is_annotation_contained, remove_contained_boxes
|
||||||
|
from .cusum_annotator import annotate_with_cusum
|
||||||
|
from .energy_detector import (
|
||||||
|
annotate_with_obw,
|
||||||
|
calculate_full_detected_bandwidth,
|
||||||
|
calculate_nominal_bandwidth,
|
||||||
|
calculate_occupied_bandwidth,
|
||||||
|
detect_signals_energy,
|
||||||
|
)
|
||||||
|
from .parallel_signal_separator import (
|
||||||
|
find_spectral_components,
|
||||||
|
split_annotation_by_components,
|
||||||
|
split_recording_annotations,
|
||||||
|
)
|
||||||
|
from .qualify_slice import qualify_slice_from_annotations
|
||||||
|
from .signal_isolation import isolate_signal
|
||||||
|
from .threshold_qualifier import threshold_qualifier
|
||||||
55
src/ria_toolkit_oss/annotations/annotation_transforms.py
Normal file
55
src/ria_toolkit_oss/annotations/annotation_transforms.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
from ria_toolkit_oss.data.annotation import Annotation
|
||||||
|
|
||||||
|
# TODO figure out how to transfer labels in the merge case
|
||||||
|
|
||||||
|
|
||||||
|
def remove_contained_boxes(annotations: list[Annotation]):
|
||||||
|
"""
|
||||||
|
Remove all annotations (bounding boxes) that are entirely contained within other boxes in the list.
|
||||||
|
|
||||||
|
:param annotations: A list of Annotation objects.
|
||||||
|
:type annotations: list[Annotation]
|
||||||
|
|
||||||
|
:returns: A new list of Annotation objects.
|
||||||
|
:rtype: list[Annotation]"""
|
||||||
|
|
||||||
|
output_boxes = []
|
||||||
|
|
||||||
|
for i in range(len(annotations)):
|
||||||
|
contained = False
|
||||||
|
for j in range(len(annotations)):
|
||||||
|
if i != j and is_annotation_contained(annotations[i], annotations[j]):
|
||||||
|
contained = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not contained:
|
||||||
|
output_boxes.append(annotations[i])
|
||||||
|
|
||||||
|
return output_boxes
|
||||||
|
|
||||||
|
|
||||||
|
def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an annotation box is entirely contained within another annotation bounding box.
|
||||||
|
|
||||||
|
:param inner: The inner box.
|
||||||
|
:type inner: Annotation.
|
||||||
|
:param outer: The outer box.
|
||||||
|
:type outer: Annotation.
|
||||||
|
|
||||||
|
:returns: True if inner is within outer, false otherwise.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
|
||||||
|
inner_sample_stop = inner.sample_start + inner.sample_count
|
||||||
|
outer_sample_stop = outer.sample_start + outer.sample_count
|
||||||
|
|
||||||
|
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
|
||||||
|
if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def merge_annotations(annotations: list[Annotation], overlap_threshold) -> list[Annotation]:
|
||||||
|
raise NotImplementedError
|
||||||
203
src/ria_toolkit_oss/annotations/cusum_annotator.py
Normal file
203
src/ria_toolkit_oss/annotations/cusum_annotator.py
Normal file
|
|
@ -0,0 +1,203 @@
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data import Annotation, Recording
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_with_cusum(
|
||||||
|
recording: Recording,
|
||||||
|
label: Optional[str] = "segment",
|
||||||
|
window_size: Optional[int] = 1,
|
||||||
|
min_duration: Optional[float] = None,
|
||||||
|
tolerance: Optional[int] = None,
|
||||||
|
annotation_type: Optional[str] = "standalone",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add annotations that divide the recording into distinct time segments.
|
||||||
|
|
||||||
|
This algorithm computes the cumulative sum of the sample magnitudes and
|
||||||
|
determines break points in the signal.
|
||||||
|
|
||||||
|
This tool can be used to find points where a signal turns on or off, or
|
||||||
|
changes between a low and high amplitude.
|
||||||
|
|
||||||
|
:param recording: A ``Recording`` object to annotate.
|
||||||
|
:type recording: ``ria_toolkit_oss.data.Recording``
|
||||||
|
:param label: Label for the detected segments.
|
||||||
|
:type label: str
|
||||||
|
:param window_size: The length (in samples) of the moving average window.
|
||||||
|
:type window_size: int
|
||||||
|
:param min_duration: The minimum duration (in ms) of a segment.
|
||||||
|
The algorithm will not produce annotations shorter than this length.
|
||||||
|
:type min_duration: float
|
||||||
|
:param tolerance: The minimum length (in samples) of a segment.
|
||||||
|
:type tolerance: int
|
||||||
|
:param annotation_type: Annotation type (standalone, parallel, intersection).
|
||||||
|
:type annotation_type: str
|
||||||
|
"""
|
||||||
|
|
||||||
|
sample_rate = recording.metadata["sample_rate"]
|
||||||
|
center_frequency = recording.metadata.get("center_frequency", 0)
|
||||||
|
|
||||||
|
# Create an object of the time segmenter
|
||||||
|
time_segmenter = TimeSegmenter(sample_rate, min_duration, window_size, tolerance)
|
||||||
|
|
||||||
|
change_points = time_segmenter.apply(recording.data[0])
|
||||||
|
|
||||||
|
time_segments_indices = np.append(np.insert(change_points, 0, 0), len(recording.data[0]))
|
||||||
|
annotations = []
|
||||||
|
for i in range(len(time_segments_indices) - 1):
|
||||||
|
# Build comment JSON with type metadata
|
||||||
|
comment_data = {
|
||||||
|
"type": annotation_type,
|
||||||
|
"generator": "cusum_annotator",
|
||||||
|
"params": {
|
||||||
|
"window_size": window_size,
|
||||||
|
"min_duration": min_duration,
|
||||||
|
"tolerance": tolerance,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
f_min, f_max = detect_frequency(
|
||||||
|
signal=recording.data[0],
|
||||||
|
start=time_segments_indices[i],
|
||||||
|
stop=time_segments_indices[i + 1],
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
annotations.append(
|
||||||
|
Annotation(
|
||||||
|
sample_start=time_segments_indices[i],
|
||||||
|
sample_count=time_segments_indices[i + 1] - time_segments_indices[i],
|
||||||
|
freq_lower_edge=center_frequency + f_min,
|
||||||
|
freq_upper_edge=center_frequency + f_max,
|
||||||
|
label=label,
|
||||||
|
comment=json.dumps(comment_data),
|
||||||
|
detail={"generator": "cusum_annotator"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_cusum(_signal, sample_rate: int, tolerance: int = None, min_duration: float = -1):
|
||||||
|
"""
|
||||||
|
This function efficiently computes the cumulative sum of a give list (_signal), with an optional tolerance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- _signal: array of iq samples.
|
||||||
|
- Tolerance: the least acceptable length of a block, Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- cusum (array): Array of the cumulative sum of the given list
|
||||||
|
- sample_rate (int): __description_
|
||||||
|
- change_points (array): Array of the indices at which a change in the CUSUM direction happens.
|
||||||
|
- min_duration (float): The least acceptable time width of each segment (in ms). Defaults to -1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# efficiently calculate the running sum of the signal
|
||||||
|
# cusum = list(itertools.accumulate((_signal - np.mean(_signal))))
|
||||||
|
x = _signal - np.mean(_signal)
|
||||||
|
cusum = np.cumsum(x)
|
||||||
|
|
||||||
|
# 'diff' computes the differences between the consecutive values,
|
||||||
|
# then 'sign' determines if it is +ve or -ve.
|
||||||
|
change_indicators = np.sign(np.diff(cusum))
|
||||||
|
change_points = np.where(np.diff(change_indicators))[0] + 1
|
||||||
|
|
||||||
|
# Limit the change_points
|
||||||
|
# Reject those whose number of samples < minimum accepted #n of samples in (min duration) ms.
|
||||||
|
if min_duration is not None and min_duration > 0:
|
||||||
|
min_samples_wide = int(min_duration * sample_rate / 1000)
|
||||||
|
segments_lengths = np.diff(change_points)
|
||||||
|
segments_lengths = np.insert(segments_lengths, 0, change_points[0])
|
||||||
|
change_points = change_points[np.where(segments_lengths > min_samples_wide)[0]]
|
||||||
|
return cusum, change_points
|
||||||
|
|
||||||
|
|
||||||
|
def detect_frequency(signal, start, stop, sample_rate):
|
||||||
|
signal_segment = signal[start:stop]
|
||||||
|
if len(signal_segment) > 0:
|
||||||
|
fft_data = np.abs(np.fft.fftshift(np.fft.fft(signal_segment)))
|
||||||
|
fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate))
|
||||||
|
|
||||||
|
# Use a spectral threshold to find the 'height' of the orange block
|
||||||
|
spectral_thresh = np.max(fft_data) * 0.15
|
||||||
|
sig_indices = np.where(fft_data > spectral_thresh)[0]
|
||||||
|
|
||||||
|
if len(sig_indices) > 4:
|
||||||
|
return fft_freqs[sig_indices[0]], fft_freqs[sig_indices[-1]]
|
||||||
|
else:
|
||||||
|
return -sample_rate / 4, sample_rate / 4
|
||||||
|
else:
|
||||||
|
return -sample_rate / 4, sample_rate / 4
|
||||||
|
|
||||||
|
|
||||||
|
class TimeSegmenter:
|
||||||
|
"""Time Segmenter class, it creates a segmenter object with certain\
|
||||||
|
characteristics to easily split an input signal to segments based on\
|
||||||
|
the cumulative sum of deviations (of the signal mean)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, sample_rate: int, min_duration: float = 1, moving_average_window: int = 3, tolerance: int = None
|
||||||
|
):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate (int): _description_
|
||||||
|
min_duration (float, optional): _description_. Defaults to 1.
|
||||||
|
moving_average_window (int, optional): _description_. Defaults to 3.
|
||||||
|
tolerance (int, optional): _description_. Defaults to None.
|
||||||
|
"""
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.min_duration = min_duration
|
||||||
|
self.moving_average_window = moving_average_window
|
||||||
|
self._moving_avg_filter = self._init_filter()
|
||||||
|
self.tolerance = tolerance
|
||||||
|
|
||||||
|
def _init_filter(self):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
return np.ones(self.moving_average_window) / self.moving_average_window
|
||||||
|
|
||||||
|
def _apply_filter(self, iqsignal: np.array):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iqsignal (np.array): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
return np.convolve(abs(iqsignal), self._moving_avg_filter, mode="same")
|
||||||
|
|
||||||
|
def _create_segments(self, iq_signal: np.array, change_points: np.array):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iq_signal (np.array): _description_
|
||||||
|
change_points (np.array): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
return np.split(iq_signal, change_points)
|
||||||
|
|
||||||
|
def apply(self, iq_signal: np.array):
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iq_signal (np.array): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
smoothed_signal = self._apply_filter(iq_signal)
|
||||||
|
_, change_points = _compute_cusum(smoothed_signal, self.sample_rate, self.tolerance, self.min_duration)
|
||||||
|
# segments = self._create_segments(iq_signal, change_points)
|
||||||
|
return change_points
|
||||||
438
src/ria_toolkit_oss/annotations/energy_detector.py
Normal file
438
src/ria_toolkit_oss/annotations/energy_detector.py
Normal file
|
|
@ -0,0 +1,438 @@
|
||||||
|
"""
|
||||||
|
Energy-based signal detection and bandwidth analysis.
|
||||||
|
|
||||||
|
Provides automatic annotation generation using energy-based signal detection
|
||||||
|
and occupied bandwidth calculation following ITU-R SM.328 standard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.signal import filtfilt
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data import Annotation, Recording
|
||||||
|
|
||||||
|
|
||||||
|
def detect_signals_energy(
|
||||||
|
recording: Recording,
|
||||||
|
k: int = 10,
|
||||||
|
threshold_factor: float = 1.2,
|
||||||
|
window_size: int = 200,
|
||||||
|
min_distance: int = 5000,
|
||||||
|
label: str = "signal",
|
||||||
|
annotation_type: str = "standalone",
|
||||||
|
freq_method: str = "nbw",
|
||||||
|
nfft: int = None,
|
||||||
|
obw_power: float = 0.99,
|
||||||
|
) -> Recording:
|
||||||
|
"""
|
||||||
|
Detect signal bursts using energy-based method with adaptive noise floor estimation.
|
||||||
|
|
||||||
|
This algorithm smooths the signal with a moving average filter, estimates the noise
|
||||||
|
floor from k segments, applies a threshold to detect regions above noise, and merges
|
||||||
|
nearby detections. Detected time boundaries are then assigned frequency bounds based
|
||||||
|
on the selected frequency method.
|
||||||
|
|
||||||
|
Time Detection Algorithm:
|
||||||
|
1. Smooth signal using moving average (envelope detection)
|
||||||
|
2. Divide smoothed signal into k segments
|
||||||
|
3. Estimate noise floor as median of segment mean powers
|
||||||
|
4. Detect regions where power exceeds threshold_factor * noise_floor
|
||||||
|
5. Merge regions closer than min_distance samples
|
||||||
|
|
||||||
|
Frequency Bounding (freq_method):
|
||||||
|
- 'nbw': Nominal bandwidth (OBW + center frequency) - DEFAULT
|
||||||
|
- 'obw': Occupied bandwidth (99.99% power, includes siedelobes)
|
||||||
|
- 'full-detected': Lowest to highest spectral component
|
||||||
|
- 'full-bandwidth': Entire Nyquist span (center_freq ± sample_rate/2)
|
||||||
|
|
||||||
|
:param recording: Recording to analyze
|
||||||
|
:type recording: Recording
|
||||||
|
:param k: Number of segments for noise floor estimation (default: 10)
|
||||||
|
:type k: int
|
||||||
|
:param threshold_factor: Threshold multiplier above noise floor (typical: 1.2-2.0, default: 1.2)
|
||||||
|
:type threshold_factor: float
|
||||||
|
:param window_size: Moving average window size in samples (default: 200)
|
||||||
|
:type window_size: int
|
||||||
|
:param min_distance: Minimum distance between separate signals in samples (default: 5000)
|
||||||
|
:type min_distance: int
|
||||||
|
:param label: Label for detected annotations (default: "signal")
|
||||||
|
:type label: str
|
||||||
|
:param annotation_type: Annotation type (standalone, parallel, intersection, default: standalone)
|
||||||
|
:type annotation_type: str
|
||||||
|
:param freq_method: How to calculate frequency bounds (default: 'nbw')
|
||||||
|
:type freq_method: str
|
||||||
|
:param nfft: FFT size for frequency calculations (default: None)
|
||||||
|
:type nfft: int
|
||||||
|
:param obw_power: Power percentage for OBW (0.9999 = 99.99%, default: 0.99)
|
||||||
|
:type obw_power: float
|
||||||
|
|
||||||
|
:returns: New Recording with added annotations
|
||||||
|
:rtype: Recording
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> from ria.io import load_recording
|
||||||
|
>>> from ria_toolkit_oss.annotations import detect_signals_energy
|
||||||
|
>>> recording = load_recording("capture.sigmf")
|
||||||
|
|
||||||
|
>>> # Detect with NBW frequency bounds (default, best for real signals)
|
||||||
|
>>> annotated = detect_signals_energy(recording, label="burst")
|
||||||
|
|
||||||
|
>>> # Detect with OBW (more conservative, includes siedelobes)
|
||||||
|
>>> annotated = detect_signals_energy(
|
||||||
|
... recording, label="burst", freq_method="obw"
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # Detect with full detected range (captures all spectral components)
|
||||||
|
>>> annotated = detect_signals_energy(
|
||||||
|
... recording, label="burst", freq_method="full-detected"
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
# Extract signal data (use first channel only)
|
||||||
|
signal = recording.data[0]
|
||||||
|
|
||||||
|
# Calculate smoothed signal power
|
||||||
|
kernel = np.ones(window_size) / window_size
|
||||||
|
smoothed_power = filtfilt(kernel, [1], np.abs(signal) ** 2)
|
||||||
|
|
||||||
|
# Estimate noise floor using segment-based median (robust to signal presence)
|
||||||
|
segments = np.array_split(smoothed_power, k)
|
||||||
|
noise_floor = np.median([np.mean(s) for s in segments])
|
||||||
|
|
||||||
|
# Detect signal boundaries (regions above threshold)
|
||||||
|
enter = noise_floor * threshold_factor
|
||||||
|
exit = enter * 0.8
|
||||||
|
boundaries = []
|
||||||
|
start = None
|
||||||
|
active = False
|
||||||
|
|
||||||
|
for i, p in enumerate(smoothed_power):
|
||||||
|
if not active and p > enter:
|
||||||
|
start = i
|
||||||
|
active = True
|
||||||
|
elif active and p < exit:
|
||||||
|
boundaries.append((start, i - start))
|
||||||
|
active = False
|
||||||
|
|
||||||
|
if active:
|
||||||
|
boundaries.append((start, len(smoothed_power) - start))
|
||||||
|
|
||||||
|
# Merge boundaries that are closer than min_distance
|
||||||
|
merged_boundaries = []
|
||||||
|
if boundaries:
|
||||||
|
start, length = boundaries[0]
|
||||||
|
for next_start, next_length in boundaries[1:]:
|
||||||
|
if next_start - (start + length) < min_distance:
|
||||||
|
# Merge with current boundary
|
||||||
|
length = next_start + next_length - start
|
||||||
|
else:
|
||||||
|
# Save current and start new boundary
|
||||||
|
merged_boundaries.append((start, length))
|
||||||
|
start, length = next_start, next_length
|
||||||
|
# Add final boundary
|
||||||
|
merged_boundaries.append((start, length))
|
||||||
|
|
||||||
|
# Create annotations from detected boundaries
|
||||||
|
sample_rate = recording.metadata["sample_rate"]
|
||||||
|
center_frequency = recording.metadata.get("center_frequency", 0)
|
||||||
|
|
||||||
|
# Validate frequency method
|
||||||
|
valid_freq_methods = ["nbw", "obw", "full-detected", "full-bandwidth"]
|
||||||
|
if freq_method not in valid_freq_methods:
|
||||||
|
raise ValueError(f"Invalid freq_method '{freq_method}'. " f"Must be one of: {', '.join(valid_freq_methods)}")
|
||||||
|
|
||||||
|
annotations = []
|
||||||
|
for start_sample, sample_count in merged_boundaries:
|
||||||
|
# Calculate frequency bounds based on method
|
||||||
|
freq_lower, freq_upper = calculate_frequency_bounds(
|
||||||
|
freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
|
||||||
|
)
|
||||||
|
# Build comment JSON with type metadata
|
||||||
|
comment_data = {
|
||||||
|
"type": annotation_type,
|
||||||
|
"generator": "energy_detector",
|
||||||
|
"freq_method": freq_method,
|
||||||
|
"params": {
|
||||||
|
"threshold_factor": threshold_factor,
|
||||||
|
"window_size": window_size,
|
||||||
|
"noise_floor": float(noise_floor),
|
||||||
|
"threshold": float(enter),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
anno = Annotation(
|
||||||
|
sample_start=start_sample,
|
||||||
|
sample_count=sample_count,
|
||||||
|
freq_lower_edge=freq_lower,
|
||||||
|
freq_upper_edge=freq_upper,
|
||||||
|
label=label,
|
||||||
|
comment=json.dumps(comment_data),
|
||||||
|
detail={"generator": "energy_detector", "freq_method": freq_method},
|
||||||
|
)
|
||||||
|
annotations.append(anno)
|
||||||
|
|
||||||
|
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_occupied_bandwidth(
|
||||||
|
signal: np.ndarray,
|
||||||
|
sampling_rate: float,
|
||||||
|
nfft: int = None,
|
||||||
|
power_percentage: float = 0.99,
|
||||||
|
):
|
||||||
|
if nfft is None:
|
||||||
|
nfft = max(65536, 2 ** int(np.floor(np.log2(len(signal)))))
|
||||||
|
|
||||||
|
window = np.blackman(len(signal))
|
||||||
|
spec = np.fft.fftshift(np.fft.fft(signal * window, n=nfft))
|
||||||
|
|
||||||
|
psd = np.abs(spec) ** 2
|
||||||
|
psd = psd / psd.sum() # normalize
|
||||||
|
|
||||||
|
freqs = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))
|
||||||
|
|
||||||
|
cdf = np.cumsum(psd)
|
||||||
|
|
||||||
|
tail = (1 - power_percentage) / 2
|
||||||
|
|
||||||
|
lower_idx = np.searchsorted(cdf, tail)
|
||||||
|
upper_idx = np.searchsorted(cdf, 1 - tail)
|
||||||
|
|
||||||
|
return freqs[upper_idx] - freqs[lower_idx], freqs[lower_idx], freqs[upper_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_nominal_bandwidth(
|
||||||
|
signal: np.ndarray,
|
||||||
|
sampling_rate: float,
|
||||||
|
nfft: int = None,
|
||||||
|
power_percentage: float = 0.99,
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Calculate nominal bandwidth and center frequency.
|
||||||
|
|
||||||
|
Nominal bandwidth (NBW) is the occupied bandwidth along with the center
|
||||||
|
frequency of the signal's spectral occupancy. Useful for characterizing
|
||||||
|
signals with unknown or drifting center frequencies.
|
||||||
|
|
||||||
|
:param signal: Complex IQ signal samples
|
||||||
|
:type signal: np.ndarray
|
||||||
|
:param sampling_rate: Sample rate in Hz
|
||||||
|
:type sampling_rate: float
|
||||||
|
:param nfft: FFT size
|
||||||
|
:type nfft: int
|
||||||
|
:param power_percentage: Fraction of power to contain
|
||||||
|
:type power_percentage: float
|
||||||
|
|
||||||
|
:returns: Tuple of (nominal_bandwidth_hz, center_frequency_hz)
|
||||||
|
:rtype: Tuple[float, float]
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> from ria_toolkit_oss.annotations import calculate_nominal_bandwidth
|
||||||
|
>>> nbw, center = calculate_nominal_bandwidth(signal, sampling_rate=10e6)
|
||||||
|
>>> print(f"NBW: {nbw/1e6:.3f} MHz, Center: {center/1e6:.3f} MHz")
|
||||||
|
"""
|
||||||
|
bw, lower_freq, upper_freq = calculate_occupied_bandwidth(signal, sampling_rate, nfft, power_percentage)
|
||||||
|
|
||||||
|
# Center frequency is midpoint of occupied band
|
||||||
|
center_freq = (lower_freq + upper_freq) / 2
|
||||||
|
|
||||||
|
return lower_freq, upper_freq, center_freq
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_full_detected_bandwidth(
|
||||||
|
signal: np.ndarray,
|
||||||
|
sampling_rate: float,
|
||||||
|
nfft: int = None,
|
||||||
|
start_offset: int = 1000,
|
||||||
|
) -> Tuple[float, float, float]:
|
||||||
|
"""
|
||||||
|
Calculate frequency range from lowest to highest spectral component.
|
||||||
|
|
||||||
|
Unlike OBW/NBW which define a power-based bandwidth, this calculates
|
||||||
|
the absolute frequency span from the lowest non-zero spectral component
|
||||||
|
to the highest non-zero component.
|
||||||
|
|
||||||
|
Useful for:
|
||||||
|
- Signals with spectral gaps
|
||||||
|
- Multiple parallel signals (captures all of them)
|
||||||
|
- Understanding total occupied spectrum vs. actual bandwidth
|
||||||
|
|
||||||
|
:param signal: Complex IQ signal samples
|
||||||
|
:type signal: np.ndarray
|
||||||
|
:param sampling_rate: Sample rate in Hz
|
||||||
|
:type sampling_rate: float
|
||||||
|
:param nfft: FFT size
|
||||||
|
:type nfft: int
|
||||||
|
:param start_offset: Skip samples at start
|
||||||
|
:type start_offset: int
|
||||||
|
|
||||||
|
:returns: Tuple of (bandwidth_hz, lower_freq_hz, upper_freq_hz)
|
||||||
|
:rtype: Tuple[float, float, float]
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> # Signal with two components at different frequencies
|
||||||
|
>>> bw, f_low, f_high = calculate_full_detected_bandwidth(
|
||||||
|
... signal, sampling_rate=10e6, nfft=65536
|
||||||
|
... )
|
||||||
|
>>> print(f"Full span: {f_low/1e6:.3f} to {f_high/1e6:.3f} MHz")
|
||||||
|
"""
|
||||||
|
# Validate input
|
||||||
|
if len(signal) < nfft + start_offset:
|
||||||
|
raise ValueError(
|
||||||
|
f"Signal too short: need {nfft + start_offset} samples, "
|
||||||
|
f"got {len(signal)}. Reduce nfft or start_offset."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract segment
|
||||||
|
signal_segment = signal[start_offset : nfft + start_offset]
|
||||||
|
|
||||||
|
# Compute FFT and power spectral density
|
||||||
|
freq_spectrum = np.fft.fft(signal_segment, n=nfft)
|
||||||
|
psd = np.abs(freq_spectrum) ** 2
|
||||||
|
|
||||||
|
# Shift to center DC
|
||||||
|
psd_shifted = np.fft.fftshift(psd)
|
||||||
|
freq_bins = np.fft.fftshift(np.fft.fftfreq(nfft, 1 / sampling_rate))
|
||||||
|
|
||||||
|
# Find noise floor (mean of lowest 10% of bins) and all bins above noise floor
|
||||||
|
noise_floor = np.mean(np.sort(psd_shifted)[: int(len(psd_shifted) * 0.1)])
|
||||||
|
above_noise = np.where(psd_shifted > noise_floor * 1.5)[0]
|
||||||
|
|
||||||
|
if len(above_noise) == 0:
|
||||||
|
# No signal above noise, return zero bandwidth
|
||||||
|
return 0.0, 0.0, 0.0
|
||||||
|
|
||||||
|
# Get frequency range of signal components
|
||||||
|
lower_idx = above_noise[0]
|
||||||
|
upper_idx = above_noise[-1]
|
||||||
|
|
||||||
|
lower_freq = freq_bins[lower_idx]
|
||||||
|
upper_freq = freq_bins[upper_idx]
|
||||||
|
|
||||||
|
bandwidth = upper_freq - lower_freq
|
||||||
|
|
||||||
|
return bandwidth, lower_freq, upper_freq
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_with_obw(
|
||||||
|
recording: Recording,
|
||||||
|
label: str = "signal",
|
||||||
|
annotation_type: str = "standalone",
|
||||||
|
nfft: int = None,
|
||||||
|
power_percentage: float = 0.99,
|
||||||
|
) -> Recording:
|
||||||
|
"""
|
||||||
|
Create a single annotation spanning the occupied bandwidth of the entire recording.
|
||||||
|
|
||||||
|
Analyzes the full recording to find its occupied bandwidth and creates an annotation
|
||||||
|
covering that frequency range for the entire time duration.
|
||||||
|
|
||||||
|
:param recording: Recording to analyze
|
||||||
|
:type recording: Recording
|
||||||
|
:param label: Annotation label
|
||||||
|
:type label: str
|
||||||
|
:param annotation_type: Annotation type
|
||||||
|
:type annotation_type: str
|
||||||
|
:param nfft: FFT size
|
||||||
|
:type nfft: int
|
||||||
|
:param power_percentage: Power percentage for OBW calculation
|
||||||
|
:type power_percentage: float
|
||||||
|
|
||||||
|
:returns: Recording with OBW annotation added
|
||||||
|
:rtype: Recording
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> from ria_toolkit_oss.annotations import annotate_with_obw
|
||||||
|
>>> annotated = annotate_with_obw(recording, label="signal_obw")
|
||||||
|
"""
|
||||||
|
signal = recording.data[0]
|
||||||
|
sample_rate = recording.metadata["sample_rate"]
|
||||||
|
center_freq = recording.metadata.get("center_frequency", 0)
|
||||||
|
|
||||||
|
# Calculate OBW
|
||||||
|
obw, lower_offset, upper_offset = calculate_occupied_bandwidth(signal, sample_rate, nfft, power_percentage)
|
||||||
|
|
||||||
|
# Convert baseband offsets to absolute frequencies
|
||||||
|
freq_lower = center_freq + lower_offset
|
||||||
|
freq_upper = center_freq + upper_offset
|
||||||
|
|
||||||
|
# Create comment JSON
|
||||||
|
comment_data = {
|
||||||
|
"type": annotation_type,
|
||||||
|
"generator": "obw_annotator",
|
||||||
|
"obw_hz": float(obw),
|
||||||
|
"power_percentage": power_percentage,
|
||||||
|
"params": {"nfft": nfft},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create annotation spanning entire recording
|
||||||
|
anno = Annotation(
|
||||||
|
sample_start=0,
|
||||||
|
sample_count=len(signal),
|
||||||
|
freq_lower_edge=freq_lower,
|
||||||
|
freq_upper_edge=freq_upper,
|
||||||
|
label=label,
|
||||||
|
comment=json.dumps(comment_data),
|
||||||
|
detail={"generator": "obw_annotator", "obw_hz": float(obw)},
|
||||||
|
)
|
||||||
|
|
||||||
|
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + [anno])
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_frequency_bounds(
|
||||||
|
freq_method, center_frequency, sample_rate, nfft, signal, start_sample, sample_count, obw_power
|
||||||
|
):
|
||||||
|
if freq_method == "full-bandwidth":
|
||||||
|
# Full Nyquist span
|
||||||
|
freq_lower = center_frequency - (sample_rate / 2)
|
||||||
|
freq_upper = center_frequency + (sample_rate / 2)
|
||||||
|
else:
|
||||||
|
# Extract segment for frequency analysis
|
||||||
|
segment_start = start_sample
|
||||||
|
segment_end = min(start_sample + sample_count, len(signal))
|
||||||
|
segment = signal[segment_start:segment_end]
|
||||||
|
|
||||||
|
if nfft is None or len(segment) >= nfft:
|
||||||
|
if freq_method == "nbw":
|
||||||
|
# Nominal bandwidth (OBW + center frequency)
|
||||||
|
try:
|
||||||
|
lower_freq, upper_freq, _ = calculate_nominal_bandwidth(segment, sample_rate, nfft, obw_power)
|
||||||
|
freq_lower = center_frequency + lower_freq
|
||||||
|
freq_upper = center_frequency + upper_freq
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# Fallback if calculation fails
|
||||||
|
freq_lower = center_frequency - (sample_rate / 2)
|
||||||
|
freq_upper = center_frequency + (sample_rate / 2)
|
||||||
|
|
||||||
|
elif freq_method == "obw":
|
||||||
|
# Occupied bandwidth
|
||||||
|
try:
|
||||||
|
_, f_lower, f_upper = calculate_occupied_bandwidth(segment, sample_rate, nfft, obw_power)
|
||||||
|
freq_lower = center_frequency + f_lower
|
||||||
|
freq_upper = center_frequency + f_upper
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# Fallback if calculation fails
|
||||||
|
freq_lower = center_frequency - (sample_rate / 2)
|
||||||
|
freq_upper = center_frequency + (sample_rate / 2)
|
||||||
|
|
||||||
|
elif freq_method == "full-detected":
|
||||||
|
# Full detected range (lowest to highest component)
|
||||||
|
try:
|
||||||
|
_, f_lower, f_upper = calculate_full_detected_bandwidth(segment, sample_rate, nfft)
|
||||||
|
freq_lower = center_frequency + f_lower
|
||||||
|
freq_upper = center_frequency + f_upper
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# Fallback if calculation fails
|
||||||
|
freq_lower = center_frequency - (sample_rate / 2)
|
||||||
|
freq_upper = center_frequency + (sample_rate / 2)
|
||||||
|
else:
|
||||||
|
# Segment too short for FFT, use full bandwidth
|
||||||
|
freq_lower = center_frequency - (sample_rate / 2)
|
||||||
|
freq_upper = center_frequency + (sample_rate / 2)
|
||||||
|
|
||||||
|
return freq_lower, freq_upper
|
||||||
435
src/ria_toolkit_oss/annotations/parallel_signal_separator.py
Normal file
435
src/ria_toolkit_oss/annotations/parallel_signal_separator.py
Normal file
|
|
@ -0,0 +1,435 @@
|
||||||
|
"""
|
||||||
|
Parallel signal separation for multi-component frequency-offset signals.
|
||||||
|
|
||||||
|
Provides methods to detect and separate overlapping frequency-domain signals
|
||||||
|
that occupy the same time window but different frequency bands.
|
||||||
|
|
||||||
|
This module implements **spectral peak detection** to identify distinct frequency
|
||||||
|
components and split single time-domain annotations into frequency-specific
|
||||||
|
sub-annotations.
|
||||||
|
|
||||||
|
**Key Design Decisions** (per Codex review):
|
||||||
|
|
||||||
|
1. **Complex IQ Support**: Uses `scipy.signal.welch` with `return_onesided=False`
|
||||||
|
for proper complex signal handling. Window length automatically adapts to
|
||||||
|
signal length via `nperseg=min(nfft, len(signal))` to handle bursts <nfft.
|
||||||
|
|
||||||
|
2. **Frequency Representation**: Components are detected in **relative** frequency
|
||||||
|
(baseband, centered at 0 Hz). Caller must add RF center_frequency_hz when
|
||||||
|
writing to SigMF annotations. This separation of concerns avoids the frequency
|
||||||
|
context bug where absolute Hz would be meaningless for baseband processing.
|
||||||
|
|
||||||
|
3. **Bandwidth Estimation**: Dual strategy avoids -3dB limitations:
|
||||||
|
- Primary: -3dB rolloff for typical narrowband signals
|
||||||
|
- Fallback: Cumulative power (99% like OBW) for wide/OFDM signals
|
||||||
|
- Auto-fallback when -3dB bandwidth is anomalous
|
||||||
|
|
||||||
|
4. **Noise Floor**: Auto-estimated via `np.percentile(psd_db, 10)` from data
|
||||||
|
to adapt across hardware (Pluto vs. ThinkRF). User can override if needed.
|
||||||
|
|
||||||
|
5. **Filter Sizing (Optional FIR extraction)**: When extracting components,
|
||||||
|
uses Kaiser window FIR with proper stopband specification. Auto-sizes
|
||||||
|
numtaps based on desired transition bandwidth. Includes downsampling
|
||||||
|
guidance for long captures.
|
||||||
|
|
||||||
|
6. **CLI Surface**: Single `separate` subcommand for all separation operations.
|
||||||
|
Can be chained after any detector or used standalone on existing annotations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Two WiFi channels captured simultaneously:
|
||||||
|
|
||||||
|
>>> from ria_toolkit_oss.annotations import find_spectral_components
|
||||||
|
>>> # Detect the two distinct channels (returns relative frequencies)
|
||||||
|
>>> components = find_spectral_components(signal, sampling_rate=20e6)
|
||||||
|
>>> print(f"Found {len(components)} components")
|
||||||
|
Found 2 components
|
||||||
|
|
||||||
|
The module is designed to work with detected time-domain annotations,
|
||||||
|
allowing splitting of overlapping signals into separate training samples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy import ndimage
|
||||||
|
from scipy import signal as scipy_signal
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data import Annotation, Recording
|
||||||
|
|
||||||
|
|
||||||
|
def find_spectral_components(
|
||||||
|
signal_data: np.ndarray,
|
||||||
|
sampling_rate: float,
|
||||||
|
nfft: int = 65536,
|
||||||
|
noise_threshold_db: Optional[float] = None,
|
||||||
|
min_component_bw: float = 50e3,
|
||||||
|
time_percentile: float = 70.0,
|
||||||
|
) -> List[Tuple[float, float, float]]:
|
||||||
|
"""
|
||||||
|
Find distinct frequency components using spectral peak detection.
|
||||||
|
|
||||||
|
Identifies separate frequency components in a signal by analyzing the power
|
||||||
|
spectral density and finding peaks corresponding to distinct signals. This is
|
||||||
|
useful for separating parallel signals that occupy different frequency bands.
|
||||||
|
|
||||||
|
**Frequency Representation**: Returns frequencies in **baseband/relative** Hz
|
||||||
|
(centered at 0). To get absolute RF frequencies, add center_frequency_hz from
|
||||||
|
recording metadata to all returned values.
|
||||||
|
|
||||||
|
Algorithm:
|
||||||
|
1. Compute power spectral density using Welch (properly handles complex IQ)
|
||||||
|
2. Auto-estimate noise floor from data if not specified
|
||||||
|
3. Smooth PSD to reduce spurious peaks
|
||||||
|
4. Find local maxima above noise floor
|
||||||
|
5. Estimate bandwidth per peak using -3dB (fallback: cumulative power)
|
||||||
|
6. Filter components below minimum bandwidth threshold
|
||||||
|
|
||||||
|
:param signal_data: Complex IQ signal samples (np.complex64/128)
|
||||||
|
:type signal_data: np.ndarray
|
||||||
|
:param sampling_rate: Sample rate in Hz
|
||||||
|
:type sampling_rate: float
|
||||||
|
:param nfft: FFT size / window length for Welch. Automatically capped at
|
||||||
|
signal length to handle bursts (default: 65536)
|
||||||
|
:type nfft: int
|
||||||
|
:param noise_threshold_db: Minimum SNR threshold in dB. If None (default),
|
||||||
|
auto-estimates as np.percentile(psd_db, 10).
|
||||||
|
Adapt this across hardware (Pluto: ~-100, ThinkRF: ~-60).
|
||||||
|
:type noise_threshold_db: Optional[float]
|
||||||
|
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz)
|
||||||
|
:type min_component_bw: float
|
||||||
|
:param power_threshold: Cumulative power threshold for fallback bandwidth
|
||||||
|
estimation (default: 0.99 = 99% power, like OBW)
|
||||||
|
:type power_threshold: float
|
||||||
|
|
||||||
|
:returns: List of (center_freq_hz, lower_freq_hz, upper_freq_hz) tuples.
|
||||||
|
**All frequencies are relative (baseband, 0-centered).**
|
||||||
|
Add recording metadata['center_frequency'] to get absolute RF frequencies.
|
||||||
|
:rtype: List[Tuple[float, float, float]]
|
||||||
|
|
||||||
|
:raises ValueError: If signal has fewer than 256 samples
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> from ria.io import load_recording
|
||||||
|
>>> from ria_toolkit_oss.annotations import find_spectral_components
|
||||||
|
>>> recording = load_recording("capture.sigmf")
|
||||||
|
>>> segment = recording.data[0][start:end]
|
||||||
|
>>> # Components in relative (baseband) frequency
|
||||||
|
>>> components = find_spectral_components(segment, sampling_rate=20e6)
|
||||||
|
>>> for center_rel, lower_rel, upper_rel in components:
|
||||||
|
... # Convert to absolute RF frequency
|
||||||
|
... center_abs = recording.metadata['center_frequency'] + center_rel
|
||||||
|
... print(f"Component @ {center_abs/1e9:.3f} GHz")
|
||||||
|
"""
|
||||||
|
# Validate input
|
||||||
|
min_samples = 256
|
||||||
|
if len(signal_data) < min_samples:
|
||||||
|
raise ValueError(f"Signal too short: need at least {min_samples} samples, " f"got {len(signal_data)}.")
|
||||||
|
|
||||||
|
# Compute PSD using Welch method for complex IQ signals
|
||||||
|
# CRITICAL: return_onesided=False for proper complex signal handling
|
||||||
|
nperseg = min(nfft, len(signal_data))
|
||||||
|
noverlap = nperseg // 2
|
||||||
|
|
||||||
|
# --- STFT ---
|
||||||
|
freqs, times, Zxx = scipy_signal.stft(
|
||||||
|
signal_data,
|
||||||
|
fs=sampling_rate,
|
||||||
|
window="blackman",
|
||||||
|
nperseg=nperseg,
|
||||||
|
noverlap=noverlap,
|
||||||
|
return_onesided=False,
|
||||||
|
boundary=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shift zero freq to center
|
||||||
|
Zxx = np.fft.fftshift(Zxx, axes=0)
|
||||||
|
freqs = np.fft.fftshift(freqs)
|
||||||
|
|
||||||
|
# Power spectrogram
|
||||||
|
power = np.abs(Zxx) ** 2
|
||||||
|
power_db = 10 * np.log10(power + 1e-12)
|
||||||
|
|
||||||
|
# --- Aggregate across time robustly ---
|
||||||
|
# Using percentile instead of mean prevents short signals from being diluted
|
||||||
|
freq_profile_db = np.percentile(power_db, time_percentile, axis=1)
|
||||||
|
|
||||||
|
# --- Noise floor estimation ---
|
||||||
|
if noise_threshold_db is None:
|
||||||
|
noise_threshold_db = np.percentile(freq_profile_db, 20)
|
||||||
|
|
||||||
|
threshold = noise_threshold_db + 3 # 3 dB above noise floor
|
||||||
|
|
||||||
|
# --- Smooth lightly (avoid merging nearby signals) ---
|
||||||
|
freq_profile_db = ndimage.gaussian_filter1d(freq_profile_db, sigma=1.5)
|
||||||
|
|
||||||
|
# --- Binary mask of significant frequencies ---
|
||||||
|
mask = freq_profile_db > threshold
|
||||||
|
|
||||||
|
# --- Find contiguous frequency regions ---
|
||||||
|
labeled, num_features = ndimage.label(mask)
|
||||||
|
|
||||||
|
components = []
|
||||||
|
|
||||||
|
for region_label in range(1, num_features + 1):
|
||||||
|
region_indices = np.where(labeled == region_label)[0]
|
||||||
|
|
||||||
|
if len(region_indices) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lower_idx = region_indices[0]
|
||||||
|
upper_idx = region_indices[-1]
|
||||||
|
|
||||||
|
lower_freq = freqs[lower_idx]
|
||||||
|
upper_freq = freqs[upper_idx]
|
||||||
|
bw = upper_freq - lower_freq
|
||||||
|
|
||||||
|
if bw < min_component_bw:
|
||||||
|
continue
|
||||||
|
|
||||||
|
center_freq = (lower_freq + upper_freq) / 2
|
||||||
|
components.append((center_freq, lower_freq, upper_freq))
|
||||||
|
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
def split_annotation_by_components(
|
||||||
|
annotation: Annotation,
|
||||||
|
signal: np.ndarray,
|
||||||
|
sampling_rate: float,
|
||||||
|
center_frequency_hz: float = 0.0,
|
||||||
|
nfft: int = 65536,
|
||||||
|
noise_threshold_db: Optional[float] = None,
|
||||||
|
min_component_bw: float = 50e3,
|
||||||
|
) -> List[Annotation]:
|
||||||
|
"""
|
||||||
|
Split an annotation into multiple annotations by detected frequency components.
|
||||||
|
|
||||||
|
Takes an existing annotation spanning multiple frequency components and
|
||||||
|
analyzes the frequency content to create separate sub-annotations for
|
||||||
|
each distinct frequency component.
|
||||||
|
|
||||||
|
**Use case**: Energy detection found a time window with 2-3 parallel WiFi
|
||||||
|
channels. This function splits it into separate annotations per channel.
|
||||||
|
|
||||||
|
**Frequency Handling**: `find_spectral_components` returns relative (baseband)
|
||||||
|
frequencies. This function adds `center_frequency_hz` to convert to absolute
|
||||||
|
RF frequencies for SigMF annotation bounds. This ensures correct frequency
|
||||||
|
context across baseband and RF domains.
|
||||||
|
|
||||||
|
:param annotation: Original annotation to split
|
||||||
|
:type annotation: Annotation
|
||||||
|
:param signal: Full signal array (complex IQ)
|
||||||
|
:type signal: np.ndarray
|
||||||
|
:param sampling_rate: Sample rate in Hz
|
||||||
|
:type sampling_rate: float
|
||||||
|
:param center_frequency_hz: RF center frequency to add to relative frequencies
|
||||||
|
from peak detection (default: 0.0 = baseband)
|
||||||
|
:type center_frequency_hz: float
|
||||||
|
:param nfft: FFT size for analysis (default: 65536, auto-capped at signal length)
|
||||||
|
:type nfft: int
|
||||||
|
:param noise_threshold_db: Noise floor threshold in dB. If None (default),
|
||||||
|
auto-estimates from data.
|
||||||
|
:type noise_threshold_db: Optional[float]
|
||||||
|
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz)
|
||||||
|
:type min_component_bw: float
|
||||||
|
|
||||||
|
:returns: List of new annotations (one per detected component).
|
||||||
|
Returns empty list if no components found or segment too short.
|
||||||
|
:rtype: List[Annotation]
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> from ria.io import load_recording
|
||||||
|
>>> from ria_toolkit_oss.annotations import split_annotation_by_components
|
||||||
|
>>> recording = load_recording("capture.sigmf")
|
||||||
|
>>> # Original annotation spans multiple channels
|
||||||
|
>>> original = recording.annotations[0]
|
||||||
|
>>> # Split using RF center frequency from metadata
|
||||||
|
>>> components = split_annotation_by_components(
|
||||||
|
... original,
|
||||||
|
... recording.data[0],
|
||||||
|
... recording.metadata['sample_rate'],
|
||||||
|
... center_frequency_hz=recording.metadata.get('center_frequency', 0.0)
|
||||||
|
... )
|
||||||
|
>>> print(f"Split into {len(components)} components")
|
||||||
|
Split into 2 components
|
||||||
|
|
||||||
|
**Algorithm**:
|
||||||
|
1. Extract segment corresponding to annotation time bounds
|
||||||
|
2. Find frequency components in that segment (returns relative frequencies)
|
||||||
|
3. Add center_frequency_hz to get absolute RF frequencies
|
||||||
|
4. Create new annotation for each component
|
||||||
|
5. Preserve original metadata (label, type, etc.)
|
||||||
|
6. Add component info to comment JSON
|
||||||
|
|
||||||
|
**Notes**:
|
||||||
|
- Original annotation is not modified
|
||||||
|
- Returns empty list if segment too short (<256 samples)
|
||||||
|
- Segments <nfft get auto-downsampled to nfft (see find_spectral_components)
|
||||||
|
- Each component inherits label from original
|
||||||
|
- Component frequencies in comment JSON are absolute (RF) frequencies
|
||||||
|
"""
|
||||||
|
# Extract segment corresponding to annotation time bounds
|
||||||
|
start_sample = annotation.sample_start
|
||||||
|
end_sample = min(start_sample + annotation.sample_count, len(signal))
|
||||||
|
segment = signal[start_sample:end_sample]
|
||||||
|
|
||||||
|
# Validate segment length is enough for spectral analysis
|
||||||
|
if len(segment) < 256:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Find components in this segment (returns relative/baseband frequencies)
|
||||||
|
try:
|
||||||
|
components = find_spectral_components(segment, sampling_rate, nfft, noise_threshold_db, min_component_bw)
|
||||||
|
except ValueError:
|
||||||
|
# Spectral analysis failed (e.g., not complex IQ)
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not components:
|
||||||
|
# No components found
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Create annotations for each component
|
||||||
|
new_annotations = []
|
||||||
|
for center_freq_rel, lower_freq_rel, upper_freq_rel in components:
|
||||||
|
# Convert relative (baseband) frequencies to absolute (RF) frequencies
|
||||||
|
center_freq_abs = center_frequency_hz + center_freq_rel
|
||||||
|
lower_freq_abs = center_frequency_hz + lower_freq_rel
|
||||||
|
upper_freq_abs = center_frequency_hz + upper_freq_rel
|
||||||
|
|
||||||
|
# Parse original annotation metadata
|
||||||
|
try:
|
||||||
|
comment_data = json.loads(annotation.comment)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
comment_data = {"type": "standalone"}
|
||||||
|
|
||||||
|
# Add component information (with absolute RF frequencies)
|
||||||
|
comment_data["split_from_annotation"] = True
|
||||||
|
comment_data["original_freq_bounds"] = {
|
||||||
|
"lower": float(annotation.freq_lower_edge),
|
||||||
|
"upper": float(annotation.freq_upper_edge),
|
||||||
|
}
|
||||||
|
comment_data["component_freq_bounds_rf"] = {
|
||||||
|
"center": float(center_freq_abs),
|
||||||
|
"lower": float(lower_freq_abs),
|
||||||
|
"upper": float(upper_freq_abs),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create new annotation with absolute RF frequency bounds
|
||||||
|
new_anno = Annotation(
|
||||||
|
sample_start=annotation.sample_start,
|
||||||
|
sample_count=annotation.sample_count,
|
||||||
|
freq_lower_edge=lower_freq_abs,
|
||||||
|
freq_upper_edge=upper_freq_abs,
|
||||||
|
label=annotation.label,
|
||||||
|
comment=json.dumps(comment_data),
|
||||||
|
detail={
|
||||||
|
"generator": "parallel_signal_separator",
|
||||||
|
"center_freq_hz": float(center_freq_abs),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
new_annotations.append(new_anno)
|
||||||
|
|
||||||
|
return new_annotations
|
||||||
|
|
||||||
|
|
||||||
|
def split_recording_annotations(
|
||||||
|
recording: Recording,
|
||||||
|
indices: Optional[List[int]] = None,
|
||||||
|
nfft: int = 65536,
|
||||||
|
noise_threshold_db: Optional[float] = None,
|
||||||
|
min_component_bw: float = 50e3,
|
||||||
|
) -> Recording:
|
||||||
|
"""
|
||||||
|
Split multiple annotations in a recording by frequency components.
|
||||||
|
|
||||||
|
Processes specified annotations (or all if indices=None), replacing each
|
||||||
|
with its frequency-separated components. Uses RF center_frequency from
|
||||||
|
recording metadata for proper absolute frequency conversion.
|
||||||
|
|
||||||
|
:param recording: Recording to process
|
||||||
|
:type recording: Recording
|
||||||
|
:param indices: Annotation indices to split (None = all, default: None).
|
||||||
|
Use indices=[] to skip splitting (returns unchanged recording).
|
||||||
|
:type indices: Optional[List[int]]
|
||||||
|
:param nfft: FFT size for spectral analysis (default: 65536,
|
||||||
|
auto-capped at signal segment length)
|
||||||
|
:type nfft: int
|
||||||
|
:param noise_threshold_db: Noise floor threshold in dB. If None (default),
|
||||||
|
auto-estimates from each segment.
|
||||||
|
:type noise_threshold_db: Optional[float]
|
||||||
|
:param min_component_bw: Minimum component bandwidth in Hz (default: 50 kHz).
|
||||||
|
Components narrower than this are filtered out.
|
||||||
|
:type min_component_bw: float
|
||||||
|
|
||||||
|
:returns: New Recording with split annotations
|
||||||
|
:rtype: Recording
|
||||||
|
|
||||||
|
**Example**::
|
||||||
|
|
||||||
|
>>> from ria.io import load_recording
|
||||||
|
>>> from ria_toolkit_oss.annotations import split_recording_annotations
|
||||||
|
>>> recording = load_recording("capture.sigmf")
|
||||||
|
>>> # Split all annotations
|
||||||
|
>>> split_rec = split_recording_annotations(recording)
|
||||||
|
>>> print(f"Original: {len(recording.annotations)} annotations")
|
||||||
|
>>> print(f"Split: {len(split_rec.annotations)} annotations")
|
||||||
|
Original: 5 annotations
|
||||||
|
Split: 9 annotations
|
||||||
|
|
||||||
|
**Algorithm**:
|
||||||
|
1. For each annotation in indices (or all if None):
|
||||||
|
2. Call split_annotation_by_components with RF center_frequency
|
||||||
|
3. If components found, replace annotation with components
|
||||||
|
4. If no components found, keep original annotation
|
||||||
|
5. Annotations not in indices are kept unchanged
|
||||||
|
|
||||||
|
**Notes**:
|
||||||
|
- Original recording is not modified
|
||||||
|
- Returns empty Recording.annotations if recording has no annotations
|
||||||
|
- RF center_frequency from metadata ensures correct absolute frequencies
|
||||||
|
- If an annotation can't be split (too short, wrong format), original kept
|
||||||
|
"""
|
||||||
|
if indices is None:
|
||||||
|
# Split all annotations
|
||||||
|
indices = list(range(len(recording.annotations)))
|
||||||
|
|
||||||
|
if not recording.annotations:
|
||||||
|
# No annotations to split
|
||||||
|
return recording
|
||||||
|
|
||||||
|
signal = recording.data[0]
|
||||||
|
sample_rate = recording.metadata["sample_rate"]
|
||||||
|
center_frequency = recording.metadata.get("center_frequency", 0.0)
|
||||||
|
|
||||||
|
# Build new annotation list
|
||||||
|
new_annotations = []
|
||||||
|
for i, anno in enumerate(recording.annotations):
|
||||||
|
if i in indices:
|
||||||
|
# Attempt to split this annotation
|
||||||
|
try:
|
||||||
|
components = split_annotation_by_components(
|
||||||
|
anno,
|
||||||
|
signal,
|
||||||
|
sample_rate,
|
||||||
|
center_frequency_hz=center_frequency,
|
||||||
|
nfft=nfft,
|
||||||
|
noise_threshold_db=noise_threshold_db,
|
||||||
|
min_component_bw=min_component_bw,
|
||||||
|
)
|
||||||
|
if components:
|
||||||
|
# Split successful, use components
|
||||||
|
new_annotations.extend(components)
|
||||||
|
else:
|
||||||
|
# No components found, keep original
|
||||||
|
new_annotations.append(anno)
|
||||||
|
except Exception:
|
||||||
|
# Split failed for any reason, keep original
|
||||||
|
new_annotations.append(anno)
|
||||||
|
else:
|
||||||
|
# Not in split list, keep as-is
|
||||||
|
new_annotations.append(anno)
|
||||||
|
|
||||||
|
return Recording(data=recording.data, metadata=recording.metadata, annotations=new_annotations)
|
||||||
35
src/ria_toolkit_oss/annotations/qualify_slice.py
Normal file
35
src/ria_toolkit_oss/annotations/qualify_slice.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
|
|
||||||
|
def qualify_slice_from_annotations(recording: Recording, slice_length: int):
|
||||||
|
"""
|
||||||
|
Slice a recording into many smaller recordings,
|
||||||
|
discarding any slices which do not have annotations that apply to those samples.
|
||||||
|
Used together with an annotation based qualifier.
|
||||||
|
|
||||||
|
:param recording: The recording to slice.
|
||||||
|
:type recording: Recording
|
||||||
|
:param slice_length: The length in samples of a slice.
|
||||||
|
:type slice_length: int"""
|
||||||
|
|
||||||
|
if len(recording.annotations) == 0:
|
||||||
|
print("Warning, no annotations.")
|
||||||
|
|
||||||
|
annotation_mask = np.zeros(len(recording.data[0]))
|
||||||
|
|
||||||
|
for annotation in recording.annotations:
|
||||||
|
annotation_mask[annotation.sample_start : annotation.sample_start + annotation.sample_count] = 1
|
||||||
|
|
||||||
|
output_recordings = []
|
||||||
|
|
||||||
|
for i in range((len(recording.data[0]) // slice_length) - 1):
|
||||||
|
start_index = slice_length * i
|
||||||
|
end_index = slice_length * (i + 1)
|
||||||
|
|
||||||
|
if 1 in annotation_mask[start_index:end_index]:
|
||||||
|
sl = recording.data[:, start_index:end_index]
|
||||||
|
output_recordings.append(Recording(data=sl, metadata=recording.metadata))
|
||||||
|
|
||||||
|
return output_recordings
|
||||||
97
src/ria_toolkit_oss/annotations/signal_isolation.py
Normal file
97
src/ria_toolkit_oss/annotations/signal_isolation.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
import numpy as np
|
||||||
|
from scipy.signal import butter, lfilter
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data.annotation import Annotation
|
||||||
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
|
||||||
|
|
||||||
|
def isolate_signal(recording: Recording, annotation: Annotation) -> Recording:
|
||||||
|
"""
|
||||||
|
Slice, filter and frequency shift the input recording according to the bounding box defined by the annotation.
|
||||||
|
|
||||||
|
:param recording: The input Recording to be sliced.
|
||||||
|
:type recording: Recording
|
||||||
|
:param annotation: The Annotation object defining the area of the recording to isolate.
|
||||||
|
:type annotation: Annotation
|
||||||
|
:param decimate: Decimate the input signal after filtering to reduce the sample rate.
|
||||||
|
:type decimate: bool
|
||||||
|
|
||||||
|
:returns: The subsection of the original recording defined by the annotation.
|
||||||
|
:rtype: Recording"""
|
||||||
|
|
||||||
|
sample_start = max(0, annotation.sample_start)
|
||||||
|
sample_stop = min(len(recording), annotation.sample_start + annotation.sample_count)
|
||||||
|
|
||||||
|
anno_base_center_freq = (annotation.freq_lower_edge + annotation.freq_upper_edge) / 2 - recording.metadata.get(
|
||||||
|
"center_frequency", 0
|
||||||
|
)
|
||||||
|
|
||||||
|
anno_bw = annotation.freq_upper_edge - annotation.freq_lower_edge
|
||||||
|
|
||||||
|
signal_slice = recording.data[0, sample_start:sample_stop]
|
||||||
|
|
||||||
|
# normalize
|
||||||
|
signal_slice = signal_slice / np.max(np.abs(signal_slice))
|
||||||
|
|
||||||
|
isolation_bw = anno_bw
|
||||||
|
|
||||||
|
# frequency shift the center of the box about zero
|
||||||
|
shifted_signal_slice = frequency_shift_iq_samples(
|
||||||
|
iq_samples=signal_slice,
|
||||||
|
sample_rate=recording.metadata["sample_rate"],
|
||||||
|
shift_frequency=-1 * anno_base_center_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
# filter
|
||||||
|
if isolation_bw < recording.metadata["sample_rate"] - 1:
|
||||||
|
filtered_signal = apply_complex_lowpass_filter(
|
||||||
|
signal=shifted_signal_slice, cutoff_frequency=isolation_bw, sample_rate=recording.metadata["sample_rate"]
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
filtered_signal = shifted_signal_slice
|
||||||
|
|
||||||
|
output = Recording(data=[filtered_signal], metadata=recording.metadata)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def frequency_shift_iq_samples(iq_samples, sample_rate, shift_frequency):
|
||||||
|
# Number of samples
|
||||||
|
num_samples = len(iq_samples)
|
||||||
|
|
||||||
|
# Create a time vector from 0 to the total duration in seconds
|
||||||
|
time_vector = np.arange(num_samples) / sample_rate
|
||||||
|
|
||||||
|
# Generate the complex exponential for the frequency shift
|
||||||
|
complex_exponential = np.exp(1j * 2 * np.pi * shift_frequency * time_vector)
|
||||||
|
|
||||||
|
# Apply the frequency shift to the IQ samples
|
||||||
|
shifted_samples = iq_samples * complex_exponential
|
||||||
|
|
||||||
|
return shifted_samples
|
||||||
|
|
||||||
|
|
||||||
|
# Function to apply a lowpass Butterworth filter to a complex signal
|
||||||
|
def apply_complex_lowpass_filter(signal, cutoff_frequency, sample_rate, order=5):
|
||||||
|
# Design the lowpass filter
|
||||||
|
b, a = design_complex_lowpass_filter(cutoff_frequency, sample_rate, order)
|
||||||
|
|
||||||
|
# Apply the lowpass filter
|
||||||
|
filtered_signal = lfilter(b, a, signal)
|
||||||
|
return filtered_signal
|
||||||
|
|
||||||
|
|
||||||
|
def design_complex_lowpass_filter(cutoff_frequency, sample_rate, order=5):
|
||||||
|
# Nyquist frequency for complex signals is the sample rate
|
||||||
|
nyquist = sample_rate
|
||||||
|
|
||||||
|
# Ensure the cutoff frequency is positive and within the Nyquist limit
|
||||||
|
if cutoff_frequency <= 0 or cutoff_frequency > nyquist:
|
||||||
|
raise ValueError("Cutoff frequency must be between 0 and the Nyquist frequency.")
|
||||||
|
|
||||||
|
# Normalize the cutoff frequency to the Nyquist frequency
|
||||||
|
cutoff_normalized = cutoff_frequency / nyquist
|
||||||
|
|
||||||
|
# Create a Butterworth lowpass filter
|
||||||
|
b, a = butter(order, cutoff_normalized, btype="low")
|
||||||
|
return b, a
|
||||||
359
src/ria_toolkit_oss/annotations/threshold_qualifier.py
Normal file
359
src/ria_toolkit_oss/annotations/threshold_qualifier.py
Normal file
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""
|
||||||
|
Temporal signal detection and boundary refinement via Hysteresis Thresholding.
|
||||||
|
|
||||||
|
Provides methods to detect signal bursts in the time domain by triggering on
|
||||||
|
smoothed power peaks and expanding boundaries to capture the full energy envelope.
|
||||||
|
|
||||||
|
This module implements a **dual-threshold trigger** to solve the 'chatter'
|
||||||
|
problem in noisy environments, ensuring that signal annotations encapsulate
|
||||||
|
the entire rise and fall of a burst rather than just the peak.
|
||||||
|
|
||||||
|
**Key Design Decisions**:
|
||||||
|
|
||||||
|
1. **Hysteresis Logic (Dual-Threshold)**:
|
||||||
|
- **Trigger**: High threshold (`threshold * max_power`) ensures high confidence
|
||||||
|
in signal presence.
|
||||||
|
- **Boundary**: Low threshold (`0.5 * trigger`) allows the annotation to
|
||||||
|
"crawl" outward, capturing the lower-energy start and end of the burst
|
||||||
|
often missed by simple single-threshold detectors.
|
||||||
|
|
||||||
|
2. **Temporal Smoothing**: Uses a moving average window (`window_size`) prior
|
||||||
|
- to thresholding. This prevents high-frequency noise spikes from causing
|
||||||
|
fragmented annotations and provides a more stable estimate of the
|
||||||
|
signal's power envelope.
|
||||||
|
|
||||||
|
3. **Spectral Profiling**: Once a temporal segment is isolated, the module
|
||||||
|
- performs an automated FFT analysis. It identifies the **90% spectral
|
||||||
|
occupancy** to define the frequency boundaries (`f_min`, `f_max`),
|
||||||
|
allowing the detector to work on narrowband and wideband signals without
|
||||||
|
manual frequency tuning.
|
||||||
|
|
||||||
|
4. **Baseband/RF Mapping**: Automatically handles the conversion from
|
||||||
|
- relative FFT bin frequencies to absolute RF frequencies by referencing
|
||||||
|
`recording.metadata["center_frequency"]`.
|
||||||
|
|
||||||
|
5. **False Positive Mitigation**: Implements a hard minimum duration check
|
||||||
|
- (10ms) to ignore transient hardware spikes or noise floor fluctuations
|
||||||
|
that do not constitute a valid signal burst.
|
||||||
|
|
||||||
|
The module is designed to be the primary "first-pass" detector for pulsed
|
||||||
|
waveforms (like ADS-B, Lora, or bursty FSK) before passing them to
|
||||||
|
classification or demodulation stages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data import Annotation, Recording
|
||||||
|
|
||||||
|
|
||||||
|
def _find_ranges(indices, max_gap):
|
||||||
|
"""
|
||||||
|
Groups individual indices into continuous temporal ranges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices: Array of indices where the signal exceeded a threshold.
|
||||||
|
max_gap: Maximum gap allowed between indices to consider them part
|
||||||
|
of the same range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of (start, stop) tuples representing detected signal segments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(indices) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
start = indices[0]
|
||||||
|
prev = indices[0]
|
||||||
|
ranges = []
|
||||||
|
|
||||||
|
for i in range(1, len(indices)):
|
||||||
|
if indices[i] - prev > max_gap:
|
||||||
|
ranges.append((start, prev))
|
||||||
|
start = indices[i]
|
||||||
|
prev = indices[i]
|
||||||
|
|
||||||
|
ranges.append((start, prev))
|
||||||
|
|
||||||
|
return ranges
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_and_filter_ranges(
|
||||||
|
smoothed_power: np.ndarray,
|
||||||
|
initial_ranges: list[tuple[int, int]],
|
||||||
|
boundary_val: float,
|
||||||
|
min_duration_samples: int,
|
||||||
|
) -> list[tuple[int, int]]:
|
||||||
|
"""Apply hysteresis expansion and minimum-duration filtering."""
|
||||||
|
out: list[tuple[int, int]] = []
|
||||||
|
n = len(smoothed_power)
|
||||||
|
for start, stop in initial_ranges:
|
||||||
|
if (stop - start) < min_duration_samples:
|
||||||
|
continue
|
||||||
|
|
||||||
|
true_start = start
|
||||||
|
while true_start > 0 and smoothed_power[true_start] > boundary_val:
|
||||||
|
true_start -= 1
|
||||||
|
|
||||||
|
true_stop = stop
|
||||||
|
while true_stop < n - 1 and smoothed_power[true_stop] > boundary_val:
|
||||||
|
true_stop += 1
|
||||||
|
|
||||||
|
if (true_stop - true_start) >= min_duration_samples:
|
||||||
|
out.append((true_start, true_stop))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_ranges(ranges: list[tuple[int, int]], max_gap: int) -> list[tuple[int, int]]:
|
||||||
|
"""Merge overlapping or near-adjacent ranges."""
|
||||||
|
if not ranges:
|
||||||
|
return []
|
||||||
|
ranges = sorted(ranges, key=lambda r: r[0])
|
||||||
|
merged = [ranges[0]]
|
||||||
|
for s, e in ranges[1:]:
|
||||||
|
last_s, last_e = merged[-1]
|
||||||
|
if s <= last_e + max_gap:
|
||||||
|
merged[-1] = (last_s, max(last_e, e))
|
||||||
|
else:
|
||||||
|
merged.append((s, e))
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_noise_floor(power: np.ndarray, quantile: float = 20.0) -> float:
|
||||||
|
"""Estimate baseline from the quieter portion of the envelope."""
|
||||||
|
return float(np.percentile(power, quantile))
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_group_gap(sample_rate: float) -> int:
|
||||||
|
"""Use a fixed temporal grouping gap instead of reusing the smoothing window."""
|
||||||
|
return max(1, int(0.001 * sample_rate))
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_spectral_bounds(signal_segment: np.ndarray, sample_rate: float) -> tuple[float, float]:
|
||||||
|
"""Estimate occupied bandwidth from a smoothed magnitude spectrum."""
|
||||||
|
if len(signal_segment) == 0:
|
||||||
|
return -sample_rate / 4, sample_rate / 4
|
||||||
|
|
||||||
|
window = np.hanning(len(signal_segment))
|
||||||
|
windowed = signal_segment * window
|
||||||
|
|
||||||
|
fft_data = np.abs(np.fft.fftshift(np.fft.fft(windowed)))
|
||||||
|
fft_freqs = np.fft.fftshift(np.fft.fftfreq(len(signal_segment), 1 / sample_rate))
|
||||||
|
|
||||||
|
# Smooth the spectrum so noise-like wideband bursts form a contiguous mask
|
||||||
|
# instead of thousands of tiny isolated runs.
|
||||||
|
spectral_smooth_bins = max(5, min(257, (len(signal_segment) // 512) | 1))
|
||||||
|
spectral_kernel = np.ones(spectral_smooth_bins, dtype=np.float64) / spectral_smooth_bins
|
||||||
|
smoothed_fft = np.convolve(fft_data, spectral_kernel, mode="same")
|
||||||
|
|
||||||
|
spectral_floor = float(np.percentile(smoothed_fft, 20))
|
||||||
|
spectral_peak = float(np.max(smoothed_fft))
|
||||||
|
spectral_ratio = spectral_peak / max(spectral_floor, 1e-12)
|
||||||
|
|
||||||
|
if spectral_ratio < 1.2:
|
||||||
|
return -sample_rate / 4, sample_rate / 4
|
||||||
|
|
||||||
|
spectral_thresh = spectral_floor + 0.1 * (spectral_peak - spectral_floor)
|
||||||
|
sig_indices = np.where(smoothed_fft > spectral_thresh)[0]
|
||||||
|
|
||||||
|
if len(sig_indices) == 0:
|
||||||
|
peak_idx = int(np.argmax(smoothed_fft))
|
||||||
|
bin_hz = sample_rate / len(signal_segment)
|
||||||
|
half_bins = max(1, int(np.ceil(10_000.0 / bin_hz)))
|
||||||
|
lo_idx = max(0, peak_idx - half_bins)
|
||||||
|
hi_idx = min(len(smoothed_fft) - 1, peak_idx + half_bins)
|
||||||
|
else:
|
||||||
|
runs = _find_ranges(sig_indices, max_gap=max(1, spectral_smooth_bins // 2))
|
||||||
|
peak_idx = int(np.argmax(smoothed_fft))
|
||||||
|
lo_idx, hi_idx = min(
|
||||||
|
runs,
|
||||||
|
key=lambda run: 0 if run[0] <= peak_idx <= run[1] else min(abs(run[0] - peak_idx), abs(run[1] - peak_idx)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prevent extremely narrow tone boxes from collapsing to just a few bins.
|
||||||
|
min_total_bw_hz = 20_000.0
|
||||||
|
min_half_bins = max(1, int(np.ceil((min_total_bw_hz / 2) / (sample_rate / len(signal_segment)))))
|
||||||
|
center_idx = int(round((lo_idx + hi_idx) / 2))
|
||||||
|
lo_idx = max(0, min(lo_idx, center_idx - min_half_bins))
|
||||||
|
hi_idx = min(len(smoothed_fft) - 1, max(hi_idx, center_idx + min_half_bins))
|
||||||
|
|
||||||
|
return float(fft_freqs[lo_idx]), float(fft_freqs[hi_idx])
|
||||||
|
|
||||||
|
|
||||||
|
def threshold_qualifier(
|
||||||
|
recording: Recording,
|
||||||
|
threshold: float,
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
annotation_type: Optional[str] = "standalone",
|
||||||
|
channel: int = 0,
|
||||||
|
) -> Recording:
|
||||||
|
"""
|
||||||
|
Annotate a recording with bounding boxes for regions above a threshold.
|
||||||
|
Threshold is defined as a fraction of the maximum sample magnitude.
|
||||||
|
This algorithm searches for samples above the threshold and combines them into ranges if they
|
||||||
|
are within window_size of each other.
|
||||||
|
Detects and annotates signals using energy thresholding and spectral analysis.
|
||||||
|
|
||||||
|
The algorithm follows these steps:
|
||||||
|
1. Smooths power data using a moving average.
|
||||||
|
2. Identifies 'peak' regions exceeding a high trigger threshold.
|
||||||
|
3. Uses hysteresis to expand boundaries until power drops below a lower threshold.
|
||||||
|
4. Performs an FFT on each segment to determine frequency occupancy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recording: The Recording object containing IQ or real signal data.
|
||||||
|
threshold: Sensitivity multiplier (0.0 to 1.0) applied to max power.
|
||||||
|
window_size: Size of the smoothing filter in samples. Defaults to 1ms worth of samples.
|
||||||
|
label: Custom string label for annotations.
|
||||||
|
annotation_type: Metadata string for the 'type' field in the annotation.
|
||||||
|
channel: Index of the channel to annotate. Defaults to 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new Recording object populated with detected Annotations.
|
||||||
|
"""
|
||||||
|
# Extract signal and metadata
|
||||||
|
sample_data = recording.data[channel]
|
||||||
|
sample_rate = recording.metadata["sample_rate"]
|
||||||
|
center_frequency = recording.metadata.get("center_frequency", 0)
|
||||||
|
|
||||||
|
if window_size is None:
|
||||||
|
window_size = max(64, int(sample_rate * 0.001))
|
||||||
|
|
||||||
|
# --- 1. SIGNAL CONDITIONING ---
|
||||||
|
# Convert to power (Magnitude squared)
|
||||||
|
power_data = np.abs(sample_data) ** 2
|
||||||
|
smoothing_window = np.ones(window_size) / window_size
|
||||||
|
smoothed_power = np.convolve(power_data, smoothing_window, mode="same")
|
||||||
|
group_gap_samples = _estimate_group_gap(sample_rate)
|
||||||
|
|
||||||
|
# Define thresholds using peak relative to baseline.
|
||||||
|
max_power = np.max(smoothed_power)
|
||||||
|
noise_floor = _estimate_noise_floor(smoothed_power)
|
||||||
|
dynamic_range_ratio = max_power / max(noise_floor, 1e-12)
|
||||||
|
|
||||||
|
# Soft early exit: keep a guard for low-contrast noise, but compute it from
|
||||||
|
# the quieter tail of the envelope so burst-heavy captures are not rejected.
|
||||||
|
if dynamic_range_ratio < 1.5:
|
||||||
|
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations)
|
||||||
|
|
||||||
|
trigger_val = noise_floor + threshold * (max_power - noise_floor)
|
||||||
|
boundary_val = noise_floor + 0.5 * threshold * (max_power - noise_floor)
|
||||||
|
|
||||||
|
# --- 2. INITIAL DETECTION ---
|
||||||
|
# Enforce an explicit minimum duration in seconds; this is stable across
|
||||||
|
# varying capture lengths and avoids over-fitting to recording length.
|
||||||
|
min_duration_samples = max(1, int(0.005 * sample_rate))
|
||||||
|
annotations = []
|
||||||
|
|
||||||
|
# Pass 1: Detect stronger bursts.
|
||||||
|
indices = np.where(smoothed_power > trigger_val)[0]
|
||||||
|
pass1_initial = _find_ranges(indices=indices, max_gap=group_gap_samples)
|
||||||
|
pass1_ranges = _expand_and_filter_ranges(
|
||||||
|
smoothed_power=smoothed_power,
|
||||||
|
initial_ranges=pass1_initial,
|
||||||
|
boundary_val=boundary_val,
|
||||||
|
min_duration_samples=min_duration_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pass 2: Recover weaker bursts on residual power not already covered.
|
||||||
|
# This improves recall in mixed-amplitude captures.
|
||||||
|
# Expand each Pass-1 range by the smoothing window on both sides so the
|
||||||
|
# smoothing skirts of a strong burst are not re-detected as a weak burst
|
||||||
|
# immediately adjacent to it (mirrors the guard used in Pass 3).
|
||||||
|
mask = np.ones_like(smoothed_power, dtype=np.float32)
|
||||||
|
pass2_mask_expand = window_size
|
||||||
|
for s, e in pass1_ranges:
|
||||||
|
mask[max(0, s - pass2_mask_expand) : min(len(mask), e + pass2_mask_expand)] = 0.0
|
||||||
|
residual_power = smoothed_power * mask
|
||||||
|
|
||||||
|
residual_max = float(np.max(residual_power))
|
||||||
|
residual_ratio = residual_max / max(noise_floor, 1e-12)
|
||||||
|
|
||||||
|
pass2_ranges: list[tuple[int, int]] = []
|
||||||
|
if residual_ratio >= 2.0:
|
||||||
|
weak_threshold = max(0.3, threshold * 0.7)
|
||||||
|
weak_trigger = noise_floor + weak_threshold * (residual_max - noise_floor)
|
||||||
|
weak_boundary = noise_floor + 0.5 * weak_threshold * (residual_max - noise_floor)
|
||||||
|
weak_indices = np.where(residual_power > weak_trigger)[0]
|
||||||
|
pass2_initial = _find_ranges(indices=weak_indices, max_gap=group_gap_samples)
|
||||||
|
pass2_ranges = _expand_and_filter_ranges(
|
||||||
|
smoothed_power=residual_power,
|
||||||
|
initial_ranges=pass2_initial,
|
||||||
|
boundary_val=weak_boundary,
|
||||||
|
min_duration_samples=min_duration_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pass 3: Detect sustained faint bursts via macro-window averaging.
|
||||||
|
# Targets bursts whose peak power is near the trigger level but whose
|
||||||
|
# *average* power is consistently elevated above the noise floor — these
|
||||||
|
# are missed by peak-based detection because only a few short spikes exceed
|
||||||
|
# the trigger, all too brief to pass the minimum-duration filter.
|
||||||
|
#
|
||||||
|
# The mask is applied to power_data *before* convolving so that bright
|
||||||
|
# burst energy does not bleed through the long window into adjacent regions,
|
||||||
|
# which would inflate macro_residual_max and push the trigger above the
|
||||||
|
# faint burst's average power.
|
||||||
|
macro_window_size = max(window_size * 16, int(sample_rate * 0.02))
|
||||||
|
macro_kernel = np.ones(macro_window_size, dtype=np.float64) / macro_window_size
|
||||||
|
# Expand each annotated range by half the macro window on both sides so that
|
||||||
|
# the long convolution cannot "see" the leading/trailing edges of already-
|
||||||
|
# annotated bursts, which would produce spurious short fragments in Pass 3.
|
||||||
|
macro_expand = macro_window_size * 2
|
||||||
|
masked_power_for_macro = power_data.copy()
|
||||||
|
n = len(masked_power_for_macro)
|
||||||
|
for s, e in pass1_ranges + pass2_ranges:
|
||||||
|
masked_power_for_macro[max(0, s - macro_expand) : min(n, e + macro_expand)] = 0.0
|
||||||
|
macro_residual = np.convolve(masked_power_for_macro, macro_kernel, mode="same")
|
||||||
|
|
||||||
|
macro_residual_max = float(np.max(macro_residual))
|
||||||
|
|
||||||
|
pass3_ranges: list[tuple[int, int]] = []
|
||||||
|
if macro_residual_max / max(noise_floor, 1e-12) >= 1.3:
|
||||||
|
macro_trigger = noise_floor + threshold * (macro_residual_max - noise_floor)
|
||||||
|
macro_boundary = noise_floor + 0.5 * threshold * (macro_residual_max - noise_floor)
|
||||||
|
macro_indices = np.where(macro_residual > macro_trigger)[0]
|
||||||
|
macro_initial = _find_ranges(indices=macro_indices, max_gap=group_gap_samples)
|
||||||
|
pass3_ranges = _expand_and_filter_ranges(
|
||||||
|
smoothed_power=macro_residual,
|
||||||
|
initial_ranges=macro_initial,
|
||||||
|
boundary_val=macro_boundary,
|
||||||
|
min_duration_samples=min_duration_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_ranges = _merge_ranges(pass1_ranges + pass2_ranges + pass3_ranges, max_gap=group_gap_samples)
|
||||||
|
|
||||||
|
for true_start, true_stop in all_ranges:
|
||||||
|
|
||||||
|
# --- 4. SPECTRAL ANALYSIS (Frequency Detection) ---
|
||||||
|
signal_segment = sample_data[true_start:true_stop]
|
||||||
|
f_min, f_max = _estimate_spectral_bounds(signal_segment, sample_rate)
|
||||||
|
|
||||||
|
# --- 5. ANNOTATION GENERATION ---
|
||||||
|
ann_label = label if label is not None else f"{int(threshold*100)}%"
|
||||||
|
|
||||||
|
# Pack metadata for the UI/Downstream processing
|
||||||
|
comment_data = {
|
||||||
|
"type": annotation_type,
|
||||||
|
"generator": "threshold_qualifier",
|
||||||
|
"params": {
|
||||||
|
"threshold": threshold,
|
||||||
|
"window_size": window_size,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
anno = Annotation(
|
||||||
|
sample_start=true_start,
|
||||||
|
sample_count=true_stop - true_start,
|
||||||
|
freq_lower_edge=center_frequency + f_min,
|
||||||
|
freq_upper_edge=center_frequency + f_max,
|
||||||
|
label=ann_label,
|
||||||
|
comment=json.dumps(comment_data),
|
||||||
|
detail={"generator": "hysteresis_qualifier"},
|
||||||
|
)
|
||||||
|
annotations.append(anno)
|
||||||
|
|
||||||
|
# Return a new Recording object including the new annotations
|
||||||
|
return Recording(data=recording.data, metadata=recording.metadata, annotations=recording.annotations + annotations)
|
||||||
1
src/ria_toolkit_oss/app/__init__.py
Normal file
1
src/ria_toolkit_oss/app/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""App runner: pull and run containerized RIA applications."""
|
||||||
278
src/ria_toolkit_oss/app/cli.py
Normal file
278
src/ria_toolkit_oss/app/cli.py
Normal file
|
|
@ -0,0 +1,278 @@
|
||||||
|
"""Unified ``ria-app`` CLI.
|
||||||
|
|
||||||
|
Subcommands:
|
||||||
|
|
||||||
|
- ``ria-app pull <app>[:tag]`` — pull a RIA app image from the configured registry.
|
||||||
|
- ``ria-app run <app>[:tag]`` — pull (if needed) and run, auto-configuring
|
||||||
|
GPU/USB/network flags from image labels set by CI.
|
||||||
|
- ``ria-app list`` — list locally cached RIA app images.
|
||||||
|
- ``ria-app stop <app>`` — stop a running app container.
|
||||||
|
- ``ria-app logs <app>`` — tail logs of a running app container.
|
||||||
|
- ``ria-app configure`` — set default registry/namespace.
|
||||||
|
|
||||||
|
Image references resolve as::
|
||||||
|
|
||||||
|
my-classifier -> {registry}/{namespace}/my-classifier:latest
|
||||||
|
group/my-classifier -> {registry}/group/my-classifier:latest
|
||||||
|
host/group/app:tag -> host/group/app:tag (fully-qualified passthrough)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from . import config as _config
|
||||||
|
|
||||||
|
_LABEL_PROFILE = "ria.profile"
|
||||||
|
_LABEL_HARDWARE = "ria.hardware"
|
||||||
|
_LABEL_APP = "ria.app"
|
||||||
|
|
||||||
|
|
||||||
|
def _engine(cfg: _config.AppConfig, sudo_override: bool = False) -> list[str]:
|
||||||
|
for exe in ("docker", "podman"):
|
||||||
|
if shutil.which(exe):
|
||||||
|
use_sudo = sudo_override or cfg.sudo
|
||||||
|
return ["sudo", exe] if use_sudo else [exe]
|
||||||
|
print("error: neither 'docker' nor 'podman' found on PATH", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_ref(app: str, cfg: _config.AppConfig) -> str:
|
||||||
|
ref = app if ":" in app.split("/")[-1] else f"{app}:latest"
|
||||||
|
slashes = ref.count("/")
|
||||||
|
if slashes >= 2:
|
||||||
|
return ref
|
||||||
|
if slashes == 1:
|
||||||
|
return f"{cfg.registry}/{ref}" if cfg.registry else ref
|
||||||
|
if not cfg.registry or not cfg.namespace:
|
||||||
|
print(
|
||||||
|
"error: app is not fully qualified and no default registry/namespace configured. "
|
||||||
|
"Run `ria-app configure` or pass a full image reference (registry/namespace/app:tag).",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(2)
|
||||||
|
return f"{cfg.registry}/{cfg.namespace}/{ref}"
|
||||||
|
|
||||||
|
|
||||||
|
def _container_name(ref: str) -> str:
|
||||||
|
name = ref.rsplit("/", 1)[-1].split(":", 1)[0]
|
||||||
|
return f"ria-app-{name}"
|
||||||
|
|
||||||
|
|
||||||
|
def _inspect_labels(engine: list[str], ref: str) -> dict:
|
||||||
|
try:
|
||||||
|
out = subprocess.check_output(
|
||||||
|
[*engine, "image", "inspect", "--format", "{{json .Config.Labels}}", ref],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
return json.loads(out.decode().strip()) or {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _gpu_available() -> bool:
|
||||||
|
if os.path.exists("/dev/nvidia0"):
|
||||||
|
return True
|
||||||
|
return shutil.which("nvidia-smi") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _hardware_flags(labels: dict, no_gpu: bool, no_usb: bool, no_host_net: bool) -> tuple[list[str], list[str]]:
|
||||||
|
flags: list[str] = []
|
||||||
|
notes: list[str] = []
|
||||||
|
profile = (labels.get(_LABEL_PROFILE) or "").lower()
|
||||||
|
hardware = (labels.get(_LABEL_HARDWARE) or "").lower()
|
||||||
|
hw_items = {h.strip() for h in hardware.split(",") if h.strip()}
|
||||||
|
|
||||||
|
wants_gpu = any(k in profile for k in ("nvidia", "holoscan", "cuda"))
|
||||||
|
if wants_gpu and not no_gpu:
|
||||||
|
if _gpu_available():
|
||||||
|
flags += ["--gpus", "all"]
|
||||||
|
else:
|
||||||
|
notes.append(
|
||||||
|
"image wants GPU but no NVIDIA runtime detected — skipping --gpus (use --force-gpu to override)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hw_items & {"pluto", "rtlsdr", "hackrf", "bladerf"} and not no_usb:
|
||||||
|
flags += ["--device", "/dev/bus/usb"]
|
||||||
|
|
||||||
|
if hw_items & {"usrp", "thinkrf", "pluto"} and not no_host_net:
|
||||||
|
flags += ["--net", "host"]
|
||||||
|
|
||||||
|
return flags, notes
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_configure(args: argparse.Namespace) -> int:
|
||||||
|
cfg = _config.load()
|
||||||
|
if args.registry:
|
||||||
|
cfg.registry = args.registry
|
||||||
|
if args.namespace:
|
||||||
|
cfg.namespace = args.namespace
|
||||||
|
if args.sudo is not None:
|
||||||
|
cfg.sudo = args.sudo
|
||||||
|
path = _config.save(cfg)
|
||||||
|
print(f"Saved app config to {path}")
|
||||||
|
print(f" registry: {cfg.registry or '(unset)'}")
|
||||||
|
print(f" namespace: {cfg.namespace or '(unset)'}")
|
||||||
|
print(f" sudo: {cfg.sudo}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_pull(args: argparse.Namespace) -> int:
|
||||||
|
cfg = _config.load()
|
||||||
|
engine = _engine(cfg, args.sudo)
|
||||||
|
ref = _resolve_ref(args.app, cfg)
|
||||||
|
print(f"Pulling {ref}")
|
||||||
|
return subprocess.call([*engine, "pull", ref])
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_run(args: argparse.Namespace) -> int:
|
||||||
|
cfg = _config.load()
|
||||||
|
engine = _engine(cfg, args.sudo)
|
||||||
|
ref = _resolve_ref(args.app, cfg)
|
||||||
|
|
||||||
|
if not _inspect_labels(engine, ref):
|
||||||
|
rc = subprocess.call([*engine, "pull", ref])
|
||||||
|
if rc != 0:
|
||||||
|
return rc
|
||||||
|
|
||||||
|
labels = _inspect_labels(engine, ref)
|
||||||
|
no_gpu = args.no_gpu and not args.force_gpu
|
||||||
|
hw_flags, notes = _hardware_flags(labels, no_gpu=no_gpu, no_usb=args.no_usb, no_host_net=args.no_host_net)
|
||||||
|
if args.force_gpu and "--gpus" not in hw_flags:
|
||||||
|
hw_flags = ["--gpus", "all", *hw_flags]
|
||||||
|
|
||||||
|
cmd = [*engine, "run", "--rm"]
|
||||||
|
if not args.foreground:
|
||||||
|
cmd += ["-d"]
|
||||||
|
cmd += ["--name", args.name or _container_name(ref)]
|
||||||
|
cmd += hw_flags
|
||||||
|
|
||||||
|
if args.config:
|
||||||
|
cmd += ["-v", f"{args.config}:/config/config.yaml:ro", "-e", "RIA_CONFIG=/config/config.yaml"]
|
||||||
|
|
||||||
|
for env in args.env or []:
|
||||||
|
cmd += ["-e", env]
|
||||||
|
for vol in args.volume or []:
|
||||||
|
cmd += ["-v", vol]
|
||||||
|
for port in args.publish or []:
|
||||||
|
cmd += ["-p", port]
|
||||||
|
|
||||||
|
cmd += list(args.docker_args or [])
|
||||||
|
cmd += [ref]
|
||||||
|
cmd += list(args.app_args or [])
|
||||||
|
|
||||||
|
if args.dry_run:
|
||||||
|
print(" ".join(cmd))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
label_str = ", ".join(f"{k}={v}" for k, v in labels.items() if k.startswith("ria.")) or "(no ria.* labels)"
|
||||||
|
print(f"Running {ref} [{label_str}]")
|
||||||
|
if hw_flags:
|
||||||
|
print(f" auto flags: {' '.join(hw_flags)}")
|
||||||
|
for note in notes:
|
||||||
|
print(f" note: {note}")
|
||||||
|
return subprocess.call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_list(args: argparse.Namespace) -> int:
|
||||||
|
cfg = _config.load()
|
||||||
|
engine = _engine(cfg, args.sudo)
|
||||||
|
return subprocess.call(
|
||||||
|
[
|
||||||
|
*engine,
|
||||||
|
"images",
|
||||||
|
"--filter",
|
||||||
|
f"label={_LABEL_APP}",
|
||||||
|
"--format",
|
||||||
|
"table {{.Repository}}:{{.Tag}}\t{{.ID}}\t{{.Size}}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_stop(args: argparse.Namespace) -> int:
|
||||||
|
cfg = _config.load()
|
||||||
|
engine = _engine(cfg, args.sudo)
|
||||||
|
name = args.name or _container_name(_resolve_ref(args.app, cfg))
|
||||||
|
return subprocess.call([*engine, "stop", name])
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_logs(args: argparse.Namespace) -> int:
|
||||||
|
cfg = _config.load()
|
||||||
|
engine = _engine(cfg, args.sudo)
|
||||||
|
name = args.name or _container_name(_resolve_ref(args.app, cfg))
|
||||||
|
cmd = [*engine, "logs"]
|
||||||
|
if args.follow:
|
||||||
|
cmd += ["-f"]
|
||||||
|
cmd += [name]
|
||||||
|
return subprocess.call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(prog="ria-app")
|
||||||
|
parser.add_argument("--sudo", action="store_true", default=False, help="Run docker/podman via sudo")
|
||||||
|
sub = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
p_cfg = sub.add_parser("configure", help="Set default registry/namespace")
|
||||||
|
p_cfg.add_argument("--registry", default=None, help="Default container registry (e.g. registry.riahub.ai)")
|
||||||
|
p_cfg.add_argument("--namespace", default=None, help="Default namespace (e.g. qoherent)")
|
||||||
|
p_cfg.add_argument(
|
||||||
|
"--sudo",
|
||||||
|
dest="sudo",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=None,
|
||||||
|
help="Persist sudo default (--sudo / --no-sudo)",
|
||||||
|
)
|
||||||
|
|
||||||
|
p_pull = sub.add_parser("pull", help="Pull an app image")
|
||||||
|
p_pull.add_argument("app", help="App name or image reference")
|
||||||
|
|
||||||
|
p_run = sub.add_parser("run", help="Run an app, auto-detecting hardware flags")
|
||||||
|
p_run.add_argument("app", help="App name or image reference")
|
||||||
|
p_run.add_argument("--name", default=None, help="Container name (default: ria-app-<app>)")
|
||||||
|
p_run.add_argument("--config", default=None, help="Path to config.yaml to mount into the container")
|
||||||
|
p_run.add_argument("-e", "--env", action="append", help="Extra env var (KEY=VALUE)")
|
||||||
|
p_run.add_argument("-v", "--volume", action="append", help="Extra volume mount")
|
||||||
|
p_run.add_argument("-p", "--publish", action="append", help="Publish port")
|
||||||
|
p_run.add_argument("--foreground", "-F", action="store_true", help="Run in foreground (no -d)")
|
||||||
|
p_run.add_argument("--no-gpu", action="store_true", help="Skip --gpus flag even if image wants GPU")
|
||||||
|
p_run.add_argument("--force-gpu", action="store_true", help="Force --gpus all even if no NVIDIA runtime detected")
|
||||||
|
p_run.add_argument("--no-usb", action="store_true", help="Skip --device /dev/bus/usb")
|
||||||
|
p_run.add_argument("--no-host-net", action="store_true", help="Skip --net host")
|
||||||
|
p_run.add_argument("--dry-run", action="store_true", help="Print the container command and exit")
|
||||||
|
p_run.add_argument("--docker-args", nargs=argparse.REMAINDER, help="Pass remaining args to docker/podman run")
|
||||||
|
p_run.add_argument("--app-args", nargs=argparse.REMAINDER, help="Pass remaining args to the app entrypoint")
|
||||||
|
|
||||||
|
sub.add_parser("list", help="List locally cached RIA app images")
|
||||||
|
|
||||||
|
p_stop = sub.add_parser("stop", help="Stop a running app")
|
||||||
|
p_stop.add_argument("app", help="App name or image reference")
|
||||||
|
p_stop.add_argument("--name", default=None, help="Container name override")
|
||||||
|
|
||||||
|
p_logs = sub.add_parser("logs", help="Tail logs of a running app")
|
||||||
|
p_logs.add_argument("app", help="App name or image reference")
|
||||||
|
p_logs.add_argument("--name", default=None, help="Container name override")
|
||||||
|
p_logs.add_argument("-f", "--follow", action="store_true", help="Follow log output")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dispatch = {
|
||||||
|
"configure": _cmd_configure,
|
||||||
|
"pull": _cmd_pull,
|
||||||
|
"run": _cmd_run,
|
||||||
|
"list": _cmd_list,
|
||||||
|
"stop": _cmd_stop,
|
||||||
|
"logs": _cmd_logs,
|
||||||
|
}
|
||||||
|
sys.exit(dispatch[args.command](args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
51
src/ria_toolkit_oss/app/config.py
Normal file
51
src/ria_toolkit_oss/app/config.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""App runner configuration at ``~/.ria/toolkit.json``.
|
||||||
|
|
||||||
|
Schema::
|
||||||
|
|
||||||
|
{
|
||||||
|
"registry": "registry.riahub.ai",
|
||||||
|
"namespace": "qoherent"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_DEFAULT_PATH = Path(os.environ.get("RIA_TOOLKIT_CONFIG", str(Path.home() / ".ria" / "toolkit.json")))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AppConfig:
|
||||||
|
registry: str = ""
|
||||||
|
namespace: str = ""
|
||||||
|
sudo: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def default_path() -> Path:
|
||||||
|
return _DEFAULT_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def load(path: Path | None = None) -> AppConfig:
|
||||||
|
p = path or _DEFAULT_PATH
|
||||||
|
if not p.exists():
|
||||||
|
return AppConfig(
|
||||||
|
registry=os.environ.get("RIA_REGISTRY", ""),
|
||||||
|
namespace=os.environ.get("RIA_NAMESPACE", ""),
|
||||||
|
)
|
||||||
|
data = json.loads(p.read_text())
|
||||||
|
return AppConfig(
|
||||||
|
registry=data.get("registry", "") or os.environ.get("RIA_REGISTRY", ""),
|
||||||
|
namespace=data.get("namespace", "") or os.environ.get("RIA_NAMESPACE", ""),
|
||||||
|
sudo=bool(data.get("sudo", False)) or os.environ.get("RIA_DOCKER_SUDO", "") not in ("", "0", "false"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save(cfg: AppConfig, path: Path | None = None) -> Path:
|
||||||
|
p = path or _DEFAULT_PATH
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p.write_text(json.dumps(asdict(cfg), indent=2))
|
||||||
|
return p
|
||||||
8
src/ria_toolkit_oss/data/__init__.py
Normal file
8
src/ria_toolkit_oss/data/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
"""
|
||||||
|
The Data package contains abstract data types tailored for radio machine learning, such as ``Recording``, as well
|
||||||
|
as the abstract interfaces for the radio dataset and radio dataset builder framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = ["Annotation", "Recording"]
|
||||||
|
from .annotation import Annotation
|
||||||
|
from .recording import Recording
|
||||||
|
|
@ -7,8 +7,8 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.datasets.license.dataset_license import DatasetLicense
|
from ria_toolkit_oss.data.datasets.license.dataset_license import DatasetLicense
|
||||||
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
|
from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset
|
||||||
from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute
|
from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,7 +21,8 @@ class DatasetBuilder(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_url: str = abstract_attribute()
|
_url: str = abstract_attribute()
|
||||||
_SHA256: str # SHA256 checksum.
|
_SHA256: Optional[str] = None # SHA256 checksum.
|
||||||
|
_MD5: Optional[str] = None # MD5 checksum.
|
||||||
_name: str = abstract_attribute()
|
_name: str = abstract_attribute()
|
||||||
_author: str = abstract_attribute()
|
_author: str = abstract_attribute()
|
||||||
_license: DatasetLicense = abstract_attribute()
|
_license: DatasetLicense = abstract_attribute()
|
||||||
|
|
@ -109,13 +109,10 @@ def copy_file(original_source: str | os.PathLike, new_source: str | os.PathLike)
|
||||||
|
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
original_file = h5py.File(original_source, "r")
|
with h5py.File(original_source, "r") as original_file:
|
||||||
|
with h5py.File(new_source, "w") as new_file:
|
||||||
with h5py.File(new_source, "w") as new_file:
|
for key in original_file.keys():
|
||||||
for key in original_file.keys():
|
original_file.copy(key, new_file)
|
||||||
original_file.copy(key, new_file)
|
|
||||||
|
|
||||||
original_file.close()
|
|
||||||
|
|
||||||
|
|
||||||
def make_empty_clone(original_source: str | os.PathLike, new_source: str | os.PathLike, example_length: int) -> None:
|
def make_empty_clone(original_source: str | os.PathLike, new_source: str | os.PathLike, example_length: int) -> None:
|
||||||
|
|
@ -172,8 +169,10 @@ def delete_example_inplace(source: str | os.PathLike, idx: int) -> None:
|
||||||
with h5py.File(source, "a") as f:
|
with h5py.File(source, "a") as f:
|
||||||
ds, md = f["data"], f["metadata/metadata"]
|
ds, md = f["data"], f["metadata/metadata"]
|
||||||
m, c, n = ds.shape
|
m, c, n = ds.shape
|
||||||
assert 0 <= idx <= m - 1
|
if not (0 <= idx <= m - 1):
|
||||||
assert len(ds) == len(md)
|
raise IndexError(f"Index {idx} out of range [0, {m - 1}]")
|
||||||
|
if len(ds) != len(md):
|
||||||
|
raise ValueError("Data and metadata array lengths do not match")
|
||||||
|
|
||||||
new_ds = f.create_dataset(
|
new_ds = f.create_dataset(
|
||||||
"data.temp",
|
"data.temp",
|
||||||
|
|
@ -218,4 +217,3 @@ def overwrite_file(source: str | os.PathLike, new_data: np.ndarray) -> None:
|
||||||
ds_name = tuple(f.keys())[0]
|
ds_name = tuple(f.keys())[0]
|
||||||
del f[ds_name]
|
del f[ds_name]
|
||||||
f.create_dataset(ds_name, data=new_data)
|
f.create_dataset(ds_name, data=new_data)
|
||||||
f.close()
|
|
||||||
|
|
@ -7,11 +7,11 @@ from typing import Optional
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.datasets.h5helpers import (
|
from ria_toolkit_oss.data.datasets.h5helpers import (
|
||||||
append_entry_inplace,
|
append_entry_inplace,
|
||||||
copy_dataset_entry_by_index,
|
copy_dataset_entry_by_index,
|
||||||
)
|
)
|
||||||
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
|
from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset
|
||||||
|
|
||||||
|
|
||||||
class IQDataset(RadioDataset, ABC):
|
class IQDataset(RadioDataset, ABC):
|
||||||
|
|
@ -19,7 +19,7 @@ class IQDataset(RadioDataset, ABC):
|
||||||
radiofrequency (RF) signals represented as In-phase (I) and Quadrature (Q) samples.
|
radiofrequency (RF) signals represented as In-phase (I) and Quadrature (Q) samples.
|
||||||
|
|
||||||
For machine learning tasks that involve processing spectrograms, please use
|
For machine learning tasks that involve processing spectrograms, please use
|
||||||
ria_toolkit_oss.datatypes.datasets.SpectDataset instead.
|
ria_toolkit_oss.data.datasets.SpectDataset instead.
|
||||||
|
|
||||||
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
|
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
|
||||||
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine
|
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine
|
||||||
|
|
@ -169,8 +169,10 @@ class IQDataset(RadioDataset, ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if split_factor is not None and example_length is not None:
|
if split_factor is not None and example_length is not None:
|
||||||
# Raise warning and use split factor
|
# Warn and use split factor
|
||||||
raise Warning("split_factor and example_length should not both be specified.")
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn("split_factor and example_length should not both be specified.")
|
||||||
|
|
||||||
if not inplace:
|
if not inplace:
|
||||||
# ds = self.create_new_dataset(example_length=example_length)
|
# ds = self.create_new_dataset(example_length=example_length)
|
||||||
|
|
@ -12,7 +12,7 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.datasets.h5helpers import (
|
from ria_toolkit_oss.data.datasets.h5helpers import (
|
||||||
append_entry_inplace,
|
append_entry_inplace,
|
||||||
copy_file,
|
copy_file,
|
||||||
copy_over_example,
|
copy_over_example,
|
||||||
|
|
@ -29,7 +29,7 @@ class RadioDataset(ABC):
|
||||||
|
|
||||||
This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class
|
This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class
|
||||||
should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different
|
should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different
|
||||||
types of radio datasets. For example, see ria_toolkit_oss.datatypes.datasets.IQDataset, which is a radio dataset
|
types of radio datasets. For example, see ria_toolkit_oss.data.datasets.IQDataset, which is a radio dataset
|
||||||
subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature)
|
subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature)
|
||||||
samples.
|
samples.
|
||||||
|
|
||||||
|
|
@ -255,7 +255,9 @@ class RadioDataset(ABC):
|
||||||
else:
|
else:
|
||||||
classes_to_augment = classes_to_augment.encode("utf-8")
|
classes_to_augment = classes_to_augment.encode("utf-8")
|
||||||
if classes_to_augment not in class_sizes:
|
if classes_to_augment not in class_sizes:
|
||||||
raise ValueError(f"class name of {i} does not belong to the class key of {class_key}")
|
raise ValueError(
|
||||||
|
f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
|
||||||
|
)
|
||||||
|
|
||||||
result_sizes = get_result_sizes(
|
result_sizes = get_result_sizes(
|
||||||
level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
|
level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
|
||||||
|
|
@ -375,7 +377,7 @@ class RadioDataset(ABC):
|
||||||
counters[key] = counters.get(key, 0)
|
counters[key] = counters.get(key, 0)
|
||||||
|
|
||||||
idx = 0
|
idx = 0
|
||||||
with h5py.File(self.source, "a") as f:
|
with h5py.File(self.source, "r") as f:
|
||||||
while idx < len(self):
|
while idx < len(self):
|
||||||
labels = f["metadata/metadata"][class_key]
|
labels = f["metadata/metadata"][class_key]
|
||||||
current_class = labels[idx]
|
current_class = labels[idx]
|
||||||
|
|
@ -514,7 +516,7 @@ class RadioDataset(ABC):
|
||||||
|
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
||||||
with h5py.File(self.source, "a") as f:
|
with h5py.File(self.source, "r") as f:
|
||||||
while idx < len(self):
|
while idx < len(self):
|
||||||
labels = f["metadata/metadata"][class_key]
|
labels = f["metadata/metadata"][class_key]
|
||||||
current_class = labels[idx]
|
current_class = labels[idx]
|
||||||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import os
|
import os
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
|
from ria_toolkit_oss.data.datasets.radio_dataset import RadioDataset
|
||||||
|
|
||||||
|
|
||||||
class SpectDataset(RadioDataset, ABC):
|
class SpectDataset(RadioDataset, ABC):
|
||||||
|
|
@ -13,7 +13,7 @@ class SpectDataset(RadioDataset, ABC):
|
||||||
radio signal spectrograms.
|
radio signal spectrograms.
|
||||||
|
|
||||||
For machine learning tasks that involve processing on IQ samples, please use
|
For machine learning tasks that involve processing on IQ samples, please use
|
||||||
ria_toolkit_oss.datatypes.datasets.IQDataset instead.
|
ria_toolkit_oss.data.datasets.IQDataset instead.
|
||||||
|
|
||||||
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
|
This is an abstract interface defining common properties and behaviour of IQDatasets. Therefore, this class
|
||||||
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine
|
should not be instantiated directly. Instead, it is subclassed to define custom interfaces for specific machine
|
||||||
|
|
@ -6,11 +6,8 @@ from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.random import Generator
|
from numpy.random import Generator
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.datasets import RadioDataset
|
from ria_toolkit_oss.data.datasets import RadioDataset
|
||||||
from ria_toolkit_oss.datatypes.datasets.h5helpers import (
|
from ria_toolkit_oss.data.datasets.h5helpers import copy_over_example, make_empty_clone
|
||||||
copy_over_example,
|
|
||||||
make_empty_clone,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]:
|
def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]:
|
||||||
|
|
@ -31,7 +28,7 @@ def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDatase
|
||||||
cases.
|
cases.
|
||||||
|
|
||||||
This function is deterministic, meaning it will always produce the same split. For a random split, see
|
This function is deterministic, meaning it will always produce the same split. For a random split, see
|
||||||
ria_toolkit_oss.datatypes.datasets.random_split.
|
ria_toolkit_oss.data.datasets.random_split.
|
||||||
|
|
||||||
:param dataset: Dataset to be split.
|
:param dataset: Dataset to be split.
|
||||||
:type dataset: RadioDataset
|
:type dataset: RadioDataset
|
||||||
|
|
@ -50,7 +47,7 @@ def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDatase
|
||||||
>>> import string
|
>>> import string
|
||||||
>>> import numpy as np
|
>>> import numpy as np
|
||||||
>>> import pandas as pd
|
>>> import pandas as pd
|
||||||
>>> from ria_toolkit_oss.datatypes.datasets import split
|
>>> from ria_toolkit_oss.data.datasets import split
|
||||||
|
|
||||||
First, let's generate some random data:
|
First, let's generate some random data:
|
||||||
|
|
||||||
|
|
@ -126,7 +123,7 @@ def random_split(
|
||||||
training and test datasets.
|
training and test datasets.
|
||||||
|
|
||||||
This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified.
|
This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified.
|
||||||
If it is important to ensure the closest possible split, consider using ria_toolkit_oss.datatypes.datasets.split
|
If it is important to ensure the closest possible split, consider using ria_toolkit_oss.data.datasets.split
|
||||||
instead.
|
instead.
|
||||||
|
|
||||||
:param dataset: Dataset to be split.
|
:param dataset: Dataset to be split.
|
||||||
|
|
@ -144,7 +141,7 @@ def random_split(
|
||||||
:rtype: list of RadioDataset
|
:rtype: list of RadioDataset
|
||||||
|
|
||||||
See Also:
|
See Also:
|
||||||
ria_toolkit_oss.datatypes.datasets.split: Usage is the same as for ``random_split()``.
|
ria_toolkit_oss.data.datasets.split: Usage is the same as for ``random_split()``.
|
||||||
"""
|
"""
|
||||||
if not isinstance(dataset, RadioDataset):
|
if not isinstance(dataset, RadioDataset):
|
||||||
raise ValueError(f"'dataset' must be RadioDataset or one of its subclasses, got {type(dataset)}.")
|
raise ValueError(f"'dataset' must be RadioDataset or one of its subclasses, got {type(dataset)}.")
|
||||||
|
|
@ -247,7 +244,7 @@ def _validate_sublists(list_of_lists: list[list[str]], ids: list[str]) -> None:
|
||||||
"""Ensure that each ID is present in one and only one sublist."""
|
"""Ensure that each ID is present in one and only one sublist."""
|
||||||
all_elements = [item for sublist in list_of_lists for item in sublist]
|
all_elements = [item for sublist in list_of_lists for item in sublist]
|
||||||
|
|
||||||
assert len(all_elements) == len(set(all_elements)) and list(set(ids)).sort() == list(set(all_elements)).sort()
|
assert len(all_elements) == len(set(all_elements)) and sorted(set(ids)) == sorted(set(all_elements))
|
||||||
|
|
||||||
|
|
||||||
def _generate_split_source_filenames(
|
def _generate_split_source_filenames(
|
||||||
|
|
@ -12,7 +12,7 @@ from typing import Any, Iterator, Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.annotation import Annotation
|
from ria_toolkit_oss.data.annotation import Annotation
|
||||||
|
|
||||||
PROTECTED_KEYS = ["rec_id", "timestamp"]
|
PROTECTED_KEYS = ["rec_id", "timestamp"]
|
||||||
|
|
||||||
|
|
@ -26,7 +26,7 @@ class Recording:
|
||||||
Metadata is stored in a dictionary of key value pairs,
|
Metadata is stored in a dictionary of key value pairs,
|
||||||
to include information such as sample_rate and center_frequency.
|
to include information such as sample_rate and center_frequency.
|
||||||
|
|
||||||
Annotations are a list of :class:`~ria_toolkit_oss.datatypes.Annotation`,
|
Annotations are a list of :class:`~ria_toolkit_oss.data.Annotation`,
|
||||||
defining bounding boxes in time and frequency with labels and metadata.
|
defining bounding boxes in time and frequency with labels and metadata.
|
||||||
|
|
||||||
Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide
|
Here, signal data is represented as a NumPy array. This class is then extended in the RIA Backends to provide
|
||||||
|
|
@ -46,7 +46,7 @@ class Recording:
|
||||||
|
|
||||||
:param metadata: Additional information associated with the recording.
|
:param metadata: Additional information associated with the recording.
|
||||||
:type metadata: dict, optional
|
:type metadata: dict, optional
|
||||||
:param annotations: A collection of :class:`~ria_toolkit_oss.datatypes.Annotation` objects defining bounding boxes.
|
:param annotations: A collection of :class:`~ria_toolkit_oss.data.Annotation` objects defining bounding boxes.
|
||||||
:type annotations: list of Annotations, optional
|
:type annotations: list of Annotations, optional
|
||||||
|
|
||||||
:param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as
|
:param dtype: Explicitly specify the data-type of the complex samples. Must be a complex NumPy type, such as
|
||||||
|
|
@ -66,7 +66,7 @@ class Recording:
|
||||||
**Examples:**
|
**Examples:**
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording, Annotation
|
>>> from ria_toolkit_oss.data import Recording, Annotation
|
||||||
|
|
||||||
>>> # Create an array of complex samples, just 1s in this case.
|
>>> # Create an array of complex samples, just 1s in this case.
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
|
|
@ -146,7 +146,7 @@ class Recording:
|
||||||
self._metadata["timestamp"] = time.time()
|
self._metadata["timestamp"] = time.time()
|
||||||
else:
|
else:
|
||||||
if not isinstance(self._metadata["timestamp"], (int, float)):
|
if not isinstance(self._metadata["timestamp"], (int, float)):
|
||||||
raise ValueError("timestamp must be int or float, not ", type(self._metadata["timestamp"]))
|
raise ValueError(f"timestamp must be int or float, not {type(self._metadata['timestamp'])}")
|
||||||
|
|
||||||
if "rec_id" not in self.metadata:
|
if "rec_id" not in self.metadata:
|
||||||
self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
|
self._metadata["rec_id"] = generate_recording_id(data=self.data, timestamp=self._metadata["timestamp"])
|
||||||
|
|
@ -244,7 +244,7 @@ class Recording:
|
||||||
@property
|
@property
|
||||||
def sample_rate(self) -> float | None:
|
def sample_rate(self) -> float | None:
|
||||||
"""
|
"""
|
||||||
:return: Sample rate of the recording, or None is 'sample_rate' is not in metadata.
|
:return: Sample rate of the recording, or None if 'sample_rate' is not in metadata.
|
||||||
:type: str
|
:type: str
|
||||||
"""
|
"""
|
||||||
return self.metadata.get("sample_rate")
|
return self.metadata.get("sample_rate")
|
||||||
|
|
@ -311,7 +311,7 @@ class Recording:
|
||||||
Create a recording and add metadata:
|
Create a recording and add metadata:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
>>>
|
>>>
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -366,7 +366,7 @@ class Recording:
|
||||||
Create a recording and update metadata:
|
Create a recording and update metadata:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -393,6 +393,7 @@ class Recording:
|
||||||
"""
|
"""
|
||||||
if key not in self.metadata:
|
if key not in self.metadata:
|
||||||
self.add_to_metadata(key=key, value=value)
|
self.add_to_metadata(key=key, value=value)
|
||||||
|
return
|
||||||
|
|
||||||
if not _is_jsonable(value):
|
if not _is_jsonable(value):
|
||||||
raise ValueError("Value must be JSON serializable.")
|
raise ValueError("Value must be JSON serializable.")
|
||||||
|
|
@ -420,7 +421,7 @@ class Recording:
|
||||||
Create a recording and add metadata:
|
Create a recording and add metadata:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -444,7 +445,7 @@ class Recording:
|
||||||
'rec_id': 'fda0f41...'} # Example value
|
'rec_id': 'fda0f41...'} # Example value
|
||||||
"""
|
"""
|
||||||
if key not in PROTECTED_KEYS:
|
if key not in PROTECTED_KEYS:
|
||||||
self._metadata.pop(key)
|
self._metadata.pop(key, None)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Key {key} is protected and cannot be modified or removed.")
|
raise ValueError(f"Key {key} is protected and cannot be modified or removed.")
|
||||||
|
|
||||||
|
|
@ -453,7 +454,7 @@ class Recording:
|
||||||
|
|
||||||
:param output_path: The output image path. Defaults to "images/signal.png".
|
:param output_path: The output image path. Defaults to "images/signal.png".
|
||||||
:type output_path: str, optional
|
:type output_path: str, optional
|
||||||
:param kwargs: Keyword arguments passed on to utils.view.view_sig.
|
:param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_sig.
|
||||||
:type: dict of keyword arguments
|
:type: dict of keyword arguments
|
||||||
|
|
||||||
**Examples:**
|
**Examples:**
|
||||||
|
|
@ -461,7 +462,7 @@ class Recording:
|
||||||
Create a recording and view it as a plot in a .png image:
|
Create a recording and view it as a plot in a .png image:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from utils.data import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -479,7 +480,7 @@ class Recording:
|
||||||
def simple_view(self, **kwargs) -> None:
|
def simple_view(self, **kwargs) -> None:
|
||||||
"""Create a plot of various signal visualizations as a PNG or SVG image.
|
"""Create a plot of various signal visualizations as a PNG or SVG image.
|
||||||
|
|
||||||
:param kwargs: Keyword arguments passed on to utils.view.view_signal_simple.create_plots.
|
:param kwargs: Keyword arguments passed on to ria_toolkit_oss.view.view_signal_simple.view_simple_sig.
|
||||||
:type: dict of keyword arguments
|
:type: dict of keyword arguments
|
||||||
|
|
||||||
**Examples:**
|
**Examples:**
|
||||||
|
|
@ -487,7 +488,7 @@ class Recording:
|
||||||
Create a recording and view it as a plot in a .png image:
|
Create a recording and view it as a plot in a .png image:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from utils.data import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -510,7 +511,7 @@ class Recording:
|
||||||
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
|
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
|
||||||
|
|
||||||
:param recording: The recording to be written to file.
|
:param recording: The recording to be written to file.
|
||||||
:type recording: ria_toolkit_oss.datatypes.Recording
|
:type recording: ria_toolkit_oss.data.Recording
|
||||||
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
|
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
|
||||||
:type filename: os.PathLike or str, optional
|
:type filename: os.PathLike or str, optional
|
||||||
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
|
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
|
||||||
|
|
@ -544,7 +545,7 @@ class Recording:
|
||||||
Create a recording and save it to a .npy file:
|
Create a recording and save it to a .npy file:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -595,13 +596,13 @@ class Recording:
|
||||||
Create a recording and save it to a .wav file:
|
Create a recording and save it to a .wav file:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from utils.data import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
>>> samples = numpy.exp(1j * 2 * numpy.pi * 0.1 * numpy.arange(10000))
|
>>> samples = numpy.exp(1j * 2 * numpy.pi * 0.1 * numpy.arange(10000))
|
||||||
>>> metadata = {"sample_rate": 1e6, "center_frequency": 915e6}
|
>>> metadata = {"sample_rate": 1e6, "center_frequency": 915e6}
|
||||||
>>> recording = Recording(data=samples, metadata=metadata)
|
>>> recording = Recording(data=samples, metadata=metadata)
|
||||||
>>> recording.to_wav()
|
>>> recording.to_wav()
|
||||||
"""
|
"""
|
||||||
from utils.io.recording import to_wav
|
from ria_toolkit_oss.io.recording import to_wav
|
||||||
|
|
||||||
return to_wav(
|
return to_wav(
|
||||||
recording=self,
|
recording=self,
|
||||||
|
|
@ -645,13 +646,13 @@ class Recording:
|
||||||
Create a recording and save it to a .blue file:
|
Create a recording and save it to a .blue file:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from utils.data import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {"sample_rate": 1e6, "center_frequency": 2.44e9}
|
>>> metadata = {"sample_rate": 1e6, "center_frequency": 2.44e9}
|
||||||
>>> recording = Recording(data=samples, metadata=metadata)
|
>>> recording = Recording(data=samples, metadata=metadata)
|
||||||
>>> recording.to_blue()
|
>>> recording.to_blue()
|
||||||
"""
|
"""
|
||||||
from utils.io.recording import to_blue
|
from ria_toolkit_oss.io.recording import to_blue
|
||||||
|
|
||||||
return to_blue(recording=self, filename=filename, path=path, data_format=data_format, overwrite=overwrite)
|
return to_blue(recording=self, filename=filename, path=path, data_format=data_format, overwrite=overwrite)
|
||||||
|
|
||||||
|
|
@ -673,7 +674,7 @@ class Recording:
|
||||||
Create a recording and trim it:
|
Create a recording and trim it:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64)
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -702,7 +703,14 @@ class Recording:
|
||||||
data = self.data[:, start_sample:end_sample]
|
data = self.data[:, start_sample:end_sample]
|
||||||
|
|
||||||
new_annotations = copy.deepcopy(self.annotations)
|
new_annotations = copy.deepcopy(self.annotations)
|
||||||
|
trimmed_annotations = []
|
||||||
for annotation in new_annotations:
|
for annotation in new_annotations:
|
||||||
|
# skip annotations entirely outside the trim window
|
||||||
|
if annotation.sample_start + annotation.sample_count <= start_sample:
|
||||||
|
continue
|
||||||
|
if annotation.sample_start >= end_sample:
|
||||||
|
continue
|
||||||
|
|
||||||
# trim annotation if it goes outside the trim boundaries
|
# trim annotation if it goes outside the trim boundaries
|
||||||
if annotation.sample_start < start_sample:
|
if annotation.sample_start < start_sample:
|
||||||
annotation.sample_count = annotation.sample_count - (start_sample - annotation.sample_start)
|
annotation.sample_count = annotation.sample_count - (start_sample - annotation.sample_start)
|
||||||
|
|
@ -713,8 +721,9 @@ class Recording:
|
||||||
|
|
||||||
# shift annotation to align with the new start point
|
# shift annotation to align with the new start point
|
||||||
annotation.sample_start = annotation.sample_start - start_sample
|
annotation.sample_start = annotation.sample_start - start_sample
|
||||||
|
trimmed_annotations.append(annotation)
|
||||||
|
|
||||||
return Recording(data=data, metadata=self.metadata, annotations=new_annotations)
|
return Recording(data=data, metadata=self.metadata, annotations=trimmed_annotations)
|
||||||
|
|
||||||
def normalize(self) -> Recording:
|
def normalize(self) -> Recording:
|
||||||
"""Scale the recording data, relative to its maximum value, so that the magnitude of the maximum sample is 1.
|
"""Scale the recording data, relative to its maximum value, so that the magnitude of the maximum sample is 1.
|
||||||
|
|
@ -727,7 +736,7 @@ class Recording:
|
||||||
Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1:
|
Create a recording with maximum amplitude 0.5 and normalize to a maximum amplitude of 1:
|
||||||
|
|
||||||
>>> import numpy
|
>>> import numpy
|
||||||
>>> from ria_toolkit_oss.datatypes import Recording
|
>>> from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
>>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5
|
>>> samples = numpy.ones(10000, dtype=numpy.complex64) * 0.5
|
||||||
>>> metadata = {
|
>>> metadata = {
|
||||||
|
|
@ -743,7 +752,10 @@ class Recording:
|
||||||
>>> print(numpy.max(numpy.abs(normalized_recording.data)))
|
>>> print(numpy.max(numpy.abs(normalized_recording.data)))
|
||||||
1
|
1
|
||||||
"""
|
"""
|
||||||
scaled_data = self.data / np.max(abs(self.data))
|
max_val = np.max(abs(self.data))
|
||||||
|
if max_val == 0:
|
||||||
|
raise ValueError("Cannot normalize a recording with all-zero data.")
|
||||||
|
scaled_data = self.data / max_val
|
||||||
return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations)
|
return Recording(data=scaled_data, metadata=self.metadata, annotations=self.annotations)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
"""
|
|
||||||
The datatypes package contains abstract data types tailored for radio machine learning.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__all__ = ["Annotation", "Recording"]
|
|
||||||
|
|
||||||
from .annotation import Annotation
|
|
||||||
from .recording import Recording
|
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
"""
|
"""
|
||||||
Utilities for input/output operations on the ria_toolkit_oss.datatypes.Recording object.
|
Utilities for input/output operations on the ria_toolkit_oss.data.Recording object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import json
|
||||||
import numbers
|
import numbers
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
|
import warnings
|
||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
|
@ -17,8 +19,8 @@ from quantiphy import Quantity
|
||||||
from sigmf import SigMFFile, sigmffile
|
from sigmf import SigMFFile, sigmffile
|
||||||
from sigmf.utils import get_data_type_str
|
from sigmf.utils import get_data_type_str
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes import Annotation
|
from ria_toolkit_oss.data import Annotation
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
|
||||||
_BLUE_META_PREFIX = "META_"
|
_BLUE_META_PREFIX = "META_"
|
||||||
_BLUE_META_TAG_MAX_LEN = 60
|
_BLUE_META_TAG_MAX_LEN = 60
|
||||||
|
|
@ -62,7 +64,7 @@ def to_npy(
|
||||||
"""Write recording to ``.npy`` binary file.
|
"""Write recording to ``.npy`` binary file.
|
||||||
|
|
||||||
:param recording: The recording to be written to file.
|
:param recording: The recording to be written to file.
|
||||||
:type recording: ria_toolkit_oss.datatypes.Recording
|
:type recording: ria_toolkit_oss.data.Recording
|
||||||
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
|
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
|
||||||
:type filename: os.PathLike or str, optional
|
:type filename: os.PathLike or str, optional
|
||||||
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
|
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
|
||||||
|
|
@ -91,15 +93,35 @@ def to_npy(
|
||||||
metadata = recording.metadata
|
metadata = recording.metadata
|
||||||
annotations = recording.annotations
|
annotations = recording.annotations
|
||||||
|
|
||||||
with open(file=fullpath, mode="wb") as f:
|
# Serialize metadata and annotations as JSON to avoid pickle-based deserialization.
|
||||||
np.save(f, data)
|
# JSON is safe; pickle allows arbitrary code execution when loading untrusted files.
|
||||||
np.save(f, metadata)
|
metadata_bytes = json.dumps(convert_to_serializable(metadata)).encode()
|
||||||
np.save(f, annotations)
|
annotations_bytes = json.dumps([a.__dict__ for a in annotations]).encode()
|
||||||
|
|
||||||
|
with open(file=fullpath, mode="wb") as f:
|
||||||
|
# Write format version marker first so from_npy can detect the safe JSON format.
|
||||||
|
np.save(f, np.array("ria-toolkit-oss-v2"))
|
||||||
|
np.save(f, data)
|
||||||
|
np.save(f, np.frombuffer(metadata_bytes, dtype=np.uint8))
|
||||||
|
np.save(f, np.frombuffer(annotations_bytes, dtype=np.uint8))
|
||||||
|
|
||||||
# print(f"Saved recording to {os.getcwd()}/{fullpath}")
|
|
||||||
return str(fullpath)
|
return str(fullpath)
|
||||||
|
|
||||||
|
|
||||||
|
_NPY_MAGIC = b"\x93NUMPY"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_npy_magic(filepath: str) -> None:
|
||||||
|
"""Raise ValueError if the file does not start with the NumPy magic bytes."""
|
||||||
|
try:
|
||||||
|
with open(filepath, "rb") as f:
|
||||||
|
header = f.read(6)
|
||||||
|
except OSError as e:
|
||||||
|
raise IOError(f"Cannot open file for validation: {filepath}") from e
|
||||||
|
if header != _NPY_MAGIC:
|
||||||
|
raise ValueError(f"File does not appear to be a valid NumPy .npy file (bad magic bytes): {filepath}")
|
||||||
|
|
||||||
|
|
||||||
def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
|
def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
|
||||||
"""Load a recording from a ``.npy`` binary file.
|
"""Load a recording from a ``.npy`` binary file.
|
||||||
|
|
||||||
|
|
@ -113,7 +135,7 @@ def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
|
||||||
:raises IOError: If there is an issue encountered during the file reading process.
|
:raises IOError: If there is an issue encountered during the file reading process.
|
||||||
|
|
||||||
:return: The recording, as initialized from the ``.npy`` file.
|
:return: The recording, as initialized from the ``.npy`` file.
|
||||||
:rtype: ria_toolkit_oss.datatypes.Recording
|
:rtype: ria_toolkit_oss.data.Recording
|
||||||
"""
|
"""
|
||||||
|
|
||||||
filename, extension = os.path.splitext(file)
|
filename, extension = os.path.splitext(file)
|
||||||
|
|
@ -126,14 +148,37 @@ def from_npy(file: os.PathLike | str, legacy: bool = False) -> Recording:
|
||||||
if legacy:
|
if legacy:
|
||||||
return from_npy_legacy(filename)
|
return from_npy_legacy(filename)
|
||||||
|
|
||||||
|
_check_npy_magic(filename)
|
||||||
|
|
||||||
with open(file=filename, mode="rb") as f:
|
with open(file=filename, mode="rb") as f:
|
||||||
data = np.load(f, allow_pickle=True)
|
first = np.load(f, allow_pickle=False)
|
||||||
metadata = np.load(f, allow_pickle=True)
|
|
||||||
metadata = metadata.tolist()
|
if first.ndim == 0 and first.dtype.kind in ("U", "S") and str(first) == "ria-toolkit-oss-v2":
|
||||||
try:
|
# Safe JSON format written by current to_npy.
|
||||||
annotations = list(np.load(f, allow_pickle=True))
|
data = np.load(f, allow_pickle=False)
|
||||||
except EOFError:
|
raw_meta = np.load(f, allow_pickle=False)
|
||||||
annotations = []
|
metadata = json.loads(raw_meta.tobytes().decode())
|
||||||
|
try:
|
||||||
|
raw_ann = np.load(f, allow_pickle=False)
|
||||||
|
ann_list = json.loads(raw_ann.tobytes().decode())
|
||||||
|
from ria_toolkit_oss.data.annotation import Annotation
|
||||||
|
|
||||||
|
annotations = [Annotation(**a) for a in ann_list]
|
||||||
|
except EOFError:
|
||||||
|
annotations = []
|
||||||
|
else:
|
||||||
|
# Legacy pickle-based format. Only load files from trusted sources.
|
||||||
|
warnings.warn(
|
||||||
|
"Loading .npy file in legacy pickle format — only load files from trusted sources. "
|
||||||
|
"Re-save with to_npy() to upgrade to the safe JSON format.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
data = first # already loaded without pickle (numeric array)
|
||||||
|
metadata = np.load(f, allow_pickle=True).tolist()
|
||||||
|
try:
|
||||||
|
annotations = list(np.load(f, allow_pickle=True))
|
||||||
|
except EOFError:
|
||||||
|
annotations = []
|
||||||
|
|
||||||
recording = Recording(data=data, metadata=metadata, annotations=annotations)
|
recording = Recording(data=data, metadata=metadata, annotations=annotations)
|
||||||
return recording
|
return recording
|
||||||
|
|
@ -153,7 +198,7 @@ def from_npy_legacy(file: os.PathLike | str) -> Recording:
|
||||||
:raises IOError: If there is an issue encountered during the file reading process.
|
:raises IOError: If there is an issue encountered during the file reading process.
|
||||||
|
|
||||||
:return: The recording, as initialized from the legacy ``.npy`` file.
|
:return: The recording, as initialized from the legacy ``.npy`` file.
|
||||||
:rtype: ria_toolkit_oss.datatypes.Recording
|
:rtype: ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
**Examples:**
|
**Examples:**
|
||||||
|
|
||||||
|
|
@ -171,14 +216,20 @@ def from_npy_legacy(file: os.PathLike | str) -> Recording:
|
||||||
# Rebuild with .npy extension.
|
# Rebuild with .npy extension.
|
||||||
filename = str(filename) + ".npy"
|
filename = str(filename) + ".npy"
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"from_npy_legacy uses pickle deserialization for extended metadata — only load files from trusted sources.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
_check_npy_magic(filename)
|
||||||
|
|
||||||
with open(filename, "rb") as f:
|
with open(filename, "rb") as f:
|
||||||
# Read IQ data (2, N) format
|
# Read IQ data (2, N) format
|
||||||
iqdata = np.load(f)
|
iqdata = np.load(f, allow_pickle=False)
|
||||||
|
|
||||||
# Read basic metadata array [center_freq, rec_length, decimation, sample_rate]
|
# Read basic metadata array [center_freq, rec_length, decimation, sample_rate]
|
||||||
meta = np.load(f)
|
meta = np.load(f, allow_pickle=False)
|
||||||
|
|
||||||
# Read extended metadata dict
|
# Read extended metadata dict (legacy format requires pickle)
|
||||||
extended_meta = np.load(f, allow_pickle=True)[0]
|
extended_meta = np.load(f, allow_pickle=True)[0]
|
||||||
|
|
||||||
# Convert IQ data from (2, N) to (N,) complex format
|
# Convert IQ data from (2, N) to (N,) complex format
|
||||||
|
|
@ -219,7 +270,7 @@ def to_sigmf(
|
||||||
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
|
The SigMF io format is defined by the `SigMF Specification Project <https://github.com/sigmf/SigMF>`_
|
||||||
|
|
||||||
:param recording: The recording to be written to file.
|
:param recording: The recording to be written to file.
|
||||||
:type recording: ria_toolkit_oss.datatypes.Recording
|
:type recording: ria_toolkit_oss.data.Recording
|
||||||
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
|
:param filename: The name of the file where the recording is to be saved. Defaults to auto generated filename.
|
||||||
:type filename: os.PathLike or str, optional
|
:type filename: os.PathLike or str, optional
|
||||||
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
|
:param path: The directory path to where the recording is to be saved. Defaults to recordings/.
|
||||||
|
|
@ -279,7 +330,7 @@ def to_sigmf(
|
||||||
converted_metadata = {
|
converted_metadata = {
|
||||||
sigmf_key: metadata[metadata_key]
|
sigmf_key: metadata[metadata_key]
|
||||||
for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items()
|
for sigmf_key, metadata_key in SIGMF_KEY_CONVERSION.items()
|
||||||
if metadata_key in metadata
|
if metadata_key in metadata and sigmf_key != SigMFFile.HASH_KEY
|
||||||
}
|
}
|
||||||
|
|
||||||
# Merge dictionaries, giving priority to sigmf_meta
|
# Merge dictionaries, giving priority to sigmf_meta
|
||||||
|
|
@ -316,7 +367,7 @@ def to_sigmf(
|
||||||
meta_dict = sigMF_metafile.ordered_metadata()
|
meta_dict = sigMF_metafile.ordered_metadata()
|
||||||
meta_dict["ria"] = metadata
|
meta_dict["ria"] = metadata
|
||||||
|
|
||||||
sigMF_metafile.tofile(meta_file_path)
|
sigMF_metafile.tofile(meta_file_path, overwrite=overwrite)
|
||||||
|
|
||||||
|
|
||||||
def from_sigmf(file: os.PathLike | str) -> Recording:
|
def from_sigmf(file: os.PathLike | str) -> Recording:
|
||||||
|
|
@ -330,13 +381,12 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
|
||||||
:raises IOError: If there is an issue encountered during the file reading process.
|
:raises IOError: If there is an issue encountered during the file reading process.
|
||||||
|
|
||||||
:return: The recording, as initialized from the SigMF files.
|
:return: The recording, as initialized from the SigMF files.
|
||||||
:rtype: ria_toolkit_oss.datatypes.Recording
|
:rtype: ria_toolkit_oss.data.Recording
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file = str(file)
|
file = str(file)
|
||||||
if len(file) > 11:
|
if not file.endswith((".sigmf-data", ".sigmf-meta", ".sigmf")):
|
||||||
if file[-11:-5] != ".sigmf":
|
file = file + ".sigmf-data"
|
||||||
file = file + ".sigmf-data"
|
|
||||||
|
|
||||||
sigmf_file = sigmffile.fromfile(file)
|
sigmf_file = sigmffile.fromfile(file)
|
||||||
|
|
||||||
|
|
@ -349,7 +399,7 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
|
||||||
# Process core keys
|
# Process core keys
|
||||||
if key.startswith("core:"):
|
if key.startswith("core:"):
|
||||||
base_key = key[5:] # Remove 'core:' prefix
|
base_key = key[5:] # Remove 'core:' prefix
|
||||||
converted_key = SIGMF_KEY_CONVERSION.get(base_key, base_key)
|
converted_key = SIGMF_KEY_CONVERSION.get(key, base_key)
|
||||||
# Process ria keys
|
# Process ria keys
|
||||||
elif key.startswith("ria:"):
|
elif key.startswith("ria:"):
|
||||||
converted_key = key[4:] # Remove 'ria:' prefix
|
converted_key = key[4:] # Remove 'ria:' prefix
|
||||||
|
|
@ -393,7 +443,7 @@ def to_wav(
|
||||||
in the ICMT (comment) field for human readability.
|
in the ICMT (comment) field for human readability.
|
||||||
|
|
||||||
:param recording: The recording to be written to file.
|
:param recording: The recording to be written to file.
|
||||||
:type recording: ria_toolkit_oss.datatypes.Recording
|
:type recording: ria_toolkit_oss.data.Recording
|
||||||
:param filename: The name of the file where the recording is to be saved.
|
:param filename: The name of the file where the recording is to be saved.
|
||||||
Defaults to auto-generated filename.
|
Defaults to auto-generated filename.
|
||||||
:type filename: str, optional
|
:type filename: str, optional
|
||||||
|
|
@ -503,7 +553,7 @@ def from_wav(file: os.PathLike | str) -> Recording:
|
||||||
:raises ValueError: If file is not stereo or has unsupported format.
|
:raises ValueError: If file is not stereo or has unsupported format.
|
||||||
|
|
||||||
:return: The recording, as initialized from the WAV file.
|
:return: The recording, as initialized from the WAV file.
|
||||||
:rtype: ria_toolkit_oss.datatypes.Recording
|
:rtype: ria_toolkit_oss.data.Recording
|
||||||
"""
|
"""
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
|
|
@ -585,7 +635,7 @@ def to_blue(
|
||||||
Commonly used with X-Midas and other RF/radar signal processing tools.
|
Commonly used with X-Midas and other RF/radar signal processing tools.
|
||||||
|
|
||||||
:param recording: The recording to be written to file.
|
:param recording: The recording to be written to file.
|
||||||
:type recording: ria_toolkit_oss.datatypes.Recording
|
:type recording: ria_toolkit_oss.data.Recording
|
||||||
:param filename: The name of the file where the recording is to be saved.
|
:param filename: The name of the file where the recording is to be saved.
|
||||||
Defaults to auto-generated filename.
|
Defaults to auto-generated filename.
|
||||||
:type filename: str, optional
|
:type filename: str, optional
|
||||||
|
|
@ -742,7 +792,7 @@ def from_blue(file: os.PathLike | str) -> Recording:
|
||||||
:raises ValueError: If file format is not valid or unsupported.
|
:raises ValueError: If file format is not valid or unsupported.
|
||||||
|
|
||||||
:return: The recording, as initialized from the Blue file.
|
:return: The recording, as initialized from the Blue file.
|
||||||
:rtype: ria_toolkit_oss.datatypes.Recording
|
:rtype: ria_toolkit_oss.data.Recording
|
||||||
"""
|
"""
|
||||||
filename = str(file)
|
filename = str(file)
|
||||||
if not filename.endswith(".blue"):
|
if not filename.endswith(".blue"):
|
||||||
|
|
@ -867,7 +917,7 @@ def load_recording(file: os.PathLike) -> Recording:
|
||||||
:raises ValueError: If the inferred file extension is not supported.
|
:raises ValueError: If the inferred file extension is not supported.
|
||||||
|
|
||||||
:return: The recording, as initialized from file(s).
|
:return: The recording, as initialized from file(s).
|
||||||
:rtype: ria_toolkit_oss.datatypes.Recording
|
:rtype: ria_toolkit_oss.data.Recording
|
||||||
"""
|
"""
|
||||||
_, extension = os.path.splitext(file)
|
_, extension = os.path.splitext(file)
|
||||||
extension = extension.lstrip(".")
|
extension = extension.lstrip(".")
|
||||||
|
|
|
||||||
26
src/ria_toolkit_oss/orchestration/__init__.py
Normal file
26
src/ria_toolkit_oss/orchestration/__init__.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Orchestration layer for automated RF capture campaigns."""
|
||||||
|
|
||||||
|
from .campaign import (
|
||||||
|
CampaignConfig,
|
||||||
|
CaptureStep,
|
||||||
|
QAConfig,
|
||||||
|
RecorderConfig,
|
||||||
|
TransmitterConfig,
|
||||||
|
)
|
||||||
|
from .executor import CampaignExecutor, CampaignResult, StepResult
|
||||||
|
from .labeler import label_recording
|
||||||
|
from .qa import QAResult, check_recording
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CampaignConfig",
|
||||||
|
"CaptureStep",
|
||||||
|
"QAConfig",
|
||||||
|
"RecorderConfig",
|
||||||
|
"TransmitterConfig",
|
||||||
|
"CampaignExecutor",
|
||||||
|
"CampaignResult",
|
||||||
|
"StepResult",
|
||||||
|
"label_recording",
|
||||||
|
"QAResult",
|
||||||
|
"check_recording",
|
||||||
|
]
|
||||||
503
src/ria_toolkit_oss/orchestration/campaign.py
Normal file
503
src/ria_toolkit_oss/orchestration/campaign.py
Normal file
|
|
@ -0,0 +1,503 @@
|
||||||
|
"""Campaign configuration schema and YAML parser for orchestrated RF captures."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# Allowed characters in campaign names when used as filename components.
|
||||||
|
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9_\-]")
|
||||||
|
|
||||||
|
# Reasonable RF bounds for consumer/research SDR hardware.
|
||||||
|
_FREQ_MIN_HZ = 1.0 # 1 Hz
|
||||||
|
_FREQ_MAX_HZ = 300e9 # 300 GHz
|
||||||
|
_GAIN_MIN_DB = -30.0
|
||||||
|
_GAIN_MAX_DB = 120.0
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parsing helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def parse_duration(value: str | float | int) -> float:
|
||||||
|
"""Parse a duration string to seconds.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"30s" → 30.0
|
||||||
|
"1.5m" or "1.5min" → 90.0
|
||||||
|
"2h" → 7200.0
|
||||||
|
30 (numeric) → 30.0
|
||||||
|
"""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
value = str(value).strip()
|
||||||
|
match = re.fullmatch(r"([\d.]+)\s*(s|sec|m|min|h|hr)?", value, re.IGNORECASE)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Cannot parse duration: '{value}'")
|
||||||
|
amount = float(match.group(1))
|
||||||
|
unit = (match.group(2) or "s").lower()
|
||||||
|
if unit in ("h", "hr"):
|
||||||
|
return amount * 3600
|
||||||
|
if unit in ("m", "min"):
|
||||||
|
return amount * 60
|
||||||
|
return amount
|
||||||
|
|
||||||
|
|
||||||
|
def parse_frequency(value: str | float | int) -> float:
|
||||||
|
"""Parse a frequency string to Hz.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"2.45GHz" → 2_450_000_000.0
|
||||||
|
"40MHz" → 40_000_000.0
|
||||||
|
"915e6" → 915_000_000.0
|
||||||
|
2.45e9 (numeric) → 2_450_000_000.0
|
||||||
|
"""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
result = float(value)
|
||||||
|
if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
|
||||||
|
raise ValueError(
|
||||||
|
f"Frequency {result:.3g} Hz is outside the supported range "
|
||||||
|
f"({_FREQ_MIN_HZ:.0f} Hz – {_FREQ_MAX_HZ:.3g} Hz)"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
value = str(value).strip()
|
||||||
|
|
||||||
|
# Try bare numeric first (handles scientific notation like "915e6")
|
||||||
|
try:
|
||||||
|
result = float(value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
|
||||||
|
raise ValueError(
|
||||||
|
f"Frequency {result:.3g} Hz is outside the supported range "
|
||||||
|
f"({_FREQ_MIN_HZ:.0f} Hz – {_FREQ_MAX_HZ:.3g} Hz): '{value}'"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Handle suffix notation: "2.45GHz", "40MHz", "40M", "433k"
|
||||||
|
match = re.fullmatch(r"([\d.]+)\s*(k|M|G)(?:\s*Hz?)?", value, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
amount = float(match.group(1))
|
||||||
|
suffix = match.group(2).upper()
|
||||||
|
result = amount * {"K": 1e3, "M": 1e6, "G": 1e9}[suffix]
|
||||||
|
if not (_FREQ_MIN_HZ <= result <= _FREQ_MAX_HZ):
|
||||||
|
raise ValueError(
|
||||||
|
f"Frequency {result:.3g} Hz is outside the supported range "
|
||||||
|
f"({_FREQ_MIN_HZ:.0f} Hz – {_FREQ_MAX_HZ:.3g} Hz): '{value}'"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
raise ValueError(f"Cannot parse frequency: '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gain(value: str | float | int) -> float | str:
|
||||||
|
"""Parse a gain string.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"40dB" or "40 dB" → 40.0
|
||||||
|
"auto" → "auto"
|
||||||
|
40 (numeric) → 40.0
|
||||||
|
"""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
result = float(value)
|
||||||
|
if not (_GAIN_MIN_DB <= result <= _GAIN_MAX_DB):
|
||||||
|
raise ValueError(f"Gain {result} dB is outside the supported range ({_GAIN_MIN_DB} – {_GAIN_MAX_DB} dB)")
|
||||||
|
return result
|
||||||
|
value = str(value).strip()
|
||||||
|
if value.lower() == "auto":
|
||||||
|
return "auto"
|
||||||
|
match = re.fullmatch(r"([\d.+\-]+)\s*dB?", value, re.IGNORECASE)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Cannot parse gain: '{value}'")
|
||||||
|
result = float(match.group(1))
|
||||||
|
if not (_GAIN_MIN_DB <= result <= _GAIN_MAX_DB):
|
||||||
|
raise ValueError(
|
||||||
|
f"Gain {result} dB is outside the supported range ({_GAIN_MIN_DB} – {_GAIN_MAX_DB} dB): '{value}'"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def parse_bandwidth_mhz(value: str | float | int | None) -> Optional[float]:
|
||||||
|
"""Parse a bandwidth string to MHz.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
"20MHz" → 20.0
|
||||||
|
"40MHz" → 40.0
|
||||||
|
20 (numeric, assumed MHz) → 20.0
|
||||||
|
None → None
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
value = str(value).strip()
|
||||||
|
match = re.fullmatch(r"([\d.]+)\s*MHz?", value, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
return float(match.group(1))
|
||||||
|
match = re.fullmatch(r"([\d.]+)", value)
|
||||||
|
if match:
|
||||||
|
return float(match.group(1))
|
||||||
|
raise ValueError(f"Cannot parse bandwidth: '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config dataclasses
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecorderConfig:
|
||||||
|
"""SDR recorder configuration."""
|
||||||
|
|
||||||
|
device: str
|
||||||
|
center_freq: float # Hz
|
||||||
|
sample_rate: float # Hz
|
||||||
|
gain: float | str # dB float, or "auto"
|
||||||
|
bandwidth: Optional[float] = None # Hz, None = match sample_rate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "RecorderConfig":
|
||||||
|
gain = parse_gain(d.get("gain", "auto"))
|
||||||
|
bandwidth_raw = d.get("bandwidth") or d.get("bandwidth_hz")
|
||||||
|
bandwidth = parse_frequency(bandwidth_raw) if bandwidth_raw else None
|
||||||
|
return cls(
|
||||||
|
device=str(d["device"]),
|
||||||
|
center_freq=parse_frequency(d["center_freq"]),
|
||||||
|
sample_rate=parse_frequency(d["sample_rate"]),
|
||||||
|
gain=gain,
|
||||||
|
bandwidth=bandwidth,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CaptureStep:
|
||||||
|
"""A single timed capture within a transmitter schedule."""
|
||||||
|
|
||||||
|
duration: float # seconds
|
||||||
|
label: str # used as filename component
|
||||||
|
|
||||||
|
# WiFi-specific
|
||||||
|
channel: Optional[int] = None
|
||||||
|
bandwidth_mhz: Optional[float] = None # MHz
|
||||||
|
traffic: Optional[str] = None
|
||||||
|
|
||||||
|
# Bluetooth-specific
|
||||||
|
connection_interval_ms: Optional[float] = None
|
||||||
|
|
||||||
|
# Power (dBm), optional
|
||||||
|
power_dbm: Optional[float] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict, auto_label: bool = True) -> "CaptureStep":
|
||||||
|
duration = parse_duration(d["duration"])
|
||||||
|
label = d.get("label", "")
|
||||||
|
if not label and auto_label:
|
||||||
|
parts = []
|
||||||
|
if d.get("channel"):
|
||||||
|
parts.append(f"ch{d['channel']:02d}")
|
||||||
|
if d.get("bandwidth"):
|
||||||
|
bw = parse_bandwidth_mhz(d["bandwidth"])
|
||||||
|
parts.append(f"{int(bw)}mhz")
|
||||||
|
if d.get("traffic"):
|
||||||
|
parts.append(str(d["traffic"]).replace(" ", "_"))
|
||||||
|
label = "_".join(parts) if parts else "capture"
|
||||||
|
return cls(
|
||||||
|
duration=duration,
|
||||||
|
label=label,
|
||||||
|
channel=d.get("channel"),
|
||||||
|
bandwidth_mhz=parse_bandwidth_mhz(d.get("bandwidth")),
|
||||||
|
traffic=d.get("traffic"),
|
||||||
|
connection_interval_ms=d.get("connection_interval_ms"),
|
||||||
|
power_dbm=float(d["power"].removesuffix("dBm").strip()) if d.get("power") else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransmitterConfig:
|
||||||
|
"""Configuration for a single transmitter device in the campaign."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str # "wifi", "bluetooth", "sdr", "external"
|
||||||
|
control_method: str # "external_script" | "sdr" | "sdr_remote"
|
||||||
|
schedule: list[CaptureStep]
|
||||||
|
|
||||||
|
# For external_script control
|
||||||
|
script: Optional[str] = None # path to control script
|
||||||
|
device: Optional[str] = None # e.g. "/dev/wlan0"
|
||||||
|
|
||||||
|
# For sdr_remote control — keys: host, ssh_user, ssh_key_path, device_type, device_id, zmq_port
|
||||||
|
sdr_remote: Optional[dict] = None
|
||||||
|
|
||||||
|
# For sdr_agent control — keys: modulation, order, symbol_rate, center_frequency, filter, rolloff
|
||||||
|
sdr_agent: Optional[dict] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "TransmitterConfig":
|
||||||
|
schedule = [CaptureStep.from_dict(s) for s in d.get("schedule", [])]
|
||||||
|
return cls(
|
||||||
|
id=str(d["id"]),
|
||||||
|
type=str(d["type"]),
|
||||||
|
control_method=str(d.get("control_method", "external_script")),
|
||||||
|
schedule=schedule,
|
||||||
|
script=d.get("script"),
|
||||||
|
device=d.get("device"),
|
||||||
|
sdr_remote=d.get("sdr_remote"),
|
||||||
|
sdr_agent=d.get("sdr_agent"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QAConfig:
|
||||||
|
"""Quality assurance thresholds."""
|
||||||
|
|
||||||
|
snr_threshold_db: float = 10.0
|
||||||
|
min_duration_s: float = 25.0
|
||||||
|
flag_for_review: bool = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "QAConfig":
|
||||||
|
return cls(
|
||||||
|
snr_threshold_db=float(str(d.get("snr_threshold", "10")).rstrip("dB").strip()),
|
||||||
|
min_duration_s=parse_duration(d.get("min_duration", "25s")),
|
||||||
|
flag_for_review=bool(d.get("flag_for_review", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OutputConfig:
|
||||||
|
"""Where to save captured recordings."""
|
||||||
|
|
||||||
|
format: str = "sigmf"
|
||||||
|
path: str = "recordings"
|
||||||
|
device_id: Optional[str] = None # for device-profile campaigns
|
||||||
|
repo: Optional[str] = None
|
||||||
|
folder: Optional[str] = None # repo subfolder: None = use campaign name, "" = no subfolder, str = custom
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict) -> "OutputConfig":
|
||||||
|
return cls(
|
||||||
|
format=str(d.get("format", "sigmf")),
|
||||||
|
path=str(d.get("path", "recordings")),
|
||||||
|
device_id=d.get("device_id"),
|
||||||
|
repo=d.get("repo"),
|
||||||
|
folder=d.get("folder"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignConfig:
|
||||||
|
"""Full campaign configuration parsed from YAML."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
recorder: RecorderConfig
|
||||||
|
transmitters: list[TransmitterConfig]
|
||||||
|
qa: QAConfig = field(default_factory=QAConfig)
|
||||||
|
output: OutputConfig = field(default_factory=OutputConfig)
|
||||||
|
mode: str = "controlled_testbed"
|
||||||
|
loops: int = 1 # repeat full schedule this many times; labels get _run{N:02d} suffix
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Loaders
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, raw: dict) -> "CampaignConfig":
|
||||||
|
"""Build a CampaignConfig from a parsed dictionary.
|
||||||
|
|
||||||
|
Accepts the same structure as the campaign YAML, already loaded into
|
||||||
|
a Python dict (e.g. from a JSON HTTP request body).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required fields are missing or malformed.
|
||||||
|
KeyError: If ``recorder`` key is absent.
|
||||||
|
"""
|
||||||
|
campaign_meta = raw.get("campaign", {})
|
||||||
|
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
|
||||||
|
if not transmitters:
|
||||||
|
raise ValueError("Campaign config must define at least one transmitter")
|
||||||
|
if "recorder" not in raw:
|
||||||
|
raise ValueError("Campaign config is missing required 'recorder' section")
|
||||||
|
raw_name = str(campaign_meta.get("name", "unnamed"))
|
||||||
|
safe_name = _SAFE_NAME_RE.sub("_", raw_name)
|
||||||
|
return cls(
|
||||||
|
name=safe_name,
|
||||||
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
|
loops=max(1, int(campaign_meta.get("loops", 1))),
|
||||||
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
|
transmitters=transmitters,
|
||||||
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: str | Path) -> "CampaignConfig":
|
||||||
|
"""Load a full campaign config YAML.
|
||||||
|
|
||||||
|
Expected format::
|
||||||
|
|
||||||
|
campaign:
|
||||||
|
name: "wifi_capture_001"
|
||||||
|
mode: "controlled_testbed"
|
||||||
|
|
||||||
|
transmitters:
|
||||||
|
- id: "laptop_wifi"
|
||||||
|
type: "wifi"
|
||||||
|
control_method: "external_script"
|
||||||
|
script: "./scripts/wifi_control.sh"
|
||||||
|
device: "/dev/wlan0"
|
||||||
|
schedule:
|
||||||
|
- channel: 6
|
||||||
|
bandwidth: "20MHz"
|
||||||
|
traffic: "iperf_udp"
|
||||||
|
duration: "30s"
|
||||||
|
|
||||||
|
recorder:
|
||||||
|
device: "usrp_b210"
|
||||||
|
center_freq: "2.45GHz"
|
||||||
|
sample_rate: "40MHz"
|
||||||
|
gain: "40dB"
|
||||||
|
|
||||||
|
qa:
|
||||||
|
snr_threshold: "10dB"
|
||||||
|
min_duration: "25s"
|
||||||
|
flag_for_review: true
|
||||||
|
|
||||||
|
output:
|
||||||
|
format: "sigmf"
|
||||||
|
path: "./recordings"
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
try:
|
||||||
|
with open(path) as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Campaign config not found: {path}")
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
raise ValueError(f"Invalid YAML in {path}: {e}")
|
||||||
|
|
||||||
|
campaign_meta = raw.get("campaign", {})
|
||||||
|
transmitters = [TransmitterConfig.from_dict(t) for t in raw.get("transmitters", [])]
|
||||||
|
if not transmitters:
|
||||||
|
raise ValueError("Campaign config must define at least one transmitter")
|
||||||
|
if "recorder" not in raw:
|
||||||
|
raise ValueError(f"Campaign config is missing required 'recorder' section in {path}")
|
||||||
|
raw_name = str(campaign_meta.get("name", path.stem))
|
||||||
|
safe_name = _SAFE_NAME_RE.sub("_", raw_name)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=safe_name,
|
||||||
|
mode=str(campaign_meta.get("mode", "controlled_testbed")),
|
||||||
|
loops=max(1, int(campaign_meta.get("loops", 1))),
|
||||||
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
|
transmitters=transmitters,
|
||||||
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_device_profile(cls, path: str | Path) -> "CampaignConfig":
|
||||||
|
"""Build a campaign config from an App 1 device profile YAML.
|
||||||
|
|
||||||
|
Expected format::
|
||||||
|
|
||||||
|
device:
|
||||||
|
name: "iPhone_13_WiFi"
|
||||||
|
type: "wifi"
|
||||||
|
protocol: "wifi_24ghz"
|
||||||
|
|
||||||
|
capture:
|
||||||
|
channels: [1, 6, 11] # WiFi only
|
||||||
|
bandwidth: "20MHz" # WiFi only
|
||||||
|
traffic_patterns: ["idle", "ping", "iperf_udp"]
|
||||||
|
duration_per_config: "30s"
|
||||||
|
|
||||||
|
recorder:
|
||||||
|
device: "usrp_b210"
|
||||||
|
center_freq: "2.45GHz"
|
||||||
|
sample_rate: "40MHz"
|
||||||
|
gain: "auto"
|
||||||
|
|
||||||
|
output:
|
||||||
|
path: "./recordings"
|
||||||
|
device_id: "iphone13_wifi_001"
|
||||||
|
|
||||||
|
For WiFi devices, schedule is expanded as channels × traffic_patterns.
|
||||||
|
For Bluetooth devices (no channels), schedule is traffic_patterns only.
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
try:
|
||||||
|
with open(path) as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Device profile not found: {path}")
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
raise ValueError(f"Invalid YAML in {path}: {e}")
|
||||||
|
|
||||||
|
device = raw.get("device", {})
|
||||||
|
capture = raw.get("capture", {})
|
||||||
|
device_type = str(device.get("type", "wifi")).lower()
|
||||||
|
device_name = str(device.get("name", path.stem))
|
||||||
|
duration = parse_duration(capture.get("duration_per_config", "30s"))
|
||||||
|
traffic_patterns = capture.get("traffic_patterns", ["idle"])
|
||||||
|
|
||||||
|
# Build capture schedule
|
||||||
|
schedule: list[CaptureStep] = []
|
||||||
|
|
||||||
|
if device_type in ("wifi", "wifi_24ghz", "wifi_5ghz"):
|
||||||
|
channels = capture.get("channels", [6])
|
||||||
|
bw_str = capture.get("bandwidth", "20MHz")
|
||||||
|
bw_mhz = parse_bandwidth_mhz(bw_str)
|
||||||
|
for ch in channels:
|
||||||
|
for traffic in traffic_patterns:
|
||||||
|
label = f"ch{ch:02d}_{int(bw_mhz)}mhz_{traffic}"
|
||||||
|
schedule.append(
|
||||||
|
CaptureStep(
|
||||||
|
duration=duration,
|
||||||
|
label=label,
|
||||||
|
channel=ch,
|
||||||
|
bandwidth_mhz=bw_mhz,
|
||||||
|
traffic=traffic,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Bluetooth / generic — no channels
|
||||||
|
for traffic in traffic_patterns:
|
||||||
|
schedule.append(
|
||||||
|
CaptureStep(
|
||||||
|
duration=duration,
|
||||||
|
label=traffic,
|
||||||
|
traffic=traffic,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id = raw.get("output", {}).get("device_id", device_name.lower().replace(" ", "_"))
|
||||||
|
transmitter = TransmitterConfig(
|
||||||
|
id=device_id,
|
||||||
|
type=device_type,
|
||||||
|
control_method=str(capture.get("control_method", "external_script")),
|
||||||
|
schedule=schedule,
|
||||||
|
script=capture.get("script"),
|
||||||
|
device=capture.get("device"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=f"enroll_{device_id}",
|
||||||
|
mode="controlled_testbed",
|
||||||
|
recorder=RecorderConfig.from_dict(raw["recorder"]),
|
||||||
|
transmitters=[transmitter],
|
||||||
|
qa=QAConfig.from_dict(raw.get("qa", {})),
|
||||||
|
output=OutputConfig.from_dict(raw.get("output", {})),
|
||||||
|
)
|
||||||
|
|
||||||
|
def total_capture_time_s(self) -> float:
|
||||||
|
"""Sum of all step durations across all transmitters and loops."""
|
||||||
|
return sum(step.duration for tx in self.transmitters for step in tx.schedule) * self.loops
|
||||||
|
|
||||||
|
def total_steps(self) -> int:
|
||||||
|
"""Total number of capture steps across all transmitters and loops."""
|
||||||
|
return sum(len(tx.schedule) for tx in self.transmitters) * self.loops
|
||||||
570
src/ria_toolkit_oss/orchestration/executor.py
Normal file
570
src/ria_toolkit_oss/orchestration/executor.py
Normal file
|
|
@ -0,0 +1,570 @@
|
||||||
|
"""Campaign executor: runs a capture campaign end-to-end."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field, replace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
from ria_toolkit_oss.io.recording import to_sigmf
|
||||||
|
|
||||||
|
from .campaign import CampaignConfig, CaptureStep, TransmitterConfig
|
||||||
|
from .labeler import build_output_filename, label_recording
|
||||||
|
from .qa import QAResult, check_recording
|
||||||
|
from .tx_executor import TxExecutor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Device name aliases: campaign YAML names → get_sdr_device() names
|
||||||
|
_DEVICE_ALIASES = {
|
||||||
|
"usrp_b210": "usrp",
|
||||||
|
"usrp_b200": "usrp",
|
||||||
|
"usrp": "usrp",
|
||||||
|
"plutosdr": "pluto",
|
||||||
|
"pluto": "pluto",
|
||||||
|
"hackrf": "hackrf",
|
||||||
|
"hackrf_one": "hackrf",
|
||||||
|
"bladerf": "bladerf",
|
||||||
|
"rtlsdr": "rtlsdr",
|
||||||
|
"rtl_sdr": "rtlsdr",
|
||||||
|
"thinkrf": "thinkrf",
|
||||||
|
# Simulated device — no hardware required
|
||||||
|
"mock": "mock",
|
||||||
|
"sim": "mock",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StepResult:
|
||||||
|
"""Outcome of a single capture step."""
|
||||||
|
|
||||||
|
transmitter_id: str
|
||||||
|
step_label: str
|
||||||
|
output_path: Optional[str]
|
||||||
|
qa: QAResult
|
||||||
|
capture_timestamp: float
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ok(self) -> bool:
|
||||||
|
return self.error is None and self.qa.passed
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"transmitter_id": self.transmitter_id,
|
||||||
|
"step_label": self.step_label,
|
||||||
|
"output_path": self.output_path,
|
||||||
|
"capture_timestamp": self.capture_timestamp,
|
||||||
|
"qa": self.qa.to_dict(),
|
||||||
|
"error": self.error,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignResult:
|
||||||
|
"""Aggregate outcome of a full campaign."""
|
||||||
|
|
||||||
|
campaign_name: str
|
||||||
|
steps: list[StepResult] = field(default_factory=list)
|
||||||
|
start_time: float = field(default_factory=time.time)
|
||||||
|
end_time: Optional[float] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_steps(self) -> int:
|
||||||
|
return len(self.steps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passed(self) -> int:
|
||||||
|
return sum(1 for s in self.steps if s.ok)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def flagged(self) -> int:
|
||||||
|
return sum(1 for s in self.steps if not s.error and s.qa.flagged)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def failed(self) -> int:
|
||||||
|
return sum(1 for s in self.steps if s.error or not s.qa.passed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration_s(self) -> float:
|
||||||
|
if self.end_time:
|
||||||
|
return self.end_time - self.start_time
|
||||||
|
return time.time() - self.start_time
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"campaign_name": self.campaign_name,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
"passed": self.passed,
|
||||||
|
"flagged": self.flagged,
|
||||||
|
"failed": self.failed,
|
||||||
|
"duration_s": round(self.duration_s, 1),
|
||||||
|
"steps": [s.to_dict() for s in self.steps],
|
||||||
|
}
|
||||||
|
|
||||||
|
def write_report(self, path: str | Path) -> None:
|
||||||
|
"""Write a JSON QA report to disk."""
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(self.to_dict(), f, indent=2)
|
||||||
|
logger.info(f"QA report written to {path}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# External script interface
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _run_script(script: str, *args: str, timeout: float = 15.0) -> str:
|
||||||
|
"""Run an external control script and return stdout.
|
||||||
|
|
||||||
|
The script is called as::
|
||||||
|
|
||||||
|
<script> <arg1> <arg2> ...
|
||||||
|
|
||||||
|
A non-zero return code raises RuntimeError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script: Path to executable script. Must be an absolute path to an
|
||||||
|
existing regular file. Relative paths are rejected to prevent
|
||||||
|
accidentally executing files that are not the intended script.
|
||||||
|
*args: Positional arguments forwarded to the script.
|
||||||
|
timeout: Maximum seconds to wait.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Script stdout as a string.
|
||||||
|
"""
|
||||||
|
if not Path(script).is_absolute():
|
||||||
|
raise RuntimeError(f"Script path must be absolute: {script}")
|
||||||
|
script_path = Path(script).resolve()
|
||||||
|
if not script_path.is_file():
|
||||||
|
raise RuntimeError(f"Script not found or is not a regular file: {script}")
|
||||||
|
|
||||||
|
cmd = [str(script_path), *args]
|
||||||
|
logger.debug(f"Running script: {' '.join(cmd)}")
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
raise RuntimeError(f"Script timed out after {timeout}s: {script}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(f"Script not found: {script}")
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Script exited {result.returncode}: {result.stderr.strip() or result.stdout.strip()}")
|
||||||
|
return result.stdout.strip()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Campaign executor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tx_params(transmitter: TransmitterConfig) -> dict | None:
|
||||||
|
"""Build a tx_params dict from a transmitter's signal config for SigMF labeling.
|
||||||
|
|
||||||
|
For sdr_agent transmitters, returns the synthetic generation parameters
|
||||||
|
(modulation, order, symbol_rate, etc.) so recordings capture what was
|
||||||
|
transmitted. Returns None for control methods without signal-level params.
|
||||||
|
"""
|
||||||
|
sdr_agent_cfg = getattr(transmitter, "sdr_agent", None)
|
||||||
|
if not sdr_agent_cfg:
|
||||||
|
return None
|
||||||
|
# Extract known signal-level fields; ignore infra fields
|
||||||
|
_INFRA_KEYS = {"node_id", "session_code"}
|
||||||
|
return {k: v for k, v in sdr_agent_cfg.items() if k not in _INFRA_KEYS and v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignExecutor:
|
||||||
|
"""Executes a :class:`CampaignConfig` end-to-end.
|
||||||
|
|
||||||
|
Initialises the SDR recorder once, then for each (transmitter, step):
|
||||||
|
1. Configures the transmitter (via external script or SDR TX)
|
||||||
|
2. Records IQ samples
|
||||||
|
3. Labels the recording with device/config metadata
|
||||||
|
4. Runs QA checks
|
||||||
|
5. Saves the recording to disk
|
||||||
|
6. Stops/resets the transmitter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Parsed campaign configuration.
|
||||||
|
progress_cb: Optional callback ``(step_index, total_steps, step_result)``
|
||||||
|
called after each step completes. Useful for status reporting.
|
||||||
|
verbose: Enable debug logging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: CampaignConfig,
|
||||||
|
progress_cb: Optional[Callable[[int, int, StepResult], None]] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
skip_local_tx: bool = False,
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.progress_cb = progress_cb
|
||||||
|
self.skip_local_tx = skip_local_tx
|
||||||
|
self._sdr = None
|
||||||
|
self._remote_tx_controllers: dict = {}
|
||||||
|
self._tx_executors: dict[str, tuple] = {} # tx_id → (TxExecutor, stop_event, thread)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def run(self) -> CampaignResult:
|
||||||
|
"""Execute the full campaign and return a :class:`CampaignResult`.
|
||||||
|
|
||||||
|
Initialises the SDR, runs all steps across all transmitters,
|
||||||
|
then closes the SDR. If SDR initialisation fails the exception
|
||||||
|
propagates immediately (nothing is captured).
|
||||||
|
"""
|
||||||
|
result = CampaignResult(campaign_name=self.config.name)
|
||||||
|
|
||||||
|
loops = self.config.loops
|
||||||
|
logger.info(
|
||||||
|
f"Starting campaign '{self.config.name}': "
|
||||||
|
f"{self.config.total_steps()} steps"
|
||||||
|
+ (f" ({self.config.total_steps() // loops} × {loops} loops)" if loops > 1 else "")
|
||||||
|
+ f", ~{self.config.total_capture_time_s():.0f}s capture time"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_sdr()
|
||||||
|
self._init_remote_tx_controllers()
|
||||||
|
try:
|
||||||
|
total = self.config.total_steps()
|
||||||
|
step_index = 0
|
||||||
|
|
||||||
|
for loop_idx in range(loops):
|
||||||
|
if loops > 1:
|
||||||
|
logger.info(f"Loop {loop_idx + 1}/{loops}")
|
||||||
|
for transmitter in self.config.transmitters:
|
||||||
|
logger.info(f"Transmitter: {transmitter.id} ({len(transmitter.schedule)} steps)")
|
||||||
|
for step in transmitter.schedule:
|
||||||
|
looped_step = replace(step, label=f"{step.label}_run{loop_idx + 1:02d}") if loops > 1 else step
|
||||||
|
step_result = self._execute_step(transmitter, looped_step)
|
||||||
|
result.steps.append(step_result)
|
||||||
|
step_index += 1
|
||||||
|
|
||||||
|
if self.progress_cb:
|
||||||
|
self.progress_cb(step_index, total, step_result)
|
||||||
|
|
||||||
|
if step_result.error:
|
||||||
|
logger.warning(f"Step '{looped_step.label}' error: {step_result.error}")
|
||||||
|
elif step_result.qa.flagged:
|
||||||
|
logger.warning(
|
||||||
|
f"Step '{looped_step.label}' flagged for review: " + "; ".join(step_result.qa.issues)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Step '{looped_step.label}' OK "
|
||||||
|
f"(SNR {step_result.qa.snr_db:.1f} dB, "
|
||||||
|
f"{step_result.qa.duration_s:.1f}s)"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._close_sdr()
|
||||||
|
self._close_remote_tx_controllers()
|
||||||
|
self._close_tx_executors()
|
||||||
|
|
||||||
|
result.end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Campaign complete: {result.passed}/{result.total_steps} passed, "
|
||||||
|
f"{result.flagged} flagged, {result.failed} failed"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# SDR management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _init_sdr(self) -> None:
|
||||||
|
"""Initialise and configure the SDR recorder."""
|
||||||
|
from ria_toolkit_oss.sdr import get_sdr_device
|
||||||
|
|
||||||
|
rec = self.config.recorder
|
||||||
|
device_name = _DEVICE_ALIASES.get(rec.device.lower(), rec.device.lower())
|
||||||
|
logger.info(f"Initialising SDR: {device_name} @ {rec.center_freq/1e6:.2f} MHz")
|
||||||
|
|
||||||
|
self._sdr = get_sdr_device(device_name)
|
||||||
|
gain = None if rec.gain == "auto" else float(rec.gain)
|
||||||
|
self._sdr.init_rx(
|
||||||
|
sample_rate=rec.sample_rate,
|
||||||
|
center_frequency=rec.center_freq,
|
||||||
|
gain=gain,
|
||||||
|
channel=0,
|
||||||
|
)
|
||||||
|
if rec.bandwidth and hasattr(self._sdr, "set_rx_bandwidth"):
|
||||||
|
self._sdr.set_rx_bandwidth(rec.bandwidth)
|
||||||
|
|
||||||
|
def _close_sdr(self) -> None:
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
self._sdr.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"SDR close error: {e}")
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Remote Tx controller management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _init_remote_tx_controllers(self) -> None:
|
||||||
|
"""Open SSH+ZMQ connections for all sdr_remote transmitters."""
|
||||||
|
from ria_toolkit_oss.remote_control import RemoteTransmitterController
|
||||||
|
|
||||||
|
for tx in self.config.transmitters:
|
||||||
|
if tx.control_method != "sdr_remote":
|
||||||
|
continue
|
||||||
|
cfg = tx.sdr_remote
|
||||||
|
if not cfg:
|
||||||
|
raise RuntimeError(f"Transmitter '{tx.id}' uses sdr_remote but has no sdr_remote config")
|
||||||
|
logger.info(f"Connecting remote Tx controller for {tx.id} → {cfg['host']}")
|
||||||
|
ctrl = RemoteTransmitterController(
|
||||||
|
host=cfg["host"],
|
||||||
|
ssh_user=cfg["ssh_user"],
|
||||||
|
ssh_key_path=cfg["ssh_key_path"],
|
||||||
|
zmq_port=int(cfg.get("zmq_port", 5556)),
|
||||||
|
)
|
||||||
|
ctrl.set_radio(
|
||||||
|
device_type=cfg["device_type"],
|
||||||
|
device_id=cfg.get("device_id", ""),
|
||||||
|
)
|
||||||
|
self._remote_tx_controllers[tx.id] = ctrl
|
||||||
|
|
||||||
|
def _close_remote_tx_controllers(self) -> None:
|
||||||
|
for tx_id, ctrl in list(self._remote_tx_controllers.items()):
|
||||||
|
try:
|
||||||
|
ctrl.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Error closing remote Tx controller {tx_id}: {exc}")
|
||||||
|
self._remote_tx_controllers.clear()
|
||||||
|
|
||||||
|
def _close_tx_executors(self) -> None:
|
||||||
|
for tx_id, (_, stop_event, t) in list(self._tx_executors.items()):
|
||||||
|
stop_event.set()
|
||||||
|
t.join(timeout=5.0)
|
||||||
|
self._tx_executors.clear()
|
||||||
|
|
||||||
|
def _record(self, duration_s: float) -> Recording:
|
||||||
|
"""Capture ``duration_s`` seconds of IQ samples."""
|
||||||
|
num_samples = int(duration_s * self.config.recorder.sample_rate)
|
||||||
|
return self._sdr.record(num_samples=num_samples)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _execute_step(self, transmitter: TransmitterConfig, step: CaptureStep) -> StepResult:
|
||||||
|
"""Run a single capture step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StepResult with QA outcome and output path (or error string).
|
||||||
|
"""
|
||||||
|
capture_timestamp = time.time()
|
||||||
|
output_path: Optional[str] = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._start_transmitter(transmitter, step)
|
||||||
|
recording = self._record(step.duration)
|
||||||
|
self._stop_transmitter(transmitter, step)
|
||||||
|
except Exception as e:
|
||||||
|
# Best-effort stop on error
|
||||||
|
try:
|
||||||
|
self._stop_transmitter(transmitter, step)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return StepResult(
|
||||||
|
transmitter_id=transmitter.id,
|
||||||
|
step_label=step.label,
|
||||||
|
output_path=None,
|
||||||
|
qa=QAResult(passed=False, flagged=True, snr_db=0.0, duration_s=0.0, issues=[f"Capture error: {e}"]),
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Label recording
|
||||||
|
recording = label_recording(
|
||||||
|
recording=recording,
|
||||||
|
device_id=transmitter.id,
|
||||||
|
step=step,
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
campaign_name=self.config.name,
|
||||||
|
tx_params=_extract_tx_params(transmitter),
|
||||||
|
)
|
||||||
|
|
||||||
|
# QA
|
||||||
|
qa_result = check_recording(recording, self.config.qa)
|
||||||
|
|
||||||
|
# Save
|
||||||
|
try:
|
||||||
|
output_path = self._save(recording, transmitter.id, step)
|
||||||
|
except Exception as e:
|
||||||
|
return StepResult(
|
||||||
|
transmitter_id=transmitter.id,
|
||||||
|
step_label=step.label,
|
||||||
|
output_path=None,
|
||||||
|
qa=qa_result,
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
error=f"Save failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return StepResult(
|
||||||
|
transmitter_id=transmitter.id,
|
||||||
|
step_label=step.label,
|
||||||
|
output_path=output_path,
|
||||||
|
qa=qa_result,
|
||||||
|
capture_timestamp=capture_timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Transmitter control (external script interface)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _start_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||||
|
"""Configure the transmitter for this step.
|
||||||
|
|
||||||
|
For ``external_script`` control method the script is called as::
|
||||||
|
|
||||||
|
<script> configure <step_params_json>
|
||||||
|
|
||||||
|
where ``step_params_json`` is a JSON object with channel, bandwidth,
|
||||||
|
traffic, etc. The script is responsible for applying the configuration
|
||||||
|
and returning promptly (i.e. not blocking for the capture duration).
|
||||||
|
|
||||||
|
For ``sdr_remote`` the remote ZMQ controller calls ``init_tx`` then
|
||||||
|
starts a background transmit thread that runs for the step duration.
|
||||||
|
"""
|
||||||
|
if transmitter.control_method == "external_script":
|
||||||
|
if not transmitter.script:
|
||||||
|
logger.debug(f"No script configured for {transmitter.id}, skipping configure")
|
||||||
|
return
|
||||||
|
params = self._step_params_json(transmitter, step)
|
||||||
|
_run_script(transmitter.script, "configure", params)
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr":
|
||||||
|
logger.debug("SDR TX not yet implemented — skipping start")
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr_remote":
|
||||||
|
ctrl = self._remote_tx_controllers.get(transmitter.id)
|
||||||
|
if ctrl is None:
|
||||||
|
raise RuntimeError(f"No remote Tx controller found for transmitter '{transmitter.id}'")
|
||||||
|
gain = step.power_dbm if step.power_dbm is not None else 0.0
|
||||||
|
ctrl.init_tx(
|
||||||
|
center_frequency=self.config.recorder.center_freq,
|
||||||
|
sample_rate=self.config.recorder.sample_rate,
|
||||||
|
gain=gain,
|
||||||
|
channel=step.channel or 0,
|
||||||
|
)
|
||||||
|
# Start transmission in background; _record() runs concurrently
|
||||||
|
ctrl.transmit_async(step.duration + 1.0)
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr_agent":
|
||||||
|
if self.skip_local_tx:
|
||||||
|
logger.debug(f"skip_local_tx — TX for '{transmitter.id}' delegated to TX agent node")
|
||||||
|
return
|
||||||
|
if not transmitter.sdr_agent:
|
||||||
|
logger.warning(f"Transmitter '{transmitter.id}' has no sdr_agent config — skipping")
|
||||||
|
return
|
||||||
|
step_dict: dict = {"label": step.label, "duration": step.duration + 1.0}
|
||||||
|
if step.power_dbm is not None:
|
||||||
|
step_dict["power_dbm"] = step.power_dbm
|
||||||
|
tx_config = {
|
||||||
|
"id": transmitter.id,
|
||||||
|
"sdr_agent": transmitter.sdr_agent,
|
||||||
|
"schedule": [step_dict],
|
||||||
|
}
|
||||||
|
rec = self.config.recorder
|
||||||
|
tx_device = transmitter.device or rec.device
|
||||||
|
sdr_device = _DEVICE_ALIASES.get(tx_device.lower(), tx_device.lower())
|
||||||
|
stop_event = threading.Event()
|
||||||
|
executor = TxExecutor(tx_config, sdr_device=sdr_device, stop_event=stop_event)
|
||||||
|
t = threading.Thread(target=executor.run, daemon=True, name=f"tx-{transmitter.id}")
|
||||||
|
self._tx_executors[transmitter.id] = (executor, stop_event, t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown control method '{transmitter.control_method}' — skipping")
|
||||||
|
|
||||||
|
def _stop_transmitter(self, transmitter: TransmitterConfig, step: CaptureStep) -> None:
|
||||||
|
"""Signal the transmitter to stop.
|
||||||
|
|
||||||
|
Calls ``<script> stop`` for external_script transmitters.
|
||||||
|
For ``sdr_remote``, waits for the background transmit thread to finish.
|
||||||
|
"""
|
||||||
|
if transmitter.control_method == "external_script":
|
||||||
|
if not transmitter.script:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
_run_script(transmitter.script, "stop")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Script stop failed for {transmitter.id}: {e}")
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr_remote":
|
||||||
|
ctrl = self._remote_tx_controllers.get(transmitter.id)
|
||||||
|
if ctrl is not None:
|
||||||
|
ctrl.wait_transmit(timeout=step.duration + 10.0)
|
||||||
|
|
||||||
|
elif transmitter.control_method == "sdr_agent":
|
||||||
|
entry = self._tx_executors.pop(transmitter.id, None)
|
||||||
|
if entry is not None:
|
||||||
|
_, stop_event, t = entry
|
||||||
|
stop_event.set()
|
||||||
|
t.join(timeout=step.duration + 10.0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _step_params_json(transmitter: TransmitterConfig, step: CaptureStep) -> str:
|
||||||
|
"""Serialise step parameters to a JSON string for the control script."""
|
||||||
|
params: dict = {"device": transmitter.device or ""}
|
||||||
|
if step.channel is not None:
|
||||||
|
params["channel"] = step.channel
|
||||||
|
if step.bandwidth_mhz is not None:
|
||||||
|
params["bandwidth_mhz"] = step.bandwidth_mhz
|
||||||
|
if step.traffic is not None:
|
||||||
|
params["traffic"] = step.traffic
|
||||||
|
if step.power_dbm is not None:
|
||||||
|
params["power_dbm"] = step.power_dbm
|
||||||
|
return json.dumps(params)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Output
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _save(self, recording: Recording, device_id: str, step: CaptureStep) -> str:
|
||||||
|
"""Save a recording to disk and return the data file path."""
|
||||||
|
out = self.config.output
|
||||||
|
rel_filename = build_output_filename(device_id, step)
|
||||||
|
out_dir = Path(out.path).resolve()
|
||||||
|
|
||||||
|
# build_output_filename returns "<device_id>/<label>"
|
||||||
|
# to_sigmf needs filename (base) and path (dir) separately
|
||||||
|
parts = Path(rel_filename)
|
||||||
|
subdir = (out_dir / parts.parent).resolve()
|
||||||
|
|
||||||
|
# Prevent path traversal: the resolved subdir must stay within the configured output directory.
|
||||||
|
try:
|
||||||
|
subdir.relative_to(out_dir)
|
||||||
|
except ValueError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Output path escape detected: '{subdir}' is outside configured output directory '{out_dir}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
subdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
base = parts.name
|
||||||
|
|
||||||
|
to_sigmf(recording, filename=base, path=str(subdir), overwrite=True)
|
||||||
|
return str(subdir / f"{base}.sigmf-data")
|
||||||
86
src/ria_toolkit_oss/orchestration/labeler.py
Normal file
86
src/ria_toolkit_oss/orchestration/labeler.py
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""Timestamp-based labeling for captured recordings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
|
||||||
|
from .campaign import CaptureStep
|
||||||
|
|
||||||
|
|
||||||
|
def label_recording(
|
||||||
|
recording: Recording,
|
||||||
|
device_id: str,
|
||||||
|
step: CaptureStep,
|
||||||
|
capture_timestamp: float,
|
||||||
|
campaign_name: Optional[str] = None,
|
||||||
|
tx_params: Optional[dict] = None,
|
||||||
|
) -> Recording:
|
||||||
|
"""Apply device identity and capture configuration labels to a recording's metadata.
|
||||||
|
|
||||||
|
Labels are stored in the ``ria:*`` namespace when the recording is saved
|
||||||
|
as SigMF, via the existing ``update_metadata`` mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recording: The recording to label.
|
||||||
|
device_id: Identifier for the transmitting device (e.g. "iphone13_wifi_001").
|
||||||
|
step: The capture step that was active during this recording.
|
||||||
|
capture_timestamp: Unix timestamp (float) of when capture started.
|
||||||
|
campaign_name: Optional campaign name for cross-recording reference.
|
||||||
|
tx_params: Optional dict of transmitter signal parameters (e.g. modulation,
|
||||||
|
order, symbol_rate) written as ``ria:tx_<key>`` fields so downstream
|
||||||
|
training pipelines know what was transmitted into the recording.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same recording with updated metadata.
|
||||||
|
"""
|
||||||
|
recording.update_metadata("device_id", device_id)
|
||||||
|
recording.update_metadata("capture_timestamp", capture_timestamp)
|
||||||
|
recording.update_metadata("step_label", step.label)
|
||||||
|
recording.update_metadata("step_duration_s", step.duration)
|
||||||
|
|
||||||
|
if campaign_name:
|
||||||
|
recording.update_metadata("campaign", campaign_name)
|
||||||
|
|
||||||
|
# WiFi-specific labels
|
||||||
|
if step.channel is not None:
|
||||||
|
recording.update_metadata("wifi_channel", step.channel)
|
||||||
|
if step.bandwidth_mhz is not None:
|
||||||
|
recording.update_metadata("wifi_bandwidth_mhz", step.bandwidth_mhz)
|
||||||
|
|
||||||
|
# Bluetooth-specific labels
|
||||||
|
if step.connection_interval_ms is not None:
|
||||||
|
recording.update_metadata("bt_connection_interval_ms", step.connection_interval_ms)
|
||||||
|
|
||||||
|
# Traffic pattern (WiFi + BT)
|
||||||
|
if step.traffic is not None:
|
||||||
|
recording.update_metadata("traffic_pattern", step.traffic)
|
||||||
|
|
||||||
|
# TX power
|
||||||
|
if step.power_dbm is not None:
|
||||||
|
recording.update_metadata("tx_power_dbm", step.power_dbm)
|
||||||
|
|
||||||
|
# Transmitter signal parameters (e.g. from sdr_agent synthetic generation)
|
||||||
|
if tx_params:
|
||||||
|
for key, value in tx_params.items():
|
||||||
|
recording.update_metadata(f"tx_{key}", value)
|
||||||
|
|
||||||
|
return recording
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_filename(device_id: str, step: CaptureStep) -> str:
|
||||||
|
"""Generate a deterministic filename for a labeled recording.
|
||||||
|
|
||||||
|
Format: ``<device_id>/<step_label>``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_id: Device identifier string.
|
||||||
|
step: Capture step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relative path string (no extension) to use as ``filename`` in ``to_sigmf()``.
|
||||||
|
"""
|
||||||
|
safe_id = device_id.replace("/", "_").replace(" ", "_")
|
||||||
|
safe_label = step.label.replace("/", "_").replace(" ", "_")
|
||||||
|
return f"{safe_id}/{safe_label}"
|
||||||
109
src/ria_toolkit_oss/orchestration/qa.py
Normal file
109
src/ria_toolkit_oss/orchestration/qa.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""QA metrics for captured RF recordings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
|
||||||
|
from .campaign import QAConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QAResult:
|
||||||
|
"""Result of QA checks on a single recording."""
|
||||||
|
|
||||||
|
passed: bool
|
||||||
|
flagged: bool # True if any metric is below threshold (but not hard-failed)
|
||||||
|
snr_db: float
|
||||||
|
duration_s: float
|
||||||
|
issues: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"passed": self.passed,
|
||||||
|
"flagged": self.flagged,
|
||||||
|
"snr_db": round(self.snr_db, 2),
|
||||||
|
"duration_s": round(self.duration_s, 3),
|
||||||
|
"issues": self.issues,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_snr_db(samples: np.ndarray, signal_fraction: float = 0.7) -> float:
|
||||||
|
"""Estimate SNR from IQ samples using PSD-based signal/noise separation.
|
||||||
|
|
||||||
|
Computes an FFT of the samples and assumes the top ``signal_fraction``
|
||||||
|
of power bins are signal and the remainder are noise. This is a
|
||||||
|
heuristic appropriate for a controlled testbed where a single dominant
|
||||||
|
signal is expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples: 1-D complex array of IQ samples.
|
||||||
|
signal_fraction: Fraction of PSD bins to treat as signal (0–1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated SNR in dB, or 0.0 if the noise floor is zero.
|
||||||
|
"""
|
||||||
|
n_fft = min(4096, len(samples))
|
||||||
|
window = np.hanning(n_fft)
|
||||||
|
psd = np.abs(np.fft.fft(samples[:n_fft] * window)) ** 2
|
||||||
|
|
||||||
|
psd_sorted = np.sort(psd)[::-1]
|
||||||
|
n_signal = min(max(1, int(n_fft * signal_fraction)), n_fft - 1)
|
||||||
|
signal_power = psd_sorted[:n_signal].mean()
|
||||||
|
noise_power = psd_sorted[n_signal:].mean()
|
||||||
|
|
||||||
|
if noise_power <= 0.0:
|
||||||
|
return 0.0
|
||||||
|
return float(10.0 * np.log10(signal_power / noise_power))
|
||||||
|
|
||||||
|
|
||||||
|
def check_recording(recording: Recording, config: QAConfig) -> QAResult:
|
||||||
|
"""Run QA checks on a recording against the campaign QA config.
|
||||||
|
|
||||||
|
Checks performed:
|
||||||
|
- Duration: number of samples / sample_rate >= min_duration_s
|
||||||
|
- SNR: estimated SNR >= snr_threshold_db
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recording: Recording to evaluate.
|
||||||
|
config: QA thresholds from the campaign config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QAResult with pass/flag status and per-metric details.
|
||||||
|
"""
|
||||||
|
issues: list[str] = []
|
||||||
|
flagged = False
|
||||||
|
|
||||||
|
# --- Duration check ---
|
||||||
|
sample_rate = recording.metadata.get("sample_rate", 1.0)
|
||||||
|
n_samples = recording.data.shape[-1]
|
||||||
|
duration_s = n_samples / sample_rate if sample_rate else 0.0
|
||||||
|
|
||||||
|
if duration_s < config.min_duration_s:
|
||||||
|
issues.append(f"Duration too short: {duration_s:.1f}s < {config.min_duration_s:.1f}s threshold")
|
||||||
|
flagged = True
|
||||||
|
|
||||||
|
# --- SNR check ---
|
||||||
|
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
|
||||||
|
snr_db = estimate_snr_db(samples)
|
||||||
|
|
||||||
|
if snr_db < config.snr_threshold_db:
|
||||||
|
issues.append(f"SNR below threshold: {snr_db:.1f} dB < {config.snr_threshold_db:.1f} dB")
|
||||||
|
flagged = True
|
||||||
|
|
||||||
|
# In flag_for_review mode: flag but don't hard-fail
|
||||||
|
if config.flag_for_review:
|
||||||
|
passed = True # always accept; human reviews flagged recordings
|
||||||
|
else:
|
||||||
|
passed = not flagged
|
||||||
|
|
||||||
|
return QAResult(
|
||||||
|
passed=passed,
|
||||||
|
flagged=flagged,
|
||||||
|
snr_db=snr_db,
|
||||||
|
duration_s=duration_s,
|
||||||
|
issues=issues,
|
||||||
|
)
|
||||||
299
src/ria_toolkit_oss/orchestration/tx_executor.py
Normal file
299
src/ria_toolkit_oss/orchestration/tx_executor.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""TX campaign executor — synthesises and transmits signals via a local SDR.
|
||||||
|
|
||||||
|
The TxExecutor receives a transmitter config dict (matching the
|
||||||
|
``sdr_agent`` control method's schema) and a step schedule, then for each
|
||||||
|
step builds a signal chain with the block generator and transmits it via
|
||||||
|
the local SDR device.
|
||||||
|
|
||||||
|
Supported modulations (``modulation`` field in config):
|
||||||
|
BPSK, QPSK, 8PSK, 16QAM, 64QAM, 256QAM, FSK, OOK, GMSK, OQPSK
|
||||||
|
|
||||||
|
Example config dict (matches CampaignConfig transmitter with
|
||||||
|
``control_method: sdr_agent``)::
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "synthetic-tx",
|
||||||
|
"type": "sdr",
|
||||||
|
"control_method": "sdr_agent",
|
||||||
|
"sdr_agent": {
|
||||||
|
"modulation": "QPSK",
|
||||||
|
"order": 4,
|
||||||
|
"symbol_rate": 1000000,
|
||||||
|
"center_frequency": 0.0,
|
||||||
|
"filter": "rrc",
|
||||||
|
"rolloff": 0.35
|
||||||
|
},
|
||||||
|
"schedule": [
|
||||||
|
{"label": "step1", "duration": 10, "power_dbm": -10}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_hz(val: object) -> float:
|
||||||
|
"""Parse a frequency value that may be a float (Hz) or a string like '2.45GHz'."""
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return float(val)
|
||||||
|
s = str(val).strip()
|
||||||
|
for suffix, mult in (("GHz", 1e9), ("MHz", 1e6), ("kHz", 1e3), ("Hz", 1.0)):
|
||||||
|
if s.endswith(suffix):
|
||||||
|
return float(s[: -len(suffix)]) * mult
|
||||||
|
return float(s)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_seconds(val: object) -> float:
|
||||||
|
"""Parse a duration value that may be a float (seconds) or a string like '5s'."""
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return float(val)
|
||||||
|
s = str(val).strip()
|
||||||
|
return float(s[:-1]) if s.endswith("s") else float(s)
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from modulation name → (PSK/QAM order, generator_type)
|
||||||
|
# 'psk' uses PSKGenerator, 'qam' uses QAMGenerator
|
||||||
|
_MOD_TABLE: dict[str, tuple[int, str]] = {
|
||||||
|
"BPSK": (1, "psk"),
|
||||||
|
"QPSK": (2, "psk"),
|
||||||
|
"8PSK": (3, "psk"),
|
||||||
|
"16QAM": (4, "qam"),
|
||||||
|
"64QAM": (6, "qam"),
|
||||||
|
"256QAM": (8, "qam"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_SPECIAL_MODS = {"FSK", "OOK", "GMSK", "OQPSK"}
|
||||||
|
|
||||||
|
# usrp-uhd-client's tx_recording() streams 2 000-sample chunks and loops the
|
||||||
|
# source buffer for the full tx_time, so only this many samples ever need to
|
||||||
|
# be in RAM regardless of step duration or sample rate.
|
||||||
|
# 50 000 complex64 samples ≈ 400 kB — enough spectral diversity for looping.
|
||||||
|
_SYNTH_BLOCK_SAMPLES = 50_000
|
||||||
|
|
||||||
|
|
||||||
|
class TxExecutor:
|
||||||
|
"""Synthesise and transmit a signal campaign via a local SDR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Transmitter config dict (must have ``sdr_agent`` sub-dict with
|
||||||
|
modulation params, and ``schedule`` list of step dicts).
|
||||||
|
sdr_device: SDR device name to open in TX mode (e.g. "pluto", "usrp").
|
||||||
|
stop_event: External event that aborts the TX loop mid-step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: dict,
|
||||||
|
sdr_device: str = "unknown",
|
||||||
|
stop_event: threading.Event | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.sdr_device = sdr_device
|
||||||
|
self.stop_event = stop_event or threading.Event()
|
||||||
|
self._sdr: Any = None
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""Execute all steps in the schedule, transmitting for each step duration."""
|
||||||
|
agent_cfg: dict = self.config.get("sdr_agent") or {}
|
||||||
|
schedule: list[dict] = self.config.get("schedule") or []
|
||||||
|
|
||||||
|
if not schedule:
|
||||||
|
logger.warning("TxExecutor: no schedule steps — nothing to transmit")
|
||||||
|
return
|
||||||
|
|
||||||
|
modulation: str = agent_cfg.get("modulation", "QPSK").upper()
|
||||||
|
symbol_rate: float = float(agent_cfg.get("symbol_rate", 1e6))
|
||||||
|
center_freq: float = _parse_hz(agent_cfg.get("center_frequency", 0.0))
|
||||||
|
filter_type: str = agent_cfg.get("filter", "rrc").lower()
|
||||||
|
rolloff: float = float(agent_cfg.get("rolloff", 0.35))
|
||||||
|
loops: int = max(1, int(self.config.get("loops", 1)))
|
||||||
|
|
||||||
|
# Upsampling factor: samples_per_symbol, fixed at 8 for SDR compatibility.
|
||||||
|
sps = 8
|
||||||
|
sample_rate = symbol_rate * sps
|
||||||
|
|
||||||
|
self._init_sdr(sample_rate, center_freq)
|
||||||
|
try:
|
||||||
|
for loop_idx in range(loops):
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
break
|
||||||
|
if loops > 1:
|
||||||
|
logger.info("TX loop %d/%d", loop_idx + 1, loops)
|
||||||
|
for step in schedule:
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
break
|
||||||
|
looped_step = (
|
||||||
|
{**step, "label": f"{step.get('label', 'step')}_run{loop_idx + 1:02d}"} if loops > 1 else step
|
||||||
|
)
|
||||||
|
self._execute_step(looped_step, modulation, sps, symbol_rate, filter_type, rolloff)
|
||||||
|
finally:
|
||||||
|
self._close_sdr()
|
||||||
|
|
||||||
|
def _execute_step(
|
||||||
|
self,
|
||||||
|
step: dict,
|
||||||
|
modulation: str,
|
||||||
|
sps: int,
|
||||||
|
symbol_rate: float,
|
||||||
|
filter_type: str,
|
||||||
|
rolloff: float,
|
||||||
|
) -> None:
|
||||||
|
duration: float = _parse_seconds(step.get("duration", 10.0))
|
||||||
|
label: str = step.get("label", "step")
|
||||||
|
gain: float = float(step.get("power_dbm") or 0.0)
|
||||||
|
sample_rate = symbol_rate * sps
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"TX step '%s': %.0f s, %s @ %.3f MHz (sps=%d, filter=%s)",
|
||||||
|
label,
|
||||||
|
duration,
|
||||||
|
modulation,
|
||||||
|
symbol_rate / 1e6,
|
||||||
|
sps,
|
||||||
|
filter_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_samples = int(duration * sample_rate)
|
||||||
|
|
||||||
|
# Synthesise a short representative block. tx_recording() loops this
|
||||||
|
# buffer for the full tx_time using a 2 000-sample streaming callback,
|
||||||
|
# so peak memory is O(_SYNTH_BLOCK_SAMPLES) regardless of duration.
|
||||||
|
block_size = min(num_samples, _SYNTH_BLOCK_SAMPLES)
|
||||||
|
signal = self._synthesise(modulation, sps, block_size, filter_type, rolloff)
|
||||||
|
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
# Apply gain update if SDR supports it
|
||||||
|
if hasattr(self._sdr, "set_tx_gain"):
|
||||||
|
self._sdr.set_tx_gain(gain)
|
||||||
|
self._sdr.tx_recording(signal, tx_time=duration)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("TX step '%s' SDR error: %s", label, exc)
|
||||||
|
else:
|
||||||
|
# No SDR available — simulate by sleeping for the step duration.
|
||||||
|
logger.warning("TX step '%s': no SDR — simulating %.0f s delay", label, duration)
|
||||||
|
self.stop_event.wait(timeout=duration)
|
||||||
|
|
||||||
|
def _synthesise(
|
||||||
|
self,
|
||||||
|
modulation: str,
|
||||||
|
sps: int,
|
||||||
|
num_samples: int,
|
||||||
|
filter_type: str,
|
||||||
|
rolloff: float,
|
||||||
|
):
|
||||||
|
"""Build a block-generator chain and return IQ samples as a numpy array."""
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.signal.block_generator import (
|
||||||
|
BinarySource,
|
||||||
|
GMSKModulator,
|
||||||
|
Mapper,
|
||||||
|
OOKModulator,
|
||||||
|
OQPSKModulator,
|
||||||
|
RaisedCosineFilter,
|
||||||
|
RootRaisedCosineFilter,
|
||||||
|
Upsampling,
|
||||||
|
)
|
||||||
|
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
|
||||||
|
FSKModulator,
|
||||||
|
)
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError(f"ria_toolkit_oss block generator not available: {exc}") from exc
|
||||||
|
|
||||||
|
# ── Special modulations with their own source-connected modulator ──
|
||||||
|
if modulation in ("OOK", "GMSK", "OQPSK"):
|
||||||
|
src = BinarySource()
|
||||||
|
if modulation == "OOK":
|
||||||
|
mod = OOKModulator(src, samples_per_symbol=sps)
|
||||||
|
elif modulation == "GMSK":
|
||||||
|
mod = GMSKModulator(src, samples_per_symbol=sps)
|
||||||
|
else:
|
||||||
|
mod = OQPSKModulator(src, samples_per_symbol=sps)
|
||||||
|
recording = mod.record(num_samples)
|
||||||
|
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
||||||
|
if len(flat) < num_samples:
|
||||||
|
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||||
|
return flat[:num_samples]
|
||||||
|
|
||||||
|
if modulation == "FSK":
|
||||||
|
symbol_rate = num_samples / sps
|
||||||
|
bits_per_sym = 1 # 2-FSK
|
||||||
|
num_bits = max(num_samples // sps, 128) * bits_per_sym
|
||||||
|
bits = BinarySource()((1, num_bits))
|
||||||
|
mod = FSKModulator(
|
||||||
|
num_bits_per_symbol=bits_per_sym,
|
||||||
|
frequency_spacing=symbol_rate * 0.5,
|
||||||
|
symbol_duration=1.0 / max(symbol_rate, 1.0),
|
||||||
|
sampling_frequency=symbol_rate * sps,
|
||||||
|
)
|
||||||
|
flat = np.asarray(mod(bits)).flatten().astype(np.complex64)
|
||||||
|
if len(flat) < num_samples:
|
||||||
|
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||||
|
return flat[:num_samples]
|
||||||
|
|
||||||
|
# ── PSK / QAM via Mapper → Upsampling → pulse filter ──────────────
|
||||||
|
if modulation not in _MOD_TABLE:
|
||||||
|
logger.warning("Unknown modulation %r — defaulting to QPSK", modulation)
|
||||||
|
modulation = "QPSK"
|
||||||
|
|
||||||
|
bits_per_sym, gen_type = _MOD_TABLE[modulation]
|
||||||
|
mod_family = "QAM" if gen_type == "qam" else "PSK"
|
||||||
|
|
||||||
|
source = BinarySource()
|
||||||
|
mapper = Mapper(constellation_type=mod_family, num_bits_per_symbol=bits_per_sym)
|
||||||
|
upsampler = Upsampling(factor=sps)
|
||||||
|
|
||||||
|
mapper.connect_input([source])
|
||||||
|
upsampler.connect_input([mapper])
|
||||||
|
|
||||||
|
if filter_type in ("rrc",):
|
||||||
|
pulse_filter = RootRaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
||||||
|
pulse_filter.connect_input([upsampler])
|
||||||
|
recording = pulse_filter.record(num_samples)
|
||||||
|
elif filter_type in ("rc",):
|
||||||
|
pulse_filter = RaisedCosineFilter(span_in_symbols=6, upsampling_factor=sps, beta=rolloff)
|
||||||
|
pulse_filter.connect_input([upsampler])
|
||||||
|
recording = pulse_filter.record(num_samples)
|
||||||
|
else:
|
||||||
|
# "none", "rect", "gaussian" — use upsampler output directly
|
||||||
|
recording = upsampler.record(num_samples)
|
||||||
|
|
||||||
|
flat = np.asarray(recording.data).flatten().astype(np.complex64)
|
||||||
|
if len(flat) < num_samples:
|
||||||
|
flat = np.tile(flat, num_samples // len(flat) + 1)
|
||||||
|
return flat[:num_samples]
|
||||||
|
|
||||||
|
def _init_sdr(self, sample_rate: float, center_freq: float) -> None:
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.sdr import get_sdr_device
|
||||||
|
|
||||||
|
self._sdr = get_sdr_device(self.sdr_device)
|
||||||
|
self._sdr.init_tx(
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
center_frequency=center_freq,
|
||||||
|
gain=0,
|
||||||
|
channel=0,
|
||||||
|
gain_mode="manual",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"TX SDR initialised: %s @ %.3f MHz, %.1f Msps", self.sdr_device, center_freq / 1e6, sample_rate / 1e6
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("TX SDR init failed (%s) — will simulate: %s", self.sdr_device, exc)
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
def _close_sdr(self) -> None:
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
self._sdr.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("TX SDR close error: %s", exc)
|
||||||
|
self._sdr = None
|
||||||
6
src/ria_toolkit_oss/remote_control/__init__.py
Normal file
6
src/ria_toolkit_oss/remote_control/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
"""Remote SDR transmitter control via SSH + ZMQ."""
|
||||||
|
|
||||||
|
from .remote_transmitter import RemoteTransmitter
|
||||||
|
from .remote_transmitter_controller import RemoteTransmitterController
|
||||||
|
|
||||||
|
__all__ = ["RemoteTransmitter", "RemoteTransmitterController"]
|
||||||
152
src/ria_toolkit_oss/remote_control/remote_transmitter.py
Normal file
152
src/ria_toolkit_oss/remote_control/remote_transmitter.py
Normal file
|
|
@ -0,0 +1,152 @@
|
||||||
|
"""Server-side ZMQ RPC receiver for SDR transmission.
|
||||||
|
|
||||||
|
Run this script on the Tx machine. The script binds a ZMQ REP socket and
|
||||||
|
waits for JSON-RPC commands from a :class:`RemoteTransmitterController`.
|
||||||
|
|
||||||
|
Requires: zmq, and ria-toolkit or utils installed for SDR support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from contextlib import redirect_stderr, redirect_stdout
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteTransmitter:
|
||||||
|
"""Executes SDR Tx commands received over ZMQ.
|
||||||
|
|
||||||
|
Loads the appropriate SDR driver dynamically so the script can run on
|
||||||
|
machines that have only a subset of SDR libraries installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
def set_radio(self, radio_str: str, identifier: str = "") -> None:
|
||||||
|
"""Initialise the SDR radio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
radio_str: SDR type — pluto | usrp | hackrf | bladerf.
|
||||||
|
identifier: Device-specific identifier (IP, serial, etc.).
|
||||||
|
"""
|
||||||
|
radio_str = radio_str.lower()
|
||||||
|
try:
|
||||||
|
if radio_str in ("pluto", "plutosdr"):
|
||||||
|
from ria_toolkit_oss.sdr.pluto import Pluto
|
||||||
|
|
||||||
|
self._sdr = Pluto(identifier)
|
||||||
|
elif radio_str in ("usrp",):
|
||||||
|
from ria_toolkit_oss.sdr.usrp import USRP
|
||||||
|
|
||||||
|
self._sdr = USRP(identifier)
|
||||||
|
elif radio_str in ("hackrf", "hackrf_one"):
|
||||||
|
from ria_toolkit_oss.sdr.hackrf import HackRF
|
||||||
|
|
||||||
|
self._sdr = HackRF(identifier)
|
||||||
|
elif radio_str in ("bladerf", "blade"):
|
||||||
|
from ria_toolkit_oss.sdr.blade import Blade
|
||||||
|
|
||||||
|
self._sdr = Blade(identifier)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown SDR type: {radio_str!r}")
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError(f"SDR driver for '{radio_str}' is not installed: {exc}") from exc
|
||||||
|
|
||||||
|
def init_tx(
|
||||||
|
self,
|
||||||
|
center_frequency: float,
|
||||||
|
sample_rate: float,
|
||||||
|
gain: float,
|
||||||
|
channel: int = 0,
|
||||||
|
gain_mode: str = "absolute",
|
||||||
|
) -> None:
|
||||||
|
if self._sdr is None:
|
||||||
|
raise RuntimeError("Call set_radio() before init_tx()")
|
||||||
|
self._sdr.init_tx(
|
||||||
|
center_frequency=center_frequency,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
gain=gain,
|
||||||
|
channel=channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
def transmit(self, duration_s: float) -> None:
|
||||||
|
"""Transmit a continuous wave for ``duration_s`` seconds."""
|
||||||
|
if self._sdr is None:
|
||||||
|
raise RuntimeError("Call set_radio() and init_tx() before transmit()")
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Transmit in a loop until duration has elapsed
|
||||||
|
end = time.monotonic() + duration_s
|
||||||
|
while time.monotonic() < end:
|
||||||
|
try:
|
||||||
|
self._sdr.tx_cw()
|
||||||
|
except AttributeError:
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop transmission and close the SDR."""
|
||||||
|
if self._sdr is not None:
|
||||||
|
try:
|
||||||
|
self._sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._sdr = None
|
||||||
|
|
||||||
|
def run_function(self, command_dict: dict) -> dict:
|
||||||
|
"""Dispatch a JSON-RPC command and return a response dict."""
|
||||||
|
out_buf = io.StringIO()
|
||||||
|
err_buf = io.StringIO()
|
||||||
|
fn = command_dict.get("function_name", "")
|
||||||
|
try:
|
||||||
|
with redirect_stdout(out_buf), redirect_stderr(err_buf):
|
||||||
|
if fn == "set_radio":
|
||||||
|
self.set_radio(
|
||||||
|
radio_str=command_dict["radio_str"],
|
||||||
|
identifier=command_dict.get("identifier", ""),
|
||||||
|
)
|
||||||
|
elif fn == "init_tx":
|
||||||
|
self.init_tx(
|
||||||
|
center_frequency=command_dict["center_frequency"],
|
||||||
|
sample_rate=command_dict["sample_rate"],
|
||||||
|
gain=command_dict["gain"],
|
||||||
|
channel=command_dict.get("channel", 0),
|
||||||
|
gain_mode=command_dict.get("gain_mode", "absolute"),
|
||||||
|
)
|
||||||
|
elif fn == "transmit":
|
||||||
|
self.transmit(duration_s=command_dict.get("duration_s", 1.0))
|
||||||
|
elif fn == "stop":
|
||||||
|
self.stop()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown function: {fn!r}")
|
||||||
|
return {"status": True, "message": out_buf.getvalue(), "error_message": err_buf.getvalue()}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Error executing %s", fn)
|
||||||
|
return {"status": False, "message": out_buf.getvalue(), "error_message": str(exc)}
|
||||||
|
|
||||||
|
|
||||||
|
def _serve(port: int) -> None:
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REP)
|
||||||
|
socket.bind(f"tcp://*:{port}")
|
||||||
|
logger.info("RemoteTransmitter listening on port %d", port)
|
||||||
|
tx = RemoteTransmitter()
|
||||||
|
while True:
|
||||||
|
raw = socket.recv()
|
||||||
|
cmd = json.loads(raw.decode())
|
||||||
|
response = tx.run_function(cmd)
|
||||||
|
socket.send(json.dumps(response).encode())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
parser = argparse.ArgumentParser(description="SDR Tx ZMQ server")
|
||||||
|
parser.add_argument("--port", type=int, default=5556)
|
||||||
|
args = parser.parse_args()
|
||||||
|
_serve(args.port)
|
||||||
|
|
@ -0,0 +1,218 @@
|
||||||
|
"""Client-side SSH + ZMQ controller for a remote SDR transmitter.
|
||||||
|
|
||||||
|
Run this on the Rx machine (or hub). It SSH-es into the Tx machine,
|
||||||
|
starts :mod:`remote_transmitter` there, then sends JSON-RPC commands over
|
||||||
|
ZMQ.
|
||||||
|
|
||||||
|
Requires: paramiko, zmq.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import paramiko
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STARTUP_WAIT_S = 2.0 # seconds to wait for remote ZMQ server to bind
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteTransmitterController:
|
||||||
|
"""SSH into a Tx machine, start the ZMQ server, and send commands.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: IP or hostname of the Tx machine.
|
||||||
|
ssh_user: SSH username.
|
||||||
|
ssh_key_path: Path to SSH private key file.
|
||||||
|
zmq_port: ZMQ port that the remote transmitter will bind on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
ssh_user: str,
|
||||||
|
ssh_key_path: str,
|
||||||
|
zmq_port: int = 5556,
|
||||||
|
) -> None:
|
||||||
|
self._host = host
|
||||||
|
self._zmq_port = zmq_port
|
||||||
|
self._ssh: paramiko.SSHClient | None = None
|
||||||
|
self._ssh_stdout = None
|
||||||
|
self._context: zmq.Context | None = None
|
||||||
|
self._socket: zmq.Socket | None = None
|
||||||
|
self._tx_thread: threading.Thread | None = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
self._connect(host, ssh_user, ssh_key_path, zmq_port)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Connection management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _connect(self, host: str, ssh_user: str, ssh_key_path: str, zmq_port: int) -> None:
|
||||||
|
"""Open SSH tunnel, start remote server, connect ZMQ socket."""
|
||||||
|
try:
|
||||||
|
import paramiko
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError("paramiko is required for remote SDR control: pip install paramiko") from exc
|
||||||
|
try:
|
||||||
|
import zmq
|
||||||
|
except ImportError as exc:
|
||||||
|
raise RuntimeError("pyzmq is required for remote SDR control: pip install pyzmq") from exc
|
||||||
|
|
||||||
|
logger.info("SSH connecting to %s@%s …", ssh_user, host)
|
||||||
|
self._ssh = paramiko.SSHClient()
|
||||||
|
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
|
self._ssh.connect(hostname=host, username=ssh_user, key_filename=ssh_key_path)
|
||||||
|
|
||||||
|
cmd = f"python -m ria_toolkit_oss.remote_control.remote_transmitter --port {zmq_port}"
|
||||||
|
logger.info("Starting remote Tx server: %s", cmd)
|
||||||
|
_, self._ssh_stdout, _ = self._ssh.exec_command(cmd)
|
||||||
|
|
||||||
|
time.sleep(_STARTUP_WAIT_S)
|
||||||
|
|
||||||
|
self._context = zmq.Context()
|
||||||
|
self._socket = self._context.socket(zmq.REQ)
|
||||||
|
self._socket.connect(f"tcp://{host}:{zmq_port}")
|
||||||
|
logger.info("ZMQ connected to tcp://%s:%d", host, zmq_port)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Tear down ZMQ and SSH connections."""
|
||||||
|
if self._socket is not None:
|
||||||
|
try:
|
||||||
|
self._socket.close(linger=0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._socket = None
|
||||||
|
if self._context is not None:
|
||||||
|
try:
|
||||||
|
self._context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._context = None
|
||||||
|
if self._ssh_stdout is not None:
|
||||||
|
try:
|
||||||
|
self._ssh_stdout.channel.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._ssh_stdout = None
|
||||||
|
if self._ssh is not None:
|
||||||
|
try:
|
||||||
|
self._ssh.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._ssh = None
|
||||||
|
logger.info("RemoteTransmitterController closed")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# ZMQ dispatch
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _send(self, command: dict) -> dict:
|
||||||
|
"""Send a JSON-RPC command and return the response dict (thread-safe)."""
|
||||||
|
with self._lock:
|
||||||
|
if self._socket is None:
|
||||||
|
raise RuntimeError("Controller is closed")
|
||||||
|
self._socket.send(json.dumps(command).encode())
|
||||||
|
raw = self._socket.recv()
|
||||||
|
reply: dict = json.loads(raw.decode())
|
||||||
|
if not reply.get("status"):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Remote command '{command.get('function_name')}' failed: "
|
||||||
|
f"{reply.get('error_message', 'unknown error')}"
|
||||||
|
)
|
||||||
|
return reply
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def set_radio(self, device_type: str, device_id: str = "") -> None:
|
||||||
|
"""Initialise the SDR radio on the Tx machine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_type: SDR type — ``pluto``, ``usrp``, ``hackrf``, ``bladerf``.
|
||||||
|
device_id: Device-specific identifier (IP, serial, etc.).
|
||||||
|
"""
|
||||||
|
logger.info("set_radio(%s, %r)", device_type, device_id)
|
||||||
|
self._send({"function_name": "set_radio", "radio_str": device_type, "identifier": device_id})
|
||||||
|
|
||||||
|
def init_tx(
|
||||||
|
self,
|
||||||
|
center_frequency: float,
|
||||||
|
sample_rate: float,
|
||||||
|
gain: float,
|
||||||
|
channel: int = 0,
|
||||||
|
gain_mode: str = "absolute",
|
||||||
|
) -> None:
|
||||||
|
"""Configure Tx parameters on the remote SDR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
center_frequency: Center frequency in Hz.
|
||||||
|
sample_rate: Sample rate in Hz.
|
||||||
|
gain: Tx gain in dB.
|
||||||
|
channel: RF channel index (default 0).
|
||||||
|
gain_mode: ``"absolute"`` (default) or ``"relative"``.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"init_tx: fc=%.3f MHz, fs=%.3f MHz, gain=%.1f dB, ch=%d",
|
||||||
|
center_frequency / 1e6,
|
||||||
|
sample_rate / 1e6,
|
||||||
|
gain,
|
||||||
|
channel,
|
||||||
|
)
|
||||||
|
self._send(
|
||||||
|
{
|
||||||
|
"function_name": "init_tx",
|
||||||
|
"center_frequency": center_frequency,
|
||||||
|
"sample_rate": sample_rate,
|
||||||
|
"gain": gain,
|
||||||
|
"channel": channel,
|
||||||
|
"gain_mode": gain_mode,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def transmit_async(self, duration_s: float) -> None:
|
||||||
|
"""Start a timed CW transmission in a background thread.
|
||||||
|
|
||||||
|
Returns immediately. Call :meth:`wait_transmit` after recording to
|
||||||
|
ensure the transmit thread has finished before the next step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
duration_s: Transmission duration in seconds.
|
||||||
|
"""
|
||||||
|
logger.info("transmit_async: %.1f s", duration_s)
|
||||||
|
|
||||||
|
def _run() -> None:
|
||||||
|
try:
|
||||||
|
self._send({"function_name": "transmit", "duration_s": duration_s})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Background transmit error: %s", exc)
|
||||||
|
|
||||||
|
self._tx_thread = threading.Thread(target=_run, daemon=True, name="remote-tx")
|
||||||
|
self._tx_thread.start()
|
||||||
|
|
||||||
|
def wait_transmit(self, timeout: float | None = None) -> None:
|
||||||
|
"""Wait for the background transmit thread to finish.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum seconds to wait. ``None`` = wait indefinitely.
|
||||||
|
"""
|
||||||
|
if self._tx_thread is not None:
|
||||||
|
self._tx_thread.join(timeout=timeout)
|
||||||
|
self._tx_thread = None
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop transmission and release the remote SDR, then close connections."""
|
||||||
|
logger.info("Sending stop to remote Tx")
|
||||||
|
try:
|
||||||
|
self._send({"function_name": "stop"})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("stop command error (may be normal if connection closed): %s", exc)
|
||||||
|
finally:
|
||||||
|
self.close()
|
||||||
|
|
@ -4,6 +4,82 @@ It streamlines tasks involving signal reception and transmission, as well as com
|
||||||
operations such as detecting and configuring available devices.
|
operations such as detecting and configuring available devices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = ["SDR", "SDRError", "SDRParameterError"]
|
__all__ = [
|
||||||
|
"SDR",
|
||||||
|
"SDRError",
|
||||||
|
"SDRParameterError",
|
||||||
|
"SdrDisconnectedError",
|
||||||
|
"MockSDR",
|
||||||
|
"get_sdr_device",
|
||||||
|
"detect_available",
|
||||||
|
]
|
||||||
|
|
||||||
from .sdr import SDR, SDRError, SDRParameterError
|
from .mock import MockSDR
|
||||||
|
from .sdr import ( # noqa: F401
|
||||||
|
SDR,
|
||||||
|
SdrDisconnectedError,
|
||||||
|
SDRError,
|
||||||
|
SDRParameterError,
|
||||||
|
translate_disconnect,
|
||||||
|
)
|
||||||
|
|
||||||
|
_DRIVER_CANDIDATES: tuple[tuple[str, str, str], ...] = (
|
||||||
|
("mock", "ria_toolkit_oss.sdr.mock", "MockSDR"),
|
||||||
|
("pluto", "ria_toolkit_oss.sdr.pluto", "Pluto"),
|
||||||
|
("hackrf", "ria_toolkit_oss.sdr.hackrf", "HackRF"),
|
||||||
|
("rtlsdr", "ria_toolkit_oss.sdr.rtlsdr", "RTLSDR"),
|
||||||
|
("usrp", "ria_toolkit_oss.sdr.usrp", "USRP"),
|
||||||
|
("blade", "ria_toolkit_oss.sdr.blade", "Blade"),
|
||||||
|
("thinkrf", "ria_toolkit_oss.sdr.thinkrf", "ThinkRF"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_available() -> dict[str, type]:
|
||||||
|
"""Return ``{device_name: driver_class}`` for every driver whose module imports cleanly.
|
||||||
|
|
||||||
|
Importability is a proxy for "the user has installed this driver's optional dependency".
|
||||||
|
It does not probe for physical hardware presence — that requires actually instantiating
|
||||||
|
the driver, which can be slow and side-effectful.
|
||||||
|
"""
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
out: dict[str, type] = {}
|
||||||
|
for name, module_path, cls_name in _DRIVER_CANDIDATES:
|
||||||
|
try:
|
||||||
|
mod = importlib.import_module(module_path)
|
||||||
|
out[name] = getattr(mod, cls_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def get_sdr_device(device_type: str, ident: str | None = None, tx: bool = False) -> SDR:
|
||||||
|
"""Return an SDR instance for *device_type*.
|
||||||
|
|
||||||
|
For ``"mock"`` / ``"sim"`` device types, returns a :class:`MockSDR`
|
||||||
|
immediately (no hardware required). For all real device types, delegates
|
||||||
|
to ``ria_toolkit_oss_cli.ria_toolkit_oss.common.get_sdr_device`` if the
|
||||||
|
CLI package is installed; otherwise raises ``ImportError`` with a helpful
|
||||||
|
message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_type: Device name (``"mock"``, ``"pluto"``, ``"usrp"``, …).
|
||||||
|
ident: Optional device identifier (IP address, serial number, …).
|
||||||
|
tx: If True, require TX capability.
|
||||||
|
"""
|
||||||
|
if device_type in ("mock", "sim"):
|
||||||
|
return MockSDR()
|
||||||
|
|
||||||
|
# Delegate real device types to the CLI package which holds the driver
|
||||||
|
# imports behind hardware-specific optional dependencies.
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||||
|
get_sdr_device as _cli_get,
|
||||||
|
)
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
f"ria_toolkit_oss_cli is required to use hardware SDR device '{device_type}'. "
|
||||||
|
"Install it with: pip install ria-toolkit-oss-cli"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return _cli_get(device_type, ident=ident, tx=tx)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from bladerf import _bladerf
|
from bladerf import _bladerf
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
from ria_toolkit_oss.sdr import SDR, SDRError, SDRParameterError
|
from ria_toolkit_oss.sdr import SDR, SDRError, SDRParameterError
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr._external.libhackrf import HackRF as hrf
|
from ria_toolkit_oss.sdr._external.libhackrf import HackRF as hrf
|
||||||
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ class HackRF(SDR):
|
||||||
:param channel: The channel the HackRF is set to. (Not actually used)
|
:param channel: The channel the HackRF is set to. (Not actually used)
|
||||||
:type channel: int
|
:type channel: int
|
||||||
:param gain_mode: 'absolute' passes gain directly to the sdr,
|
:param gain_mode: 'absolute' passes gain directly to the sdr,
|
||||||
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (40).
|
'relative' means that gain should be a negative value, and it will be subtracted from the max gain (40).
|
||||||
:type gain_mode: str
|
:type gain_mode: str
|
||||||
"""
|
"""
|
||||||
print("Initializing RX")
|
print("Initializing RX")
|
||||||
|
|
|
||||||
131
src/ria_toolkit_oss/sdr/mock.py
Normal file
131
src/ria_toolkit_oss/sdr/mock.py
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
"""Simulated SDR device for testing without hardware.
|
||||||
|
|
||||||
|
Set ``recorder.device = "mock"`` (or ``"sim"``) in a campaign config to use
|
||||||
|
this driver. The inference loop can also use it by specifying ``device:
|
||||||
|
"mock"`` in the SDR start request.
|
||||||
|
|
||||||
|
The mock generates complex float32 AWGN samples normalised to [-1, 1].
|
||||||
|
It satisfies both interfaces used in this codebase:
|
||||||
|
|
||||||
|
- ``record(num_samples)`` / ``_stream_rx(callback)`` — used by
|
||||||
|
``CampaignExecutor`` (inherits from ``SDR`` base class).
|
||||||
|
- ``rx(num_samples)`` — PlutoSDR-style interface used by the controller
|
||||||
|
inference loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ria_toolkit_oss.sdr.sdr import SDR
|
||||||
|
|
||||||
|
_DEFAULT_BUFFER_SIZE = 4096
|
||||||
|
# Simulated sample rate throttle: sleep this long between buffers so the
|
||||||
|
# loop does not spin at 100% CPU. 10 ms ≈ 100 buffers/s which is fine for
|
||||||
|
# tests and campaign execution timing.
|
||||||
|
_SLEEP_PER_BUFFER_S = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class MockSDR(SDR):
|
||||||
|
"""Software-simulated SDR that generates AWGN noise.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
buffer_size: Number of complex samples per streaming buffer.
|
||||||
|
seed: Optional RNG seed for reproducible output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE, seed: int | None = None):
|
||||||
|
super().__init__()
|
||||||
|
self.rx_buffer_size: int = buffer_size
|
||||||
|
self._rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
# Direct attribute aliases used by _apply_sdr_config in the controller.
|
||||||
|
self.center_freq: float = 2.45e9
|
||||||
|
self.sample_rate: float = 10e6
|
||||||
|
self.gain: float = 40.0
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Abstract method implementations
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def init_rx(
|
||||||
|
self,
|
||||||
|
sample_rate: float,
|
||||||
|
center_frequency: float,
|
||||||
|
gain,
|
||||||
|
channel: int = 0,
|
||||||
|
gain_mode: str = "manual",
|
||||||
|
) -> None:
|
||||||
|
self.rx_sample_rate = float(sample_rate)
|
||||||
|
self.rx_center_frequency = float(center_frequency)
|
||||||
|
self.rx_gain = 40.0 if gain is None else float(gain)
|
||||||
|
# Mirror to the attribute names used by _apply_sdr_config.
|
||||||
|
self.sample_rate = self.rx_sample_rate
|
||||||
|
self.center_freq = self.rx_center_frequency
|
||||||
|
self.gain = self.rx_gain
|
||||||
|
self._rx_initialized = True
|
||||||
|
|
||||||
|
def init_tx(
|
||||||
|
self,
|
||||||
|
sample_rate: float,
|
||||||
|
center_frequency: float,
|
||||||
|
gain,
|
||||||
|
channel: int = 0,
|
||||||
|
gain_mode: str = "manual",
|
||||||
|
) -> None:
|
||||||
|
self.tx_sample_rate = float(sample_rate)
|
||||||
|
self.tx_center_frequency = float(center_frequency)
|
||||||
|
self.tx_gain = 40.0 if gain is None else float(gain)
|
||||||
|
self._tx_initialized = True
|
||||||
|
|
||||||
|
def _stream_rx(self, callback) -> None:
|
||||||
|
"""Generate 1-D AWGN buffers and pass each to *callback* until stopped.
|
||||||
|
|
||||||
|
Uses 1-D arrays so the base class ``_validate_buffer`` check does not
|
||||||
|
incorrectly flag them as corrupted (the (1, N) form triggers a false
|
||||||
|
positive in the all-same-value check).
|
||||||
|
"""
|
||||||
|
self._enable_rx = True
|
||||||
|
while self._enable_rx:
|
||||||
|
buf = self._awgn(self.rx_buffer_size)
|
||||||
|
callback(buf)
|
||||||
|
time.sleep(_SLEEP_PER_BUFFER_S)
|
||||||
|
|
||||||
|
def _stream_tx(self, callback) -> None:
|
||||||
|
self._enable_tx = True
|
||||||
|
while self._enable_tx:
|
||||||
|
callback(self.rx_buffer_size)
|
||||||
|
time.sleep(_SLEEP_PER_BUFFER_S)
|
||||||
|
|
||||||
|
def set_clock_source(self, source: str) -> None:
|
||||||
|
pass # no-op
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._enable_rx = False
|
||||||
|
self._enable_tx = False
|
||||||
|
self._rx_initialized = False
|
||||||
|
self._tx_initialized = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# PlutoSDR-style interface used by the controller inference loop
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def rx(self, num_samples: int) -> np.ndarray:
|
||||||
|
"""Return *num_samples* complex64 AWGN samples (PlutoSDR-style)."""
|
||||||
|
return self._awgn(num_samples)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _awgn(self, n: int) -> np.ndarray:
|
||||||
|
"""Return *n* normalised complex64 AWGN samples as a 1-D array."""
|
||||||
|
real = self._rng.standard_normal(n).astype(np.float32)
|
||||||
|
imag = self._rng.standard_normal(n).astype(np.float32)
|
||||||
|
buf = real + 1j * imag
|
||||||
|
peak = np.abs(buf).max()
|
||||||
|
if peak > 1e-9:
|
||||||
|
buf /= peak
|
||||||
|
return buf
|
||||||
|
|
@ -7,8 +7,13 @@ from typing import Optional
|
||||||
import adi
|
import adi
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr.sdr import SDR, SDRError, SDRParameterError
|
from ria_toolkit_oss.sdr.sdr import (
|
||||||
|
SDR,
|
||||||
|
SDRError,
|
||||||
|
SDRParameterError,
|
||||||
|
translate_disconnect,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Pluto(SDR):
|
class Pluto(SDR):
|
||||||
|
|
@ -164,6 +169,25 @@ class Pluto(SDR):
|
||||||
# send callback complex signal
|
# send callback complex signal
|
||||||
callback(buffer=signal, metadata=None)
|
callback(buffer=signal, metadata=None)
|
||||||
|
|
||||||
|
def rx(self, num_samples: Optional[int] = None) -> np.ndarray:
|
||||||
|
"""PlutoSDR-style single-buffer capture returning a complex64 array.
|
||||||
|
|
||||||
|
Sets the radio buffer size to *num_samples* (if given) and returns one
|
||||||
|
buffer directly from ``self.radio.rx()``. Raises
|
||||||
|
:class:`SdrDisconnectedError` on USB/device drop so callers (e.g. the
|
||||||
|
streamer) can report the failure and stop cleanly instead of crashing.
|
||||||
|
"""
|
||||||
|
if num_samples is not None:
|
||||||
|
try:
|
||||||
|
self.set_rx_buffer_size(buffer_size=int(num_samples))
|
||||||
|
except Exception as exc:
|
||||||
|
raise translate_disconnect(exc) from exc
|
||||||
|
try:
|
||||||
|
samples = self.radio.rx()
|
||||||
|
except Exception as exc:
|
||||||
|
raise translate_disconnect(exc) from exc
|
||||||
|
return np.asarray(samples)
|
||||||
|
|
||||||
def _record_fast(self, num_samples):
|
def _record_fast(self, num_samples):
|
||||||
"""Optimized single-buffer capture for ≤16M samples."""
|
"""Optimized single-buffer capture for ≤16M samples."""
|
||||||
|
|
||||||
|
|
@ -329,7 +353,12 @@ class Pluto(SDR):
|
||||||
elif tx_time is not None:
|
elif tx_time is not None:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
tx_time = len(recording) / self.tx_sample_rate
|
if isinstance(recording, Recording):
|
||||||
|
tx_time = recording.data.shape[-1] / self.tx_sample_rate
|
||||||
|
elif isinstance(recording, np.ndarray):
|
||||||
|
tx_time = recording.shape[-1] / self.tx_sample_rate
|
||||||
|
else:
|
||||||
|
tx_time = len(recording[0]) / self.tx_sample_rate
|
||||||
|
|
||||||
data = self._format_tx_data(recording=recording)
|
data = self._format_tx_data(recording=recording)
|
||||||
|
|
||||||
|
|
@ -360,7 +389,10 @@ class Pluto(SDR):
|
||||||
self._enable_tx = True
|
self._enable_tx = True
|
||||||
while self._enable_tx is True:
|
while self._enable_tx is True:
|
||||||
buffer = self._convert_tx_samples(callback(self.tx_buffer_size))
|
buffer = self._convert_tx_samples(callback(self.tx_buffer_size))
|
||||||
self.radio.tx(buffer[0])
|
# pyadi-iio's ``radio.tx`` auto-wraps single-channel 1-D input.
|
||||||
|
# Indexing ``buffer[0]`` was a latent bug for callbacks that
|
||||||
|
# returned 1-D samples (scalar → TypeError inside pyadi).
|
||||||
|
self.radio.tx(buffer)
|
||||||
|
|
||||||
def set_rx_center_frequency(self, center_frequency):
|
def set_rx_center_frequency(self, center_frequency):
|
||||||
"""
|
"""
|
||||||
|
|
@ -431,7 +463,7 @@ class Pluto(SDR):
|
||||||
abs_gain = gain
|
abs_gain = gain
|
||||||
|
|
||||||
if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
|
if abs_gain < rx_gain_min or abs_gain > rx_gain_max:
|
||||||
abs_gain = min(max(gain, rx_gain_min), rx_gain_max)
|
abs_gain = min(max(abs_gain, rx_gain_min), rx_gain_max)
|
||||||
print(f"Gain {gain} out of range for Pluto.")
|
print(f"Gain {gain} out of range for Pluto.")
|
||||||
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
|
print(f"Gain range: {rx_gain_min} to {rx_gain_max} dB")
|
||||||
|
|
||||||
|
|
@ -490,74 +522,85 @@ class Pluto(SDR):
|
||||||
raise SDRError(e)
|
raise SDRError(e)
|
||||||
|
|
||||||
def set_tx_center_frequency(self, center_frequency):
|
def set_tx_center_frequency(self, center_frequency):
|
||||||
if center_frequency < 70e6 or center_frequency > 6e9:
|
# ``adi.Pluto`` exposes one radio handle shared between RX and TX; concurrent
|
||||||
raise SDRParameterError(
|
# RX + TX sessions (see the agent ``_SdrRegistry``) may call RX and TX
|
||||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
# setters at the same time. Serialize with ``_param_lock`` — RX setters hold
|
||||||
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
# the same reentrant lock — so native attribute writes don't interleave.
|
||||||
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
with self._param_lock:
|
||||||
)
|
if center_frequency < 70e6 or center_frequency > 6e9:
|
||||||
|
raise SDRParameterError(
|
||||||
|
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||||
|
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
||||||
|
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.radio.tx_lo = int(center_frequency)
|
self.radio.tx_lo = int(center_frequency)
|
||||||
self.tx_center_frequency = center_frequency
|
self.tx_center_frequency = center_frequency
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise SDRError(e)
|
raise SDRError(e)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise SDRParameterError(
|
raise SDRParameterError(
|
||||||
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
f"{self.__class__.__name__}: Center frequency {center_frequency/1e9:.3f} GHz "
|
||||||
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
f"out of range:\nStandard:\t[{325e6/1e9:.3f} - {3.8e9/1e9:.3f} GHz]\nHacked:\t"
|
||||||
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
f"[{70e6/1e9:.3f} - {6e9/1e9:.3f} GHz]"
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_tx_sample_rate(self, sample_rate):
|
def set_tx_sample_rate(self, sample_rate):
|
||||||
min_rate, max_rate = 65.1e3, 61.44e6
|
# ``self.radio.sample_rate`` is shared between RX and TX on Pluto — RX's
|
||||||
if sample_rate < min_rate or sample_rate > max_rate:
|
# ``set_rx_sample_rate`` writes the same native attribute. Hold ``_param_lock``
|
||||||
raise SDRParameterError(
|
# so full-duplex sessions can't interleave writes.
|
||||||
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
with self._param_lock:
|
||||||
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
min_rate, max_rate = 65.1e3, 61.44e6
|
||||||
)
|
if sample_rate < min_rate or sample_rate > max_rate:
|
||||||
|
raise SDRParameterError(
|
||||||
|
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
||||||
|
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.radio.sample_rate = sample_rate
|
self.radio.sample_rate = sample_rate
|
||||||
self.tx_sample_rate = sample_rate
|
self.tx_sample_rate = sample_rate
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise SDRError(e)
|
raise SDRError(e)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise SDRParameterError(
|
raise SDRParameterError(
|
||||||
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
f"{self.__class__.__name__}: Sample rate {sample_rate/1e6:.3f} Msps "
|
||||||
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
f"out of range: [{min_rate/1e6:.3f} - {max_rate/1e6:.3f} Msps]"
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
|
def set_tx_gain(self, gain, channel=0, gain_mode="absolute"):
|
||||||
tx_gain_min = -89
|
# Serialize with RX setters: see ``set_tx_sample_rate`` above.
|
||||||
tx_gain_max = 0
|
with self._param_lock:
|
||||||
|
tx_gain_min = -89
|
||||||
|
tx_gain_max = 0
|
||||||
|
|
||||||
if gain_mode == "relative":
|
if gain_mode == "relative":
|
||||||
if gain > 0:
|
if gain > 0:
|
||||||
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
raise SDRParameterError("When gain_mode = 'relative', gain must be < 0. This sets\
|
||||||
the gain relative to the maximum possible gain.")
|
the gain relative to the maximum possible gain.")
|
||||||
|
else:
|
||||||
|
abs_gain = tx_gain_max + gain
|
||||||
else:
|
else:
|
||||||
abs_gain = tx_gain_max + gain
|
abs_gain = gain
|
||||||
else:
|
|
||||||
abs_gain = gain
|
|
||||||
|
|
||||||
if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
|
if abs_gain < tx_gain_min or abs_gain > tx_gain_max:
|
||||||
abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
|
abs_gain = min(max(gain, tx_gain_min), tx_gain_max)
|
||||||
print(f"Gain {gain} out of range for Pluto.")
|
print(f"Gain {gain} out of range for Pluto.")
|
||||||
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
|
print(f"Gain range: {tx_gain_min} to {tx_gain_max} dB")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.tx_gain = abs_gain
|
self.tx_gain = abs_gain
|
||||||
|
|
||||||
if channel == 0:
|
if channel == 0:
|
||||||
self.radio.tx_hardwaregain_chan0 = int(abs_gain)
|
self.radio.tx_hardwaregain_chan0 = int(abs_gain)
|
||||||
elif channel == 1:
|
elif channel == 1:
|
||||||
self.radio.tx_hardwaregain_chan1 = int(abs_gain)
|
self.radio.tx_hardwaregain_chan1 = int(abs_gain)
|
||||||
else:
|
else:
|
||||||
raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.")
|
raise SDRParameterError(f"Pluto channel must be 0 or 1 but was {channel}.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise SDRError(e)
|
raise SDRError(e)
|
||||||
|
|
||||||
def set_tx_channel(self, channel):
|
def set_tx_channel(self, channel):
|
||||||
if channel == 0:
|
if channel == 0:
|
||||||
|
|
@ -583,6 +626,8 @@ class Pluto(SDR):
|
||||||
self.tx_buffer_size = buffer_size
|
self.tx_buffer_size = buffer_size
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
if not hasattr(self, "radio"):
|
||||||
|
return
|
||||||
if self.radio.tx_cyclic_buffer:
|
if self.radio.tx_cyclic_buffer:
|
||||||
self.radio.tx_destroy_buffer()
|
self.radio.tx_destroy_buffer()
|
||||||
del self.radio
|
del self.radio
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ try:
|
||||||
except ImportError as exc: # pragma: no cover - dependency provided by end user
|
except ImportError as exc: # pragma: no cover - dependency provided by end user
|
||||||
raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc
|
raise ImportError("pyrtlsdr is required to use the RTLSDR class") from exc
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
|
||||||
|
|
||||||
class SDR(ABC):
|
class SDR(ABC):
|
||||||
|
|
@ -32,7 +32,6 @@ class SDR(ABC):
|
||||||
self._accumulated_buffer = None
|
self._accumulated_buffer = None
|
||||||
self._max_num_buffers = None
|
self._max_num_buffers = None
|
||||||
self._num_buffers_processed = 0
|
self._num_buffers_processed = 0
|
||||||
self._accumulated_buffer = None
|
|
||||||
self._last_buffer = None
|
self._last_buffer = None
|
||||||
self._corrupted_buffer_count = 0
|
self._corrupted_buffer_count = 0
|
||||||
|
|
||||||
|
|
@ -44,6 +43,13 @@ class SDR(ABC):
|
||||||
self.tx_gain = None
|
self.tx_gain = None
|
||||||
self._param_lock = threading.RLock() # Reentrant lock
|
self._param_lock = threading.RLock() # Reentrant lock
|
||||||
|
|
||||||
|
# Pending config consumed by rx() on first call and by _apply_sdr_config
|
||||||
|
# in the agent inference loop. Subclasses that need different defaults
|
||||||
|
# (e.g. MockSDR) can overwrite these in their own __init__.
|
||||||
|
self.center_freq: float = 2.4e9
|
||||||
|
self.sample_rate: float = 10e6
|
||||||
|
self.gain: float = 40.0
|
||||||
|
|
||||||
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
||||||
"""
|
"""
|
||||||
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
|
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
|
||||||
|
|
@ -101,6 +107,32 @@ class SDR(ABC):
|
||||||
self._num_buffers_processed = 0
|
self._num_buffers_processed = 0
|
||||||
return recording
|
return recording
|
||||||
|
|
||||||
|
def rx(self, num_samples: int) -> "np.ndarray":
|
||||||
|
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
|
||||||
|
|
||||||
|
This is the interface used by the agent inference loop. On first call,
|
||||||
|
``init_rx()`` is invoked automatically using the values stored in
|
||||||
|
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
|
||||||
|
``_apply_sdr_config``). Subsequent calls stream directly.
|
||||||
|
|
||||||
|
Subclasses may override this for hardware-native capture APIs (e.g.
|
||||||
|
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
|
||||||
|
``self.radio.rx()``).
|
||||||
|
"""
|
||||||
|
if not self._rx_initialized:
|
||||||
|
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
|
||||||
|
self.init_rx(
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
center_frequency=self.center_freq,
|
||||||
|
gain=gain,
|
||||||
|
channel=0,
|
||||||
|
)
|
||||||
|
recording = self.record(num_samples=num_samples)
|
||||||
|
# Recording.data is either a list of 1-D arrays (one per channel) or a
|
||||||
|
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
|
||||||
|
data = recording.data
|
||||||
|
return data[0] if hasattr(data, "__getitem__") else data
|
||||||
|
|
||||||
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
||||||
"""
|
"""
|
||||||
Stream iq samples as interleaved bytes via zmq.
|
Stream iq samples as interleaved bytes via zmq.
|
||||||
|
|
@ -282,7 +314,7 @@ class SDR(ABC):
|
||||||
elif num_samples is not None:
|
elif num_samples is not None:
|
||||||
self._num_samples_to_transmit = num_samples
|
self._num_samples_to_transmit = num_samples
|
||||||
elif tx_time is not None:
|
elif tx_time is not None:
|
||||||
self._num_samples_to_transmit = tx_time * self.tx_sample_rate
|
self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
|
||||||
else:
|
else:
|
||||||
self._num_samples_to_transmit = len(recording)
|
self._num_samples_to_transmit = len(recording)
|
||||||
|
|
||||||
|
|
@ -529,3 +561,51 @@ class SDROverflowError(SDRError):
|
||||||
"""Buffer overflow detected."""
|
"""Buffer overflow detected."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SdrDisconnectedError(SDRError):
|
||||||
|
"""Raised when the SDR device disappears mid-operation (USB unplug, network drop)."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Substrings that strongly indicate a device has disappeared rather than a
|
||||||
|
# transient / recoverable error. Checked case-insensitively against str(exc).
|
||||||
|
_DISCONNECT_MARKERS = (
|
||||||
|
"no such device",
|
||||||
|
"device not found",
|
||||||
|
"not found",
|
||||||
|
"broken pipe",
|
||||||
|
"disconnected",
|
||||||
|
"no device",
|
||||||
|
"device unplugged",
|
||||||
|
"usb",
|
||||||
|
"i/o error",
|
||||||
|
"input/output error",
|
||||||
|
"errno 19", # ENODEV
|
||||||
|
"errno 5", # EIO
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def translate_disconnect(exc: BaseException) -> BaseException:
|
||||||
|
"""Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*.
|
||||||
|
|
||||||
|
Drivers wrap their native-API calls with::
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.radio.rx()
|
||||||
|
except Exception as exc:
|
||||||
|
raise translate_disconnect(exc) from exc
|
||||||
|
|
||||||
|
The caller (e.g. the streamer) can then catch ``SdrDisconnectedError``
|
||||||
|
specifically and report it to the hub rather than crashing the loop.
|
||||||
|
"""
|
||||||
|
if isinstance(exc, SdrDisconnectedError):
|
||||||
|
return exc
|
||||||
|
msg = str(exc).lower()
|
||||||
|
if any(marker in msg for marker in _DISCONNECT_MARKERS):
|
||||||
|
return SdrDisconnectedError(str(exc))
|
||||||
|
# OSError subclass with ENODEV / EIO errno is also a disconnect signal.
|
||||||
|
if isinstance(exc, OSError) and getattr(exc, "errno", None) in (5, 19):
|
||||||
|
return SdrDisconnectedError(str(exc))
|
||||||
|
return exc
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import uhd
|
import uhd
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
from ria_toolkit_oss.sdr.sdr import SDR, SDRParameterError
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,7 +54,7 @@ class USRP(SDR):
|
||||||
:param channel: The channel the USRP is set to.
|
:param channel: The channel the USRP is set to.
|
||||||
:type channel: int
|
:type channel: int
|
||||||
:param gain_mode: 'absolute' passes gain directly to the sdr,
|
:param gain_mode: 'absolute' passes gain directly to the sdr,
|
||||||
'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
|
'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
|
||||||
:type gain_mode: str
|
:type gain_mode: str
|
||||||
:param rx_buffer_size: Internal buffer size for receiving samples. Defaults to 960000.
|
:param rx_buffer_size: Internal buffer size for receiving samples. Defaults to 960000.
|
||||||
:type rx_buffer_size: int
|
:type rx_buffer_size: int
|
||||||
|
|
@ -285,7 +285,7 @@ class USRP(SDR):
|
||||||
:param channel: The channel the USRP is set to.
|
:param channel: The channel the USRP is set to.
|
||||||
:type channel: int
|
:type channel: int
|
||||||
:param gain_mode: 'absolute' passes gain directly to the sdr,
|
:param gain_mode: 'absolute' passes gain directly to the sdr,
|
||||||
'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
|
'relative' means that gain should be a negative value, and it will be subtracted from the max gain.
|
||||||
:type gain_mode: str
|
:type gain_mode: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
5
src/ria_toolkit_oss/server/__init__.py
Normal file
5
src/ria_toolkit_oss/server/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""RT-OSS HTTP server for RIA Hub integration."""
|
||||||
|
|
||||||
|
from .app import create_app
|
||||||
|
|
||||||
|
__all__ = ["create_app"]
|
||||||
48
src/ria_toolkit_oss/server/app.py
Normal file
48
src/ria_toolkit_oss/server/app.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""FastAPI application factory for the RT-OSS HTTP server."""
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
|
from .auth import require_api_key
|
||||||
|
from .routers import conductor, inference
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(api_key: str = "") -> FastAPI:
|
||||||
|
"""Create and configure the RT-OSS FastAPI application.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Secret key required in the ``X-API-Key`` request header.
|
||||||
|
Pass an empty string to disable authentication (development only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured FastAPI application instance.
|
||||||
|
"""
|
||||||
|
app = FastAPI(
|
||||||
|
title="RIA Toolkit OSS Server",
|
||||||
|
version="0.1.0",
|
||||||
|
description=(
|
||||||
|
"HTTP API for RT-OSS campaign orchestration and RF zone inference. "
|
||||||
|
"All endpoints (except /health) require the X-API-Key header when "
|
||||||
|
"an API key is configured."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
app.state.api_key = api_key
|
||||||
|
|
||||||
|
app.include_router(
|
||||||
|
conductor.router,
|
||||||
|
prefix="/conductor",
|
||||||
|
tags=["Conductor"],
|
||||||
|
dependencies=[Depends(require_api_key)],
|
||||||
|
)
|
||||||
|
app.include_router(
|
||||||
|
inference.router,
|
||||||
|
prefix="/inference",
|
||||||
|
tags=["Inference"],
|
||||||
|
dependencies=[Depends(require_api_key)],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/health", tags=["Health"])
|
||||||
|
async def health():
|
||||||
|
"""Health check — always returns 200."""
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
return app
|
||||||
36
src/ria_toolkit_oss/server/auth.py
Normal file
36
src/ria_toolkit_oss/server/auth.py
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""API key authentication dependency."""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_api_key(
|
||||||
|
request: Request,
|
||||||
|
api_key: str | None = Depends(_api_key_header),
|
||||||
|
) -> None:
|
||||||
|
"""FastAPI dependency that enforces X-API-Key header authentication.
|
||||||
|
|
||||||
|
If no API key is configured on the server (empty string), all requests
|
||||||
|
are allowed — this is intended for local development only.
|
||||||
|
"""
|
||||||
|
expected: str = request.app.state.api_key
|
||||||
|
if not expected:
|
||||||
|
return # dev mode: no key set, allow all
|
||||||
|
if not hmac.compare_digest(api_key or "", expected):
|
||||||
|
client = getattr(request.client, "host", "unknown")
|
||||||
|
logger.warning(
|
||||||
|
"Authentication failure from %s — %s %s",
|
||||||
|
client,
|
||||||
|
request.method,
|
||||||
|
request.url.path,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid or missing API key",
|
||||||
|
)
|
||||||
47
src/ria_toolkit_oss/server/cli.py
Normal file
47
src/ria_toolkit_oss/server/cli.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
"""CLI entry point for the RT-OSS HTTP server.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ria-server # default: 0.0.0.0:8080, no auth
|
||||||
|
RT_OSS_API_KEY=secret ria-server # enforce X-API-Key header
|
||||||
|
RT_OSS_PORT=9000 ria-server # custom port
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
RT_OSS_API_KEY Shared secret for X-API-Key auth (empty = dev mode, no auth)
|
||||||
|
RT_OSS_PORT TCP port to listen on (default: 8080)
|
||||||
|
RT_OSS_HOST Bind address (default: 0.0.0.0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def serve() -> None:
|
||||||
|
try:
|
||||||
|
import uvicorn
|
||||||
|
except ImportError:
|
||||||
|
raise SystemExit(
|
||||||
|
"uvicorn is required to run the RT-OSS server.\n" "Install it with: pip install 'ria-toolkit-oss[server]'"
|
||||||
|
)
|
||||||
|
|
||||||
|
from .app import create_app
|
||||||
|
|
||||||
|
api_key = os.environ.get("RT_OSS_API_KEY", "")
|
||||||
|
host = os.environ.get("RT_OSS_HOST", "0.0.0.0")
|
||||||
|
port = int(os.environ.get("RT_OSS_PORT", "8080"))
|
||||||
|
|
||||||
|
app = create_app(api_key=api_key)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
print(
|
||||||
|
"\n"
|
||||||
|
"╔══════════════════════════════════════════════════════════════╗\n"
|
||||||
|
"║ WARNING: RT_OSS_API_KEY is not set. ║\n"
|
||||||
|
"║ The server is running with NO authentication. ║\n"
|
||||||
|
"║ Anyone who can reach this port has full API access. ║\n"
|
||||||
|
"║ Set RT_OSS_API_KEY=<secret> before exposing to a network. ║\n"
|
||||||
|
"╚══════════════════════════════════════════════════════════════╝\n",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
uvicorn.run(app, host=host, port=port)
|
||||||
114
src/ria_toolkit_oss/server/models.py
Normal file
114
src/ria_toolkit_oss/server/models.py
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""Pydantic request and response models for the RT-OSS HTTP server."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Conductor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DeployRequest(BaseModel):
|
||||||
|
config: dict
|
||||||
|
|
||||||
|
|
||||||
|
class DeployResponse(BaseModel):
|
||||||
|
campaign_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignStatusResponse(BaseModel):
|
||||||
|
campaign_id: str
|
||||||
|
status: str
|
||||||
|
config_name: str
|
||||||
|
progress: int
|
||||||
|
total_steps: int
|
||||||
|
started_at: float
|
||||||
|
ended_at: float | None = None
|
||||||
|
result: dict | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelResponse(BaseModel):
|
||||||
|
campaign_id: str
|
||||||
|
cancelled: bool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Inference
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SdrConfig(BaseModel):
|
||||||
|
device: str
|
||||||
|
center_freq: float
|
||||||
|
sample_rate: float
|
||||||
|
gain: float | str = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class LoadModelRequest(BaseModel):
|
||||||
|
model_path: str
|
||||||
|
label_map: dict[str, int] # class_name -> class_index
|
||||||
|
|
||||||
|
@field_validator("model_path")
|
||||||
|
@classmethod
|
||||||
|
def validate_model_path(cls, v: str) -> str:
|
||||||
|
p = Path(v)
|
||||||
|
if ".." in p.parts:
|
||||||
|
raise ValueError("model_path must not contain path traversal components")
|
||||||
|
if p.suffix.lower() != ".onnx":
|
||||||
|
raise ValueError("model_path must point to an .onnx file")
|
||||||
|
# Resolve to catch symlink-based traversal; return the resolved absolute path
|
||||||
|
# so callers always work with the real filesystem location.
|
||||||
|
resolved = p.resolve()
|
||||||
|
if resolved.suffix.lower() != ".onnx":
|
||||||
|
raise ValueError("Resolved model_path must point to an .onnx file")
|
||||||
|
return str(resolved)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadModelResponse(BaseModel):
|
||||||
|
loaded: bool
|
||||||
|
model_path: str
|
||||||
|
num_classes: int
|
||||||
|
|
||||||
|
|
||||||
|
class StartInferenceRequest(BaseModel):
|
||||||
|
sdr_config: SdrConfig
|
||||||
|
|
||||||
|
|
||||||
|
class StartInferenceResponse(BaseModel):
|
||||||
|
running: bool
|
||||||
|
|
||||||
|
|
||||||
|
class StopInferenceResponse(BaseModel):
|
||||||
|
stopped: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigureRequest(BaseModel):
|
||||||
|
"""Partial SDR reconfiguration — only supplied fields are updated."""
|
||||||
|
|
||||||
|
center_freq: float | None = None
|
||||||
|
sample_rate: float | None = None
|
||||||
|
gain: float | str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigureResponse(BaseModel):
|
||||||
|
configured: bool
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceStatusResponse(BaseModel):
|
||||||
|
"""Latest inference result as returned by GET /inference/status.
|
||||||
|
|
||||||
|
When ``idle`` is True the radio is scanning but no signal was detected.
|
||||||
|
``device_id`` is the raw prediction label from the model's label map.
|
||||||
|
The frontend is responsible for mapping device_id to a human name and
|
||||||
|
determining whether the device is authorized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
timestamp: float
|
||||||
|
idle: bool = False
|
||||||
|
device_id: str | None = None # prediction label; None when idle
|
||||||
|
confidence: float = 0.0
|
||||||
|
snr_db: float = 0.0
|
||||||
0
src/ria_toolkit_oss/server/routers/__init__.py
Normal file
0
src/ria_toolkit_oss/server/routers/__init__.py
Normal file
112
src/ria_toolkit_oss/server/routers/conductor.py
Normal file
112
src/ria_toolkit_oss/server/routers/conductor.py
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
"""Conductor routes: campaign deployment, status, and cancellation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
|
||||||
|
from ria_toolkit_oss.orchestration.campaign import CampaignConfig
|
||||||
|
from ria_toolkit_oss.orchestration.executor import CampaignExecutor
|
||||||
|
|
||||||
|
from ..models import (
|
||||||
|
CampaignStatusResponse,
|
||||||
|
CancelResponse,
|
||||||
|
DeployRequest,
|
||||||
|
DeployResponse,
|
||||||
|
)
|
||||||
|
from ..state import (
|
||||||
|
CampaignCancelledError,
|
||||||
|
CampaignState,
|
||||||
|
get_campaign,
|
||||||
|
set_campaign,
|
||||||
|
update_campaign,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_progress_cb(campaign_id: str, cancel_event: threading.Event):
|
||||||
|
def cb(step_index: int, total_steps: int, step_result: Any) -> None:
|
||||||
|
update_campaign(campaign_id, progress=step_index)
|
||||||
|
if cancel_event.is_set():
|
||||||
|
raise CampaignCancelledError(f"Cancelled at step {step_index}/{total_steps}")
|
||||||
|
|
||||||
|
return cb
|
||||||
|
|
||||||
|
|
||||||
|
def _run_campaign_thread(campaign_id: str, cfg: CampaignConfig) -> None:
|
||||||
|
state = get_campaign(campaign_id)
|
||||||
|
try:
|
||||||
|
result = CampaignExecutor(
|
||||||
|
config=cfg,
|
||||||
|
progress_cb=_make_progress_cb(campaign_id, state.cancel_event),
|
||||||
|
).run()
|
||||||
|
update_campaign(
|
||||||
|
campaign_id, status="completed", progress=cfg.total_steps(), result=result.to_dict(), ended_at=time.time()
|
||||||
|
)
|
||||||
|
except CampaignCancelledError:
|
||||||
|
update_campaign(campaign_id, status="cancelled", ended_at=time.time())
|
||||||
|
except Exception as e:
|
||||||
|
update_campaign(campaign_id, status="failed", error=str(e), ended_at=time.time())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/deploy", response_model=DeployResponse)
|
||||||
|
async def deploy(request: DeployRequest):
|
||||||
|
"""Deploy a campaign config and start execution. Returns a ``campaign_id`` for polling.
|
||||||
|
Cancellation takes effect at step boundaries, not mid-capture.
|
||||||
|
|
||||||
|
External scripts are not permitted in server-deployed campaigns. Configure
|
||||||
|
transmitters without the ``script`` field, or run campaigns via the CLI.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cfg = CampaignConfig.from_dict(request.config)
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
|
||||||
|
|
||||||
|
if cfg.transmitters and any(t.script for t in cfg.transmitters):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="External scripts are not permitted in server-deployed campaigns. "
|
||||||
|
"Remove the 'script' field from all transmitters, or run the campaign via the CLI.",
|
||||||
|
)
|
||||||
|
|
||||||
|
campaign_id = str(uuid.uuid4())
|
||||||
|
cancel_event = threading.Event()
|
||||||
|
thread = threading.Thread(target=_run_campaign_thread, args=(campaign_id, cfg), daemon=True)
|
||||||
|
set_campaign(
|
||||||
|
CampaignState(
|
||||||
|
campaign_id=campaign_id,
|
||||||
|
status="running",
|
||||||
|
config_name=cfg.name,
|
||||||
|
cancel_event=cancel_event,
|
||||||
|
thread=thread,
|
||||||
|
total_steps=cfg.total_steps(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
return DeployResponse(campaign_id=campaign_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status/{campaign_id}", response_model=CampaignStatusResponse)
|
||||||
|
async def get_status(campaign_id: str):
|
||||||
|
"""Get the status and progress of a deployed campaign."""
|
||||||
|
state = get_campaign(campaign_id)
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
|
||||||
|
return CampaignStatusResponse(**state.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cancel/{campaign_id}", response_model=CancelResponse)
|
||||||
|
async def cancel(campaign_id: str):
|
||||||
|
"""Request cancellation. Takes effect at the next step boundary."""
|
||||||
|
state = get_campaign(campaign_id)
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Campaign {campaign_id!r} not found")
|
||||||
|
if state.status != "running":
|
||||||
|
return CancelResponse(campaign_id=campaign_id, cancelled=False)
|
||||||
|
state.cancel_event.set()
|
||||||
|
return CancelResponse(campaign_id=campaign_id, cancelled=True)
|
||||||
253
src/ria_toolkit_oss/server/routers/inference.py
Normal file
253
src/ria_toolkit_oss/server/routers/inference.py
Normal file
|
|
@ -0,0 +1,253 @@
|
||||||
|
"""Inference routes: model loading, inference loop control, and status polling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from scipy.special import softmax
|
||||||
|
|
||||||
|
from ..models import (
|
||||||
|
ConfigureRequest,
|
||||||
|
ConfigureResponse,
|
||||||
|
InferenceStatusResponse,
|
||||||
|
LoadModelRequest,
|
||||||
|
LoadModelResponse,
|
||||||
|
StartInferenceRequest,
|
||||||
|
StartInferenceResponse,
|
||||||
|
StopInferenceResponse,
|
||||||
|
)
|
||||||
|
from ..state import InferenceState, get_inference, set_inference
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_INFERENCE_NUM_SAMPLES = 4096
|
||||||
|
|
||||||
|
# Prediction labels that mean "no signal detected" — UI should treat these as idle.
|
||||||
|
_IDLE_LABELS: frozenset[str] = frozenset({"noise", "idle", "no_signal", "unknown_protocol", "background"})
|
||||||
|
|
||||||
|
|
||||||
|
def _load_onnx_session(model_path: str):
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
except ImportError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="onnxruntime not installed. Install with: pip install ria-toolkit-oss[server]",
|
||||||
|
)
|
||||||
|
resolved = Path(model_path).resolve()
|
||||||
|
if not resolved.is_file():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail=f"Model file not found: {model_path}",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return ort.InferenceSession(str(resolved), providers=["CPUExecutionProvider"])
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Failed to load ONNX model: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_samples(samples: np.ndarray, expected_shape: tuple) -> np.ndarray:
|
||||||
|
"""Reshape complex IQ samples to float32 matching the model's expected input.
|
||||||
|
|
||||||
|
Supports ``(batch, 2*N)`` interleaved and ``(batch, 2, N)`` two-channel conventions.
|
||||||
|
"""
|
||||||
|
iq = samples.astype(np.complex64)
|
||||||
|
i_ch, q_ch = iq.real, iq.imag
|
||||||
|
|
||||||
|
if len(expected_shape) == 2:
|
||||||
|
n = expected_shape[1] // 2
|
||||||
|
interleaved = np.empty(expected_shape[1], dtype=np.float32)
|
||||||
|
interleaved[0::2] = i_ch[:n]
|
||||||
|
interleaved[1::2] = q_ch[:n]
|
||||||
|
return interleaved.reshape(1, -1)
|
||||||
|
elif len(expected_shape) == 3:
|
||||||
|
n = expected_shape[2]
|
||||||
|
return np.stack([i_ch[:n], q_ch[:n]], axis=0).astype(np.float32).reshape(1, 2, n)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model input shape: {expected_shape}")
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_current_inference(state: InferenceState, timeout: float = 5.0) -> None:
|
||||||
|
state.stop_event.set()
|
||||||
|
if state.thread and state.thread.is_alive():
|
||||||
|
state.thread.join(timeout=timeout)
|
||||||
|
if state.thread.is_alive():
|
||||||
|
logger.warning("Inference thread did not stop within %.1fs; SDR resources may not be released", timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_sdr_config(sdr, config: dict) -> None:
|
||||||
|
"""Re-initialise the SDR receiver with updated parameters."""
|
||||||
|
gain = config.get("gain")
|
||||||
|
if gain == "auto":
|
||||||
|
gain = None
|
||||||
|
elif gain is not None:
|
||||||
|
gain = float(gain)
|
||||||
|
kwargs: dict = {}
|
||||||
|
if config.get("center_freq") is not None:
|
||||||
|
kwargs["center_frequency"] = float(config["center_freq"])
|
||||||
|
if config.get("sample_rate") is not None:
|
||||||
|
kwargs["sample_rate"] = float(config["sample_rate"])
|
||||||
|
if gain is not None:
|
||||||
|
kwargs["gain"] = gain
|
||||||
|
if kwargs:
|
||||||
|
sdr.init_rx(**kwargs, channel=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _inference_loop(state: InferenceState, sdr) -> None:
|
||||||
|
from ria_toolkit_oss.orchestration.qa import estimate_snr_db
|
||||||
|
|
||||||
|
state.sdr = sdr
|
||||||
|
state.set_running(True)
|
||||||
|
session = state.session
|
||||||
|
input_name = session.get_inputs()[0].name
|
||||||
|
expected_shape = tuple(
|
||||||
|
d if isinstance(d, int) and d > 0 else _INFERENCE_NUM_SAMPLES for d in session.get_inputs()[0].shape
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not state.stop_event.is_set():
|
||||||
|
# Apply any pending SDR reconfiguration before the next capture.
|
||||||
|
pending = state.pop_pending_config()
|
||||||
|
if pending:
|
||||||
|
try:
|
||||||
|
_apply_sdr_config(sdr, pending)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("SDR reconfigure failed: %s", exc)
|
||||||
|
|
||||||
|
recording = sdr.record(num_samples=_INFERENCE_NUM_SAMPLES)
|
||||||
|
samples = recording.data[0] if recording.data.ndim > 1 else recording.data
|
||||||
|
snr_db = estimate_snr_db(samples)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_input = _preprocess_samples(samples, expected_shape)
|
||||||
|
logits = session.run(None, {input_name: model_input})[0][0].astype(np.float32)
|
||||||
|
probs = softmax(logits)
|
||||||
|
pred_idx = int(np.argmax(probs))
|
||||||
|
prediction = state.index_to_label.get(pred_idx, str(pred_idx))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Inference prediction failed: %s", exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_idle = prediction in _IDLE_LABELS
|
||||||
|
|
||||||
|
state.set_latest(
|
||||||
|
{
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"idle": is_idle,
|
||||||
|
"device_id": prediction if not is_idle else None,
|
||||||
|
"confidence": round(float(probs[pred_idx]), 4),
|
||||||
|
"snr_db": round(snr_db, 2),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
state.sdr = None
|
||||||
|
try:
|
||||||
|
sdr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
state.set_running(False)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/load", response_model=LoadModelResponse)
|
||||||
|
async def load_model(request: LoadModelRequest):
|
||||||
|
"""Load an ONNX model. Stops any running inference first.
|
||||||
|
|
||||||
|
``label_map`` maps class names to integer indices (e.g. ``{"iphone13_wifi_001": 0}``).
|
||||||
|
``enrolled_devices`` enriches status responses with human names and authorization flags.
|
||||||
|
"""
|
||||||
|
existing = get_inference()
|
||||||
|
if existing and existing.get_running():
|
||||||
|
_stop_current_inference(existing)
|
||||||
|
|
||||||
|
session = _load_onnx_session(request.model_path)
|
||||||
|
set_inference(
|
||||||
|
InferenceState(
|
||||||
|
model_path=request.model_path,
|
||||||
|
label_map=request.label_map,
|
||||||
|
index_to_label={v: k for k, v in request.label_map.items()},
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return LoadModelResponse(loaded=True, model_path=request.model_path, num_classes=len(request.label_map))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=StartInferenceResponse)
|
||||||
|
async def start_inference(request: StartInferenceRequest):
|
||||||
|
"""Start continuous inference. Requires a model to be loaded first."""
|
||||||
|
state = get_inference()
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT, detail="No model loaded. Call POST /inference/load first."
|
||||||
|
)
|
||||||
|
if state.get_running():
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Inference is already running.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ria_toolkit_oss.orchestration.executor import _DEVICE_ALIASES
|
||||||
|
from ria_toolkit_oss.sdr import get_sdr_device
|
||||||
|
except ImportError as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"SDR import failed: {e}")
|
||||||
|
|
||||||
|
sdr_cfg = request.sdr_config
|
||||||
|
# Merge any pending configure request on top of the start config.
|
||||||
|
pending = state.pop_pending_config() or {}
|
||||||
|
center_freq = float(pending.get("center_freq") or sdr_cfg.center_freq)
|
||||||
|
sample_rate = float(pending.get("sample_rate") or sdr_cfg.sample_rate)
|
||||||
|
raw_gain = pending.get("gain") if "gain" in pending else sdr_cfg.gain
|
||||||
|
gain = None if raw_gain == "auto" else float(raw_gain)
|
||||||
|
try:
|
||||||
|
sdr = get_sdr_device(_DEVICE_ALIASES.get(sdr_cfg.device.lower(), sdr_cfg.device.lower()))
|
||||||
|
sdr.init_rx(sample_rate=sample_rate, center_frequency=center_freq, gain=gain, channel=0)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"SDR initialisation failed: {e}")
|
||||||
|
|
||||||
|
state.stop_event.clear()
|
||||||
|
state.thread = threading.Thread(target=_inference_loop, args=(state, sdr), daemon=True)
|
||||||
|
state.thread.start()
|
||||||
|
return StartInferenceResponse(running=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stop", response_model=StopInferenceResponse)
|
||||||
|
async def stop_inference():
|
||||||
|
"""Stop the running inference loop."""
|
||||||
|
state = get_inference()
|
||||||
|
if not state or not state.get_running():
|
||||||
|
return StopInferenceResponse(stopped=False)
|
||||||
|
_stop_current_inference(state)
|
||||||
|
return StopInferenceResponse(stopped=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/configure", response_model=ConfigureResponse)
|
||||||
|
async def configure_inference(request: ConfigureRequest):
|
||||||
|
"""Update SDR parameters (center_freq, sample_rate, gain) on the fly.
|
||||||
|
|
||||||
|
If inference is running the change is applied at the next capture boundary.
|
||||||
|
If inference is not running the config is stored and applied when it starts.
|
||||||
|
Only fields present in the request body are updated.
|
||||||
|
"""
|
||||||
|
state = get_inference()
|
||||||
|
if not state:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="No model loaded. Call POST /inference/load first.",
|
||||||
|
)
|
||||||
|
pending = {k: v for k, v in request.model_dump().items() if v is not None}
|
||||||
|
if pending:
|
||||||
|
state.set_pending_config(pending)
|
||||||
|
return ConfigureResponse(configured=bool(pending))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status", response_model=InferenceStatusResponse | None)
|
||||||
|
async def inference_status():
|
||||||
|
"""Return the latest inference result, or null if no model is loaded."""
|
||||||
|
state = get_inference()
|
||||||
|
if not state:
|
||||||
|
return None
|
||||||
|
latest = state.get_latest()
|
||||||
|
return InferenceStatusResponse(**latest) if latest else None
|
||||||
121
src/ria_toolkit_oss/server/state.py
Normal file
121
src/ria_toolkit_oss/server/state.py
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""In-memory state for running campaigns and inference sessions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignCancelledError(Exception):
|
||||||
|
"""Raised by the progress callback when a cancel is requested."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignState:
|
||||||
|
campaign_id: str
|
||||||
|
status: str # "running" | "completed" | "failed" | "cancelled"
|
||||||
|
config_name: str
|
||||||
|
cancel_event: threading.Event
|
||||||
|
thread: threading.Thread
|
||||||
|
total_steps: int = 0
|
||||||
|
progress: int = 0
|
||||||
|
result: Optional[dict] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
started_at: float = field(default_factory=time.time)
|
||||||
|
ended_at: Optional[float] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"campaign_id": self.campaign_id,
|
||||||
|
"status": self.status,
|
||||||
|
"config_name": self.config_name,
|
||||||
|
"progress": self.progress,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
"result": self.result,
|
||||||
|
"error": self.error,
|
||||||
|
"started_at": self.started_at,
|
||||||
|
"ended_at": self.ended_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceState:
|
||||||
|
model_path: str
|
||||||
|
label_map: dict[str, int] # class_name -> class_index
|
||||||
|
index_to_label: dict[int, str] # reverse: class_index -> class_name
|
||||||
|
session: Any # onnxruntime.InferenceSession
|
||||||
|
stop_event: threading.Event = field(default_factory=threading.Event)
|
||||||
|
thread: Optional[threading.Thread] = None
|
||||||
|
sdr: Any = None # live SDR object while inference is running
|
||||||
|
running: bool = False
|
||||||
|
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
||||||
|
_latest: Optional[dict] = field(default=None, repr=False)
|
||||||
|
_pending_sdr_config: Optional[dict] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
def set_latest(self, result: dict) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._latest = result
|
||||||
|
|
||||||
|
def get_latest(self) -> Optional[dict]:
|
||||||
|
with self._lock:
|
||||||
|
return self._latest
|
||||||
|
|
||||||
|
def set_pending_config(self, config: dict) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._pending_sdr_config = config
|
||||||
|
|
||||||
|
def pop_pending_config(self) -> Optional[dict]:
|
||||||
|
with self._lock:
|
||||||
|
cfg = self._pending_sdr_config
|
||||||
|
self._pending_sdr_config = None
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
def set_running(self, value: bool) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self.running = value
|
||||||
|
|
||||||
|
def get_running(self) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return self.running
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level stores
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_campaigns: dict[str, CampaignState] = {}
|
||||||
|
_campaigns_lock = threading.Lock()
|
||||||
|
|
||||||
|
_inference: Optional[InferenceState] = None
|
||||||
|
_inference_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
|
||||||
|
with _campaigns_lock:
|
||||||
|
return _campaigns.get(campaign_id)
|
||||||
|
|
||||||
|
|
||||||
|
def set_campaign(state: CampaignState) -> None:
|
||||||
|
with _campaigns_lock:
|
||||||
|
_campaigns[state.campaign_id] = state
|
||||||
|
|
||||||
|
|
||||||
|
def update_campaign(campaign_id: str, **kwargs) -> None:
|
||||||
|
with _campaigns_lock:
|
||||||
|
state = _campaigns.get(campaign_id)
|
||||||
|
if state:
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(state, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference() -> Optional[InferenceState]:
|
||||||
|
with _inference_lock:
|
||||||
|
return _inference
|
||||||
|
|
||||||
|
|
||||||
|
def set_inference(state: Optional[InferenceState]) -> None:
|
||||||
|
global _inference
|
||||||
|
with _inference_lock:
|
||||||
|
_inference = state
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
.. todo:: Need to add some information here about signal generation and the signal generators in this module.
|
.. todo:: Need to add some information here about signal generation and the signal generators in this module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -10,7 +11,7 @@ from scipy.signal import butter
|
||||||
from scipy.signal import chirp as sci_chirp
|
from scipy.signal import chirp as sci_chirp
|
||||||
from scipy.signal import hilbert, lfilter
|
from scipy.signal import hilbert, lfilter
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
|
|
||||||
|
|
||||||
def sine(
|
def sine(
|
||||||
|
|
@ -227,7 +228,7 @@ def noise(
|
||||||
|
|
||||||
# TODO figure out a better way to make it conform to [-1,1]
|
# TODO figure out a better way to make it conform to [-1,1]
|
||||||
if not np.array_equal(magnitude, magnitude2):
|
if not np.array_equal(magnitude, magnitude2):
|
||||||
print("Warning: clipping in basic_signal_generator.noise")
|
warnings.warn("basic_signal_generator.noise: magnitude clipped to [-1, 1]")
|
||||||
|
|
||||||
phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
|
phase = np.random.uniform(low=0, high=2 * np.pi, size=length)
|
||||||
complex_awgn = magnitude2 * np.exp(1j * phase)
|
complex_awgn = magnitude2 * np.exp(1j * phase)
|
||||||
|
|
@ -268,6 +269,9 @@ def chirp(sample_rate: int, num_samples: int, center_frequency: Optional[float]
|
||||||
.. todo:: Usage examples coming soon!
|
.. todo:: Usage examples coming soon!
|
||||||
"""
|
"""
|
||||||
# Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
|
# Ensure that the generated chirp signal remains within a safe frequency range to avoid aliasing.
|
||||||
|
if num_samples < 2:
|
||||||
|
raise ValueError("num_samples must be >= 2 for chirp generation")
|
||||||
|
|
||||||
chirp_start_frequency = center_frequency - sample_rate / 4
|
chirp_start_frequency = center_frequency - sample_rate / 4
|
||||||
chirp_end_frequency = center_frequency + sample_rate / 4
|
chirp_end_frequency = center_frequency + sample_rate / 4
|
||||||
|
|
||||||
|
|
@ -307,6 +311,9 @@ def lfm_chirp_complex(
|
||||||
down_part = np.flip(up_part)
|
down_part = np.flip(up_part)
|
||||||
baseband_chirp = np.concatenate([up_part, down_part])
|
baseband_chirp = np.concatenate([up_part, down_part])
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown chirp_type '{chirp_type}'. Must be 'up', 'down', or 'up_down'.")
|
||||||
|
|
||||||
# Generate the full signal by tiling the windowed chirp
|
# Generate the full signal by tiling the windowed chirp
|
||||||
num_chirps = round(total_time / chirp_period)
|
num_chirps = round(total_time / chirp_period)
|
||||||
full_signal = np.tile(baseband_chirp, num_chirps)
|
full_signal = np.tile(baseband_chirp, num_chirps)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils.signal.block_generator.block import Block
|
|
||||||
from utils.signal.block_generator.data_types import DataType
|
from ria_toolkit_oss.signal.block_generator.block import Block
|
||||||
|
from ria_toolkit_oss.signal.block_generator.data_types import DataType
|
||||||
|
|
||||||
|
|
||||||
class FrequencyUpConversion(Block):
|
class FrequencyUpConversion(Block):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
|
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
|
||||||
SignalGenerator,
|
SignalGenerator,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
|
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
|
||||||
SignalGenerator,
|
SignalGenerator,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
|
from ria_toolkit_oss.signal.block_generator.generators.signal_generator import (
|
||||||
SignalGenerator,
|
SignalGenerator,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
from ria_toolkit_oss.signal import Recordable
|
from ria_toolkit_oss.signal import Recordable
|
||||||
from ria_toolkit_oss.signal.block_generator.block import Block
|
from ria_toolkit_oss.signal.block_generator.block import Block
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
||||||
import click
|
import click
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
|
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
|
||||||
from ria_toolkit_oss.signal.block_generator.multirate.upsampling import Upsampling
|
from ria_toolkit_oss.signal.block_generator.multirate.upsampling import Upsampling
|
||||||
from ria_toolkit_oss.signal.block_generator.pulse_shaping.raised_cosine_filter import (
|
from ria_toolkit_oss.signal.block_generator.pulse_shaping.raised_cosine_filter import (
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
from ria_toolkit_oss.signal.block_generator.data_types import DataType
|
from ria_toolkit_oss.signal.block_generator.data_types import DataType
|
||||||
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
|
from ria_toolkit_oss.signal.block_generator.recordable_block import RecordableBlock
|
||||||
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock
|
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
|
|
||||||
class Recordable(ABC):
|
class Recordable(ABC):
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,13 @@ and return a corresponding numpy.ndarray with the impairment model applied;
|
||||||
we call the latter the impaired data.
|
we call the latter the impaired data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.utils.array_conversion import convert_to_2xn
|
from ria_toolkit_oss.utils.array_conversion import convert_to_2xn
|
||||||
|
|
||||||
# TODO: For round 2 of index generation, should j be at min 2 spots away from where it was to prevent adjacent patches.
|
# TODO: For round 2 of index generation, should j be at min 2 spots away from where it was to prevent adjacent patches.
|
||||||
|
|
@ -28,7 +29,7 @@ def generate_awgn(signal: ArrayLike | Recording, snr: Optional[float] = 1) -> np
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param snr: The signal-to-noise ratio in dB. Default is 1.
|
:param snr: The signal-to-noise ratio in dB. Default is 1.
|
||||||
:type snr: float, optional
|
:type snr: float, optional
|
||||||
|
|
||||||
|
|
@ -36,7 +37,7 @@ def generate_awgn(signal: ArrayLike | Recording, snr: Optional[float] = 1) -> np
|
||||||
|
|
||||||
:return: A numpy array representing the generated noise which matches the SNR of `signal`. If `signal` is a
|
:return: A numpy array representing the generated noise which matches the SNR of `signal`. If `signal` is a
|
||||||
Recording, returns a Recording object with its `data` attribute containing the generated noise array.
|
Recording, returns a Recording object with its `data` attribute containing the generated noise array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2 + 5j, 1 + 8j]])
|
>>> rec = Recording(data=[[2 + 5j, 1 + 8j]])
|
||||||
>>> new_rec = generate_awgn(rec)
|
>>> new_rec = generate_awgn(rec)
|
||||||
|
|
@ -58,13 +59,14 @@ def generate_awgn(signal: ArrayLike | Recording, snr: Optional[float] = 1) -> np
|
||||||
|
|
||||||
# Calculate the RMS power of the signal to solve for the RMS power of the noise
|
# Calculate the RMS power of the signal to solve for the RMS power of the noise
|
||||||
signal_rms_power = np.sqrt(np.mean(np.abs(data) ** 2))
|
signal_rms_power = np.sqrt(np.mean(np.abs(data) ** 2))
|
||||||
noise_rms_power = signal_rms_power / snr_linear
|
noise_rms_power = signal_rms_power / np.sqrt(snr_linear)
|
||||||
|
|
||||||
# Generate the AWGN noise which has the same shape as data
|
# Generate complex AWGN: independent Gaussian I and Q components.
|
||||||
variance = noise_rms_power**2
|
# Each component has std = noise_rms_power / sqrt(2) so total power = noise_rms_power^2.
|
||||||
magnitude = np.random.normal(loc=0, scale=np.sqrt(variance), size=(c, n))
|
component_std = noise_rms_power / np.sqrt(2)
|
||||||
phase = np.random.uniform(low=0, high=2 * np.pi, size=(c, n))
|
complex_awgn = np.random.normal(scale=component_std, size=(c, n)) + 1j * np.random.normal(
|
||||||
complex_awgn = magnitude * np.exp(1j * phase)
|
scale=component_std, size=(c, n)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(signal, Recording):
|
if isinstance(signal, Recording):
|
||||||
return Recording(data=complex_awgn, metadata=signal.metadata)
|
return Recording(data=complex_awgn, metadata=signal.metadata)
|
||||||
|
|
@ -78,14 +80,14 @@ def time_reversal(signal: ArrayLike | Recording) -> np.ndarray | Recording:
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:raises ValueError: If `signal` is not CxN complex.
|
:raises ValueError: If `signal` is not CxN complex.
|
||||||
|
|
||||||
:return: A numpy array containing the reversed I and Q data samples if `signal` is an array.
|
:return: A numpy array containing the reversed I and Q data samples if `signal` is an array.
|
||||||
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
||||||
reversed array.
|
reversed array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+2j, 3+4j, 5+6j]])
|
>>> rec = Recording(data=[[1+2j, 3+4j, 5+6j]])
|
||||||
>>> new_rec = time_reversal(rec)
|
>>> new_rec = time_reversal(rec)
|
||||||
|
|
@ -121,14 +123,14 @@ def spectral_inversion(signal: ArrayLike | Recording) -> np.ndarray | Recording:
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:raises ValueError: If `signal` is not CxN complex.
|
:raises ValueError: If `signal` is not CxN complex.
|
||||||
|
|
||||||
:return: A numpy array containing the original I and negated Q data samples if `signal` is an array.
|
:return: A numpy array containing the original I and negated Q data samples if `signal` is an array.
|
||||||
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
||||||
inverted array.
|
inverted array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[0+45j, 2-10j]])
|
>>> rec = Recording(data=[[0+45j, 2-10j]])
|
||||||
>>> new_rec = spectral_inversion(rec)
|
>>> new_rec = spectral_inversion(rec)
|
||||||
|
|
@ -163,14 +165,14 @@ def channel_swap(signal: ArrayLike | Recording) -> np.ndarray | Recording:
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:raises ValueError: If `signal` is not CxN complex.
|
:raises ValueError: If `signal` is not CxN complex.
|
||||||
|
|
||||||
:return: A numpy array containing the swapped I and Q data samples if `signal` is an array.
|
:return: A numpy array containing the swapped I and Q data samples if `signal` is an array.
|
||||||
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
||||||
swapped array.
|
swapped array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[10+20j, 7+35j]])
|
>>> rec = Recording(data=[[10+20j, 7+35j]])
|
||||||
>>> new_rec = channel_swap(rec)
|
>>> new_rec = channel_swap(rec)
|
||||||
|
|
@ -205,14 +207,14 @@ def amplitude_reversal(signal: ArrayLike | Recording) -> np.ndarray | Recording:
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:raises ValueError: If `signal` is not CxN complex.
|
:raises ValueError: If `signal` is not CxN complex.
|
||||||
|
|
||||||
:return: A numpy array containing the negated I and Q data samples if `signal` is an array.
|
:return: A numpy array containing the negated I and Q data samples if `signal` is an array.
|
||||||
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing the
|
||||||
negated array.
|
negated array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[4-3j, -5-2j, -9+1j]])
|
>>> rec = Recording(data=[[4-3j, -5-2j, -9+1j]])
|
||||||
>>> new_rec = amplitude_reversal(rec)
|
>>> new_rec = amplitude_reversal(rec)
|
||||||
|
|
@ -251,7 +253,7 @@ def drop_samples( # noqa: C901 # TODO: Simplify function
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param max_section_size: Maximum allowable size of the section to be dropped and replaced. Default is 2.
|
:param max_section_size: Maximum allowable size of the section to be dropped and replaced. Default is 2.
|
||||||
:type max_section_size: int, optional
|
:type max_section_size: int, optional
|
||||||
:param fill_type: Fill option used to replace dropped section of data (back-fill, front-fill, mean, zeros).
|
:param fill_type: Fill option used to replace dropped section of data (back-fill, front-fill, mean, zeros).
|
||||||
|
|
@ -273,7 +275,7 @@ def drop_samples( # noqa: C901 # TODO: Simplify function
|
||||||
:return: A numpy array containing the I and Q data samples with replaced subsections if
|
:return: A numpy array containing the I and Q data samples with replaced subsections if
|
||||||
`signal` is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data`
|
`signal` is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data`
|
||||||
attribute containing the array with dropped samples.
|
attribute containing the array with dropped samples.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
||||||
>>> new_rec = drop_samples(rec)
|
>>> new_rec = drop_samples(rec)
|
||||||
|
|
@ -344,7 +346,7 @@ def quantize_tape(
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param bin_number: The number of bins the signal should be divided into. Default is 4.
|
:param bin_number: The number of bins the signal should be divided into. Default is 4.
|
||||||
:type bin_number: int, optional
|
:type bin_number: int, optional
|
||||||
:param rounding_type: The type of rounding applied during processing. Default is "floor".
|
:param rounding_type: The type of rounding applied during processing. Default is "floor".
|
||||||
|
|
@ -360,7 +362,7 @@ def quantize_tape(
|
||||||
:return: A numpy array containing the quantized I and Q data samples if `signal` is an array.
|
:return: A numpy array containing the quantized I and Q data samples if `signal` is an array.
|
||||||
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing
|
If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing
|
||||||
the quantized array.
|
the quantized array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+1j, 4+4j, 1+2j, 1+4j]])
|
>>> rec = Recording(data=[[1+1j, 4+4j, 1+2j, 1+4j]])
|
||||||
>>> new_rec = quantize_tape(rec)
|
>>> new_rec = quantize_tape(rec)
|
||||||
|
|
@ -378,7 +380,8 @@ def quantize_tape(
|
||||||
raise ValueError("signal must be CxN complex.")
|
raise ValueError("signal must be CxN complex.")
|
||||||
|
|
||||||
if rounding_type not in {"ceiling", "floor"}:
|
if rounding_type not in {"ceiling", "floor"}:
|
||||||
raise UserWarning('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
warnings.warn('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
||||||
|
rounding_type = "floor"
|
||||||
|
|
||||||
if c == 1:
|
if c == 1:
|
||||||
iq_data = convert_to_2xn(data)
|
iq_data = convert_to_2xn(data)
|
||||||
|
|
@ -418,7 +421,7 @@ def quantize_parts(
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param max_section_size: Maximum allowable size of the section to be quantized. Default is 2.
|
:param max_section_size: Maximum allowable size of the section to be quantized. Default is 2.
|
||||||
:type max_section_size: int, optional
|
:type max_section_size: int, optional
|
||||||
:param bin_number: The number of bins the signal should be divided into. Default is 4.
|
:param bin_number: The number of bins the signal should be divided into. Default is 4.
|
||||||
|
|
@ -436,7 +439,7 @@ def quantize_parts(
|
||||||
:return: A numpy array containing the I and Q data samples with quantized subsections if `signal`
|
:return: A numpy array containing the I and Q data samples with quantized subsections if `signal`
|
||||||
is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute
|
is an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute
|
||||||
containing the partially quantized array.
|
containing the partially quantized array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
||||||
>>> new_rec = quantize_parts(rec)
|
>>> new_rec = quantize_parts(rec)
|
||||||
|
|
@ -455,7 +458,8 @@ def quantize_parts(
|
||||||
raise ValueError("signal must be CxN complex.")
|
raise ValueError("signal must be CxN complex.")
|
||||||
|
|
||||||
if rounding_type not in {"ceiling", "floor"}:
|
if rounding_type not in {"ceiling", "floor"}:
|
||||||
raise UserWarning('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
warnings.warn('rounding_type must be either "floor" or "ceiling", floor has been selected by default')
|
||||||
|
rounding_type = "floor"
|
||||||
|
|
||||||
if c == 1:
|
if c == 1:
|
||||||
iq_data = convert_to_2xn(data)
|
iq_data = convert_to_2xn(data)
|
||||||
|
|
@ -506,7 +510,7 @@ def magnitude_rescale(
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param starting_bounds: The bounds (inclusive) as indices in which the starting position of the rescaling occurs.
|
:param starting_bounds: The bounds (inclusive) as indices in which the starting position of the rescaling occurs.
|
||||||
Default is None, but if user does not assign any bounds, the bounds become (random index, N-1).
|
Default is None, but if user does not assign any bounds, the bounds become (random index, N-1).
|
||||||
:type starting_bounds: tuple, optional
|
:type starting_bounds: tuple, optional
|
||||||
|
|
@ -518,7 +522,7 @@ def magnitude_rescale(
|
||||||
:return: A numpy array containing the I and Q data samples with the rescaled magnitude after the random
|
:return: A numpy array containing the I and Q data samples with the rescaled magnitude after the random
|
||||||
starting point if `signal` is an array. If `signal` is a `Recording`, returns a `Recording`
|
starting point if `signal` is an array. If `signal` is a `Recording`, returns a `Recording`
|
||||||
object with its `data` attribute containing the rescaled array.
|
object with its `data` attribute containing the rescaled array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
||||||
>>> new_rec = magniute_rescale(rec)
|
>>> new_rec = magniute_rescale(rec)
|
||||||
|
|
@ -567,7 +571,7 @@ def cut_out( # noqa: C901 # TODO: Simplify function
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param max_section_size: Maximum allowable size of the section to be quantized. Default is 3.
|
:param max_section_size: Maximum allowable size of the section to be quantized. Default is 3.
|
||||||
:type max_section_size: int, optional
|
:type max_section_size: int, optional
|
||||||
:param fill_type: Fill option used to replace cutout section of data (zeros, ones, low-snr, avg-snr-1, avg-snr-2).
|
:param fill_type: Fill option used to replace cutout section of data (zeros, ones, low-snr, avg-snr-1, avg-snr-2).
|
||||||
|
|
@ -592,7 +596,7 @@ def cut_out( # noqa: C901 # TODO: Simplify function
|
||||||
:return: A numpy array containing the I and Q data samples with random sections cut out and replaced according to
|
:return: A numpy array containing the I and Q data samples with random sections cut out and replaced according to
|
||||||
`fill_type` if `signal` is an array. If `signal` is a `Recording`, returns a `Recording` object
|
`fill_type` if `signal` is an array. If `signal` is a `Recording`, returns a `Recording` object
|
||||||
with its `data` attribute containing the cut out and replaced array.
|
with its `data` attribute containing the cut out and replaced array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
||||||
>>> new_rec = cut_out(rec)
|
>>> new_rec = cut_out(rec)
|
||||||
|
|
@ -610,8 +614,11 @@ def cut_out( # noqa: C901 # TODO: Simplify function
|
||||||
raise ValueError("signal must be CxN complex.")
|
raise ValueError("signal must be CxN complex.")
|
||||||
|
|
||||||
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
|
if fill_type not in {"zeros", "ones", "low-snr", "avg-snr", "high-snr"}:
|
||||||
raise UserWarning("""fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr",
|
warnings.warn(
|
||||||
"ones" has been selected by default""")
|
'fill_type must be "zeros", "ones", "low-snr", "avg-snr", or "high-snr", '
|
||||||
|
'"ones" has been selected by default'
|
||||||
|
)
|
||||||
|
fill_type = "ones"
|
||||||
|
|
||||||
if max_section_size < 1 or max_section_size >= n:
|
if max_section_size < 1 or max_section_size >= n:
|
||||||
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")
|
raise ValueError("max_section_size must be at least 1 and must be less than the length of signal.")
|
||||||
|
|
@ -659,7 +666,7 @@ def patch_shuffle(signal: ArrayLike | Recording, max_patch_size: Optional[int] =
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param max_patch_size: Maximum allowable patch size of the data that can be shuffled. Default is 3.
|
:param max_patch_size: Maximum allowable patch size of the data that can be shuffled. Default is 3.
|
||||||
:type max_patch_size: int, optional
|
:type max_patch_size: int, optional
|
||||||
|
|
||||||
|
|
@ -669,7 +676,7 @@ def patch_shuffle(signal: ArrayLike | Recording, max_patch_size: Optional[int] =
|
||||||
:return: A numpy array containing the I and Q data samples with randomly shuffled regions if `signal` is
|
:return: A numpy array containing the I and Q data samples with randomly shuffled regions if `signal` is
|
||||||
an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing
|
an array. If `signal` is a `Recording`, returns a `Recording` object with its `data` attribute containing
|
||||||
the shuffled array.
|
the shuffled array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
>>> rec = Recording(data=[[2+5j, 1+8j, 6+4j, 3+7j, 4+9j]])
|
||||||
>>> new_rec = patch_shuffle(rec)
|
>>> new_rec = patch_shuffle(rec)
|
||||||
|
|
|
||||||
|
|
@ -9,13 +9,14 @@ not the same as the signal at the end of the medium. What is sent is not what is
|
||||||
Three causes of impairment are attenuation, distortion, and noise.
|
Three causes of impairment are attenuation, distortion, and noise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from scipy.signal import resample_poly
|
from scipy.signal import resample_poly
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
from ria_toolkit_oss.transforms import iq_augmentations
|
from ria_toolkit_oss.transforms import iq_augmentations
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,7 +31,7 @@ def add_awgn_to_signal(signal: ArrayLike | Recording, snr: Optional[float] = 1)
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex ``C x N`` array or `Recording`, where ``C`` is the number of channels
|
:param signal: Input IQ data as a complex ``C x N`` array or `Recording`, where ``C`` is the number of channels
|
||||||
and ``N`` is the length of the IQ examples.
|
and ``N`` is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param snr: The signal-to-noise ratio in dB. Default is 1.
|
:param snr: The signal-to-noise ratio in dB. Default is 1.
|
||||||
:type snr: float, optional
|
:type snr: float, optional
|
||||||
|
|
||||||
|
|
@ -38,7 +39,7 @@ def add_awgn_to_signal(signal: ArrayLike | Recording, snr: Optional[float] = 1)
|
||||||
|
|
||||||
:return: A numpy array which is the sum of the noise (which matches the SNR) and the original signal. If `signal`
|
:return: A numpy array which is the sum of the noise (which matches the SNR) and the original signal. If `signal`
|
||||||
is a `Recording`, returns a `Recording object` with its `data` attribute containing the noisy signal array.
|
is a `Recording`, returns a `Recording object` with its `data` attribute containing the noisy signal array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+1j, 2+2j]])
|
>>> rec = Recording(data=[[1+1j, 2+2j]])
|
||||||
>>> new_rec = add_awgn_to_signal(rec)
|
>>> new_rec = add_awgn_to_signal(rec)
|
||||||
|
|
@ -55,8 +56,6 @@ def add_awgn_to_signal(signal: ArrayLike | Recording, snr: Optional[float] = 1)
|
||||||
raise ValueError("signal must be CxN complex.")
|
raise ValueError("signal must be CxN complex.")
|
||||||
|
|
||||||
noise = iq_augmentations.generate_awgn(signal=data, snr=snr)
|
noise = iq_augmentations.generate_awgn(signal=data, snr=snr)
|
||||||
print(f"noise is {noise}")
|
|
||||||
|
|
||||||
noisy_signal = data + noise
|
noisy_signal = data + noise
|
||||||
|
|
||||||
if isinstance(signal, Recording):
|
if isinstance(signal, Recording):
|
||||||
|
|
@ -72,7 +71,7 @@ def time_shift(signal: ArrayLike | Recording, shift: Optional[int] = 1) -> np.nd
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param shift: The number of indices to shift by. Default is 1.
|
:param shift: The number of indices to shift by. Default is 1.
|
||||||
:type shift: int, optional
|
:type shift: int, optional
|
||||||
|
|
||||||
|
|
@ -81,7 +80,7 @@ def time_shift(signal: ArrayLike | Recording, shift: Optional[int] = 1) -> np.nd
|
||||||
|
|
||||||
:return: A numpy array which represents the time-shifted signal. If `signal` is a `Recording`,
|
:return: A numpy array which represents the time-shifted signal. If `signal` is a `Recording`,
|
||||||
returns a `Recording object` with its `data` attribute containing the time-shifted array.
|
returns a `Recording object` with its `data` attribute containing the time-shifted array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j, 5+5j]])
|
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j, 5+5j]])
|
||||||
>>> new_rec = time_shift(rec, -2)
|
>>> new_rec = time_shift(rec, -2)
|
||||||
|
|
@ -101,16 +100,18 @@ def time_shift(signal: ArrayLike | Recording, shift: Optional[int] = 1) -> np.nd
|
||||||
raise ValueError("signal must be CxN complex.")
|
raise ValueError("signal must be CxN complex.")
|
||||||
|
|
||||||
if shift > n:
|
if shift > n:
|
||||||
raise UserWarning("shift is greater than signal length")
|
warnings.warn("shift is greater than signal length")
|
||||||
|
|
||||||
shifted_data = np.zeros_like(data)
|
shifted_data = np.zeros_like(data)
|
||||||
|
|
||||||
if c == 1:
|
if c == 1:
|
||||||
# New iq array shifted left or right depending on sign of shift
|
# New iq array shifted left or right depending on sign of shift
|
||||||
# This should work even if shift > iqdata.shape[1]
|
# This should work even if shift > iqdata.shape[1]
|
||||||
if shift >= 0:
|
if shift > 0:
|
||||||
# Shift to right
|
# Shift to right
|
||||||
shifted_data[:, shift:] = data[:, :-shift]
|
shifted_data[:, shift:] = data[:, :-shift]
|
||||||
|
elif shift == 0:
|
||||||
|
shifted_data[:] = data
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Shift to the left
|
# Shift to the left
|
||||||
|
|
@ -133,7 +134,7 @@ def frequency_shift(signal: ArrayLike | Recording, shift: Optional[float] = 0.5)
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param shift: The frequency shift relative to the sample rate. Must be in the range ``[-0.5, 0.5]``.
|
:param shift: The frequency shift relative to the sample rate. Must be in the range ``[-0.5, 0.5]``.
|
||||||
Default is 0.5.
|
Default is 0.5.
|
||||||
:type shift: float, optional
|
:type shift: float, optional
|
||||||
|
|
@ -143,7 +144,7 @@ def frequency_shift(signal: ArrayLike | Recording, shift: Optional[float] = 0.5)
|
||||||
|
|
||||||
:return: A numpy array which represents the frequency-shifted signal. If `signal` is a `Recording`,
|
:return: A numpy array which represents the frequency-shifted signal. If `signal` is a `Recording`,
|
||||||
returns a `Recording object` with its `data` attribute containing the frequency-shifted array.
|
returns a `Recording object` with its `data` attribute containing the frequency-shifted array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
|
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
|
||||||
>>> new_rec = frequency_shift(rec, -0.4)
|
>>> new_rec = frequency_shift(rec, -0.4)
|
||||||
|
|
@ -188,7 +189,7 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param phase: The phase angle by which to rotate the IQ samples, in radians. Must be in the range ``[-π, π]``.
|
:param phase: The phase angle by which to rotate the IQ samples, in radians. Must be in the range ``[-π, π]``.
|
||||||
Default is π.
|
Default is π.
|
||||||
:type phase: float, optional
|
:type phase: float, optional
|
||||||
|
|
@ -198,12 +199,12 @@ def phase_shift(signal: ArrayLike | Recording, phase: Optional[float] = np.pi) -
|
||||||
|
|
||||||
:return: A numpy array which represents the phase-shifted signal. If `signal` is a `Recording`,
|
:return: A numpy array which represents the phase-shifted signal. If `signal` is a `Recording`,
|
||||||
returns a `Recording object` with its `data` attribute containing the phase-shifted array.
|
returns a `Recording object` with its `data` attribute containing the phase-shifted array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
|
>>> rec = Recording(data=[[1+1j, 2+2j, 3+3j, 4+4j]])
|
||||||
>>> new_rec = phase_shift(rec, np.pi/2)
|
>>> new_rec = phase_shift(rec, np.pi/2)
|
||||||
>>> new_rec.data
|
>>> new_rec.data
|
||||||
array([[-1+1j, -2+2j -3+3j -4+4j]])
|
array([[-1+1j, -2+2j, -3+3j, -4+4j]])
|
||||||
"""
|
"""
|
||||||
# TODO: Additional info needs to be added to docstring description
|
# TODO: Additional info needs to be added to docstring description
|
||||||
|
|
||||||
|
|
@ -245,7 +246,7 @@ def iq_imbalance(
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param amplitude_imbalance: The IQ amplitude imbalance to apply, in dB. Default is 1.5.
|
:param amplitude_imbalance: The IQ amplitude imbalance to apply, in dB. Default is 1.5.
|
||||||
:type amplitude_imbalance: float, optional
|
:type amplitude_imbalance: float, optional
|
||||||
:param phase_imbalance: The IQ phase imbalance to apply, in radians. Default is π.
|
:param phase_imbalance: The IQ phase imbalance to apply, in radians. Default is π.
|
||||||
|
|
@ -259,7 +260,7 @@ def iq_imbalance(
|
||||||
|
|
||||||
:return: A numpy array which is the original signal with an applied IQ imbalance. If `signal` is a `Recording`,
|
:return: A numpy array which is the original signal with an applied IQ imbalance. If `signal` is a `Recording`,
|
||||||
returns a `Recording object` with its `data` attribute containing the IQ imbalanced signal array.
|
returns a `Recording object` with its `data` attribute containing the IQ imbalanced signal array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[2+18j, -34+2j, 3+9j]])
|
>>> rec = Recording(data=[[2+18j, -34+2j, 3+9j]])
|
||||||
>>> new_rec = iq_imbalance(rec, 1, np.pi, 2)
|
>>> new_rec = iq_imbalance(rec, 1, np.pi, 2)
|
||||||
|
|
@ -314,7 +315,7 @@ def resample(signal: ArrayLike | Recording, up: Optional[int] = 4, down: Optiona
|
||||||
|
|
||||||
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
:param signal: Input IQ data as a complex CxN array or `Recording`, where C is the number of channels and N
|
||||||
is the length of the IQ examples.
|
is the length of the IQ examples.
|
||||||
:type signal: array_like or ria_toolkit_oss.datatypes.Recording
|
:type signal: array_like or ria_toolkit_oss.data.Recording
|
||||||
:param up: The upsampling factor. Default is 4.
|
:param up: The upsampling factor. Default is 4.
|
||||||
:type up: int, optional
|
:type up: int, optional
|
||||||
:param down: The downsampling factor. Default is 2.
|
:param down: The downsampling factor. Default is 2.
|
||||||
|
|
@ -324,7 +325,7 @@ def resample(signal: ArrayLike | Recording, up: Optional[int] = 4, down: Optiona
|
||||||
|
|
||||||
:return: A numpy array which represents the resampled signal If `signal` is a `Recording`,
|
:return: A numpy array which represents the resampled signal If `signal` is a `Recording`,
|
||||||
returns a `Recording object` with its `data` attribute containing the resampled array.
|
returns a `Recording object` with its `data` attribute containing the resampled array.
|
||||||
:rtype: np.ndarray or ria_toolkit_oss.datatypes.Recording
|
:rtype: np.ndarray or ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
>>> rec = Recording(data=[[1+1j, 2+2j]])
|
>>> rec = Recording(data=[[1+1j, 2+2j]])
|
||||||
>>> new_rec = resample(rec, 2, 1)
|
>>> new_rec = resample(rec, 2, 1)
|
||||||
|
|
@ -354,8 +355,9 @@ def resample(signal: ArrayLike | Recording, up: Optional[int] = 4, down: Optiona
|
||||||
resampled_iqdata = resampled_iqdata[:, :n]
|
resampled_iqdata = resampled_iqdata[:, :n]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
empty_array = np.zeros(resampled_iqdata.shape, dtype=resampled_iqdata.dtype)
|
empty_array = np.zeros((1, n), dtype=resampled_iqdata.dtype)
|
||||||
empty_array[:, : resampled_iqdata.shape[1]] = resampled_iqdata
|
empty_array[:, : resampled_iqdata.shape[1]] = resampled_iqdata
|
||||||
|
resampled_iqdata = empty_array
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
|
|
@ -4,14 +4,14 @@ import scipy.signal as signal
|
||||||
from plotly.graph_objs import Figure
|
from plotly.graph_objs import Figure
|
||||||
from scipy.fft import fft, fftshift
|
from scipy.fft import fft, fftshift
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes import Recording
|
from ria_toolkit_oss.data import Recording
|
||||||
|
|
||||||
|
|
||||||
def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure:
|
def spectrogram(rec: Recording, thumbnail: bool = False) -> Figure:
|
||||||
"""Create a spectrogram for the recording.
|
"""Create a spectrogram for the recording.
|
||||||
|
|
||||||
:param rec: Signal to plot.
|
:param rec: Signal to plot.
|
||||||
:type rec: utils.data.Recording
|
:type rec: ria_toolkit_oss.data.Recording
|
||||||
:param thumbnail: Whether to return a small thumbnail version or full plot.
|
:param thumbnail: Whether to return a small thumbnail version or full plot.
|
||||||
:type thumbnail: bool
|
:type thumbnail: bool
|
||||||
|
|
||||||
|
|
@ -95,7 +95,7 @@ def iq_time_series(rec: Recording) -> Figure:
|
||||||
"""Create a time series plot of the real and imaginary parts of signal.
|
"""Create a time series plot of the real and imaginary parts of signal.
|
||||||
|
|
||||||
:param rec: Signal to plot.
|
:param rec: Signal to plot.
|
||||||
:type rec: utils.data.Recording
|
:type rec: ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:return: Time series plot as a Plotly figure.
|
:return: Time series plot as a Plotly figure.
|
||||||
"""
|
"""
|
||||||
|
|
@ -125,7 +125,7 @@ def frequency_spectrum(rec: Recording) -> Figure:
|
||||||
"""Create a frequency spectrum plot from the recording.
|
"""Create a frequency spectrum plot from the recording.
|
||||||
|
|
||||||
:param rec: Input signal to plot.
|
:param rec: Input signal to plot.
|
||||||
:type rec: utils.data.Recording
|
:type rec: ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:return: Frequency spectrum as a Plotly figure.
|
:return: Frequency spectrum as a Plotly figure.
|
||||||
"""
|
"""
|
||||||
|
|
@ -160,7 +160,7 @@ def constellation(rec: Recording) -> Figure:
|
||||||
"""Create a constellation plot from the recording.
|
"""Create a constellation plot from the recording.
|
||||||
|
|
||||||
:param rec: Input signal to plot.
|
:param rec: Input signal to plot.
|
||||||
:type rec: utils.data.Recording
|
:type rec: ria_toolkit_oss.data.Recording
|
||||||
|
|
||||||
:return: Constellation as a Plotly figure.
|
:return: Constellation as a Plotly figure.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,13 @@ from typing import Optional
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from matplotlib import gridspec
|
from matplotlib import gridspec
|
||||||
|
from matplotlib.patches import Patch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from scipy.fft import fft, fftshift
|
from scipy.fft import fft, fftshift
|
||||||
from scipy.signal import spectrogram
|
from scipy.signal import spectrogram
|
||||||
from scipy.signal.windows import hann
|
from scipy.signal.windows import hann
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.view.tools import (
|
from ria_toolkit_oss.view.tools import (
|
||||||
COLORS,
|
COLORS,
|
||||||
decimate,
|
decimate,
|
||||||
|
|
@ -39,6 +40,76 @@ def set_spines(ax, spines):
|
||||||
ax.spines["left"].set_visible(False)
|
ax.spines["left"].set_visible(False)
|
||||||
|
|
||||||
|
|
||||||
|
def view_annotations(
|
||||||
|
recording: Recording,
|
||||||
|
channel: Optional[int] = 0,
|
||||||
|
output_path: Optional[str] = "images/annotations.png",
|
||||||
|
title: Optional[str] = "Annotated Spectrogram",
|
||||||
|
dpi: Optional[int] = 300,
|
||||||
|
title_fontsize: Optional[int] = 15,
|
||||||
|
dark: Optional[bool] = True,
|
||||||
|
) -> None:
|
||||||
|
# 1. Setup Plotting Environment
|
||||||
|
plt.close("all")
|
||||||
|
if dark:
|
||||||
|
plt.style.use("dark_background")
|
||||||
|
else:
|
||||||
|
plt.style.use("default")
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(12, 8))
|
||||||
|
|
||||||
|
complex_signal = recording.data[channel]
|
||||||
|
sample_rate, center_frequency, _ = extract_metadata_fields(recording.metadata)
|
||||||
|
annotations = recording.annotations
|
||||||
|
|
||||||
|
# 2. Setup Color Mapping
|
||||||
|
palette = ["#2196F3", "#9C27B0", "#64B5F6", "#7B1FA2", "#5C6BC0", "#CE93D8", "#1565C0", "#7C4DFF"]
|
||||||
|
unique_labels = sorted(list(set(ann.label for ann in annotations if ann.label)))
|
||||||
|
label_to_color = {label: palette[i % len(palette)] for i, label in enumerate(unique_labels)}
|
||||||
|
|
||||||
|
# 3. Generate Spectrogram
|
||||||
|
Pxx, freqs, times, im = ax.specgram(
|
||||||
|
complex_signal, NFFT=256, Fs=sample_rate, Fc=center_frequency, noverlap=128, cmap="twilight"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Draw Annotations (highest threshold % first so lower % renders on top)
|
||||||
|
def _threshold_sort_key(ann):
|
||||||
|
try:
|
||||||
|
return int(ann.label.rstrip("%"))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
|
||||||
|
t_start = annotation.sample_start / sample_rate
|
||||||
|
t_width = annotation.sample_count / sample_rate
|
||||||
|
f_start = annotation.freq_lower_edge
|
||||||
|
f_height = annotation.freq_upper_edge - annotation.freq_lower_edge
|
||||||
|
|
||||||
|
ann_color = label_to_color.get(annotation.label, "gray")
|
||||||
|
|
||||||
|
rect = plt.Rectangle(
|
||||||
|
(t_start, f_start), t_width, f_height, linewidth=1.5, edgecolor=ann_color, facecolor="none", alpha=0.8
|
||||||
|
)
|
||||||
|
ax.add_patch(rect)
|
||||||
|
|
||||||
|
if unique_labels:
|
||||||
|
legend_elements = [
|
||||||
|
Patch(facecolor=label_to_color[label], alpha=0.3, edgecolor=label_to_color[label], label=label)
|
||||||
|
for label in unique_labels
|
||||||
|
]
|
||||||
|
ax.legend(handles=legend_elements, loc="upper right", framealpha=0.2)
|
||||||
|
|
||||||
|
ax.set_title(title, fontsize=title_fontsize, pad=20)
|
||||||
|
ax.set_xlabel("Time (s)", fontsize=12)
|
||||||
|
ax.set_ylabel("Frequency (MHz)", fontsize=12)
|
||||||
|
ax.grid(alpha=0.1)
|
||||||
|
|
||||||
|
output_path, _ = set_path(output_path=output_path)
|
||||||
|
plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"Professional annotation plot saved to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
def view_channels(
|
def view_channels(
|
||||||
recording: Recording,
|
recording: Recording,
|
||||||
output_path: Optional[str] = "images/signal.png",
|
output_path: Optional[str] = "images/signal.png",
|
||||||
|
|
@ -209,9 +280,7 @@ def view_sig(
|
||||||
)
|
)
|
||||||
|
|
||||||
set_spines(spec_ax, spines)
|
set_spines(spec_ax, spines)
|
||||||
spec_ax.set_title("Spectrogram", fontsize=subtitle_fontsize)
|
spec_ax.set_title("Spectrogram", loc="center", fontsize=subtitle_fontsize)
|
||||||
spec_ax.set_ylabel("Frequency (Hz)")
|
|
||||||
spec_ax.set_xlabel("Time (s)")
|
|
||||||
|
|
||||||
if iq:
|
if iq:
|
||||||
iq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :])
|
iq_ax = plt.subplot(gs[plot_y_indx : plot_y_indx + 2, :])
|
||||||
|
|
@ -295,7 +364,11 @@ def view_sig(
|
||||||
set_spines(meta_ax, spines)
|
set_spines(meta_ax, spines)
|
||||||
|
|
||||||
if logo and os.path.isfile(logo_path):
|
if logo and os.path.isfile(logo_path):
|
||||||
logo_ax = plt.subplot(gs[plot_y_indx + 2 :, 2])
|
# logo_ax = plt.subplot(gs[plot_y_indx:, 2])
|
||||||
|
logo_pos = [0.75, 0.05, 0.2, 0.08]
|
||||||
|
logo_ax = fig.add_axes(logo_pos, anchor="SE", zorder=10)
|
||||||
|
plot_x_indx = plot_x_indx + 1
|
||||||
|
|
||||||
logo_ax.axis("off")
|
logo_ax.axis("off")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -314,7 +387,6 @@ def view_sig(
|
||||||
hspace=2.5, # Vertical space between subplots
|
hspace=2.5, # Vertical space between subplots
|
||||||
)
|
)
|
||||||
|
|
||||||
# save path handling
|
|
||||||
output_path, _ = set_path(output_path=output_path)
|
output_path, _ = set_path(output_path=output_path)
|
||||||
plt.savefig(output_path, dpi=dpi)
|
plt.savefig(output_path, dpi=dpi)
|
||||||
print(f"Saved signal plot to {output_path}")
|
print(f"Saved signal plot to {output_path}")
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
@ -11,7 +12,7 @@ import numpy as np
|
||||||
from scipy.fft import fft, fftshift
|
from scipy.fft import fft, fftshift
|
||||||
from scipy.signal.windows import hann
|
from scipy.signal.windows import hann
|
||||||
|
|
||||||
from ria_toolkit_oss.datatypes.recording import Recording
|
from ria_toolkit_oss.data.recording import Recording
|
||||||
from ria_toolkit_oss.view.tools import (
|
from ria_toolkit_oss.view.tools import (
|
||||||
COLORS,
|
COLORS,
|
||||||
decimate,
|
decimate,
|
||||||
|
|
@ -20,6 +21,52 @@ from ria_toolkit_oss.view.tools import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2):
|
||||||
|
if annotations and not compact_mode:
|
||||||
|
for annotation in annotations:
|
||||||
|
start_idx = annotation.get("core:sample_start", 0)
|
||||||
|
length = annotation.get("core:sample_count", 0)
|
||||||
|
start_time = start_idx / sample_rate_hz
|
||||||
|
end_time = (start_idx + length) / sample_rate_hz
|
||||||
|
freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4)
|
||||||
|
freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4)
|
||||||
|
comment = annotation.get("core:comment", "{}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
comment_data = json.loads(comment) if isinstance(comment, str) else comment
|
||||||
|
ann_type = comment_data.get("type", "unknown")
|
||||||
|
if ann_type == "intersection":
|
||||||
|
color = COLORS["success"]
|
||||||
|
elif ann_type == "parallel":
|
||||||
|
color = COLORS["primary"]
|
||||||
|
elif ann_type == "standalone":
|
||||||
|
color = COLORS["warning"]
|
||||||
|
else:
|
||||||
|
color = COLORS["error"]
|
||||||
|
except Exception:
|
||||||
|
color = COLORS["error"]
|
||||||
|
|
||||||
|
rect = plt.Rectangle(
|
||||||
|
(start_time, freq_low),
|
||||||
|
end_time - start_time,
|
||||||
|
freq_high - freq_low,
|
||||||
|
color=color,
|
||||||
|
alpha=0.4,
|
||||||
|
linewidth=2,
|
||||||
|
)
|
||||||
|
ax2.add_patch(rect)
|
||||||
|
if show_labels:
|
||||||
|
label = annotation.get("core:label", "Signal")
|
||||||
|
ax2.text(
|
||||||
|
start_time,
|
||||||
|
freq_high,
|
||||||
|
label,
|
||||||
|
color=COLORS["light"],
|
||||||
|
fontsize=10,
|
||||||
|
bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_nfft_size(signal, fast_mode):
|
def _get_nfft_size(signal, fast_mode):
|
||||||
if len(signal) < 1000:
|
if len(signal) < 1000:
|
||||||
nfft = 128
|
nfft = 128
|
||||||
|
|
@ -138,6 +185,7 @@ def detect_constellation_symbols(signal: np.ndarray, method: str = "differential
|
||||||
|
|
||||||
def view_simple_sig(
|
def view_simple_sig(
|
||||||
recording: Recording,
|
recording: Recording,
|
||||||
|
annotations: Optional[list] = None,
|
||||||
output_path: Optional[str] = "images/signal.png",
|
output_path: Optional[str] = "images/signal.png",
|
||||||
saveplot: Optional[bool] = True,
|
saveplot: Optional[bool] = True,
|
||||||
fast_mode: Optional[bool] = False,
|
fast_mode: Optional[bool] = False,
|
||||||
|
|
@ -261,6 +309,15 @@ def view_simple_sig(
|
||||||
|
|
||||||
ax2.set_title("Spectrogram", loc="left", pad=10)
|
ax2.set_title("Spectrogram", loc="left", pad=10)
|
||||||
|
|
||||||
|
_add_annotations(
|
||||||
|
annotations=annotations,
|
||||||
|
compact_mode=compact_mode,
|
||||||
|
show_labels=show_labels,
|
||||||
|
sample_rate_hz=sample_rate_hz,
|
||||||
|
center_freq_hz=center_freq_hz,
|
||||||
|
ax2=ax2,
|
||||||
|
)
|
||||||
|
|
||||||
if ax_constellation is not None:
|
if ax_constellation is not None:
|
||||||
constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
|
constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
|
||||||
method = "differential" if fast_mode else "combined"
|
method = "differential" if fast_mode else "combined"
|
||||||
|
|
@ -310,7 +367,7 @@ def view_simple_sig(
|
||||||
else:
|
else:
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
if show_title:
|
if show_title:
|
||||||
plt.subplots_adjust(top=0.90)
|
plt.subplots_adjust(top=0.92)
|
||||||
|
|
||||||
if saveplot:
|
if saveplot:
|
||||||
output_path, extension = set_path(output_path=output_path)
|
output_path, extension = set_path(output_path=output_path)
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user