M
madrigal
8a66860d33
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15m51s
Build Project / Build Project (3.10) (pull_request) Successful in 16m14s
Build Project / Build Project (3.11) (pull_request) Successful in 17m9s
Build Project / Build Project (3.12) (pull_request) Successful in 2m29s
Test with tox / Test with tox (3.12) (pull_request) Successful in 21m28s
Test with tox / Test with tox (3.10) (pull_request) Successful in 22m50s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23m18s
612 lines
21 KiB
Python
612 lines
21 KiB
Python
import math
|
||
import pickle
|
||
import threading
|
||
import warnings
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import zmq
|
||
|
||
from ria_toolkit_oss.data.recording import Recording
|
||
|
||
|
||
class SDR(ABC):
|
||
"""
|
||
This class defines a common interface (a template) for all SDR devices.
|
||
Each specific SDR implementation should subclass SDR and provide concrete implementations
|
||
for the abstract methods.
|
||
|
||
To add support for a new radio, subclass this interface and implement all abstract methods.
|
||
If you experience difficulties, please `contact us <mailto:info@qoherent.ai>`_, we are happy to
|
||
provide additional direction and/or help with the implementation details.
|
||
"""
|
||
|
||
def __init__(self):
|
||
|
||
self._rx_initialized = False
|
||
self._tx_initialized = False
|
||
self._enable_rx = False
|
||
self._enable_tx = False
|
||
|
||
self._accumulated_buffer = None
|
||
self._max_num_buffers = None
|
||
self._num_buffers_processed = 0
|
||
self._last_buffer = None
|
||
self._corrupted_buffer_count = 0
|
||
|
||
self.rx_sample_rate = None
|
||
self.rx_center_frequency = None
|
||
self.rx_gain = None
|
||
self.tx_sample_rate = None
|
||
self.tx_center_frequency = None
|
||
self.tx_gain = None
|
||
self._param_lock = threading.RLock() # Reentrant lock
|
||
|
||
# Pending config consumed by rx() on first call and by _apply_sdr_config
|
||
# in the agent inference loop. Subclasses that need different defaults
|
||
# (e.g. MockSDR) can overwrite these in their own __init__.
|
||
self.center_freq: float = 2.4e9
|
||
self.sample_rate: float = 10e6
|
||
self.gain: float = 40.0
|
||
|
||
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
||
"""
|
||
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
|
||
|
||
Note that ``init_rx()`` must be called before ``record()``.
|
||
|
||
:param num_samples: The number of samples to record.
|
||
:type num_samples: int, optional
|
||
:param rx_time: The time to record.
|
||
:type rx_time: int or float, optional
|
||
|
||
:return: The Recording object
|
||
:rtype: Recording
|
||
"""
|
||
|
||
if not self._rx_initialized:
|
||
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
|
||
|
||
if num_samples is not None and rx_time is not None:
|
||
raise ValueError("Only input one of num_samples or rx_time")
|
||
elif num_samples is not None:
|
||
self._num_samples_to_record = num_samples
|
||
elif rx_time is not None:
|
||
self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
|
||
else:
|
||
raise ValueError("Must provide input of one of num_samples or rx_time")
|
||
|
||
self.buffer_size = self.rx_buffer_size
|
||
num_buffers = self._num_samples_to_record // self.buffer_size + 1
|
||
|
||
self._max_num_buffers = num_buffers
|
||
self._num_buffers_processed = 0
|
||
self._last_buffer = None
|
||
self._accumulated_buffer = None
|
||
print("Starting stream")
|
||
|
||
self._stream_rx(
|
||
callback=self._accumulate_buffers_callback,
|
||
)
|
||
|
||
print("Finished stream")
|
||
metadata = {
|
||
"source": self.__class__.__name__,
|
||
"sample_rate": self.rx_sample_rate,
|
||
"center_frequency": self.rx_center_frequency,
|
||
"gain": self.rx_gain,
|
||
}
|
||
|
||
print("Creating recording")
|
||
# build recording, truncate to self._num_samples_to_record
|
||
recording = Recording(data=self._accumulated_buffer[:, : self._num_samples_to_record], metadata=metadata)
|
||
|
||
# reset to record again
|
||
self._accumulated_buffer = None
|
||
self._num_buffers_processed = 0
|
||
return recording
|
||
|
||
def rx(self, num_samples: int) -> "np.ndarray":
|
||
"""Return *num_samples* complex IQ samples as a 1-D complex64 array.
|
||
|
||
This is the interface used by the agent inference loop. On first call,
|
||
``init_rx()`` is invoked automatically using the values stored in
|
||
``center_freq``, ``sample_rate``, and ``gain`` (set beforehand by
|
||
``_apply_sdr_config``). Subsequent calls stream directly.
|
||
|
||
Subclasses may override this for hardware-native capture APIs (e.g.
|
||
``MockSDR`` uses AWGN generation; ``PlutoSDR`` could use
|
||
``self.radio.rx()``).
|
||
"""
|
||
if not self._rx_initialized:
|
||
gain = self.gain if isinstance(self.gain, (int, float)) else 40.0
|
||
self.init_rx(
|
||
sample_rate=self.sample_rate,
|
||
center_frequency=self.center_freq,
|
||
gain=gain,
|
||
channel=0,
|
||
)
|
||
recording = self.record(num_samples=num_samples)
|
||
# Recording.data is either a list of 1-D arrays (one per channel) or a
|
||
# 2-D ndarray (channels × samples). Either way, index 0 is channel 0.
|
||
data = recording.data
|
||
return data[0] if hasattr(data, "__getitem__") else data
|
||
|
||
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
||
"""
|
||
Stream iq samples as interleaved bytes via zmq.
|
||
|
||
:param zmq_address: The zmq address.
|
||
:type zmq_address:
|
||
:param n_samples: The number of samples to stream.
|
||
:type n_samples: int
|
||
:param buffer_size: The buffer size during streaming. Defaults to 10000.
|
||
:type buffer_size: int, optional
|
||
|
||
:return: The trimmed Recording.
|
||
:rtype: Recording
|
||
"""
|
||
try:
|
||
self._previous_buffer = None
|
||
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
|
||
self._num_buffers_processed = 0
|
||
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
|
||
self.context = zmq.Context()
|
||
self.socket = self.context.socket(zmq.PUB)
|
||
self.socket.bind(self.zmq_address)
|
||
|
||
self._stream_rx(
|
||
self._zmq_bytestream_callback,
|
||
)
|
||
finally:
|
||
if hasattr(self, "socket"):
|
||
self.socket.close()
|
||
if hasattr(self, "context"):
|
||
self.context.destroy()
|
||
|
||
def _accumulate_buffers_callback(self, buffer, metadata=None):
|
||
"""
|
||
Receives a buffer and saves it to self.accumulated_buffer.
|
||
"""
|
||
# expected buffer is complex samples range -1 to 1
|
||
# save the buffer until max reached
|
||
# return a recording
|
||
|
||
# Validate buffer
|
||
if not self._validate_buffer(buffer):
|
||
print("Warning: Corrupted buffer detected, skipping")
|
||
self._corrupted_buffer_count += 1
|
||
return # Skip this buffer
|
||
|
||
if isinstance(buffer, np.ndarray):
|
||
if buffer.ndim == 1:
|
||
buffer = buffer[np.newaxis, :] # make shape (1, N)
|
||
else:
|
||
buffer = np.array(buffer) # make it 1d
|
||
if len(buffer.shape) == 1:
|
||
buffer = np.array([buffer])
|
||
|
||
# First call: pre-allocate if we know the final size
|
||
if self._accumulated_buffer is None:
|
||
# Check that _max_num_buffers is set
|
||
if self._max_num_buffers is None:
|
||
raise ValueError("Number of buffers for block capture not set.")
|
||
if self._num_samples_to_record is None:
|
||
raise ValueError("Number of samples not set before RX start.")
|
||
|
||
if metadata is not None:
|
||
self.received_metadata = metadata
|
||
|
||
# Preallocate once (avoid np.zeros; use np.empty for speed)
|
||
num_channels = buffer.shape[0]
|
||
self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
|
||
self._write_position = 0
|
||
print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")
|
||
|
||
# Write new buffer into pre-allocated array
|
||
n = buffer.shape[1]
|
||
start = self._write_position
|
||
end = min(start + n, self._num_samples_to_record)
|
||
samples_to_write = end - start
|
||
|
||
if samples_to_write > 0:
|
||
self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
|
||
self._write_position = end
|
||
|
||
# Check if we're done
|
||
self._num_buffers_processed += 1
|
||
if self._num_buffers_processed >= self._max_num_buffers:
|
||
self.stop()
|
||
|
||
def _validate_buffer(self, buffer):
|
||
"""Check for obviously corrupt data."""
|
||
# Check for all zeros
|
||
if np.all(buffer == 0):
|
||
return False
|
||
# Check for all same value
|
||
if np.all(buffer == buffer[0]):
|
||
return False
|
||
return True
|
||
|
||
def _zmq_bytestream_callback(self, buffer, metadata=None):
|
||
# push to ZMQ port
|
||
data = np.array(buffer).tobytes() # convert to bytes for transport
|
||
self.socket.send(data)
|
||
|
||
self._num_buffers_processed = self._num_buffers_processed + 1
|
||
if self._max_num_buffers is not None:
|
||
if self._num_buffers_processed >= self._max_num_buffers:
|
||
self.pause_rx()
|
||
|
||
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
|
||
"""
|
||
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
|
||
Useful for inference applications with a known input size.
|
||
May reduce transfer rates, but individual buffers will not have discontinuities.
|
||
|
||
:param zmq_address: The tcp address to stream to.
|
||
:type zmq_address: str
|
||
:param buffer_size: The number of iq samples in a buffer.
|
||
:type buffer_size: int
|
||
:param num_buffers: The number of buffers to stream before stopping.
|
||
:type num_buffers: int
|
||
"""
|
||
self._max_num_buffers = num_buffers
|
||
self.buffer_size = buffer_size
|
||
self._num_buffers_processed = 0
|
||
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
|
||
self.context = zmq.Context()
|
||
self.socket = self.context.socket(zmq.PUB)
|
||
self.socket.bind(self.zmq_address)
|
||
self.set_rx_buffer_size(buffer_size)
|
||
|
||
self._stream_rx(self._zmq_pickle_buffer_callback)
|
||
|
||
def _zmq_pickle_buffer_callback(self, buffer, metadata=None):
|
||
# push to ZMQ port
|
||
# data = np.array(buffer).tobytes() # convert to bytes for transport
|
||
# self.socket.send(data)
|
||
|
||
self.socket.send(pickle.dumps(buffer))
|
||
|
||
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
|
||
|
||
self._num_buffers_processed = self._num_buffers_processed + 1
|
||
if self._max_num_buffers is not None:
|
||
if self._num_buffers_processed >= self._max_num_buffers:
|
||
self.stop()
|
||
|
||
if self._last_buffer is not None:
|
||
if np.array_equal(buffer, self._last_buffer):
|
||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
||
self._last_buffer = buffer.copy()
|
||
else:
|
||
self._last_buffer = buffer.copy()
|
||
|
||
def tx_recording(
|
||
self,
|
||
recording: Recording | np.ndarray,
|
||
num_samples: Optional[int] = None,
|
||
tx_time: Optional[int | float] = None,
|
||
):
|
||
"""
|
||
Transmit the given iq samples from the provided recording.
|
||
init_tx() must be called before this function.
|
||
|
||
:param recording: The recording to transmit.
|
||
:type recording: Recording or np.ndarray
|
||
:param num_samples: The number of samples to transmit, will repeat or
|
||
truncate the recording to this length. Defaults to None.
|
||
:type num_samples: int, optional
|
||
:param tx_time: The time to transmit, will repeat or truncate the
|
||
recording to this length. Defaults to None.
|
||
:type tx_time: int or float, optional
|
||
"""
|
||
|
||
if not self._tx_initialized:
|
||
raise RuntimeError(
|
||
"TX was not initialized. init_tx() must be called before _stream_tx() or transmit_recording()"
|
||
)
|
||
|
||
if num_samples is not None and tx_time is not None:
|
||
raise ValueError("Only input one of num_samples or tx_time")
|
||
elif num_samples is not None:
|
||
self._num_samples_to_transmit = num_samples
|
||
elif tx_time is not None:
|
||
self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
|
||
else:
|
||
self._num_samples_to_transmit = len(recording)
|
||
|
||
if isinstance(recording, np.ndarray):
|
||
self._samples_to_transmit = recording
|
||
elif isinstance(recording, Recording):
|
||
if len(recording.data) > 1:
|
||
warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
|
||
|
||
self._samples_to_transmit = recording.data[0]
|
||
|
||
self._num_samples_transmitted = 0
|
||
|
||
self._stream_tx(self._loop_recording_callback)
|
||
|
||
def _loop_recording_callback(self, num_samples):
|
||
|
||
samples_left = self._num_samples_to_transmit - self._num_samples_transmitted
|
||
# find where to start based on num_samples_transmitted
|
||
start_index = self._num_samples_transmitted % len(self._samples_to_transmit)
|
||
|
||
# generates an array of indices that wrap around as many times as necessary.
|
||
indices = np.arange(start_index, start_index + num_samples) % len(self._samples_to_transmit)
|
||
samples = self._samples_to_transmit[indices]
|
||
|
||
# zero pad at the end so we are still giving the requested buffer size
|
||
# while also giving the exact number of non zero samples
|
||
if len(samples) > samples_left:
|
||
samples[int(samples_left) :] = 0
|
||
self.pause_tx()
|
||
|
||
self._num_samples_transmitted = self._num_samples_transmitted + num_samples
|
||
|
||
return samples
|
||
|
||
def supports_bias_tee(self) -> bool:
|
||
"""Return True when the radio supports bias-tee control."""
|
||
return False
|
||
|
||
def set_bias_tee(self, enable: bool):
|
||
"""Enable or disable bias-tee power when supported by the radio."""
|
||
raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")
|
||
|
||
def pause_rx(self):
|
||
self._enable_rx = False
|
||
|
||
def pause_tx(self):
|
||
self._enable_tx = False
|
||
|
||
def stop(self):
|
||
self.pause_rx()
|
||
self.pause_tx()
|
||
|
||
def get_rx_sample_rate(self):
|
||
"""
|
||
Retrieve the current sample rate of the receiver.
|
||
|
||
Returns:
|
||
float: The receiver's sample rate in samples per second (Hz).
|
||
"""
|
||
return self.rx_sample_rate
|
||
|
||
def get_rx_center_frequency(self):
|
||
"""
|
||
Retrieve the current center frequency of the receiver.
|
||
|
||
Returns:
|
||
float: The receiver's center frequency in Hertz (Hz).
|
||
"""
|
||
return self.rx_center_frequency
|
||
|
||
def get_rx_gain(self):
|
||
"""
|
||
Retrieve the current gain setting of the receiver.
|
||
|
||
Returns:
|
||
float: The receiver's gain in decibels (dB).
|
||
"""
|
||
return self.rx_gain
|
||
|
||
def get_tx_sample_rate(self):
|
||
"""
|
||
Retrieve the current sample rate of the transmitter.
|
||
|
||
Returns:
|
||
float: The transmitter's sample rate in samples per second (Hz).
|
||
"""
|
||
return self.tx_sample_rate
|
||
|
||
def get_tx_center_frequency(self):
|
||
"""
|
||
Retrieve the current center frequency of the transmitter.
|
||
|
||
Returns:
|
||
float: The transmitter's center frequency in Hertz (Hz).
|
||
"""
|
||
return self.tx_center_frequency
|
||
|
||
def get_tx_gain(self):
|
||
"""
|
||
Retrieve the current gain setting of the transmitter.
|
||
|
||
Returns:
|
||
float: The transmitter's gain in decibels (dB).
|
||
"""
|
||
return self.tx_gain
|
||
|
||
def set_rx_sample_rate(self):
|
||
"""
|
||
Set the sample rate of the receiver.
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def set_rx_center_frequency(self):
|
||
"""
|
||
Set the center frequency of the receiver.
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def set_rx_gain(self):
|
||
"""
|
||
Set the gain setting of the receiver.
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def set_tx_sample_rate(self):
|
||
"""
|
||
Set the sample rate of the transmitter.
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def set_tx_center_frequency(self):
|
||
"""
|
||
Set the center frequency of the transmitter.
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def set_tx_gain(self):
|
||
"""
|
||
Set the gain setting of the transmitter.
|
||
"""
|
||
raise NotImplementedError
|
||
|
||
def supports_dynamic_updates(self) -> dict:
|
||
"""
|
||
Report which parameters can be updated during streaming.
|
||
|
||
Returns:
|
||
dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
|
||
"""
|
||
return {"center_frequency": False, "sample_rate": False, "gain": False}
|
||
|
||
def __del__(self):
|
||
"""Cleanup on garbage collection."""
|
||
try:
|
||
self.close()
|
||
except Exception:
|
||
pass
|
||
|
||
@abstractmethod
|
||
def close(self):
|
||
pass
|
||
|
||
@abstractmethod
|
||
def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
|
||
pass
|
||
|
||
@abstractmethod
|
||
def init_tx(self, sample_rate, center_frequency, gain, channel, gain_mode):
|
||
pass
|
||
|
||
@abstractmethod
|
||
def _stream_rx(self, callback):
|
||
pass
|
||
|
||
@abstractmethod
|
||
def _stream_tx(self, callback):
|
||
pass
|
||
|
||
@abstractmethod
|
||
def set_clock_source(self, source):
|
||
"""
|
||
Sets the clock source to external or internal.
|
||
|
||
:param source: The clock source
|
||
:type source: str
|
||
"""
|
||
pass
|
||
|
||
|
||
def _generate_full_zmq_address(input_address):
|
||
"""
|
||
Helper function for zmq streaming.
|
||
If given a port number like 5556,
|
||
return tcp localhost address at that port.
|
||
Otherwise, return the address untouched.
|
||
"""
|
||
|
||
if ("://" not in str(input_address)) and _is_valid_port(input_address):
|
||
# If no transport protocol specified, assume TCP
|
||
return "tcp://*:" + str(input_address)
|
||
else:
|
||
# Otherwise, return the input unchanged
|
||
return input_address
|
||
|
||
|
||
def _is_valid_port(port):
|
||
"""
|
||
Helper function for zmq address.
|
||
"""
|
||
try:
|
||
port_num = int(port)
|
||
return 0 <= port_num <= 65535
|
||
except ValueError:
|
||
return False
|
||
|
||
|
||
def _verify_sample_format(samples):
|
||
"""
|
||
Verify that the sample data is in the range -1 to 1.
|
||
|
||
:param buffer: An array of samples.
|
||
|
||
:Return: True if the buffer is in the correct format, false if not.
|
||
:rtype: bool
|
||
"""
|
||
|
||
return np.max(np.abs(samples)) <= 1
|
||
|
||
|
||
class SDRError(Exception):
    """Base exception for SDR errors."""


class SDRParameterError(SDRError):
    """Invalid parameter (sample rate, freq, gain)."""


class SDROverflowError(SDRError):
    """Buffer overflow detected."""


class SdrDisconnectedError(SDRError):
    """Raised when the SDR device disappears mid-operation (USB unplug, network drop)."""


# Substrings that strongly indicate a device has disappeared rather than a
# transient / recoverable error. Checked case-insensitively against str(exc).
# Markers ending in a digit only match when not followed by another digit, so
# "errno 5" does not match "errno 55".
# NOTE(review): "not found" and "usb" are broad and may classify unrelated
# errors (e.g. a USB timeout) as disconnects; confirm against the drivers.
_DISCONNECT_MARKERS = (
    "no such device",
    "device not found",
    "not found",
    "broken pipe",
    "disconnected",
    "no device",
    "device unplugged",
    "usb",
    "i/o error",
    "input/output error",
    "errno 19",  # ENODEV
    "errno 5",  # EIO
)


def _contains_marker(message: str, marker: str) -> bool:
    """Return True if *marker* occurs in *message*.

    Markers that end in a digit must not be followed by another digit, so that
    e.g. ``"errno 5"`` does not match ``"errno 55"``.
    """
    import re

    if marker[-1:].isdigit():
        return re.search(re.escape(marker) + r"(?!\d)", message) is not None
    return marker in message


def translate_disconnect(exc: BaseException) -> BaseException:
    """Return ``SdrDisconnectedError`` if *exc* looks like a USB/device drop, else *exc*.

    Drivers wrap their native-API calls with::

        try:
            return self.radio.rx()
        except Exception as exc:
            raise translate_disconnect(exc) from exc

    The caller (e.g. the streamer) can then catch ``SdrDisconnectedError``
    specifically and report it to the hub rather than crashing the loop.
    """
    if isinstance(exc, SdrDisconnectedError):
        return exc
    msg = str(exc).lower()
    if any(_contains_marker(msg, marker) for marker in _DISCONNECT_MARKERS):
        return SdrDisconnectedError(str(exc))
    # OSError subclass with EIO (5) / ENODEV (19) errno is also a disconnect signal.
    if isinstance(exc, OSError) and getattr(exc, "errno", None) in (5, 19):
        return SdrDisconnectedError(str(exc))
    return exc
|