ria-toolkit-oss/src/ria_toolkit_oss/sdr/sdr.py

384 lines
14 KiB
Python

import math
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import zmq
from ria_toolkit_oss.datatypes.recording import Recording
class SDR(ABC):
    """
    This class defines a common interface (a template) for all SDR devices.

    Each specific SDR implementation should subclass SDR and provide concrete implementations
    for the abstract methods.

    To add support for a new radio, subclass this interface and implement all abstract methods.
    If you experience difficulties, please `contact us <mailto:info@qoherent.ai>`_, we are happy to
    provide additional direction and/or help with the implementation details.

    The shared helpers in this class read some attributes that subclasses must provide:
    ``rx_sample_rate``, ``rx_center_frequency``, ``rx_gain`` and ``rx_buffer_size`` (used by
    :meth:`record`), ``tx_sample_rate`` (used by :meth:`tx_recording`) and a
    ``set_rx_buffer_size()`` method (used by :meth:`pickle_buffer_to_zmq`).
    """

    def __init__(self):
        # Must be set to True by the subclass's init_rx()/init_tx(); record() and
        # tx_recording() refuse to run otherwise.
        self._rx_initialized = False
        self._tx_initialized = False
        # Streaming enable flags; pause_rx()/pause_tx() clear them. Subclass stream loops
        # presumably poll these to know when to stop.
        self._enable_rx = False
        self._enable_tx = False
        # Receive-side bookkeeping shared by the callbacks below.
        self._max_num_buffers = None
        self._num_buffers_processed = 0
        self._accumulated_buffer = None
        self._last_buffer = None
        self._previous_buffer = None

    def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
        """
        Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.

        Note that ``init_rx()`` must be called before ``record()``.

        :param num_samples: The number of samples to record.
        :type num_samples: int, optional
        :param rx_time: The time to record; ``int(rx_time * rx_sample_rate)`` samples are captured.
        :type rx_time: int or float, optional
        :raises RuntimeError: If RX was not initialized, or if the stream delivered no buffers.
        :raises ValueError: If both or neither of ``num_samples`` and ``rx_time`` are given.
        :return: The Recording object
        :rtype: Recording
        """
        if not self._rx_initialized:
            raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
        if num_samples is not None and rx_time is not None:
            raise ValueError("Only input one of num_samples or rx_time")
        elif num_samples is not None:
            self._num_samples_to_record = num_samples
        elif rx_time is not None:
            self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
        else:
            raise ValueError("Must provide input of one of num_samples or rx_time")

        self.buffer_size = self.rx_buffer_size
        # One extra buffer so the capture covers the request even when it is not a multiple of
        # the buffer size; the surplus is truncated below.
        self._max_num_buffers = self._num_samples_to_record // self.buffer_size + 1
        self._num_buffers_processed = 0
        self._last_buffer = None
        self._accumulated_buffer = None

        print("Starting stream")
        self._stream_rx(
            callback=self._accumulate_buffers_callback,
        )
        print("Finished stream")

        if self._accumulated_buffer is None:
            raise RuntimeError("No samples were received: the RX stream ended before delivering any buffer")

        metadata = {
            "source": self.__class__.__name__,
            "sample_rate": self.rx_sample_rate,
            "center_frequency": self.rx_center_frequency,
            "gain": self.rx_gain,
        }
        print("Creating recording")
        # Build the recording, truncated to exactly the requested number of samples.
        recording = Recording(data=self._accumulated_buffer[:, : self._num_samples_to_record], metadata=metadata)
        # Release the capture so the next record() starts clean.
        self._accumulated_buffer = None
        return recording

    def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
        """
        Stream iq samples as interleaved bytes via zmq.

        Each received buffer is published on a ZMQ PUB socket as the raw bytes of the numpy array
        (``ndarray.tobytes()``). The socket and its context are released when streaming ends,
        even if binding or streaming raises.

        :param zmq_address: The zmq address to bind, or a bare port number (bound as ``tcp://*:<port>``).
        :type zmq_address: str or int
        :param n_samples: The number of samples to stream; ``np.inf`` streams without a buffer limit.
        :type n_samples: int
        :param buffer_size: The number of samples per received buffer, used only to convert
            ``n_samples`` into a buffer count. Defaults to 10000; ``None`` uses ``self.rx_buffer_size``.
            NOTE(review): unlike :meth:`pickle_buffer_to_zmq`, this does not configure the device
            buffer size, so it should match the radio's actual buffer size.
        :type buffer_size: int, optional
        :return: None
        """
        if buffer_size is None:
            buffer_size = self.rx_buffer_size
        self._previous_buffer = None
        self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
        self._num_buffers_processed = 0
        self.zmq_address = _generate_full_zmq_address(str(zmq_address))
        self.context = zmq.Context()
        try:
            self.socket = self.context.socket(zmq.PUB)
            self.socket.bind(self.zmq_address)
            self._stream_rx(
                self._zmq_bytestream_callback,
            )
        finally:
            # destroy() closes every socket created from this context and then terminates it,
            # so the bound port is released even on failure.
            self.context.destroy()

    @staticmethod
    def _warn_if_repeated_buffer(buffer, previous_buffer):
        """
        Print an overflow warning when ``buffer`` is identical to ``previous_buffer``.

        Receiving exactly the same samples twice in a row is treated as a sign of a buffer
        overflow. ``np.array_equal`` is used so buffers of different shapes (e.g. a short final
        buffer) or plain lists compare as unequal instead of raising.
        """
        if previous_buffer is not None and np.array_equal(buffer, previous_buffer):
            print("\033[93mWarning: Buffer Overflow Detected\033[0m")

    def _accumulate_buffers_callback(self, buffer, metadata=None):
        """
        Append a received buffer to ``self._accumulated_buffer`` until the buffer limit is reached.

        :param buffer: The received samples, 1-D (single channel) or 2-D (channels, samples).
            Presumably complex samples in the range -1 to 1.
        :param metadata: Optional driver metadata; the latest non-None value is kept in
            ``self.received_metadata``.
        :raises ValueError: If ``self._max_num_buffers`` was not set before streaming.
        """
        if self._max_num_buffers is None:
            raise ValueError("Number of buffers for block capture not set.")

        # np.array copies the driver's buffer; a 1-D buffer becomes a single-channel 2-D array.
        samples = np.array(buffer)
        if samples.ndim == 1:
            samples = samples[np.newaxis, :]

        if metadata is not None:
            self.received_metadata = metadata

        # TODO optimize: pre-allocate instead of concatenating (copying grows with every buffer).
        if self._accumulated_buffer is None:
            self._accumulated_buffer = samples
        else:
            self._accumulated_buffer = np.concatenate((self._accumulated_buffer, samples), axis=1)

        self._num_buffers_processed = self._num_buffers_processed + 1
        if self._num_buffers_processed >= self._max_num_buffers:
            self.stop()

        self._warn_if_repeated_buffer(samples, self._last_buffer)
        self._last_buffer = samples.copy()

    def _zmq_bytestream_callback(self, buffer, metadata=None):
        """
        Publish one received buffer as raw bytes and pause RX once the buffer limit is reached.
        """
        samples = np.asarray(buffer)
        # Convert to bytes for transport.
        self.socket.send(samples.tobytes())
        self._num_buffers_processed = self._num_buffers_processed + 1
        if self._max_num_buffers is not None and self._num_buffers_processed >= self._max_num_buffers:
            self.pause_rx()
        # TODO: consider checking for overflow before sending and skipping repeated IQ data.
        self._warn_if_repeated_buffer(samples, self._previous_buffer)
        self._previous_buffer = samples.copy()

    def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
        """
        Stream samples to a zmq address, one buffer per message serialized with :func:`pickle.dumps`.

        Useful for inference applications with a known input size.
        May reduce transfer rates, but individual buffers will not have discontinuities.
        Subscribers must only unpickle messages from a publisher they trust, because unpickling
        can execute arbitrary code. The socket and its context are released when streaming ends,
        even if binding or streaming raises.

        :param zmq_address: The tcp address to stream to, or a bare port number.
        :type zmq_address: str
        :param buffer_size: The number of iq samples in a buffer.
        :type buffer_size: int
        :param num_buffers: The number of buffers to stream before stopping.
        :type num_buffers: int
        """
        self._max_num_buffers = num_buffers
        self.buffer_size = buffer_size
        self._num_buffers_processed = 0
        # Forget buffers from earlier streams so they cannot trigger a false overflow warning.
        self._last_buffer = None
        self.zmq_address = _generate_full_zmq_address(str(zmq_address))
        self.context = zmq.Context()
        try:
            self.socket = self.context.socket(zmq.PUB)
            self.socket.bind(self.zmq_address)
            self.set_rx_buffer_size(buffer_size)
            self._stream_rx(self._zmq_pickle_buffer_callback)
        finally:
            # Closes the PUB socket and terminates the context so the port is released.
            self.context.destroy()

    def _zmq_pickle_buffer_callback(self, buffer, metadata=None):
        """
        Publish one received buffer with pickle and stop RX once the buffer limit is reached.
        """
        self.socket.send(pickle.dumps(buffer))
        self._num_buffers_processed = self._num_buffers_processed + 1
        if self._max_num_buffers is not None and self._num_buffers_processed >= self._max_num_buffers:
            self.stop()
        self._warn_if_repeated_buffer(buffer, self._last_buffer)
        self._last_buffer = np.array(buffer)

    def tx_recording(
        self,
        recording: Recording | np.ndarray,
        num_samples: Optional[int] = None,
        tx_time: Optional[int | float] = None,
    ):
        """
        Transmit the given iq samples from the provided recording.

        init_tx() must be called before this function. Only channel 0 is transmitted: for a
        Recording this is ``recording.data[0]``, for a 2-D array (channels, samples) its first row.

        :param recording: The recording to transmit.
        :type recording: Recording or np.ndarray
        :param num_samples: The number of samples to transmit, will repeat or
            truncate the recording to this length. Defaults to None (transmit the channel once).
        :type num_samples: int, optional
        :param tx_time: The time to transmit, will repeat or truncate the
            recording to ``int(tx_time * tx_sample_rate)`` samples. Defaults to None.
        :type tx_time: int or float, optional
        :raises RuntimeError: If TX was not initialized.
        :raises ValueError: If both ``num_samples`` and ``tx_time`` are given, or the samples are empty.
        :raises TypeError: If ``recording`` is neither a Recording nor a numpy array.
        """
        if not self._tx_initialized:
            raise RuntimeError(
                "TX was not initialized. init_tx() must be called before _stream_tx() or transmit_recording()"
            )
        if num_samples is not None and tx_time is not None:
            raise ValueError("Only input one of num_samples or tx_time")

        if isinstance(recording, np.ndarray):
            samples = recording
            if samples.ndim > 1:
                if samples.shape[0] > 1:
                    warnings.warn("Array is multichannel, only channel 0 data was used for transmission")
                samples = samples[0]
        elif isinstance(recording, Recording):
            if len(recording.data) > 1:
                warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
            samples = recording.data[0]
        else:
            # Without this, a previous call's samples would silently be transmitted again.
            raise TypeError(f"recording must be a Recording or np.ndarray, got {type(recording).__name__}")

        if len(samples) == 0:
            raise ValueError("Cannot transmit an empty recording")

        if num_samples is not None:
            self._num_samples_to_transmit = num_samples
        elif tx_time is not None:
            self._num_samples_to_transmit = int(tx_time * self.tx_sample_rate)
        else:
            self._num_samples_to_transmit = len(samples)

        self._samples_to_transmit = samples
        self._num_samples_transmitted = 0
        self._stream_tx(self._loop_recording_callback)

    def _loop_recording_callback(self, num_samples):
        """
        Return the next ``num_samples`` samples to transmit, looping over the source samples.

        Once the requested total (``self._num_samples_to_transmit``) is exceeded, the rest of the
        returned buffer is zero-padded, so the buffer always has the requested size, and TX is paused.
        """
        source_length = len(self._samples_to_transmit)
        # Clamp at zero: if the stream asks for another buffer after pausing, it gets only padding.
        samples_left = max(self._num_samples_to_transmit - self._num_samples_transmitted, 0)
        # Indices wrap around the source as many times as necessary.
        start_index = self._num_samples_transmitted % source_length
        indices = np.arange(start_index, start_index + num_samples) % source_length
        # Fancy indexing returns a copy, so the zero-padding below never modifies the source.
        samples = self._samples_to_transmit[indices]
        if len(samples) > samples_left:
            samples[int(samples_left) :] = 0
            self.pause_tx()
        self._num_samples_transmitted = self._num_samples_transmitted + num_samples
        return samples

    def pause_rx(self):
        """Clear the RX enable flag."""
        self._enable_rx = False

    def pause_tx(self):
        """Clear the TX enable flag."""
        self._enable_tx = False

    def stop(self):
        """Stop receiving; the default implementation only pauses RX."""
        self.pause_rx()

    def supports_bias_tee(self) -> bool:
        """Return True when the radio supports bias-tee control."""
        return False

    def set_bias_tee(self, enable: bool):
        """Enable or disable bias-tee power when supported by the radio."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")

    @abstractmethod
    def close(self):
        """Release the radio. Device-specific."""
        pass

    @abstractmethod
    def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
        """
        Configure the radio for receiving.

        Implementations must set ``self._rx_initialized = True``; :meth:`record` refuses to run otherwise.
        """
        pass

    @abstractmethod
    def init_tx(self, sample_rate, center_frequency, gain, channel, gain_mode):
        """
        Configure the radio for transmitting.

        Implementations must set ``self._tx_initialized = True``; :meth:`tx_recording` refuses to run otherwise.
        """
        pass

    @abstractmethod
    def _stream_rx(self, callback):
        """
        Deliver received buffers to ``callback(buffer, metadata=None)``.

        The callbacks in this class end a stream by calling :meth:`stop` or :meth:`pause_rx`, so
        implementations presumably loop while ``self._enable_rx`` is set.
        """
        pass

    @abstractmethod
    def _stream_tx(self, callback):
        """
        Transmit the samples returned by ``callback(num_samples)``.

        :meth:`_loop_recording_callback` ends a transmission via :meth:`pause_tx`, so
        implementations presumably loop while ``self._enable_tx`` is set.
        """
        pass

    @abstractmethod
    def set_clock_source(self, source):
        """
        Sets the clock source to external or internal.

        :param source: The clock source
        :type source: str
        """
        pass
def _generate_full_zmq_address(input_address):
    """
    Helper function for zmq streaming.

    If given a bare port number like 5556 (int or str), return ``"tcp://*:5556"``: a TCP
    address on that port that binds on all network interfaces, not just localhost.
    Anything else, including addresses that already name a transport (contain ``"://"``),
    is returned unchanged.

    :param input_address: A port number or a full zmq address.
    :return: The full zmq address.
    """
    if "://" not in str(input_address) and _is_valid_port(input_address):
        # No transport protocol specified: assume TCP on all interfaces ("*").
        return "tcp://*:" + str(input_address)
    # Otherwise, return the input unchanged.
    return input_address
def _is_valid_port(port):
"""
Helper function for zmq address.
"""
try:
port_num = int(port)
return 0 <= port_num <= 65535
except ValueError:
return False
def _verify_sample_format(samples):
"""
Verify that the sample data is in the range -1 to 1.
:param buffer: An array of samples.
:Return: True if the buffer is in the correct format, false if not.
:rtype: bool
"""
return np.max(np.abs(samples)) <= 1