"""Shared plotting primitives for signal visualization.""" from __future__ import annotations import gc import json from typing import Optional import matplotlib import matplotlib.pyplot as plt import numpy as np from scipy.fft import fft, fftshift from scipy.signal.windows import hann from ria_toolkit_oss.datatypes.recording import Recording from ria_toolkit_oss.view.tools import ( COLORS, decimate, extract_metadata_fields, set_path, ) def _add_annotations(annotations, compact_mode, show_labels, sample_rate_hz, center_freq_hz, ax2): if annotations and not compact_mode: for annotation in annotations: start_idx = annotation.get("core:sample_start", 0) length = annotation.get("core:sample_count", 0) start_time = start_idx / sample_rate_hz end_time = (start_idx + length) / sample_rate_hz freq_low = annotation.get("core:freq_lower_edge", center_freq_hz - sample_rate_hz / 4) freq_high = annotation.get("core:freq_upper_edge", center_freq_hz + sample_rate_hz / 4) comment = annotation.get("core:comment", "{}") try: comment_data = json.loads(comment) if isinstance(comment, str) else comment ann_type = comment_data.get("type", "unknown") if ann_type == "intersection": color = COLORS["success"] elif ann_type == "parallel": color = COLORS["primary"] elif ann_type == "standalone": color = COLORS["warning"] else: color = COLORS["error"] except Exception: color = COLORS["error"] rect = plt.Rectangle( (start_time, freq_low), end_time - start_time, freq_high - freq_low, color=color, alpha=0.4, linewidth=2, ) ax2.add_patch(rect) if show_labels: label = annotation.get("core:label", "Signal") ax2.text( start_time, freq_high, label, color=COLORS["light"], fontsize=10, bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7), ) def _get_nfft_size(signal, fast_mode): if len(signal) < 1000: nfft = 128 elif len(signal) < 10_000: nfft = 256 elif len(signal) < 100_000: nfft = 512 elif len(signal) < 1_000_000: nfft = 1024 else: nfft = 2048 if fast_mode: nfft = min(nfft, 512) overlap = nfft // 8 if fast_mode else nfft // 4 return nfft, overlap def _get_plot_samples(signal, fast_mode, slow_max, fast_max): max_samples = fast_max if fast_mode else slow_max if len(signal) > max_samples: start_idx = len(signal) // 2 - max_samples // 2 return signal[start_idx : start_idx + max_samples] else: return signal def _set_dpi(fast_mode, labels_mode, extension): if fast_mode: dpi = 75 elif labels_mode: dpi = 200 else: dpi = 150 return dpi if extension == "png" else None def setup_style(*, labels_mode: bool = False, compact_mode: bool = False) -> None: """Configure matplotlib with the signal-testbed styling.""" plt.style.use("dark_background") if compact_mode: base_font = 8 title_font = 10 label_font = 8 elif labels_mode: base_font = 12 title_font = 16 label_font = 14 else: base_font = 10 title_font = 12 label_font = 10 matplotlib.rcParams.update( { "figure.facecolor": "#0f172a", "axes.facecolor": "#1e293b", "axes.edgecolor": COLORS["muted"], "axes.labelcolor": COLORS["light"], "text.color": COLORS["light"], "xtick.color": COLORS["muted"], "ytick.color": COLORS["muted"], "grid.color": COLORS["muted"], "grid.alpha": 0.3, "font.size": base_font, "axes.titlesize": title_font, "axes.labelsize": label_font, "figure.titlesize": title_font + 2, "legend.frameon": False, "legend.facecolor": "none", "xtick.labelsize": base_font, "ytick.labelsize": base_font, } ) def detect_constellation_symbols(signal: np.ndarray, method: str = "differential") -> np.ndarray: """Heuristic symbol detector used for constellation highlighting.""" if len(signal) < 100: return np.ones(len(signal), dtype=bool) if method == "differential": di = np.diff(signal.imag) dq = np.diff(signal.real) derivative_magnitude = np.sqrt(di**2 + dq**2) derivative_magnitude = np.append(derivative_magnitude, 0) threshold = np.percentile(derivative_magnitude, 15) return derivative_magnitude < threshold if method == "amplitude": amplitude = np.abs(signal) amplitude_change = np.abs(np.diff(amplitude)) amplitude_change = np.append(amplitude_change, 0) threshold = np.percentile(amplitude_change, 20) return amplitude_change < threshold if method == "phase": phase = np.angle(signal) phase_diff = np.diff(np.unwrap(phase)) phase_diff = np.append(phase_diff, 0) threshold = np.percentile(np.abs(phase_diff), 20) return np.abs(phase_diff) < threshold if method == "combined": diff_stable = detect_constellation_symbols(signal, "differential") amp_stable = detect_constellation_symbols(signal, "amplitude") phase_stable = detect_constellation_symbols(signal, "phase") stability_count = diff_stable.astype(int) + amp_stable.astype(int) + phase_stable.astype(int) return stability_count >= 2 raise ValueError(f"Unknown method: {method}") def view_simple_sig( recording: Recording, annotations: Optional[list] = None, output_path: Optional[str] = "images/signal.png", saveplot: Optional[bool] = True, fast_mode: Optional[bool] = False, compact_mode: Optional[bool] = False, horizontal_mode: Optional[bool] = False, constellation_mode: Optional[bool] = False, labels_mode: Optional[bool] = False, slice: Optional[tuple] = None, title: Optional[str] = "Signal", ): """ Create a simple plot of various signal visualizations as a png or svg image. :param recording: The recording object to plot. :type recording: Recording :param output_path: The output image path. Defaults to "images/signal.png" :type output_path: str, optional :param saveplot: Whether or not to save the plot. Defaults to True. :type saveplot: bool, optional :param fast_mode: Use fast mode for faster render. Defaults to False. :type fast_mode: bool, optional :param compact_mode: Use compact mode for compact plot. Defaults to False. :type compact_mode: bool, optional :param horizontal_mode: Display plots horizontally. Defaults to False. :type horizontal_mode: bool, optional :param constellation_mode: Display constellation plot and PSD if not using compact mode. Defaults to False. :type constellation_mode: bool, optional :param labels_mode: Display more thorough labels. Defaults to False. :type labels_mode: bool, optional :param slice: Slice of signal to display. Defaults to None. :type slice: tuple[int, int], optional :param title: Title of plot. Defaults to "Signal". :type title: str, optional """ signal = recording.data[0] sample_rate_hz, center_freq_hz, sdr = extract_metadata_fields(recording.metadata) setup_style(labels_mode=labels_mode, compact_mode=compact_mode) if slice: start_idx, end_idx = slice signal = signal[start_idx:end_idx] print(f"Using slice: samples {start_idx} to {end_idx} ({len(signal):,} samples)") max_display_pixels = 100_000 if fast_mode else 250_000 display_signal = decimate(signal, max_display_pixels) if len(signal) > max_display_pixels else signal spec_signal = signal if compact_mode: fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 5]}) show_title = False show_labels = False ax_constellation = ax_psd = None elif horizontal_mode: if constellation_mode: fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6)) ax_constellation = ax3 else: fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) ax_constellation = None show_title = True show_labels = labels_mode ax_psd = None else: if constellation_mode: fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12)) ax_constellation, ax_psd = ax3, ax4 else: fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10)) ax_constellation = ax_psd = None show_title = True show_labels = labels_mode if show_title: fig.suptitle(title, fontsize=16, color=COLORS["light"], y=0.96) fig.patch.set_facecolor("#0f172a") total_duration_s = len(signal) / sample_rate_hz if sample_rate_hz else 0.0 t_s = np.linspace(0, total_duration_s, len(display_signal)) if len(display_signal) else np.array([]) ax1.plot(t_s, display_signal.real, color=COLORS["purple"], linewidth=0.8, alpha=0.8, label="I") ax1.plot(t_s, display_signal.imag, color=COLORS["magenta"], linewidth=0.8, alpha=0.8, label="Q") ax1.set_xlim(0, total_duration_s) ax1.grid(True, alpha=0.3) nfft, overlap = _get_nfft_size(signal=signal, fast_mode=fast_mode) _, freqs, _, _ = ax2.specgram( spec_signal, NFFT=nfft, Fc=center_freq_hz, Fs=sample_rate_hz, noverlap=overlap, cmap="twilight", ) ax2.set_ylim(center_freq_hz - sample_rate_hz / 2, center_freq_hz + sample_rate_hz / 2) ax2.set_xlim(0, total_duration_s) if show_labels: if horizontal_mode: ax1.set_xlabel("Time (s)") else: ax2.set_xlabel("Time (s)") ax1.set_ylabel("Amplitude") ax1.set_title(f"Time Series - {sdr} SDR", loc="left", pad=10) ax1.legend(loc="upper right") ax2.set_ylabel("Frequency (Hz)") ax2.set_title( f"Spectrogram - {center_freq_hz / 1e6:.1f} MHz ± {sample_rate_hz / 2e6:.1f} MHz", loc="left", pad=10 ) yticks = ax2.get_yticks() ax2.set_yticklabels([f"{y / 1e6:.1f}" for y in yticks]) elif not compact_mode: ax1.set_title("Time Series", loc="left", pad=10) ax1.legend(loc="upper right", fontsize=8) ax2.set_title("Spectrogram", loc="left", pad=10) _add_annotations( annotations=annotations, compact_mode=compact_mode, show_labels=show_labels, sample_rate_hz=sample_rate_hz, center_freq_hz=center_freq_hz, ax2=ax2, ) if ax_constellation is not None: constellation_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=50_000, fast_max=20_000) method = "differential" if fast_mode else "combined" stable_points = detect_constellation_symbols(constellation_samples, method=method) ax_constellation.scatter( constellation_samples.real[~stable_points], constellation_samples.imag[~stable_points], c=COLORS["muted"], s=0.5, alpha=0.2, ) ax_constellation.scatter( constellation_samples.real[stable_points], constellation_samples.imag[stable_points], c=COLORS["purple"], s=3, alpha=0.8, ) ax_constellation.set_xlabel("In-phase (I)") ax_constellation.set_ylabel("Quadrature (Q)") ax_constellation.set_title("Constellation") ax_constellation.grid(True, alpha=0.3) ax_constellation.set_aspect("equal") if ax_psd is not None: psd_samples = _get_plot_samples(signal=signal, fast_mode=fast_mode, slow_max=65_536, fast_max=16_384) window = hann(len(psd_samples)) spectrum = np.abs(fftshift(fft(psd_samples * window))) ** 2 freqs = np.linspace(-sample_rate_hz / 2, sample_rate_hz / 2, len(psd_samples)) freqs = freqs + center_freq_hz spectrum_db = 10 * np.log10(spectrum + 1e-12) ax_psd.plot(freqs / 1e6, spectrum_db, color=COLORS["accent"], linewidth=1.0) ax_psd.set_xlabel("Frequency (MHz)") ax_psd.set_ylabel("Power (dB)") ax_psd.set_title("Power Spectral Density") ax_psd.grid(True, alpha=0.3) if compact_mode: ax1.set_xticks([]) ax1.set_yticks([]) ax2.set_xticks([]) ax2.set_yticks([]) plt.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0) else: plt.tight_layout() if show_title: plt.subplots_adjust(top=0.92) if saveplot: output_path, extension = set_path(output_path=output_path) dpi_value = _set_dpi(fast_mode=fast_mode, labels_mode=labels_mode, extension=extension) plt.savefig(output_path, dpi=dpi_value, bbox_inches="tight", facecolor="#0f172a", edgecolor="none") print(f"Saved signal plot to {output_path}") return output_path plt.show() # Garbage collection and clean up to prevent memory overloading plt.close("all") gc.collect() return None __all__ = [ "setup_style", "detect_constellation_symbols", "view_simple_sig", ]