ria-toolkit-oss/src/ria_toolkit_oss/view/view_signal_simple.py
F fordg1 5cfced8855
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Failing after 1s
Build Project / Build Project (3.11) (pull_request) Failing after 1s
Build Project / Build Project (3.12) (pull_request) Failing after 1s
Test with tox / Test with tox (3.10) (pull_request) Failing after 1s
Test with tox / Test with tox (3.11) (pull_request) Failing after 1s
Test with tox / Test with tox (3.12) (pull_request) Failing after 1s
Fix merge conflicts and port all imports from utils to ria_toolkit_oss
Resolves unresolved merge conflict markers left in committed files across
the annotations, view, data, and CLI packages. Updates all remaining
imports from the old utils.* namespace to ria_toolkit_oss.datatypes,
ria_toolkit_oss.io, and ria_toolkit_oss.view equivalents.
2026-03-31 15:16:32 -04:00

388 lines
14 KiB
Python

"""Shared plotting primitives for signal visualization."""
from __future__ import annotations
import gc
import json
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, fftshift
from scipy.signal.windows import hann
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.view.tools import COLORS, decimate, extract_metadata_fields, set_path
def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2):
    """Overlay annotation boxes, and optionally their labels, on the spectrogram axis.

    Nothing is drawn when ``annotations`` is empty/``None`` or when ``compact_mode`` is set.

    :param annotations: Iterable of annotation mappings read via ``.get`` with ``core:*`` keys.
    :param compact_mode: When truthy, annotations are skipped entirely.
    :param show_labels: When truthy, the ``core:label`` text is drawn at each box's upper-left corner.
    :param sample_rate_hz: Used to convert sample indices into seconds on the x axis.
    :param center_freq_hz: Centre of the default frequency band used when an annotation has no edges.
    :param ax2: Axis receiving the rectangles and text.
    """
    if not annotations or compact_mode:
        return

    # Annotation "type" (read from the JSON comment) -> key into COLORS.
    palette_key_by_type = {
        "intersection": "success",
        "parallel": "primary",
        "standalone": "warning",
    }

    for entry in annotations:
        first_sample = entry.get("core:sample_start", 0)
        n_samples = entry.get("core:sample_count", 0)
        t_begin = first_sample / sample_rate_hz
        t_end = (first_sample + n_samples) / sample_rate_hz

        # Without explicit edges the box covers the middle half of the band.
        f_bottom = entry.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4)
        f_top = entry.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4)

        raw_comment = entry.get("core:comment", "{}")
        try:
            parsed = json.loads(raw_comment) if isinstance(raw_comment, str) else raw_comment
            kind = parsed.get("type", "unknown")
            color = COLORS[palette_key_by_type.get(kind, "error")]
        except Exception:
            # Best effort: any unparsable or unexpected comment gets the error colour.
            color = COLORS["error"]

        box = plt.Rectangle(
            (t_begin, f_bottom),
            t_end - t_begin,
            f_top - f_bottom,
            color=color,
            alpha=0.4,
            linewidth=2,
        )
        ax2.add_patch(box)

        if not show_labels:
            continue

        ax2.text(
            t_begin,
            f_top,
            entry.get("core:label", "Signal"),
            color=COLORS["light"],
            fontsize=10,
            bbox={"boxstyle": "round,pad=0.2", "facecolor": color, "alpha": 0.7},
        )
def _get_nfft_size(signal, fast_mode):
    """Choose the spectrogram FFT length and overlap from the signal length.

    :param signal: Any sized sequence of samples; only ``len(signal)`` is used.
    :param fast_mode: When truthy, cap the FFT length at 512 and use a smaller overlap.
    :return: ``(nfft, overlap)``; overlap is ``nfft // 8`` in fast mode, otherwise ``nfft // 4``.
    """
    n_samples = len(signal)

    # (exclusive upper bound on the sample count, FFT length); anything larger gets 2048.
    ladder = (
        (1000, 128),
        (10_000, 256),
        (100_000, 512),
        (1_000_000, 1024),
    )
    nfft = next((size for bound, size in ladder if n_samples < bound), 2048)

    if fast_mode:
        nfft = min(nfft, 512)
        overlap = nfft // 8
    else:
        overlap = nfft // 4
    return nfft, overlap
def _get_plot_samples(signal, fast_mode, slow_max, fast_max):
    """Return at most a mode-dependent number of samples, taken from the middle of ``signal``.

    :param signal: Sliceable, sized sequence of samples.
    :param fast_mode: Selects ``fast_max`` when truthy, otherwise ``slow_max``.
    :param slow_max: Sample cap used outside fast mode.
    :param fast_max: Sample cap used in fast mode.
    :return: ``signal`` itself when it fits within the cap, otherwise a centred slice of cap length.
    """
    cap = fast_max if fast_mode else slow_max
    total = len(signal)
    if total <= cap:
        return signal

    # Centre the window on the midpoint of the signal.
    offset = total // 2 - cap // 2
    return signal[offset : offset + cap]
def _set_dpi(fast_mode, labels_mode, extension):
    """Pick the output resolution for a saved figure.

    :param fast_mode: When truthy, use the lowest resolution (takes priority over ``labels_mode``).
    :param labels_mode: When truthy (and not in fast mode), use the highest resolution.
    :param extension: Output file extension; only the exact value ``"png"`` receives a DPI.
    :return: 75, 200 or 150 for ``"png"`` output, otherwise ``None``.
    """
    if extension != "png":
        return None
    if fast_mode:
        return 75
    return 200 if labels_mode else 150
def setup_style(*, labels_mode: bool = False, compact_mode: bool = False) -> None:
    """Apply the dark signal-testbed theme to matplotlib's global rcParams.

    :param labels_mode: Use the larger font set (ignored when ``compact_mode`` is set).
    :type labels_mode: bool, optional
    :param compact_mode: Use the smallest font set; takes priority over ``labels_mode``.
    :type compact_mode: bool, optional
    """
    plt.style.use("dark_background")

    # (base, title, axis-label) font sizes in points.
    if compact_mode:
        sizes = (8, 10, 8)
    elif labels_mode:
        sizes = (12, 16, 14)
    else:
        sizes = (10, 12, 10)
    base_pt, title_pt, label_pt = sizes

    faint = COLORS["muted"]
    bright = COLORS["light"]

    theme = {
        "figure.facecolor": "#0f172a",
        "axes.facecolor": "#1e293b",
        "axes.edgecolor": faint,
        "axes.labelcolor": bright,
        "text.color": bright,
        "xtick.color": faint,
        "ytick.color": faint,
        "grid.color": faint,
        "grid.alpha": 0.3,
        "font.size": base_pt,
        "axes.titlesize": title_pt,
        "axes.labelsize": label_pt,
        "figure.titlesize": title_pt + 2,
        "legend.frameon": False,
        "legend.facecolor": "none",
        "xtick.labelsize": base_pt,
        "ytick.labelsize": base_pt,
    }
    matplotlib.rcParams.update(theme)
def detect_constellation_symbols(signal: np.ndarray, method: str = "differential") -> np.ndarray:
    """Heuristic symbol detector used for constellation highlighting.

    Marks the samples whose sample-to-sample change is small, i.e. below a low
    percentile of the changes across the whole signal.

    :param signal: Complex samples; any array-like accepted by ``np.asarray``.
    :type signal: np.ndarray
    :param method: One of ``"differential"`` (I/Q step size), ``"amplitude"``
        (magnitude change), ``"phase"`` (unwrapped phase change) or ``"combined"``
        (at least two of the three agree). Defaults to ``"differential"``.
    :type method: str, optional

    :return: Boolean mask with one entry per input sample. Signals shorter than
        100 samples are returned as all ``True``.
    :rtype: np.ndarray

    :raises ValueError: If ``method`` is not one of the supported names. This is
        checked before the short-signal shortcut so a typo is never silently accepted.
    """
    known_methods = ("differential", "amplitude", "phase", "combined")
    if method not in known_methods:
        raise ValueError(f"Unknown method: {method}")

    # Accept lists and other array-likes, not only ndarrays.
    signal = np.asarray(signal)

    if len(signal) < 100:
        return np.ones(len(signal), dtype=bool)

    if method == "differential":
        d_i = np.diff(signal.real)
        d_q = np.diff(signal.imag)
        # np.diff yields n-1 values; append 0 so the mask lines up with the samples.
        step_size = np.append(np.sqrt(d_i**2 + d_q**2), 0)
        return step_size < np.percentile(step_size, 15)

    if method == "amplitude":
        amplitude_change = np.append(np.abs(np.diff(np.abs(signal))), 0)
        return amplitude_change < np.percentile(amplitude_change, 20)

    if method == "phase":
        # Unwrap first so 2*pi wrap-arounds are not mistaken for large phase jumps.
        phase_change = np.abs(np.append(np.diff(np.unwrap(np.angle(signal))), 0))
        return phase_change < np.percentile(phase_change, 20)

    # method == "combined": majority vote over the three individual detectors.
    votes = sum(
        detect_constellation_symbols(signal, single).astype(int)
        for single in ("differential", "amplitude", "phase")
    )
    return votes >= 2
def _simple_sig_layout(compact_mode, horizontal_mode, constellation_mode, labels_mode):
    """Create the figure and axes for the requested layout.

    :return: ``(fig, ax1, ax2, ax_constellation, ax_psd, show_title, show_labels)`` where
        the two optional axes are ``None`` when the layout does not include them.
    """
    ax_constellation = ax_psd = None
    if compact_mode:
        # Compact layout: thin time-series strip above a tall spectrogram, no title or labels.
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 5]})
        return fig, ax1, ax2, None, None, False, False

    if horizontal_mode:
        if constellation_mode:
            fig, (ax1, ax2, ax_constellation) = plt.subplots(1, 3, figsize=(18, 6))
        else:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    elif constellation_mode:
        fig, ((ax1, ax2), (ax_constellation, ax_psd)) = plt.subplots(2, 2, figsize=(16, 12))
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    return fig, ax1, ax2, ax_constellation, ax_psd, True, labels_mode


def _simple_sig_constellation(ax, signal, fast_mode):
    """Scatter the I/Q constellation on ``ax``, highlighting the samples flagged as stable."""
    samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000)
    method = "differential" if fast_mode else "combined"
    stable = detect_constellation_symbols(samples, method=method)

    # Unstable samples first, small and faint, so the highlighted ones are drawn on top.
    ax.scatter(samples.real[~stable], samples.imag[~stable], c=COLORS["muted"], s=0.5, alpha=0.2)
    ax.scatter(samples.real[stable], samples.imag[stable], c=COLORS["purple"], s=3, alpha=0.8)

    ax.set_xlabel("In-phase (I)")
    ax.set_ylabel("Quadrature (Q)")
    ax.set_title("Constellation")
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal")


def _simple_sig_psd(ax, signal, fast_mode, sample_rate_hz, center_freq_hz):
    """Plot a Hann-windowed power spectrum (in dB) of the centre of ``signal`` on ``ax``."""
    samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=65_536, fast_max=16_384)
    n_bins = len(samples)
    spectrum = np.abs(fftshift(fft(samples * hann(n_bins)))) ** 2

    # Exact FFT bin centres in fftshift order (an endpoint-inclusive linspace is off by up to a bin).
    freqs = np.fft.fftshift(np.fft.fftfreq(n_bins)) * sample_rate_hz + center_freq_hz

    # Small floor avoids log10(0) on empty bins.
    spectrum_db = 10 * np.log10(spectrum + 1e-12)
    ax.plot(freqs / 1e6, spectrum_db, color=COLORS["accent"], linewidth=1.0)
    ax.set_xlabel("Frequency (MHz)")
    ax.set_ylabel("Power (dB)")
    ax.set_title("Power Spectral Density")
    ax.grid(True, alpha=0.3)


def view_simple_sig(
    recording: Recording,
    annotations: Optional[list] = None,
    output_path: Optional[str] = "images/signal.png",
    saveplot: Optional[bool] = True,
    fast_mode: Optional[bool] = False,
    compact_mode: Optional[bool] = False,
    horizontal_mode: Optional[bool] = False,
    constellation_mode: Optional[bool] = False,
    labels_mode: Optional[bool] = False,
    slice: Optional[tuple] = None,
    title: Optional[str] = "Signal",
) -> Optional[str]:
    """
    Create a simple plot of various signal visualizations as a png or svg image.

    The first channel of the recording is drawn as a time series and a spectrogram;
    a constellation plot and a PSD are added depending on the selected modes.
    The figure created here is always closed before returning, including when
    saving or when plotting fails, so repeated calls do not accumulate figures.

    :param recording: The recording object to plot.
    :type recording: Recording
    :param annotations: Annotation mappings to overlay on the spectrogram (skipped in
        compact mode). Defaults to None.
    :type annotations: list, optional
    :param output_path: The output image path. Defaults to "images/signal.png"
    :type output_path: str, optional
    :param saveplot: Whether or not to save the plot. When False the plot is shown
        instead. Defaults to True.
    :type saveplot: bool, optional
    :param fast_mode: Use fast mode for faster render. Defaults to False.
    :type fast_mode: bool, optional
    :param compact_mode: Use compact mode for compact plot. Defaults to False.
    :type compact_mode: bool, optional
    :param horizontal_mode: Display plots horizontally. Defaults to False.
    :type horizontal_mode: bool, optional
    :param constellation_mode: Display constellation plot and PSD if not using compact mode.
        The PSD is only included in the vertical (non-horizontal) layout. Defaults to False.
    :type constellation_mode: bool, optional
    :param labels_mode: Display more thorough labels. Defaults to False.
    :type labels_mode: bool, optional
    :param slice: Slice of signal to display. Defaults to None.
    :type slice: tuple[int, int], optional
    :param title: Title of plot. Defaults to "Signal".
    :type title: str, optional

    :return: The path the image was saved to, or None when the plot was shown instead.
    :rtype: str or None
    """
    # NOTE: the ``slice`` parameter shadows the builtin; the name is kept for caller compatibility.
    from matplotlib.ticker import FuncFormatter

    signal = recording.data[0]
    sample_rate_hz, center_freq_hz, sdr = extract_metadata_fields(recording.metadata)
    setup_style(labels_mode=labels_mode, compact_mode=compact_mode)

    if slice:
        start_idx, end_idx = slice
        signal = signal[start_idx:end_idx]
        print(f"Using slice: samples {start_idx} to {end_idx} ({len(signal):,} samples)")

    # Only the time-series trace is reduced; the spectrogram below uses the full signal.
    max_display_pixels = 100_000 if fast_mode else 250_000
    display_signal = decimate(signal, max_display_pixels) if len(signal) > max_display_pixels else signal

    fig, ax1, ax2, ax_constellation, ax_psd, show_title, show_labels = _simple_sig_layout(
        compact_mode=compact_mode,
        horizontal_mode=horizontal_mode,
        constellation_mode=constellation_mode,
        labels_mode=labels_mode,
    )

    try:
        if show_title:
            fig.suptitle(title, fontsize=16, color=COLORS["light"], y=0.96)
        fig.patch.set_facecolor("#0f172a")

        # Time series (I and Q) over the full duration of the (possibly sliced) signal.
        total_duration_s = len(signal) / sample_rate_hz if sample_rate_hz else 0.0
        t_s = np.linspace(0, total_duration_s, len(display_signal)) if len(display_signal) else np.array([])
        ax1.plot(t_s, display_signal.real, color=COLORS["purple"], linewidth=0.8, alpha=0.8, label="I")
        ax1.plot(t_s, display_signal.imag, color=COLORS["magenta"], linewidth=0.8, alpha=0.8, label="Q")
        ax1.set_xlim(0, total_duration_s)
        ax1.grid(True, alpha=0.3)

        # Spectrogram centred on the recording's centre frequency.
        nfft, overlap = _get_nfft_size(signal=signal, fast_mode=fast_mode)
        ax2.specgram(
            signal,
            NFFT=nfft,
            Fc=center_freq_hz,
            Fs=sample_rate_hz,
            noverlap=overlap,
            cmap="twilight",
        )
        ax2.set_ylim(center_freq_hz - sample_rate_hz / 2, center_freq_hz + sample_rate_hz / 2)
        ax2.set_xlim(0, total_duration_s)

        if show_labels:
            (ax1 if horizontal_mode else ax2).set_xlabel("Time (s)")
            ax1.set_ylabel("Amplitude")
            ax1.set_title(f"Time Series - {sdr} SDR", loc="left", pad=10)
            ax1.legend(loc="upper right")
            # Ticks are rendered in MHz, so the axis label must say MHz as well.
            ax2.set_ylabel("Frequency (MHz)")
            ax2.set_title(
                f"Spectrogram - {center_freq_hz / 1e6:.1f} MHz ± {sample_rate_hz / 2e6:.1f} MHz", loc="left", pad=10
            )
            # A formatter keeps labels tied to tick positions; fixed label strings on an
            # automatic locator go stale as soon as the ticks are recomputed.
            ax2.yaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f"{value / 1e6:.1f}"))
        elif not compact_mode:
            ax1.set_title("Time Series", loc="left", pad=10)
            ax1.legend(loc="upper right", fontsize=8)
            ax2.set_title("Spectrogram", loc="left", pad=10)

        _add_annotations(
            annotations=annotations,
            compact_mode=compact_mode,
            show_labels=show_labels,
            sample_rate_hz=sample_rate_hz,
            center_freq_hz=center_freq_hz,
            ax2=ax2,
        )

        if ax_constellation is not None:
            _simple_sig_constellation(ax_constellation, signal, fast_mode)

        if ax_psd is not None:
            _simple_sig_psd(ax_psd, signal, fast_mode, sample_rate_hz, center_freq_hz)

        if compact_mode:
            # Edge-to-edge rendering: no ticks and no margins.
            for axis in (ax1, ax2):
                axis.set_xticks([])
                axis.set_yticks([])
            fig.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0)
        else:
            fig.tight_layout()
            if show_title:
                # Leave room for the suptitle above the subplots.
                fig.subplots_adjust(top=0.92)

        if saveplot:
            output_path, extension = set_path(output_path=output_path)
            dpi_value = _set_dpi(fast_mode=fast_mode, labels_mode=labels_mode, extension=extension)
            fig.savefig(output_path, dpi=dpi_value, bbox_inches="tight", facecolor="#0f172a", edgecolor="none")
            print(f"Saved signal plot to {output_path}")
            return output_path

        plt.show()
        # The interactive path keeps its historical behaviour of closing every open figure.
        plt.close("all")
        return None
    finally:
        # Garbage collection and clean up to prevent memory overloading. This runs on the
        # save path and on errors too, and only touches the figure created by this call.
        plt.close(fig)
        gc.collect()
# Names exported by ``from <module> import *``; the underscore-prefixed helpers
# defined in this module are not listed here.
__all__ = [
    "setup_style",
    "detect_constellation_symbols",
    "view_simple_sig",
]