updates_and_fixes #12
|
|
@ -1,5 +1,6 @@
|
||||||
import math
|
import math
|
||||||
import pickle
|
import pickle
|
||||||
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -27,17 +28,21 @@ class SDR(ABC):
|
||||||
self._tx_initialized = False
|
self._tx_initialized = False
|
||||||
self._enable_rx = False
|
self._enable_rx = False
|
||||||
self._enable_tx = False
|
self._enable_tx = False
|
||||||
|
|
||||||
self._accumulated_buffer = None
|
self._accumulated_buffer = None
|
||||||
self._max_num_buffers = None
|
self._max_num_buffers = None
|
||||||
self._num_buffers_processed = 0
|
self._num_buffers_processed = 0
|
||||||
self._accumulated_buffer = None
|
self._accumulated_buffer = None
|
||||||
self._last_buffer = None
|
self._last_buffer = None
|
||||||
|
self._corrupted_buffer_count = 0
|
||||||
|
|
||||||
self.rx_sample_rate = None
|
self.rx_sample_rate = None
|
||||||
self.rx_center_frequency = None
|
self.rx_center_frequency = None
|
||||||
self.rx_gain = None
|
self.rx_gain = None
|
||||||
self.tx_sample_rate = None
|
self.tx_sample_rate = None
|
||||||
self.tx_center_frequency = None
|
self.tx_center_frequency = None
|
||||||
self.tx_gain = None
|
self.tx_gain = None
|
||||||
|
self._param_lock = threading.RLock() # Reentrant lock
|
||||||
|
|
||||||
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
def record(self, num_samples: Optional[int] = None, rx_time: Optional[int | float] = None) -> Recording:
|
||||||
"""
|
"""
|
||||||
|
|
@ -71,7 +76,6 @@ class SDR(ABC):
|
||||||
|
|
||||||
self._max_num_buffers = num_buffers
|
self._max_num_buffers = num_buffers
|
||||||
self._num_buffers_processed = 0
|
self._num_buffers_processed = 0
|
||||||
self._num_buffers_processed = 0
|
|
||||||
self._last_buffer = None
|
self._last_buffer = None
|
||||||
self._accumulated_buffer = None
|
self._accumulated_buffer = None
|
||||||
print("Starting stream")
|
print("Starting stream")
|
||||||
|
|
@ -94,6 +98,7 @@ class SDR(ABC):
|
||||||
|
|
||||||
# reset to record again
|
# reset to record again
|
||||||
self._accumulated_buffer = None
|
self._accumulated_buffer = None
|
||||||
|
self._num_buffers_processed = 0
|
||||||
return recording
|
return recording
|
||||||
|
|
||||||
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
def stream_to_zmq(self, zmq_address, n_samples: int, buffer_size: Optional[int] = 10000):
|
||||||
|
|
@ -110,7 +115,7 @@ class SDR(ABC):
|
||||||
:return: The trimmed Recording.
|
:return: The trimmed Recording.
|
||||||
:rtype: Recording
|
:rtype: Recording
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
self._previous_buffer = None
|
self._previous_buffer = None
|
||||||
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
|
self._max_num_buffers = np.inf if n_samples == np.inf else math.ceil(n_samples / buffer_size)
|
||||||
self._num_buffers_processed = 0
|
self._num_buffers_processed = 0
|
||||||
|
|
@ -122,9 +127,11 @@ class SDR(ABC):
|
||||||
self._stream_rx(
|
self._stream_rx(
|
||||||
self._zmq_bytestream_callback,
|
self._zmq_bytestream_callback,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
self.context.destroy()
|
if hasattr(self, "socket"):
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
if hasattr(self, "context"):
|
||||||
|
self.context.destroy()
|
||||||
|
|
||||||
def _accumulate_buffers_callback(self, buffer, metadata=None):
|
def _accumulate_buffers_callback(self, buffer, metadata=None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -134,62 +141,72 @@ class SDR(ABC):
|
||||||
# save the buffer until max reached
|
# save the buffer until max reached
|
||||||
# return a recording
|
# return a recording
|
||||||
|
|
||||||
|
# Validate buffer
|
||||||
|
if not self._validate_buffer(buffer):
|
||||||
|
print("Warning: Corrupted buffer detected, skipping")
|
||||||
|
self._corrupted_buffer_count += 1
|
||||||
|
return # Skip this buffer
|
||||||
|
|
||||||
|
if isinstance(buffer, np.ndarray):
|
||||||
|
if buffer.ndim == 1:
|
||||||
|
buffer = buffer[np.newaxis, :] # make shape (1, N)
|
||||||
|
else:
|
||||||
buffer = np.array(buffer) # make it 1d
|
buffer = np.array(buffer) # make it 1d
|
||||||
if len(buffer.shape) == 1:
|
if len(buffer.shape) == 1:
|
||||||
buffer = np.array([buffer])
|
buffer = np.array([buffer])
|
||||||
|
|
||||||
# it runs these checks each time, is that an efficiency issue?
|
# First call: pre-allocate if we know the final size
|
||||||
|
if self._accumulated_buffer is None:
|
||||||
|
# Check that _max_num_buffers is set
|
||||||
if self._max_num_buffers is None:
|
if self._max_num_buffers is None:
|
||||||
# default then
|
|
||||||
# this should probably print, but that would happen every buffer...
|
|
||||||
raise ValueError("Number of buffers for block capture not set.")
|
raise ValueError("Number of buffers for block capture not set.")
|
||||||
|
if self._num_samples_to_record is None:
|
||||||
# add the given buffer to the pre-allocated buffer
|
raise ValueError("Number of samples not set before RX start.")
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.received_metadata = metadata
|
self.received_metadata = metadata
|
||||||
|
|
||||||
# TODO optimize, pre-allocate
|
# Preallocate once (avoid np.zeros; use np.empty for speed)
|
||||||
if self._accumulated_buffer is not None:
|
num_channels = buffer.shape[0]
|
||||||
self._accumulated_buffer = np.concatenate((self._accumulated_buffer, buffer), axis=1)
|
self._accumulated_buffer = np.empty((num_channels, self._num_samples_to_record), dtype=buffer.dtype)
|
||||||
else:
|
self._write_position = 0
|
||||||
# the first time
|
print(f"Pre-allocated buffer for {self._num_samples_to_record:,} samples.")
|
||||||
self._accumulated_buffer = buffer.copy()
|
|
||||||
|
|
||||||
self._num_buffers_processed = self._num_buffers_processed + 1
|
# Write new buffer into pre-allocated array
|
||||||
|
n = buffer.shape[1]
|
||||||
|
start = self._write_position
|
||||||
|
end = min(start + n, self._num_samples_to_record)
|
||||||
|
samples_to_write = end - start
|
||||||
|
|
||||||
|
if samples_to_write > 0:
|
||||||
|
self._accumulated_buffer[:, start:end] = buffer[:, : end - start]
|
||||||
|
self._write_position = end
|
||||||
|
|
||||||
|
# Check if we're done
|
||||||
|
self._num_buffers_processed += 1
|
||||||
if self._num_buffers_processed >= self._max_num_buffers:
|
if self._num_buffers_processed >= self._max_num_buffers:
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
if self._last_buffer is not None:
|
def _validate_buffer(self, buffer):
|
||||||
if (buffer == self._last_buffer).all():
|
"""Check for obviously corrupt data."""
|
||||||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
# Check for all zeros
|
||||||
self._last_buffer = buffer.copy()
|
if np.all(buffer == 0):
|
||||||
else:
|
return False
|
||||||
self._last_buffer = buffer.copy()
|
# Check for all same value
|
||||||
|
if np.all(buffer == buffer[0]):
|
||||||
# print("Number of buffers received: " + str(self._num_buffers_processed))
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _zmq_bytestream_callback(self, buffer, metadata=None):
|
def _zmq_bytestream_callback(self, buffer, metadata=None):
|
||||||
# push to ZMQ port
|
# push to ZMQ port
|
||||||
data = np.array(buffer).tobytes() # convert to bytes for transport
|
data = np.array(buffer).tobytes() # convert to bytes for transport
|
||||||
self.socket.send(data)
|
self.socket.send(data)
|
||||||
|
|
||||||
# print(f"Sent {self._num_buffers_processed} ZMQ buffers to {self.zmq_address}")
|
|
||||||
|
|
||||||
self._num_buffers_processed = self._num_buffers_processed + 1
|
self._num_buffers_processed = self._num_buffers_processed + 1
|
||||||
if self._max_num_buffers is not None:
|
if self._max_num_buffers is not None:
|
||||||
if self._num_buffers_processed >= self._max_num_buffers:
|
if self._num_buffers_processed >= self._max_num_buffers:
|
||||||
self.pause_rx()
|
self.pause_rx()
|
||||||
|
|
||||||
if self._previous_buffer is not None:
|
|
||||||
if (buffer == self._previous_buffer).all():
|
|
||||||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
|
||||||
# TODO: I suggest we think about moving this part to the top of this function
|
|
||||||
# and skip the rest of the function in case of overflow.
|
|
||||||
# like, it's not necessary to stream repeated IQ data anyways!
|
|
||||||
self._previous_buffer = buffer.copy()
|
|
||||||
|
|
||||||
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
|
def pickle_buffer_to_zmq(self, zmq_address, buffer_size, num_buffers):
|
||||||
"""
|
"""
|
||||||
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
|
Stream samples to a zmq address, packaged in binary buffers using numpy.pickle.
|
||||||
|
|
@ -229,7 +246,7 @@ class SDR(ABC):
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
if self._last_buffer is not None:
|
if self._last_buffer is not None:
|
||||||
if (buffer == self._last_buffer).all():
|
if np.array_equal(buffer, self._last_buffer):
|
||||||
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
print("\033[93mWarning: Buffer Overflow Detected\033[0m")
|
||||||
self._last_buffer = buffer.copy()
|
self._last_buffer = buffer.copy()
|
||||||
else:
|
else:
|
||||||
|
|
@ -373,6 +390,58 @@ class SDR(ABC):
|
||||||
"""
|
"""
|
||||||
return self.tx_gain
|
return self.tx_gain
|
||||||
|
|
||||||
|
def set_rx_sample_rate(self):
|
||||||
|
"""
|
||||||
|
Set the sample rate of the receiver.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_rx_center_frequency(self):
|
||||||
|
"""
|
||||||
|
Set the center frequency of the receiver.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_rx_gain(self):
|
||||||
|
"""
|
||||||
|
Set the gain setting of the receiver.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_tx_sample_rate(self):
|
||||||
|
"""
|
||||||
|
Set the sample rate of the transmitter.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_tx_center_frequency(self):
|
||||||
|
"""
|
||||||
|
Set the center frequency of the transmitter.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_tx_gain(self):
|
||||||
|
"""
|
||||||
|
Set the gain setting of the transmitter.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def supports_dynamic_updates(self) -> dict:
|
||||||
|
"""
|
||||||
|
Report which parameters can be updated during streaming.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {'center_frequency': bool, 'sample_rate': bool, 'gain': bool}
|
||||||
|
"""
|
||||||
|
return {"center_frequency": False, "sample_rate": False, "gain": False}
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Cleanup on garbage collection."""
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def close(self):
|
def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
@ -442,3 +511,21 @@ def _verify_sample_format(samples):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return np.max(np.abs(samples)) <= 1
|
return np.max(np.abs(samples)) <= 1
|
||||||
|
|
||||||
|
|
||||||
|
class SDRError(Exception):
|
||||||
|
"""Base exception for SDR errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SDRParameterError(SDRError):
|
||||||
|
"""Invalid parameter (sample rate, freq, gain)."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SDROverflowError(SDRError):
|
||||||
|
"""Buffer overflow detected."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user