384 lines
14 KiB
Python
384 lines
14 KiB
Python
import math
|
|
import pickle
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import zmq
|
|
|
|
from ria_toolkit_oss.datatypes.recording import Recording
|
|
|
|
|
|
class SDR(ABC):
|
|
"""
|
|
This class defines a common interface (a template) for all SDR devices.
|
|
Each specific SDR implementation should subclass SDR and provide concrete implementations
|
|
for the abstract methods.
|
|
|
|
To add support for a new radio, subclass this interface and implement all abstract methods.
|
|
If you experience difficulties, please `contact us <mailto:info@qoherent.ai>`_, we are happy to
|
|
provide additional direction and/or help with the implementation details.
|
|
"""
|
|
|
|
def __init__(self):
|
|
|
|
self._rx_initialized = False
|
|
self._tx_initialized = False
|
|
self._enable_rx = False
|
|
self._enable_tx = False
|
|
self._accumulated_buffer = None
|
|
self._max_num_buffers = None
|
|
self._num_buffers_processed = 0
|
|
self._accumulated_buffer = None
|
|
self._last_buffer = None
|
|
|
|
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
|
"""
|
|
Create a radio recording of a given length. Either ``num_samples`` or ``rx_time`` must be provided.
|
|
|
|
Note that ``init_rx()`` must be called before ``record()``.
|
|
|
|
:param num_samples: The number of samples to record.
|
|
:type num_samples: int, optional
|
|
:param rx_time: The time to record.
|
|
:type rx_time: int or float, optional
|
|
|
|
:return: The Recording object
|
|
:rtype: Recording
|
|
"""
|
|
|
|
if not self._rx_initialized:
|
|
raise RuntimeError("RX was not initialized. init_rx() must be called before _stream_rx() or record()")
|
|
|
|
if num_samples is not None and rx_time is not None:
|
|
raise ValueError("Only input one of num_samples or rx_time")
|
|
elif num_samples is not None:
|
|
self._num_samples_to_record = num_samples
|
|
elif rx_time is not None:
|
|
self._num_samples_to_record = int(rx_time * self.rx_sample_rate)
|
|
else:
|
|
raise ValueError("Must provide input of one of num_samples or rx_time")
|
|
|
|
self.buffer_size = self.rx_buffer_size
|
|
num_buffers = self._num_samples_to_record // self.buffer_size + 1
|
|
|
|
self._max_num_buffers = num_buffers
|
|
self._num_buffers_processed = 0
|
|
self._num_buffers_processed = 0
|
|
self._last_buffer = None
|
|
self._accumulated_buffer = None
|
|
print("Starting stream")
|
|
|
|
self._stream_rx(
|
|
callback=self._accumulate_buffers_callback,
|
|
)
|
|
|
|
print("Finished stream")
|
|
metadata = {
|
|
"source": self.__class__.__name__,
|
|
"sample_rate": self.rx_sample_rate,
|
|
"center_frequency": self.rx_center_frequency,
|
|
"gain": self.rx_gain,
|
|
}
|
|
|
|
print("Creating recording")
|
|
# build recording, truncate to self._num_samples_to_record
|
|
recording = Recording(data=self._accumulated_buffer[:, : self._num_samples_to_record], metadata=metadata)
|
|
|
|
# reset to record again
|
|
self._accumulated_buffer = None
|
|
return recording
|
|
|
|
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
|
"""
|
|
Stream iq samples as interleaved bytes via zmq.
|
|
|
|
:param zmq_address: The zmq address.
|
|
:type zmq_address:
|
|
:param n_samples: The number of samples to stream.
|
|
:type n_samples: int
|
|
:param buffer_size: The buffer size during streaming. Defaults to 10000.
|
|
:type buffer_size: int, optional
|
|
|
|
:return: The trimmed Recording.
|
|
:rtype: Recording
|
|
"""
|
|
|
|
self._previous_buffer = None
|
|
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
|
|
self._num_buffers_processed = 0
|
|
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
|
|
self.context = zmq.Context()
|
|
self.socket = self.context.socket(zmq.PUB)
|
|
self.socket.bind(self.zmq_address)
|
|
|
|
self._stream_rx(
|
|
self._zmq_bytestream_callback,
|
|
)
|
|
|
|
self.context.destroy()
|
|
self.socket.close()
|
|
|
|
def _accumulate_buffers_callback(self, buffer, metadata=None):
|
|
"""
|
|
Receives a buffer and saves it to self.accumulated_buffer.
|
|
"""
|
|
# expected buffer is complex samples range -1 to 1
|
|
# save the buffer until max reached
|
|
# return a recording
|
|
|
|
buffer = np.array(buffer) # make it 1d
|
|
if len(buffer.shape) == 1:
|
|
buffer = np.array([buffer])
|
|
|
|
# it runs these checks each time, is that an efficiency issue?
|
|
|
|
if self._max_num_buffers is None:
|
|
# default then
|
|
# this should probably print, but that would happen every buffer...
|
|
raise ValueError("Number of buffers for block capture not set.")
|
|
|
|
# add the given buffer to the pre-allocated buffer
|
|
|
|
if metadata is not None:
|
|
self.received_metadata = metadata
|
|
|
|
# TODO optimize, pre-allocate
|
|
if self._accumulated_buffer is not None:
|
|
self._accumulated_buffer = np.concatenate((self._accumulated_buffer, buffer), axis=1)
|
|
else:
|
|
# the first time
|
|
self._accumulated_buffer = buffer.copy()
|
|
|
|
self._num_buffers_processed = self._num_buffers_processed + 1
|
|
if self._num_buffers_processed >= self._max_num_buffers:
|
|
self.stop()
|
|
|
|
if self._last_buffer is not None:
|
|
if (buffer == self._last_buffer).all():
|
|
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
|
self._last_buffer = buffer.copy()
|
|
else:
|
|
self._last_buffer = buffer.copy()
|
|
|
|
# print("Number of buffers received: " + str(self._num_buffers_processed))
|
|
|
|
def _zmq_bytestream_callback(self, buffer, metadata=None):
|
|
# push to ZMQ port
|
|
data = np.array(buffer).tobytes() # convert to bytes for transport
|
|
self.socket.send(data)
|
|
|
|
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
|
|
|
|
self._num_buffers_processed = self._num_buffers_processed + 1
|
|
if self._max_num_buffers is not None:
|
|
if self._num_buffers_processed >= self._max_num_buffers:
|
|
self.pause_rx()
|
|
|
|
if self._previous_buffer is not None:
|
|
if (buffer == self._previous_buffer).all():
|
|
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
|
# TODO: I suggest we think about moving this part to the top of this function
|
|
# and skip the rest of the function in case of overflow.
|
|
# like, it's not necessary to stream repeated IQ data anyways!
|
|
self._previous_buffer = buffer.copy()
|
|
|
|
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
|
|
"""
|
|
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
|
|
Useful for inference applications with a known input size.
|
|
May reduce transfer rates, but individual buffers will not have discontinuities.
|
|
|
|
:param zmq_address: The tcp address to stream to.
|
|
:type zmq_address: str
|
|
:param buffer_size: The number of iq samples in a buffer.
|
|
:type buffer_size: int
|
|
:param num_buffers: The number of buffers to stream before stopping.
|
|
:type num_buffers: int
|
|
"""
|
|
self._max_num_buffers = num_buffers
|
|
self.buffer_size = buffer_size
|
|
self._num_buffers_processed = 0
|
|
self.zmq_address = _generate_full_zmq_address(str(zmq_address))
|
|
self.context = zmq.Context()
|
|
self.socket = self.context.socket(zmq.PUB)
|
|
self.socket.bind(self.zmq_address)
|
|
self.set_rx_buffer_size(buffer_size)
|
|
|
|
self._stream_rx(self._zmq_pickle_buffer_callback)
|
|
|
|
def _zmq_pickle_buffer_callback(self, buffer, metadata=None):
|
|
# push to ZMQ port
|
|
# data = np.array(buffer).tobytes() # convert to bytes for transport
|
|
# self.socket.send(data)
|
|
|
|
self.socket.send(pickle.dumps(buffer))
|
|
|
|
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
|
|
|
|
self._num_buffers_processed = self._num_buffers_processed + 1
|
|
if self._max_num_buffers is not None:
|
|
if self._num_buffers_processed >= self._max_num_buffers:
|
|
self.stop()
|
|
|
|
if self._last_buffer is not None:
|
|
if (buffer == self._last_buffer).all():
|
|
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
|
self._last_buffer = buffer.copy()
|
|
else:
|
|
self._last_buffer = buffer.copy()
|
|
|
|
def tx_recording(
|
|
self,
|
|
recording: Recording | np.ndarray,
|
|
num_samples: Optional[int] = None,
|
|
tx_time: Optional[int | float] = None,
|
|
):
|
|
"""
|
|
Transmit the given iq samples from the provided recording.
|
|
init_tx() must be called before this function.
|
|
|
|
:param recording: The recording to transmit.
|
|
:type recording: Recording or np.ndarray
|
|
:param num_samples: The number of samples to transmit, will repeat or
|
|
truncate the recording to this length. Defaults to None.
|
|
:type num_samples: int, optional
|
|
:param tx_time: The time to transmit, will repeat or truncate the
|
|
recording to this length. Defaults to None.
|
|
:type tx_time: int or float, optional
|
|
"""
|
|
|
|
if not self._tx_initialized:
|
|
raise RuntimeError(
|
|
"TX was not initialized. init_tx() must be called before _stream_tx() or transmit_recording()"
|
|
)
|
|
|
|
if num_samples is not None and tx_time is not None:
|
|
raise ValueError("Only input one of num_samples or tx_time")
|
|
elif num_samples is not None:
|
|
self._num_samples_to_transmit = num_samples
|
|
elif tx_time is not None:
|
|
self._num_samples_to_transmit = tx_time * self.tx_sample_rate
|
|
else:
|
|
self._num_samples_to_transmit = len(recording)
|
|
|
|
if isinstance(recording, np.ndarray):
|
|
self._samples_to_transmit = recording
|
|
elif isinstance(recording, Recording):
|
|
if len(recording.data) > 1:
|
|
warnings.warn("Recording object is multichannel, only channel 0 data was used for transmission")
|
|
|
|
self._samples_to_transmit = recording.data[0]
|
|
|
|
self._num_samples_transmitted = 0
|
|
|
|
self._stream_tx(self._loop_recording_callback)
|
|
|
|
def _loop_recording_callback(self, num_samples):
|
|
|
|
samples_left = self._num_samples_to_transmit - self._num_samples_transmitted
|
|
# find where to start based on num_samples_transmitted
|
|
start_index = self._num_samples_transmitted % len(self._samples_to_transmit)
|
|
|
|
# generates an array of indices that wrap around as many times as necessary.
|
|
indices = np.arange(start_index, start_index + num_samples) % len(self._samples_to_transmit)
|
|
samples = self._samples_to_transmit[indices]
|
|
|
|
# zero pad at the end so we are still giving the requested buffer size
|
|
# while also giving the exact number of non zero samples
|
|
if len(samples) > samples_left:
|
|
samples[int(samples_left) :] = 0
|
|
self.pause_tx()
|
|
|
|
self._num_samples_transmitted = self._num_samples_transmitted + num_samples
|
|
|
|
return samples
|
|
|
|
def pause_rx(self):
|
|
self._enable_rx = False
|
|
|
|
def pause_tx(self):
|
|
self._enable_tx = False
|
|
|
|
def stop(self):
|
|
self.pause_rx()
|
|
|
|
def supports_bias_tee(self) -> bool:
|
|
"""Return True when the radio supports bias-tee control."""
|
|
return False
|
|
|
|
def set_bias_tee(self, enable: bool):
|
|
"""Enable or disable bias-tee power when supported by the radio."""
|
|
raise NotImplementedError(f"{self.__class__.__name__} does not support bias-tee control")
|
|
|
|
@abstractmethod
|
|
def close(self):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def init_rx(self, sample_rate, center_frequency, gain, channel, gain_mode):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def init_tx(self, sample_rate, center_frequency, gain, channel, gain_mode):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def _stream_rx(self, callback):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def _stream_tx(self, callback):
|
|
pass
|
|
|
|
@abstractmethod
|
|
def set_clock_source(self, source):
|
|
"""
|
|
Sets the clock source to external or internal.
|
|
|
|
:param source: The clock source
|
|
:type source: str
|
|
"""
|
|
pass
|
|
|
|
|
|
def _generate_full_zmq_address(input_address):
    """
    Normalize an address for zmq streaming.

    A bare port number such as ``5556`` (no ``://`` transport prefix) is expanded to
    ``"tcp://*:5556"``, i.e. TCP on the wildcard host. Any other input, including
    full endpoints such as ``"ipc:///tmp/feed"``, is returned as given.
    """
    has_transport = "://" in str(input_address)
    if has_transport or not _is_valid_port(input_address):
        # Already a full endpoint (or not a port at all): leave it alone.
        return input_address

    # No transport specified and the value is a usable port: default to TCP.
    return "tcp://*:" + str(input_address)
|
|
|
|
|
|
def _is_valid_port(port):
|
|
"""
|
|
Helper function for zmq address.
|
|
"""
|
|
try:
|
|
port_num = int(port)
|
|
return 0 <= port_num <= 65535
|
|
except ValueError:
|
|
return False
|
|
|
|
|
|
def _verify_sample_format(samples):
|
|
"""
|
|
Verify that the sample data is in the range -1 to 1.
|
|
|
|
:param buffer: An array of samples.
|
|
|
|
:Return: True if the buffer is in the correct format, false if not.
|
|
:rtype: bool
|
|
"""
|
|
|
|
return np.max(np.abs(samples)) <= 1
|