Compare commits

...

11 Commits

Author SHA1 Message Date
0bd1b6e288 Merge branch 'main' into st_edits
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Build Project / Build Project (3.10) (pull_request) Successful in 49s
Test with tox / Test with tox (3.10) (pull_request) Successful in 43s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 49s
2025-10-22 12:35:26 -04:00
8105b829be Merge pull request 'Widget Support Panels' (#5) from new_widgets_support into main
All checks were successful
Build Sphinx Docs Set / Build Docs (push) Successful in 13s
Test with tox / Test with tox (3.11) (push) Successful in 32s
Test with tox / Test with tox (3.12) (push) Successful in 31s
Test with tox / Test with tox (3.10) (push) Successful in 42s
Build Project / Build Project (3.10) (push) Successful in 50s
Build Project / Build Project (3.11) (push) Successful in 49s
Build Project / Build Project (3.12) (push) Successful in 49s
Reviewed-on: #5
Reviewed-by: lswersk <lorne@qoherent.ai>
2025-10-22 12:04:53 -04:00
450fab6df2 Merge branch 'main' into new_widgets_support
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 12s
Test with tox / Test with tox (3.11) (pull_request) Successful in 33s
Test with tox / Test with tox (3.12) (pull_request) Successful in 30s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 51s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 49s
2025-10-22 12:03:38 -04:00
ben
a0b46a35e2 format fix 3?
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.11) (pull_request) Successful in 32s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Successful in 42s
Build Project / Build Project (3.10) (pull_request) Successful in 50s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 48s
2025-10-22 12:02:05 -04:00
ben
4872eea116 more formatting fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 13s
Test with tox / Test with tox (3.12) (pull_request) Successful in 30s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.10) (pull_request) Failing after 43s
Build Project / Build Project (3.10) (pull_request) Successful in 51s
Build Project / Build Project (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.12) (pull_request) Successful in 49s
2025-10-22 11:41:37 -04:00
ben
c06e58f5d6 Formatting Fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.12) (pull_request) Successful in 42s
Test with tox / Test with tox (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.10) (pull_request) Successful in 1m2s
Test with tox / Test with tox (3.10) (pull_request) Failing after 58s
Build Project / Build Project (3.11) (pull_request) Successful in 1m2s
Build Project / Build Project (3.12) (pull_request) Successful in 1m2s
2025-10-21 11:48:40 -04:00
ben
c7c7100d46 Formatting fixes 2025-10-20 14:44:51 -04:00
ben
e863040e19 onnx visualizers
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 33s
Test with tox / Test with tox (3.10) (pull_request) Failing after 42s
Build Project / Build Project (3.11) (pull_request) Successful in 52s
Build Project / Build Project (3.12) (pull_request) Successful in 51s
Build Project / Build Project (3.10) (pull_request) Successful in 56s
2025-10-20 12:16:30 -04:00
ben
2721ed866c Radio-dataset widgets
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 31s
Test with tox / Test with tox (3.10) (pull_request) Failing after 42s
Build Project / Build Project (3.10) (pull_request) Successful in 53s
Build Project / Build Project (3.11) (pull_request) Successful in 53s
Build Project / Build Project (3.12) (pull_request) Successful in 52s
2025-10-17 09:35:27 -04:00
ben
f430e626a6 Pytorch state dict widget
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.11) (pull_request) Successful in 34s
Test with tox / Test with tox (3.12) (pull_request) Successful in 32s
Test with tox / Test with tox (3.10) (pull_request) Failing after 42s
Build Project / Build Project (3.10) (pull_request) Successful in 52s
Build Project / Build Project (3.11) (pull_request) Successful in 51s
Build Project / Build Project (3.12) (pull_request) Successful in 51s
2025-10-14 14:22:37 -04:00
ben
1fb55607a2 Pytorch widgets 2025-10-09 16:55:23 -04:00
3 changed files with 1188 additions and 0 deletions

View File

@ -0,0 +1,562 @@
"""
ONNX model visualization utilities.
This module provides visualization functions for ONNX models following the same pattern
as other ria-toolkit-oss visualization modules.
"""
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
try:
import onnx
import onnx.helper
import onnx.numpy_helper
ONNX_AVAILABLE = True
except ImportError:
ONNX_AVAILABLE = False
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
    """Build an empty figure that displays an error banner in the Qoherent dark theme.

    Args:
        title: Short error headline (rendered in red with a warning icon).
        message: Human-readable description of what went wrong.
        suggestion: Optional remediation hint appended below the message.

    Returns:
        A ``go.Figure`` containing only a centred HTML-styled text annotation.
    """
    parts = [
        f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>",
        f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>",
    ]
    if suggestion:
        parts.append("<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>")
        parts.append(f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>")

    figure = go.Figure()
    # A single centred annotation carries the whole error card.
    figure.add_annotation(
        text="".join(parts),
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        xanchor="center",
        yanchor="middle",
        showarrow=False,
        align="center",
        borderwidth=2,
        bordercolor="#4a5568",
        bgcolor="#2d3748",
        font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
    )
    # Dark theme layout; the plot area itself stays empty.
    figure.update_layout(
        title="",
        height=400,
        template="plotly_dark",
        margin=dict(l=40, r=40, t=40, b=40),
        plot_bgcolor="#1a202c",
        paper_bgcolor="#1a202c",
        font=dict(color="#e2e8f0"),
    )
    # No data is plotted, so hide both axes entirely.
    figure.update_xaxes(visible=False)
    figure.update_yaxes(visible=False)
    return figure
def graph_structure(file_path: Path) -> go.Figure:
    """
    Visualize the ONNX model graph structure showing nodes and connections.
    Matches layout ID: graph_structure

    Renders each operator as a marker along a horizontal line (x = execution
    order as stored in the graph) with dotted connectors between consecutive
    operators. Marker size scales with the operator's input+output count.

    :param file_path: Path to the .onnx model file on disk.
    :return: A Plotly figure, or a styled error figure if ONNX is missing,
        the model is empty, or loading/analysis fails.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available", "ONNX library is required for model analysis.", "Install with: pip install onnx"
        )
    try:
        # Load ONNX model
        model = onnx.load(str(file_path))
        graph = model.graph
        nodes = graph.node
        if len(nodes) == 0:
            return create_styled_error_figure(
                "Empty Model", "This ONNX model contains no operators.", "Please check if the model file is valid."
            )
        # Collect per-node display info; name falls back to "<op_type>_<index>"
        # when the node has no explicit name.
        node_info = []
        for i, node in enumerate(nodes):
            node_info.append(
                {
                    "id": i,
                    "name": node.name or f"{node.op_type}_{i}",
                    "op_type": node.op_type,
                    "inputs": len(node.input),
                    "outputs": len(node.output),
                }
            )
        # Create visualization
        fig = go.Figure()
        # Simple linear layout for now: one marker per operator on y=0.
        x_positions = list(range(len(node_info)))
        y_positions = [0] * len(node_info)
        # Add nodes as scatter points; marker size = inputs+outputs+15,
        # clamped into [20, 50] so extremes stay readable.
        fig.add_trace(
            go.Scatter(
                x=x_positions,
                y=y_positions,
                mode="markers+text",
                marker=dict(
                    size=[min(max(info["inputs"] + info["outputs"] + 15, 20), 50) for info in node_info],
                    color=px.colors.qualitative.Set3[: len(node_info)],
                    opacity=0.8,
                    line=dict(width=2, color="white"),
                ),
                text=[f"{info['op_type']}" for info in node_info],
                textposition="middle center",
                textfont=dict(size=10, color="white"),
                hovertemplate="<b>%{text}</b><br>"
                + "Name: %{customdata[0]}<br>"
                + "Inputs: %{customdata[1]}<br>"
                + "Outputs: %{customdata[2]}<br>"
                + "<extra></extra>",
                customdata=[[info["name"], info["inputs"], info["outputs"]] for info in node_info],
                name="Operators",
            )
        )
        # Dotted connecting lines between consecutive operators (execution
        # order only — not the real dataflow edges of the graph).
        for i in range(len(node_info) - 1):
            fig.add_trace(
                go.Scatter(
                    x=[x_positions[i], x_positions[i + 1]],
                    y=[y_positions[i], y_positions[i + 1]],
                    mode="lines",
                    line=dict(color="gray", width=1, dash="dot"),
                    showlegend=False,
                    hoverinfo="skip",
                )
            )
        fig.update_layout(
            title={
                "text": (
                    "ONNX Graph Structure<br>"
                    f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"
                ),
                "x": 0.5,
                "xanchor": "center",
                "font": {"size": 22},
            },
            xaxis_title="Execution Order",
            yaxis_title="",
            showlegend=False,
            height=500,
            template="plotly_dark",
            yaxis=dict(showticklabels=False, showgrid=False),
            xaxis=dict(showgrid=False),
            margin=dict(l=50, r=50, t=80, b=50),
        )
        return fig
    except Exception as e:
        # Any load/analysis failure is converted to an error figure, not raised.
        return create_styled_error_figure(
            "Graph Analysis Error", "Could not analyze ONNX model structure.", f"Error: {str(e)}"
        )
def operator_analysis(file_path: Path) -> go.Figure:
    """
    Analyze the distribution and types of operators in the ONNX model.
    Matches layout ID: operator_analysis

    Shows two stacked subplots: a pie chart (share per operator type) above
    a bar chart (absolute frequency), both ordered by descending count.

    :param file_path: Path to the .onnx model file on disk.
    :return: A Plotly figure, or a styled error figure on failure.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available", "ONNX library is required for operator analysis.", "Install with: pip install onnx"
        )
    try:
        model = onnx.load(str(file_path))
        graph = model.graph
        # Count operators by op_type.
        op_counts = {}
        for node in graph.node:
            op_type = node.op_type
            op_counts[op_type] = op_counts.get(op_type, 0) + 1
        if not op_counts:
            return create_styled_error_figure(
                "No Operators",
                "This ONNX model contains no operators to analyze.",
                "Please verify the model file is valid.",
            )
        # Sort by frequency, most common first.
        sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True)
        # Create pie chart and bar chart
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Operator Distribution", "Operator Frequency"),
            specs=[[{"type": "pie"}], [{"type": "bar"}]],
        )
        # Pie chart for operator distribution
        op_names, op_values = zip(*sorted_ops) if sorted_ops else ([], [])
        fig.add_trace(
            go.Pie(
                labels=list(op_names),
                values=list(op_values),
                textinfo="label+percent",
                textposition="auto",
                showlegend=False,
            ),
            row=1,
            col=1,
        )
        # Bar chart for frequency
        fig.add_trace(
            go.Bar(
                x=list(op_names),
                y=list(op_values),
                marker_color=px.colors.qualitative.Set3[: len(op_names)],
                showlegend=False,
            ),
            row=2,
            col=1,
        )
        fig.update_layout(
            title={
                "text": (
                    "ONNX Operator Analysis<br>"
                    f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"
                ),
                "x": 0.5,
                "xanchor": "center",
                "font": {"size": 22},
            },
            height=700,
            template="plotly_dark",
        )
        return fig
    except Exception as e:
        # Convert any failure into a styled error figure rather than raising.
        return create_styled_error_figure(
            "Operator Analysis Error", "Could not analyze ONNX operators.", f"Error: {str(e)}"
        )
def model_metadata(file_path: Path) -> go.Figure:
    """
    Display comprehensive metadata about the ONNX model.
    Matches layout ID: model_metadata

    Builds a 2x2 dashboard: a file-size gauge, architecture component counts,
    an input/output table (first 5 of each), and a total-parameter indicator.

    :param file_path: Path to the .onnx model file on disk.
    :return: A Plotly figure, or a styled error figure if ONNX is unavailable
        or the model cannot be parsed.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available", "ONNX library is required for metadata analysis.", "Install with: pip install onnx"
        )
    try:
        model = onnx.load(str(file_path))
        graph = model.graph
        # Basic architecture statistics.
        total_nodes = len(graph.node)
        total_inputs = len(graph.input)
        total_outputs = len(graph.output)
        total_initializers = len(graph.initializer)
        # Sum parameter counts across all initializer tensors.
        total_params = 0
        for initializer in graph.initializer:
            try:
                tensor = onnx.numpy_helper.to_array(initializer)
                total_params += tensor.size
            except Exception:
                pass  # Skip if tensor can't be loaded
        # Get model file size
        file_size_mb = file_path.stat().st_size / (1024 * 1024)
        # Create metadata display
        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
            specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]],
        )
        # Top-left: model size gauge.
        fig.add_trace(
            go.Indicator(
                mode="number+gauge",
                value=file_size_mb,
                title={"text": "Model Size (MB)"},
                number={"suffix": " MB", "valueformat": ".2f"},
                gauge={
                    "axis": {"range": [0, max(100, file_size_mb * 1.5)]},
                    "bar": {"color": "darkblue"},
                    "steps": [
                        {"range": [0, 10], "color": "lightgreen"},
                        {"range": [10, 50], "color": "yellow"},
                        {"range": [50, 100], "color": "orange"},
                    ],
                },
            ),
            row=1,
            col=1,
        )
        # Top-right: architecture component counts.
        arch_data = ["Nodes", "Inputs", "Outputs", "Initializers"]
        arch_values = [total_nodes, total_inputs, total_outputs, total_initializers]
        fig.add_trace(
            go.Bar(x=arch_data, y=arch_values, marker_color=["blue", "green", "orange", "red"], showlegend=False),
            row=1,
            col=2,
        )
        # Mapping from ONNX TensorProto elem_type codes to dtype names.
        # Shared for inputs and outputs (previously duplicated verbatim).
        type_map = {
            1: "float32",
            2: "uint8",
            3: "int8",
            6: "int32",
            7: "int64",
            9: "bool",
            10: "float16",
            11: "double",
        }

        def _tensor_row(kind, value_info):
            """Build one [kind, name, shape, dtype] table row from a graph input/output entry."""
            shape = "Unknown"
            dtype = "Unknown"
            if value_info.type and value_info.type.tensor_type:
                # Shape: dynamic dimensions (dim_value == 0) are shown as "?".
                if value_info.type.tensor_type.shape:
                    dims = [
                        str(d.dim_value) if d.dim_value > 0 else "?"
                        for d in value_info.type.tensor_type.shape.dim
                    ]
                    shape = f"[{', '.join(dims)}]"
                elem_type = value_info.type.tensor_type.elem_type
                dtype = type_map.get(elem_type, f"type_{elem_type}")
            return [kind, value_info.name[:20], shape, dtype]

        # Bottom-left: I/O table, limited to the first 5 inputs and 5 outputs.
        io_data = [_tensor_row("Input", inp) for inp in graph.input[:5]]
        io_data += [_tensor_row("Output", out) for out in graph.output[:5]]
        if io_data:
            fig.add_trace(
                go.Table(
                    header=dict(values=["Type", "Name", "Shape", "Data Type"], fill_color="lightblue", align="left"),
                    cells=dict(values=list(zip(*io_data)), fill_color="white", align="left"),
                ),
                row=2,
                col=1,
            )
        # Bottom-right: total parameters. BUG FIX: the number carries an "M"
        # suffix, so the value must be scaled to millions — previously the raw
        # count was displayed (e.g. 1,000,000 rendered as "1000000.00M").
        fig.add_trace(
            go.Indicator(
                mode="number",
                value=total_params / 1e6,
                title={"text": "Total Parameters"},
                number={"suffix": "M", "valueformat": ".2f"},
                number_font_size=30,
            ),
            row=2,
            col=2,
        )
        fig.update_layout(
            title={
                "text": (
                    "ONNX Model Metadata<br>"
                    f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"
                ),
                "x": 0.5,
                "xanchor": "center",
                "font": {"size": 22},
            },
            height=600,
            template="plotly_dark",
            showlegend=False,
        )
        return fig
    except Exception as e:
        return create_styled_error_figure(
            "Metadata Analysis Error", "Could not extract ONNX model metadata.", f"Error: {str(e)}"
        )
def performance_metrics(file_path: Path) -> go.Figure:
    """
    Display performance and computational metrics for the ONNX model.
    Matches layout ID: performance_metrics

    Builds a 2x2 dashboard: efficiency bars (size, parameters, op count),
    a rough memory-usage estimate, an operation-type pie chart, and a
    heuristic 0-100 complexity gauge.

    :param file_path: Path to the .onnx model file on disk.
    :return: A Plotly figure, or a styled error figure on failure.
    """
    if not ONNX_AVAILABLE:
        return create_styled_error_figure(
            "ONNX Not Available",
            "ONNX library is required for performance analysis.",
            "Install with: pip install onnx",
        )
    try:
        model = onnx.load(str(file_path))
        graph = model.graph
        # Calculate metrics
        model_size_bytes = file_path.stat().st_size
        model_size_mb = model_size_bytes / (1024 * 1024)
        # Count parameters across all initializer tensors.
        total_params = 0
        for initializer in graph.initializer:
            try:
                tensor = onnx.numpy_helper.to_array(initializer)
                total_params += tensor.size
            except Exception:
                pass  # Skip tensors that cannot be materialized.
        # Estimate memory usage (rough approximation)
        param_memory_mb = (total_params * 4) / (1024 * 1024)  # Assume float32
        # Classify operators by substring match against heavy vs. cheap op lists.
        # NOTE(review): a node whose op_type matches both lists is counted in
        # both buckets, which can drive other_count negative — confirm intended.
        compute_ops = ["Conv", "MatMul", "Gemm", "LSTM", "GRU"]
        efficient_ops = ["Relu", "Add", "Mul", "BatchNormalization", "Dropout"]
        compute_count = sum(1 for node in graph.node if any(op in node.op_type for op in compute_ops))
        efficient_count = sum(1 for node in graph.node if any(op in node.op_type for op in efficient_ops))
        total_ops = len(graph.node)
        other_count = total_ops - compute_count - efficient_count
        # Create performance dashboard
        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
            specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]],
        )
        # Model efficiency metrics
        efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"]
        efficiency_values = [model_size_mb, total_params / 1e6, total_ops]
        fig.add_trace(
            go.Bar(
                x=efficiency_metrics, y=efficiency_values, marker_color=["blue", "green", "orange"], showlegend=False
            ),
            row=1,
            col=1,
        )
        # Memory usage: parameter storage plus a 2x allowance for inference.
        memory_types = ["Parameters", "Est. Inference"]
        memory_values = [param_memory_mb, param_memory_mb * 2]  # Rough estimate
        fig.add_trace(
            go.Bar(x=memory_types, y=memory_values, marker_color=["purple", "red"], showlegend=False),
            row=1,
            col=2,
        )
        # Operation types pie chart
        fig.add_trace(
            go.Pie(
                labels=["Compute Ops", "Efficient Ops", "Other Ops"],
                values=[compute_count, efficient_count, other_count],
                marker_colors=["red", "green", "gray"],
            ),
            row=2,
            col=1,
        )
        # Complexity score (simple heuristic), capped at 100.
        complexity_score = min(100, (model_size_mb * 10 + total_params / 1e6 * 20 + compute_count))
        fig.add_trace(
            go.Indicator(
                mode="gauge+number",
                value=complexity_score,
                title={"text": "Complexity Score"},
                gauge={
                    "axis": {"range": [0, 100]},
                    # Gauge bar colour tracks the same thresholds as the steps.
                    "bar": {
                        "color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"
                    },
                    "steps": [
                        {"range": [0, 40], "color": "lightgreen"},
                        {"range": [40, 70], "color": "yellow"},
                        {"range": [70, 100], "color": "lightcoral"},
                    ],
                },
            ),
            row=2,
            col=2,
        )
        fig.update_layout(
            title={
                "text": (
                    "ONNX Performance Metrics<br>"
                    f"<span style='font-size:14px; color:#a0a0a0;'>"
                    f"Complexity Score: {complexity_score:.0f}/100</span>"
                ),
                "x": 0.5,
                "xanchor": "center",
                "font": {"size": 22},
            },
            height=600,
            template="plotly_dark",
            showlegend=False,
        )
        return fig
    except Exception as e:
        return create_styled_error_figure(
            "Performance Analysis Error", "Could not analyze ONNX model performance.", f"Error: {str(e)}"
        )

View File

@ -0,0 +1,194 @@
import numpy as np
import plotly.graph_objects as go
from plotly.graph_objects import Figure
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
    """Build an empty figure that displays an error banner in the Qoherent dark theme.

    Args:
        title: Short error headline (rendered in red with a warning icon).
        message: Human-readable description of what went wrong.
        suggestion: Optional remediation hint appended below the message.

    Returns:
        A ``go.Figure`` containing only a centred HTML-styled text annotation.
    """
    parts = [
        f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>",
        f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>",
    ]
    if suggestion:
        parts.append("<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>")
        parts.append(f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>")

    figure = go.Figure()
    # A single centred annotation carries the whole error card.
    figure.add_annotation(
        text="".join(parts),
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        xanchor="center",
        yanchor="middle",
        showarrow=False,
        align="center",
        borderwidth=2,
        bordercolor="#4a5568",
        bgcolor="#2d3748",
        font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
    )
    # Dark theme layout; the plot area itself stays empty.
    figure.update_layout(
        title="",
        height=400,
        template="plotly_dark",
        margin=dict(l=40, r=40, t=40, b=40),
        plot_bgcolor="#1a202c",
        paper_bgcolor="#1a202c",
        font=dict(color="#e2e8f0"),
    )
    # No data is plotted, so hide both axes entirely.
    figure.update_xaxes(visible=False)
    figure.update_yaxes(visible=False)
    return figure
def model_summary_plot(state_dict: dict) -> Figure:
    """Generate a summary plot of the PyTorch model state dict.

    Produces a bar chart with one bar per ``*.weight`` entry: bar height is
    the parameter count and the bar text shows the tensor shape.

    Args:
        state_dict: Mapping of parameter names to tensors (or array-likes).

    Returns:
        A Plotly figure, or a styled error figure when the dict is empty or
        contains no weight layers.
    """
    if not state_dict:
        return create_styled_error_figure(
            "Empty State Dict",
            "No parameters found in state dict",
            "Ensure the model state dictionary contains weight parameters",
        )

    # Collect (layer, parameter-count, shape) for every weight tensor.
    layer_info = []
    for key, tensor in state_dict.items():
        if "weight" not in key:
            continue
        try:
            label = key.replace(".weight", "")
            # Prefer torch's numel(); fall back to flattened length, then 0.
            if hasattr(tensor, "numel"):
                count = tensor.numel()
            elif hasattr(tensor, "flatten"):
                count = len(tensor.flatten())
            else:
                count = 0
            if hasattr(tensor, "shape"):
                dims = list(tensor.shape)
            elif hasattr(tensor, "__len__"):
                dims = [len(tensor)]
            else:
                dims = []
            layer_info.append({"layer": label, "parameters": count, "shape": dims})
        except Exception as e:
            print(f"Warning: Could not process layer {key}: {e}")
            continue

    if not layer_info:
        return create_styled_error_figure(
            "No Weight Layers Found",
            "No weight layers found in state dict",
            "Ensure the state dictionary contains layers with '.weight' parameters",
        )

    # One bar per layer, annotated with the tensor shape.
    bars = go.Bar(
        x=[entry["layer"] for entry in layer_info],
        y=[entry["parameters"] for entry in layer_info],
        text=[f"Shape: {entry['shape']}" for entry in layer_info],
        textposition="auto",
    )
    fig = go.Figure(data=[bars])
    fig.update_layout(
        title="Model Layer Parameter Counts",
        xaxis_title="Layer",
        yaxis_title="Number of Parameters",
        template="plotly_dark",
    )
    return fig
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
    """Visualize weights for a specific layer.

    Renders a zero-centred heatmap for 2-D weight tensors and a histogram of
    the flattened values for any other rank. When *layer_name* is omitted,
    the first key containing "weight" is used.

    Args:
        state_dict: Mapping of parameter names to tensors (or array-likes).
        layer_name: Optional key identifying which tensor to plot.

    Returns:
        A Plotly figure, or a styled error figure on empty input, no weight
        tensors, or a processing failure.
    """
    if not state_dict:
        return create_styled_error_figure(
            "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
        )

    if layer_name is None:
        # Default to the first weight-like entry in the dict.
        candidates = [key for key in state_dict.keys() if "weight" in key]
        if not candidates:
            return create_styled_error_figure(
                "No Weight Tensors Found",
                "No weight tensors found in state dict",
                "Ensure the state dictionary contains layers with '.weight' parameters",
            )
        layer_name = candidates[0]

    try:
        tensor = state_dict[layer_name]
        # Normalize to a NumPy array, detaching torch tensors when needed.
        if hasattr(tensor, "numpy"):
            values = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
        elif hasattr(tensor, "cpu"):
            values = tensor.cpu().detach().numpy()
        else:
            values = np.array(tensor)

        if len(values.shape) == 2:
            # 2-D weights: diverging heatmap centred at zero.
            fig = go.Figure(data=go.Heatmap(z=values, colorscale="RdBu", zmid=0))
            fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
        else:
            # Any other rank: histogram of the flattened values.
            fig = go.Figure(data=[go.Histogram(x=values.flatten(), nbinsx=50)])
            fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
        return fig
    except Exception as e:
        return create_styled_error_figure(
            "Layer Processing Error",
            f"Error processing layer {layer_name}: {str(e)}",
            "Check that the layer name exists and contains valid tensor data",
        )
def weight_distribution_plot(state_dict: dict) -> Figure:
    """Show distribution of weights across all layers.

    Flattens every ``*weight*`` tensor in *state_dict* into a single pool of
    values and plots one combined 100-bin histogram.

    Args:
        state_dict: Mapping of parameter names to tensors (or array-likes).

    Returns:
        A Plotly figure, or a styled error figure when the dict is empty or
        no weight values could be extracted.
    """
    if not state_dict:
        return create_styled_error_figure(
            "Empty State Dict", "No data in state dict", "Ensure the model state dictionary contains data"
        )
    all_weights = []
    # FIX: a parallel per-value list of layer names used to be accumulated
    # here but was never read afterwards; removed to avoid O(total-parameters)
    # extra memory with no behavioral effect on the plot.
    for key, tensor in state_dict.items():
        if "weight" in key:
            try:
                # Convert to numpy if it's a torch tensor
                if hasattr(tensor, "numpy"):
                    weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
                elif hasattr(tensor, "cpu"):
                    weights_np = tensor.cpu().detach().numpy()
                else:
                    weights_np = np.array(tensor)
                all_weights.extend(weights_np.flatten())
            except Exception as e:
                print(f"Warning: Could not process weights for layer {key}: {e}")
                continue
    if not all_weights:
        return create_styled_error_figure(
            "No Weight Data Found",
            "No weight data found in state dict",
            "Ensure the state dictionary contains layers with '.weight' parameters",
        )
    fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
    fig.update_layout(
        title="Overall Weight Distribution",
        xaxis_title="Weight Value",
        yaxis_title="Frequency",
        template="plotly_dark",
    )
    return fig

View File

@ -0,0 +1,432 @@
"""
Simple, clean visualization utilities for RadioDataset analysis.
"""
import random
from typing import Optional
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.graph_objects import Figure
from plotly.subplots import make_subplots
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> Figure:
    """Build an empty figure that displays an error banner in the Qoherent dark theme.

    Args:
        title: Short error headline (rendered in red with a warning icon).
        message: Human-readable description of what went wrong.
        suggestion: Optional remediation hint appended below the message.

    Returns:
        A Plotly ``Figure`` containing only a centred HTML-styled annotation.
    """
    parts = [
        f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>",
        f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>",
    ]
    if suggestion:
        parts.append("<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>")
        parts.append(f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>")

    figure = go.Figure()
    # A single centred annotation carries the whole error card.
    figure.add_annotation(
        text="".join(parts),
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        xanchor="center",
        yanchor="middle",
        showarrow=False,
        align="center",
        borderwidth=2,
        bordercolor="#4a5568",
        bgcolor="#2d3748",
        font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
    )
    # Dark theme layout; the plot area itself stays empty.
    figure.update_layout(
        title="",
        height=400,
        template="plotly_dark",
        margin=dict(l=40, r=40, t=40, b=40),
        plot_bgcolor="#1a202c",
        paper_bgcolor="#1a202c",
        font=dict(color="#e2e8f0"),
    )
    # No data is plotted, so hide both axes entirely.
    figure.update_xaxes(visible=False)
    figure.update_yaxes(visible=False)
    return figure
def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
    """Check if dataset is compatible with a specific plot type.

    Returns (is_compatible, error_message); the message is empty on success.
    All failures — including unexpected exceptions — are reported as an
    incompatible result rather than raised.
    """
    try:
        meta = dataset.metadata
        if len(meta) == 0:
            return False, "Dataset is empty"
        if plot_type == "class_distribution":
            # Compatible when a conventional label column OR any
            # object-dtype (categorical) column exists.
            known_labels = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
            found_label = any(name in meta.columns for name in known_labels)
            found_categorical = any(meta[name].dtype == "object" for name in meta.columns)
            if not (found_label or found_categorical):
                return False, "No categorical columns found for class distribution"
        elif plot_type == "sample_spectrogram":
            if len(meta) < 1:
                return False, "No samples available for spectrogram"
            # Probe the first sample; access failures fall through so the
            # caller can rely on synthetic data generation instead.
            try:
                probe = dataset[0] if hasattr(dataset, "__getitem__") else None
                if probe is None or len(probe) < 32:
                    return False, "Insufficient sample data for spectrogram (need at least 32 points)"
            except Exception:
                pass
        return True, ""
    except Exception as e:
        return False, f"Dataset compatibility check failed: {str(e)}"
def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
    """Generate a bar plot showing the distribution of examples across classes.

    Falls back from *class_key* to conventional label column names, then to
    the first categorical/low-cardinality column; returns a styled error
    figure when no usable label column exists or on any failure.
    """
    try:
        # Check dataset compatibility first
        is_compatible, error_msg = _check_dataset_compatibility(dataset, "class_distribution")
        # NOTE(review): error_msg is unused — a generic message is shown instead.
        if not is_compatible:
            return create_styled_error_figure(
                "Dataset Not Compatible",
                "This dataset doesn't have categorical labels needed for class distribution analysis.",
                "Try using the Dataset Overview widget to explore the available data columns.",
            )
        metadata = dataset.metadata
        # Find the class column
        if class_key not in metadata.columns:
            # Try common alternatives
            alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
            for alt in alternatives:
                if alt in metadata.columns:
                    class_key = alt
                    break
            else:
                # No alternative matched: fall back to the first categorical
                # or low-cardinality (< 50 unique values) column.
                for col in metadata.columns:
                    if metadata[col].dtype == "object" or metadata[col].nunique() < 50:
                        class_key = col
                        break
        if class_key not in metadata.columns:
            return create_styled_error_figure(
                "No Class Labels Found",
                "This dataset contains numerical data without categorical labels.",
                (
                    "Try using the Dataset Overview widget for data analysis, "
                    "or check if your dataset has hidden categorical columns."
                ),
            )
        # Count examples per class (limit to top 20 for performance)
        class_counts = metadata[class_key].value_counts()
        if len(class_counts) > 20:
            class_counts = class_counts.head(20)
        class_counts = class_counts.sort_index()
        # Create simple bar plot with counts shown above each bar.
        fig = px.bar(x=class_counts.index, y=class_counts.values, title=f"Class Distribution: {class_key.title()}")
        fig.update_traces(texttemplate="%{y}", textposition="outside")
        fig.update_layout(
            xaxis_title=class_key.title(),
            yaxis_title="Number of Examples",
            showlegend=False,
            height=400,
            template="plotly_dark",
        )
        return fig
    except Exception as e:
        return create_styled_error_figure(
            "Class Distribution Error",
            "An error occurred while generating the class distribution plot.",
            f"Technical details: {str(e)}",
        )
def dataset_overview_plot(dataset) -> Figure:
    """Generate an overview plot with key dataset statistics.

    Builds a 2x2 dashboard from ``dataset.metadata`` (a DataFrame-like
    table): total example count, dtype distribution, a distribution of the
    first categorical (or numeric) column, and a per-column statistics
    table covering up to 5 columns.
    """
    try:
        metadata = dataset.metadata
        total_examples = len(metadata)
        # Classify columns once up front (this was previously computed twice,
        # here and again before the bottom-left subplot).
        categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
        numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
        dist_title = "Value Distribution" if categorical_cols else "Data Distribution"
        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"),
            specs=[
                [{"type": "indicator"}, {"type": "bar"}],
                [{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}],
            ],
        )
        # Top left: dataset size indicator.
        fig.add_trace(
            go.Indicator(
                mode="number", value=total_examples, title={"text": "Total Examples"}, number={"font": {"size": 40}}
            ),
            row=1,
            col=1,
        )
        # Top right: how many columns exist per dtype.
        dtype_counts = metadata.dtypes.value_counts()
        fig.add_trace(
            go.Bar(
                x=[str(dt) for dt in dtype_counts.index], y=dtype_counts.values, name="Data Types", showlegend=False
            ),
            row=1,
            col=2,
        )
        # Bottom left: distribution of the first categorical column if any,
        # otherwise a 20-bin histogram of the first numeric column.
        if categorical_cols:
            col = categorical_cols[0]
            value_counts = metadata[col].value_counts().head(10)  # Top 10 values only.
            fig.add_trace(
                go.Bar(x=value_counts.index, y=value_counts.values, name=f"{col} Distribution", showlegend=False),
                row=2,
                col=1,
            )
        elif numeric_cols:
            col = numeric_cols[0]
            fig.add_trace(
                go.Histogram(x=metadata[col], name=f"{col} Distribution", showlegend=False, nbinsx=20), row=2, col=1
            )
        # Bottom right: summary statistics table. Numeric columns are
        # preferred; otherwise the first 5 columns of any dtype are shown.
        stats_data = []
        display_cols = numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]
        for col in display_cols:
            if metadata[col].dtype in ["int64", "float64"]:
                stats_data.append(
                    [
                        col[:15] + "..." if len(col) > 15 else col,  # Truncate long column names
                        f"{metadata[col].mean():.3f}",
                        f"{metadata[col].std():.3f}",
                        f"{metadata[col].min():.3f}",
                        f"{metadata[col].max():.3f}",
                    ]
                )
            else:
                # Non-numeric columns report a unique-value count instead.
                unique_count = metadata[col].nunique()
                stats_data.append(
                    [col[:15] + "..." if len(col) > 15 else col, "N/A", "N/A", f"{unique_count} unique", "N/A"]
                )
        if stats_data:
            fig.add_trace(
                go.Table(
                    header=dict(
                        values=["Column", "Mean", "Std", "Min/Unique", "Max"],
                        fill_color="rgba(30, 30, 30, 0.8)",
                        align="center",
                        font=dict(color="white", size=12),
                    ),
                    cells=dict(
                        values=list(zip(*stats_data)),
                        fill_color="rgba(50, 50, 50, 0.6)",
                        align="center",
                        font=dict(color="white", size=11),
                    ),
                ),
                row=2,
                col=2,
            )
        # Title reflects sample/column counts and flags table truncation.
        total_cols = len(metadata.columns)
        title = f"Dataset Overview - {total_examples} samples, {total_cols} columns"
        if total_cols > 5:
            title += " (showing first 5)"
        fig.update_layout(title=title, height=600, showlegend=False, template="plotly_dark")
        return fig
    except Exception as e:
        return create_styled_error_figure(
            "Dataset Overview Error",
            "An error occurred while generating the dataset overview.",
            f"Technical details: {str(e)}",
        )
def _find_class_column(metadata, class_key: str) -> str:
"""Find the appropriate class column in metadata."""
if class_key in metadata.columns:
return class_key
alternatives = ["class", "label", "modulation", "impairment", "use_case"]
for alt in alternatives:
if alt in metadata.columns:
return alt
return class_key
def _get_sample_data(dataset, sample_idx: int):
"""Get sample data from dataset, with synthetic fallback."""
try:
return dataset[sample_idx]
except Exception:
# Generate synthetic signal based on class
n_samples = 1024
t = np.linspace(0, 1, n_samples)
freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample
sample_data = np.exp(1j * 2 * np.pi * freq * t)
# Add some noise
sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
return sample_data
def _calculate_spectrogram_params(n_samples: int) -> tuple[int, int, int, int]:
"""Calculate spectrogram parameters based on sample length."""
if n_samples < 32:
raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")
nperseg = min(256, max(32, n_samples // 4))
hop_length = max(1, nperseg // 2)
# Adjust for very short signals
if n_samples < nperseg:
nperseg = n_samples
hop_length = 1
n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
freq_bins = max(1, nperseg // 2)
return nperseg, hop_length, n_frames, freq_bins
def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: int, freq_bins: int):
"""Compute spectrogram using FFT."""
n_samples = len(sample_data)
Sxx = np.zeros((freq_bins, n_frames))
for i in range(n_frames):
start_idx = i * hop_length
end_idx = min(start_idx + nperseg, n_samples)
if end_idx > start_idx:
windowed = sample_data[start_idx:end_idx]
# Pad if necessary to maintain nperseg size
if len(windowed) < nperseg:
windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode="constant")
fft_result = np.fft.fft(windowed)
Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2
return Sxx
def _create_spectrogram_figure(
    Sxx,
    n_frames: int,
    hop_length: int,
    n_samples: int,
    freq_bins: int,
    sample_idx: int,
    class_key: str,
    sample_metadata,
) -> Figure:
    """Render a precomputed power spectrogram as a dark-themed heatmap."""
    # Log-scale the power; the epsilon guards against log10(0).
    power_db = 10 * np.log10(Sxx + 1e-10)
    # Axes: frame start times normalized by signal length, and normalized
    # frequency from 0 up to 0.5.
    time_axis = np.arange(n_frames) * hop_length / max(1, n_samples)
    freq_axis = np.linspace(0, 0.5, freq_bins)
    heatmap = go.Heatmap(
        z=power_db,
        x=time_axis,
        y=freq_axis,
        colorscale="viridis",
        colorbar=dict(title="Power (dB)"),
    )
    fig = go.Figure(data=heatmap)
    # Title includes the class label when the metadata row carries it.
    title = f"Sample Spectrogram (Index: {sample_idx})"
    if class_key in sample_metadata:
        title += f" - {class_key}: {sample_metadata[class_key]}"
    fig.update_layout(
        title=title,
        xaxis_title="Time",
        yaxis_title="Frequency",
        height=400,
        template="plotly_dark",
    )
    return fig
def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure:
    """Generate a spectrogram plot from a sample in the dataset.

    Parameters
    ----------
    dataset
        Dataset object exposing indexing and a ``metadata`` DataFrame.
    class_key : str
        Preferred metadata column for the class label; a conventional
        alternative is substituted when it is absent.
    sample_idx : Optional[int]
        Index of the sample to plot; a random index is chosen when ``None``.

    Returns
    -------
    Figure
        The spectrogram heatmap, or a styled error figure on any failure.
    """
    try:
        # Check dataset compatibility first (the second element of the
        # returned pair, an error message, is unused here).
        is_compatible, _ = _check_dataset_compatibility(dataset, "sample_spectrogram")
        if not is_compatible:
            return create_styled_error_figure(
                "Spectrogram Not Available",
                "This dataset doesn't have sufficient signal data for spectrogram visualization.",
                "Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample.",
            )
        metadata = dataset.metadata
        if len(metadata) == 0:
            raise ValueError("Dataset is empty")
        # Resolve the class column and pick a sample (random when unspecified).
        class_key = _find_class_column(metadata, class_key)
        if sample_idx is None:
            sample_idx = random.randint(0, len(metadata) - 1)
        sample_metadata = metadata.iloc[sample_idx]
        # Fetch the raw sample and coerce to complex for the FFT path.
        sample_data = _get_sample_data(dataset, sample_idx)
        if not np.iscomplexobj(sample_data):
            sample_data = sample_data.astype(complex)
        # Derive STFT parameters and compute the power spectrogram.
        n_samples = len(sample_data)
        nperseg, hop_length, n_frames, freq_bins = _calculate_spectrogram_params(n_samples)
        Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
        return _create_spectrogram_figure(
            Sxx, n_frames, hop_length, n_samples, freq_bins, sample_idx, class_key, sample_metadata
        )
    except Exception as e:
        # Never raise out of a plotting helper; surface failures as a figure.
        return create_styled_error_figure(
            "Spectrogram Error",
            "An error occurred while generating the spectrogram plot.",
            f"Technical details: {str(e)}",
        )