updates_and_fixes #12

Merged
madrigal merged 9 commits from updates_and_fixes into main 2025-11-18 15:01:25 -05:00
Showing only changes of commit c673967a90 - Show all commits

View File

@@ -1,5 +1,6 @@
import math import math
import pickle import pickle
import threading
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
@@ -27,17 +28,21 @@ class SDR(ABC):
self._tx_initialized = False self._tx_initialized = False
self._enable_rx = False self._enable_rx = False
self._enable_tx = False self._enable_tx = False
self._accumulated_buffer = None self._accumulated_buffer = None
self._max_num_buffers = None self._max_num_buffers = None
self._num_buffers_processed = 0 self._num_buffers_processed = 0
self._accumulated_buffer = None self._accumulated_buffer = None
self._last_buffer = None self._last_buffer = None
self._corrupted_buffer_count = 0
self.rx_sample_rate = None self.rx_sample_rate = None
self.rx_center_frequency = None self.rx_center_frequency = None
self.rx_gain = None self.rx_gain = None
self.tx_sample_rate = None self.tx_sample_rate = None
self.tx_center_frequency = None self.tx_center_frequency = None
self.tx_gain = None self.tx_gain = None
self._param_lock = threading.RLock() # Reentrant lock
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording: def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
""" """
@@ -71,7 +76,6 @@ class SDR(ABC):
self._max_num_buffers = num_buffers self._max_num_buffers = num_buffers
self._num_buffers_processed = 0 self._num_buffers_processed = 0
self._num_buffers_processed = 0
self._last_buffer = None self._last_buffer = None
self._accumulated_buffer = None self._accumulated_buffer = None
print("Starting stream") print("Starting stream")
@@ -94,6 +98,7 @@ class SDR(ABC):
# reset to record again # reset to record again
self._accumulated_buffer = None self._accumulated_buffer = None
self._num_buffers_processed = 0
return recording return recording
def stream_to_zmq(self, zmq_address, n_samples: int | float, buffer_size: Optional[int] = 10000):
    """
    Stream received sample buffers to a ZMQ PUB socket as raw bytes.

    :param zmq_address: Address (or bare port) to publish on; normalized
        via _generate_full_zmq_address.
    :param n_samples: Total number of samples to stream. np.inf streams
        until explicitly stopped (hence the widened int | float annotation).
    :param buffer_size: Samples per buffer; used to derive the buffer count.
    :return: None. (The streamed data goes out over ZMQ; nothing is
        accumulated or returned, unlike record().)
    """
    try:
        self._previous_buffer = None
        # np.inf buffers means "stream until stopped"; otherwise round up
        # so the final partial buffer still counts toward the total.
        self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
        self._num_buffers_processed = 0
        self.zmq_address = _generate_full_zmq_address(str(zmq_address))
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.bind(self.zmq_address)
        self._stream_rx(
            self._zmq_bytestream_callback,
        )
    finally:
        # Close the socket before destroying its context. hasattr guards
        # against a setup failure occurring before either was assigned.
        # NOTE(review): hasattr can also pick up a socket left over from a
        # previous call on this instance — confirm that is acceptable.
        if hasattr(self, "socket"):
            self.socket.close()
        if hasattr(self, "context"):
            self.context.destroy()
def _accumulate_buffers_callback(self, buffer, metadata=None):
    """
    Accumulate one RX buffer into a pre-allocated capture array, stopping
    the stream once the requested number of buffers has been processed.

    :param buffer: One buffer of received samples (1-D or 2-D array-like;
        2-D is interpreted as (num_channels, num_samples)).
    :param metadata: Optional per-buffer metadata; the most recent non-None
        value is kept on self.received_metadata.
    :raises ValueError: if _max_num_buffers or _num_samples_to_record was
        not set before streaming started.
    """
    # Drop obviously corrupt buffers instead of accumulating them.
    if not self._validate_buffer(buffer):
        print("Warning: Corrupted buffer detected, skipping")
        self._corrupted_buffer_count += 1
        return  # Skip this buffer

    # Normalize to shape (num_channels, num_samples).
    if isinstance(buffer, np.ndarray):
        if buffer.ndim == 1:
            buffer = buffer[np.newaxis, :]  # make shape (1, N)
    else:
        buffer = np.array(buffer)
        if buffer.ndim == 1:
            buffer = buffer[np.newaxis, :]

    # Keep the latest metadata on EVERY call, not only the first buffer —
    # saving it inside the first-call branch silently discarded updates.
    if metadata is not None:
        self.received_metadata = metadata

    # First call: pre-allocate the full capture buffer.
    if self._accumulated_buffer is None:
        if self._max_num_buffers is None:
            raise ValueError("Number of buffers for block capture not set.")
        # NOTE(review): _num_samples_to_record is not initialized in the
        # visible __init__ — presumably set by record(); confirm.
        if self._num_samples_to_record is None:
            raise ValueError("Number of samples not set before RX start.")
        num_channels = buffer.shape[0]
        # np.empty avoids zero-filling; every written cell is overwritten below.
        self._accumulated_buffer = np.empty(
            (num_channels, self._num_samples_to_record), dtype=buffer.dtype
        )
        self._write_position = 0
        print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")

    # Copy into the pre-allocated array, truncating at the capture end.
    start = self._write_position
    end = min(start + buffer.shape[1], self._num_samples_to_record)
    samples_to_write = end - start
    if samples_to_write > 0:
        self._accumulated_buffer[:, start:end] = buffer[:, :samples_to_write]
        self._write_position = end

    self._num_buffers_processed += 1
    if self._num_buffers_processed >= self._max_num_buffers:
        self.stop()
if self._last_buffer is not None: def _validate_buffer(self, buffer):
if (buffer == self._last_buffer).all(): """Check for obviously corrupt data."""
print("\033[93mWarning: Buffer Overflow Detected\033[0m") # Check for all zeros
self._last_buffer = buffer.copy() if np.all(buffer == 0):
else: return False
self._last_buffer = buffer.copy() # Check for all same value
if np.all(buffer == buffer[0]):
# print("Number of buffers received: " + str(self._num_buffers_processed)) return False
return True
def _zmq_bytestream_callback(self, buffer, metadata=None):
    """
    Push one RX buffer to the bound ZMQ PUB socket as raw bytes.

    :param buffer: One buffer of received samples (array-like).
    :param metadata: Unused; accepted for callback-signature compatibility.
    """
    # push to ZMQ port
    data = np.array(buffer).tobytes()  # convert to bytes for transport
    self.socket.send(data)
    self._num_buffers_processed = self._num_buffers_processed + 1
    if self._max_num_buffers is not None:
        if self._num_buffers_processed >= self._max_num_buffers:
            self.pause_rx()
    if self._previous_buffer is not None:
        # np.array_equal is shape-safe (a trailing partial buffer cannot
        # raise), matching the fix applied to the pickle callback's check.
        if np.array_equal(buffer, self._previous_buffer):
            print("\033[93mWarning: Buffer Overflow Detected\033[0m")
            # TODO: I suggest we think about moving this part to the top of this function
            # and skip the rest of the function in case of overflow.
            # like, it's not necessary to stream repeated IQ data anyways!
    self._previous_buffer = buffer.copy()
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers): def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
""" """
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle. Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
@@ -229,7 +246,7 @@ class SDR(ABC):
self.stop() self.stop()
if self._last_buffer is not None: if self._last_buffer is not None:
if (buffer == self._last_buffer).all(): if np.array_equal(buffer, self._last_buffer):
print("\033[93mWarning: Buffer Overflow Detected\033[0m") print("\033[93mWarning: Buffer Overflow Detected\033[0m")
self._last_buffer = buffer.copy() self._last_buffer = buffer.copy()
else: else:
@@ -373,6 +390,58 @@ class SDR(ABC):
""" """
return self.tx_gain return self.tx_gain
def set_rx_sample_rate(self):
    """
    Set the sample rate of the receiver.

    Placeholder in the base class; hardware subclasses override this.
    NOTE(review): takes no value argument here — presumably overrides
    define their own signature; confirm.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError
def set_rx_center_frequency(self):
    """
    Set the center frequency of the receiver.

    Placeholder in the base class; hardware subclasses override this.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError
def set_rx_gain(self):
    """
    Set the gain setting of the receiver.

    Placeholder in the base class; hardware subclasses override this.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError
def set_tx_sample_rate(self):
    """
    Set the sample rate of the transmitter.

    Placeholder in the base class; hardware subclasses override this.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError
def set_tx_center_frequency(self):
    """
    Set the center frequency of the transmitter.

    Placeholder in the base class; hardware subclasses override this.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError
def set_tx_gain(self):
    """
    Set the gain setting of the transmitter.

    Placeholder in the base class; hardware subclasses override this.

    :raises NotImplementedError: always, in the base class.
    """
    raise NotImplementedError
def supports_dynamic_updates(self) -> dict:
    """
    Report which parameters can be updated during streaming.

    Returns:
        dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
    """
    # Conservative default: advertise no dynamic-update support.
    # Subclasses whose hardware allows live retuning override this.
    return dict.fromkeys(("center_frequency", "sample_rate", "gain"), False)
def __del__(self):
    """Cleanup on garbage collection.

    Best-effort only: close() may fail if __init__ never completed or the
    device is already closed. The broad suppress is deliberate — CPython
    ignores exceptions raised in __del__ anyway (printing them to stderr),
    so suppressing explicitly just keeps teardown quiet.
    """
    try:
        self.close()
    except Exception:
        pass
@abstractmethod
def close(self):
    """Release device/session resources; must be implemented by subclasses."""
    pass
@@ -442,3 +511,21 @@ def _verify_sample_format(samples):
""" """
return np.max(np.abs(samples)) <= 1 return np.max(np.abs(samples)) <= 1
class SDRError(Exception):
    """Base exception for SDR errors."""


class SDRParameterError(SDRError):
    """Invalid parameter (sample rate, freq, gain)."""


class SDROverflowError(SDRError):
    """Buffer overflow detected."""