Compare commits

..

2 Commits

Author SHA1 Message Date
ben
c06e58f5d6 Formatting Fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.12) (pull_request) Successful in 42s
Test with tox / Test with tox (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.10) (pull_request) Successful in 1m2s
Test with tox / Test with tox (3.10) (pull_request) Failing after 58s
Build Project / Build Project (3.11) (pull_request) Successful in 1m2s
Build Project / Build Project (3.12) (pull_request) Successful in 1m2s
2025-10-21 11:48:40 -04:00
ben
c7c7100d46 Formatting fixes 2025-10-20 14:44:51 -04:00
3 changed files with 556 additions and 545 deletions

View File

@ -6,18 +6,16 @@ as other ria-toolkit-oss visualization modules.
"""
from pathlib import Path
from typing import Optional
import plotly.graph_objects as go
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
try:
import onnx
import onnx.helper
import onnx.numpy_helper
ONNX_AVAILABLE = True
except ImportError:
ONNX_AVAILABLE = False
@ -32,25 +30,24 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
if suggestion:
main_text += f"<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
# Add the main text annotation
fig.add_annotation(
text=main_text,
xref="paper", yref="paper",
x=0.5, y=0.5,
xanchor='center', yanchor='middle',
xref="paper",
yref="paper",
x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
borderwidth=2,
bordercolor="#4a5568",
bgcolor="#2d3748",
font=dict(
family="Arial, sans-serif",
size=14,
color="#e2e8f0"
)
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
)
# Update layout with dark theme
@ -61,7 +58,7 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
margin=dict(l=40, r=40, t=40, b=40),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
font=dict(color="#e2e8f0")
font=dict(color="#e2e8f0"),
)
# Remove axes and grid
@ -99,13 +96,15 @@ def graph_structure(file_path: Path) -> go.Figure:
# Create network diagram data
node_info = []
for i, node in enumerate(nodes):
node_info.append({
'id': i,
'name': node.name or f"{node.op_type}_{i}",
'op_type': node.op_type,
'inputs': len(node.input),
'outputs': len(node.output)
})
node_info.append(
{
"id": i,
"name": node.name or f"{node.op_type}_{i}",
"op_type": node.op_type,
"inputs": len(node.input),
"outputs": len(node.output),
}
)
# Create visualization
fig = go.Figure()
@ -115,45 +114,50 @@ def graph_structure(file_path: Path) -> go.Figure:
y_positions = [0] * len(node_info)
# Add nodes as scatter points
fig.add_trace(go.Scatter(
fig.add_trace(
go.Scatter(
x=x_positions,
y=y_positions,
mode='markers+text',
mode="markers+text",
marker=dict(
size=[min(max(info['inputs'] + info['outputs'] + 15, 20), 50) for info in node_info],
size=[min(max(info["inputs"] + info["outputs"] + 15, 20), 50) for info in node_info],
color=px.colors.qualitative.Set3[: len(node_info)],
opacity=0.8,
line=dict(width=2, color='white')
line=dict(width=2, color="white"),
),
text=[f"{info['op_type']}" for info in node_info],
textposition="middle center",
textfont=dict(size=10, color="white"),
hovertemplate="<b>%{text}</b><br>" +
"Name: %{customdata[0]}<br>" +
"Inputs: %{customdata[1]}<br>" +
"Outputs: %{customdata[2]}<br>" +
"<extra></extra>",
customdata=[[info['name'], info['inputs'], info['outputs']] for info in node_info],
name="Operators"
))
hovertemplate="<b>%{text}</b><br>"
+ "Name: %{customdata[0]}<br>"
+ "Inputs: %{customdata[1]}<br>"
+ "Outputs: %{customdata[2]}<br>"
+ "<extra></extra>",
customdata=[[info["name"], info["inputs"], info["outputs"]] for info in node_info],
name="Operators",
)
)
# Add connecting lines
for i in range(len(node_info) - 1):
fig.add_trace(go.Scatter(
fig.add_trace(
go.Scatter(
x=[x_positions[i], x_positions[i + 1]],
y=[y_positions[i], y_positions[i + 1]],
mode='lines',
line=dict(color='gray', width=1, dash='dot'),
mode="lines",
line=dict(color="gray", width=1, dash="dot"),
showlegend=False,
hoverinfo='skip'
))
hoverinfo="skip",
)
)
fig.update_layout(
title={
'text': f"ONNX Graph Structure<br><span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>",
'x': 0.5,
'xanchor': 'center',
'font': {'size': 22}
"text": ("ONNX Graph Structure<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
},
xaxis_title="Execution Order",
yaxis_title="",
@ -162,7 +166,7 @@ def graph_structure(file_path: Path) -> go.Figure:
template="plotly_dark",
yaxis=dict(showticklabels=False, showgrid=False),
xaxis=dict(showgrid=False),
margin=dict(l=50, r=50, t=80, b=50)
margin=dict(l=50, r=50, t=80, b=50),
)
return fig
@ -170,7 +174,7 @@ def graph_structure(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Graph Analysis Error",
f"Could not analyze ONNX model structure.",
"Could not analyze ONNX model structure.",
f"Error: {str(e)}"
)
@ -201,7 +205,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
return create_styled_error_figure(
"No Operators",
"This ONNX model contains no operators to analyze.",
"Please verify the model file is valid."
"Please verify the model file is valid.",
)
# Sort by frequency
@ -209,9 +213,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
# Create pie chart and bar chart
fig = make_subplots(
rows=2, cols=1,
rows=2,
cols=1,
subplot_titles=("Operator Distribution", "Operator Frequency"),
specs=[[{"type": "pie"}], [{"type": "bar"}]]
specs=[[{"type": "pie"}], [{"type": "bar"}]],
)
# Pie chart for operator distribution
@ -223,9 +228,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
values=list(op_values),
textinfo="label+percent",
textposition="auto",
showlegend=False
showlegend=False,
),
row=1, col=1
row=1,
col=1,
)
# Bar chart for frequency
@ -234,20 +240,22 @@ def operator_analysis(file_path: Path) -> go.Figure:
x=list(op_names),
y=list(op_values),
marker_color=px.colors.qualitative.Set3[: len(op_names)],
showlegend=False
showlegend=False,
),
row=2, col=1
row=2,
col=1,
)
fig.update_layout(
title={
'text': f"ONNX Operator Analysis<br><span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>",
'x': 0.5,
'xanchor': 'center',
'font': {'size': 22}
"text": ("ONNX Operator Analysis<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
},
height=700,
template="plotly_dark"
template="plotly_dark",
)
return fig
@ -255,7 +263,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Operator Analysis Error",
f"Could not analyze ONNX operators.",
"Could not analyze ONNX operators.",
f"Error: {str(e)}"
)
@ -288,7 +296,7 @@ def model_metadata(file_path: Path) -> go.Figure:
try:
tensor = onnx.numpy_helper.to_array(initializer)
total_params += tensor.size
except:
except Exception:
pass # Skip if tensor can't be loaded
# Get model file size
@ -296,10 +304,10 @@ def model_metadata(file_path: Path) -> go.Figure:
# Create metadata display
fig = make_subplots(
rows=2, cols=2,
rows=2,
cols=2,
subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
specs=[[{"type": "indicator"}, {"type": "bar"}],
[{"type": "table"}, {"type": "indicator"}]]
specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]],
)
# Model size indicator
@ -307,19 +315,20 @@ def model_metadata(file_path: Path) -> go.Figure:
go.Indicator(
mode="number+gauge",
value=file_size_mb,
title={'text': "Model Size (MB)"},
number={'suffix': ' MB', 'valueformat': '.2f'},
title={"text": "Model Size (MB)"},
number={"suffix": " MB", "valueformat": ".2f"},
gauge={
'axis': {'range': [0, max(100, file_size_mb * 1.5)]},
'bar': {'color': "darkblue"},
'steps': [
{'range': [0, 10], 'color': "lightgreen"},
{'range': [10, 50], 'color': "yellow"},
{'range': [50, 100], 'color': "orange"}
]
}
"axis": {"range": [0, max(100, file_size_mb * 1.5)]},
"bar": {"color": "darkblue"},
"steps": [
{"range": [0, 10], "color": "lightgreen"},
{"range": [10, 50], "color": "yellow"},
{"range": [50, 100], "color": "orange"},
],
},
),
row=1, col=1
row=1,
col=1,
)
# Architecture components
@ -330,10 +339,11 @@ def model_metadata(file_path: Path) -> go.Figure:
go.Bar(
x=arch_data,
y=arch_values,
marker_color=['blue', 'green', 'orange', 'red'],
marker_color=["blue", "green", "orange", "red"],
showlegend=False
),
row=1, col=2
row=1,
col=2,
)
# I/O Table
@ -346,17 +356,24 @@ def model_metadata(file_path: Path) -> go.Figure:
if inp.type and inp.type.tensor_type:
# Get shape
if inp.type.tensor_type.shape:
dims = [str(d.dim_value) if d.dim_value > 0 else "?"
for d in inp.type.tensor_type.shape.dim]
dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in inp.type.tensor_type.shape.dim]
shape = f"[{', '.join(dims)}]"
# Get data type
elem_type = inp.type.tensor_type.elem_type
type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32',
7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'}
dtype = type_map.get(elem_type, f'type_{elem_type}')
type_map = {
1: "float32",
2: "uint8",
3: "int8",
6: "int32",
7: "int64",
9: "bool",
10: "float16",
11: "double",
}
dtype = type_map.get(elem_type, f"type_{elem_type}")
io_data.append(['Input', inp.name[:20], shape, dtype])
io_data.append(["Input", inp.name[:20], shape, dtype])
# Add output info
for out in graph.output[:5]: # Limit to first 5
@ -364,32 +381,40 @@ def model_metadata(file_path: Path) -> go.Figure:
dtype = "Unknown"
if out.type and out.type.tensor_type:
if out.type.tensor_type.shape:
dims = [str(d.dim_value) if d.dim_value > 0 else "?"
for d in out.type.tensor_type.shape.dim]
dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in out.type.tensor_type.shape.dim]
shape = f"[{', '.join(dims)}]"
elem_type = out.type.tensor_type.elem_type
type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32',
7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'}
dtype = type_map.get(elem_type, f'type_{elem_type}')
type_map = {
1: "float32",
2: "uint8",
3: "int8",
6: "int32",
7: "int64",
9: "bool",
10: "float16",
11: "double",
}
dtype = type_map.get(elem_type, f"type_{elem_type}")
io_data.append(['Output', out.name[:20], shape, dtype])
io_data.append(["Output", out.name[:20], shape, dtype])
if io_data:
fig.add_trace(
go.Table(
header=dict(
values=['Type', 'Name', 'Shape', 'Data Type'],
fill_color='lightblue',
align='left'
values=["Type", "Name", "Shape", "Data Type"],
fill_color="lightblue",
align="left"
),
cells=dict(
values=list(zip(*io_data)),
fill_color='white',
align='left'
)
fill_color="white",
align="left"
),
row=2, col=1
),
row=2,
col=1,
)
# Parameters indicator
@ -397,23 +422,25 @@ def model_metadata(file_path: Path) -> go.Figure:
go.Indicator(
mode="number",
value=total_params,
title={'text': "Total Parameters"},
number={'suffix': 'M', 'valueformat': '.2f'},
number_font_size=30
title={"text": "Total Parameters"},
number={"suffix": "M", "valueformat": ".2f"},
number_font_size=30,
),
row=2, col=2
row=2,
col=2,
)
fig.update_layout(
title={
'text': f"ONNX Model Metadata<br><span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>",
'x': 0.5,
'xanchor': 'center',
'font': {'size': 22}
"text": ("ONNX Model Metadata<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
},
height=600,
template="plotly_dark",
showlegend=False
showlegend=False,
)
return fig
@ -421,7 +448,7 @@ def model_metadata(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Metadata Analysis Error",
f"Could not extract ONNX model metadata.",
"Could not extract ONNX model metadata.",
f"Error: {str(e)}"
)
@ -435,7 +462,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
return create_styled_error_figure(
"ONNX Not Available",
"ONNX library is required for performance analysis.",
"Install with: pip install onnx"
"Install with: pip install onnx",
)
try:
@ -452,29 +479,27 @@ def performance_metrics(file_path: Path) -> go.Figure:
try:
tensor = onnx.numpy_helper.to_array(initializer)
total_params += tensor.size
except:
except Exception:
pass
# Estimate memory usage (rough approximation)
param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32
# Count operations by complexity
compute_ops = ['Conv', 'MatMul', 'Gemm', 'LSTM', 'GRU']
efficient_ops = ['Relu', 'Add', 'Mul', 'BatchNormalization', 'Dropout']
compute_ops = ["Conv", "MatMul", "Gemm", "LSTM", "GRU"]
efficient_ops = ["Relu", "Add", "Mul", "BatchNormalization", "Dropout"]
compute_count = sum(1 for node in graph.node
if any(op in node.op_type for op in compute_ops))
efficient_count = sum(1 for node in graph.node
if any(op in node.op_type for op in efficient_ops))
compute_count = sum(1 for node in graph.node if any(op in node.op_type for op in compute_ops))
efficient_count = sum(1 for node in graph.node if any(op in node.op_type for op in efficient_ops))
total_ops = len(graph.node)
other_count = total_ops - compute_count - efficient_count
# Create performance dashboard
fig = make_subplots(
rows=2, cols=2,
rows=2,
cols=2,
subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
specs=[[{"type": "bar"}, {"type": "bar"}],
[{"type": "pie"}, {"type": "indicator"}]]
specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]],
)
# Model efficiency metrics
@ -485,10 +510,11 @@ def performance_metrics(file_path: Path) -> go.Figure:
go.Bar(
x=efficiency_metrics,
y=efficiency_values,
marker_color=['blue', 'green', 'orange'],
marker_color=["blue", "green", "orange"],
showlegend=False
),
row=1, col=1
row=1,
col=1,
)
# Memory usage
@ -499,20 +525,22 @@ def performance_metrics(file_path: Path) -> go.Figure:
go.Bar(
x=memory_types,
y=memory_values,
marker_color=['purple', 'red'],
marker_color=["purple", "red"],
showlegend=False
),
row=1, col=2
row=1,
col=2,
)
# Operation types pie chart
fig.add_trace(
go.Pie(
labels=['Compute Ops', 'Efficient Ops', 'Other Ops'],
labels=["Compute Ops", "Efficient Ops", "Other Ops"],
values=[compute_count, efficient_count, other_count],
marker_colors=['red', 'green', 'gray']
marker_colors=["red", "green", "gray"],
),
row=2, col=1
row=2,
col=1,
)
# Complexity score (simple heuristic)
@ -522,30 +550,35 @@ def performance_metrics(file_path: Path) -> go.Figure:
go.Indicator(
mode="gauge+number",
value=complexity_score,
title={'text': "Complexity Score"},
title={"text": "Complexity Score"},
gauge={
'axis': {'range': [0, 100]},
'bar': {'color': "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"},
'steps': [
{'range': [0, 40], 'color': "lightgreen"},
{'range': [40, 70], 'color': "yellow"},
{'range': [70, 100], 'color': "lightcoral"}
]
}
"axis": {"range": [0, 100]},
"bar": {
"color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"
},
"steps": [
{"range": [0, 40], "color": "lightgreen"},
{"range": [40, 70], "color": "yellow"},
{"range": [70, 100], "color": "lightcoral"},
],
},
),
row=2, col=2
row=2,
col=2,
)
fig.update_layout(
title={
'text': f"ONNX Performance Metrics<br><span style='font-size:14px; color:#a0a0a0;'>Complexity Score: {complexity_score:.0f}/100</span>",
'x': 0.5,
'xanchor': 'center',
'font': {'size': 22}
"text": ("ONNX Performance Metrics<br>"
f"<span style='font-size:14px; color:#a0a0a0;'>"
f"Complexity Score: {complexity_score:.0f}/100</span>"),
"x": 0.5,
"xanchor": "center",
"font": {"size": 22},
},
height=600,
template="plotly_dark",
showlegend=False
showlegend=False,
)
return fig
@ -553,6 +586,6 @@ def performance_metrics(file_path: Path) -> go.Figure:
except Exception as e:
return create_styled_error_figure(
"Performance Analysis Error",
f"Could not analyze ONNX model performance.",
"Could not analyze ONNX model performance.",
f"Error: {str(e)}"
)

View File

@ -1,35 +1,79 @@
import torch
import plotly.graph_objects as go
from plotly.graph_objects import Figure
import numpy as np
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
    """Build a dark-themed figure whose only content is a boxed error annotation.

    Args:
        title: Heading text, placed after a warning emoji inside a bold tag.
        message: Explanatory text placed beneath the heading.
        suggestion: Optional hint. When truthy, it is appended below a
            "Suggestion:" label; otherwise that section is omitted.

    Returns:
        A ``go.Figure`` with both axes hidden and the annotation centred
        in paper coordinates.
    """
    # The annotation text is assembled from HTML-styled fragments and joined once.
    fragments = [
        f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>",
        f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>",
    ]
    if suggestion:
        fragments.append(
            "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
        )
        fragments.append(f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>")

    figure = go.Figure()

    # Single centred annotation carrying the whole message.
    figure.add_annotation(
        text="".join(fragments),
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        xanchor="center",
        yanchor="middle",
        showarrow=False,
        align="center",
        borderwidth=2,
        bordercolor="#4a5568",
        bgcolor="#2d3748",
        font={"family": "Arial, sans-serif", "size": 14, "color": "#e2e8f0"},
    )

    # Dark theme layout with no title.
    figure.update_layout(
        title="",
        height=400,
        template="plotly_dark",
        margin={"l": 40, "r": 40, "t": 40, "b": 40},
        plot_bgcolor="#1a202c",
        paper_bgcolor="#1a202c",
        font={"color": "#e2e8f0"},
    )

    # Hide both axes so only the annotation is visible.
    figure.update_xaxes(visible=False)
    figure.update_yaxes(visible=False)
    return figure
def model_summary_plot(state_dict: dict) -> Figure:
"""Generate a summary plot of the PyTorch model state dict."""
if not state_dict:
# Handle empty state dict
fig = go.Figure()
fig.add_annotation(
text="No parameters found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
return create_styled_error_figure(
"Empty State Dict",
"No parameters found in state dict",
"Ensure the model state dictionary contains weight parameters"
)
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark"
)
return fig
# Count parameters by layer type
layer_info = []
for key, tensor in state_dict.items():
if 'weight' in key:
try:
layer_name = key.replace('.weight', '')
param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0
shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else []
param_count = (
tensor.numel() if hasattr(tensor, 'numel')
else len(tensor.flatten()) if hasattr(tensor, 'flatten')
else 0
)
shape = (
list(tensor.shape) if hasattr(tensor, 'shape')
else [len(tensor)] if hasattr(tensor, '__len__')
else []
)
layer_info.append({
'layer': layer_name,
'parameters': param_count,
@ -38,24 +82,12 @@ def model_summary_plot(state_dict: dict) -> Figure:
except Exception as e:
print(f"Warning: Could not process layer {key}: {e}")
continue
if not layer_info:
# Handle case where no weight layers found
fig = go.Figure()
fig.add_annotation(
text="No weight layers found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
return create_styled_error_figure(
"No Weight Layers Found",
"No weight layers found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters"
)
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark"
)
return fig
# Create bar chart of parameter counts
fig = go.Figure(data=[
go.Bar(
@ -65,53 +97,35 @@ def model_summary_plot(state_dict: dict) -> Figure:
textposition='auto',
)
])
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark"
)
return fig
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer."""
if not state_dict:
fig = go.Figure()
fig.add_annotation(
text="No data in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
return create_styled_error_figure(
"Empty State Dict",
"No data in state dict",
"Ensure the model state dictionary contains data"
)
fig.update_layout(
title="Layer Weights",
template="plotly_dark"
)
return fig
if layer_name is None:
# Get first weight tensor
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
if not weight_keys:
fig = go.Figure()
fig.add_annotation(
text="No weight tensors found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
return create_styled_error_figure(
"No Weight Tensors Found",
"No weight tensors found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters"
)
fig.update_layout(
title="Layer Weights",
template="plotly_dark"
)
return fig
layer_name = weight_keys[0]
try:
weights = state_dict[layer_name]
# Convert to numpy if it's a torch tensor
if hasattr(weights, 'numpy'):
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
@ -119,7 +133,6 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
weights_np = weights.cpu().detach().numpy()
else:
weights_np = np.array(weights)
# For 2D weights, create heatmap
if len(weights_np.shape) == 2:
fig = go.Figure(data=go.Heatmap(
@ -143,36 +156,21 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
return fig
except Exception as e:
fig = go.Figure()
fig.add_annotation(
text=f"Error processing layer {layer_name}: {str(e)}",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=14)
return create_styled_error_figure(
"Layer Processing Error",
f"Error processing layer {layer_name}: {str(e)}",
"Check that the layer name exists and contains valid tensor data"
)
fig.update_layout(
title="Layer Weights - Error",
template="plotly_dark"
)
return fig
def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers."""
if not state_dict:
fig = go.Figure()
fig.add_annotation(
text="No data in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
return create_styled_error_figure(
"Empty State Dict",
"No data in state dict",
"Ensure the model state dictionary contains data"
)
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark"
)
return fig
all_weights = []
layer_names = []
@ -187,7 +185,6 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
weights_np = tensor.cpu().detach().numpy()
else:
weights_np = np.array(tensor)
flat_weights = weights_np.flatten()
all_weights.extend(flat_weights)
layer_names.extend([key] * len(flat_weights))
@ -196,20 +193,11 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
continue
if not all_weights:
fig = go.Figure()
fig.add_annotation(
text="No weight data found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
return create_styled_error_figure(
"No Weight Data Found",
"No weight data found in state dict",
"Ensure the state dictionary contains layers with '.weight' parameters"
)
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark"
)
return fig
fig = go.Figure(data=[
go.Histogram(
@ -225,5 +213,4 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
yaxis_title="Frequency",
template="plotly_dark"
)
return fig

View File

@ -6,7 +6,6 @@ import random
from typing import Optional
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.graph_objects import Figure
@ -22,25 +21,24 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
if suggestion:
main_text += f"<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
# Add the main text annotation
fig.add_annotation(
text=main_text,
xref="paper", yref="paper",
x=0.5, y=0.5,
xanchor='center', yanchor='middle',
xref="paper",
yref="paper",
x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
borderwidth=2,
bordercolor="#4a5568",
bgcolor="#2d3748",
font=dict(
family="Arial, sans-serif",
size=14,
color="#e2e8f0"
)
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
)
# Update layout with dark theme
@ -51,7 +49,7 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
margin=dict(l=40, r=40, t=40, b=40),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
font=dict(color="#e2e8f0")
font=dict(color="#e2e8f0"),
)
# Remove axes and grid
@ -73,7 +71,7 @@ def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
if plot_type == "class_distribution":
# Check if we have any categorical columns
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
has_class_col = any(alt in metadata.columns for alt in alternatives)
@ -89,7 +87,7 @@ def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
# Check if we can access sample data (basic test)
try:
sample_data = dataset[0] if hasattr(dataset, '__getitem__') else None
sample_data = dataset[0] if hasattr(dataset, "__getitem__") else None
if sample_data is None or len(sample_data) < 32:
return False, "Insufficient sample data for spectrogram (need at least 32 points)"
except Exception:
@ -111,7 +109,7 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
return create_styled_error_figure(
"Dataset Not Compatible",
"This dataset doesn't have categorical labels needed for class distribution analysis.",
"Try using the Dataset Overview widget to explore the available data columns."
"Try using the Dataset Overview widget to explore the available data columns.",
)
metadata = dataset.metadata
@ -127,7 +125,7 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
else:
# Use first categorical column
for col in metadata.columns:
if metadata[col].dtype == 'object' or metadata[col].nunique() < 50:
if metadata[col].dtype == "object" or metadata[col].nunique() < 50:
class_key = col
break
@ -135,7 +133,8 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
return create_styled_error_figure(
"No Class Labels Found",
"This dataset contains numerical data without categorical labels.",
"Try using the Dataset Overview widget for data analysis, or check if your dataset has hidden categorical columns."
("Try using the Dataset Overview widget for data analysis, "
"or check if your dataset has hidden categorical columns."),
)
# Count examples per class (limit to top 20 for performance)
@ -146,19 +145,15 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
class_counts = class_counts.sort_index()
# Create simple bar plot
fig = px.bar(
x=class_counts.index,
y=class_counts.values,
title=f'Class Distribution: {class_key.title()}'
)
fig = px.bar(x=class_counts.index, y=class_counts.values, title=f"Class Distribution: {class_key.title()}")
fig.update_traces(texttemplate='%{y}', textposition='outside')
fig.update_traces(texttemplate="%{y}", textposition="outside")
fig.update_layout(
xaxis_title=class_key.title(),
yaxis_title='Number of Examples',
yaxis_title="Number of Examples",
showlegend=False,
height=400,
template="plotly_dark"
template="plotly_dark",
)
return fig
@ -166,8 +161,8 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
except Exception as e:
return create_styled_error_figure(
"Class Distribution Error",
f"An error occurred while generating the class distribution plot.",
f"Technical details: {str(e)}"
"An error occurred while generating the class distribution plot.",
f"Technical details: {str(e)}",
)
@ -180,91 +175,79 @@ def dataset_overview_plot(dataset) -> Figure:
# Create subplot with multiple charts
# Determine subplot titles based on data type
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']]
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
dist_title = "Value Distribution" if categorical_cols else "Data Distribution"
fig = make_subplots(
rows=2, cols=2,
rows=2,
cols=2,
subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"),
specs=[[{"type": "indicator"}, {"type": "bar"}],
[{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}]]
specs=[
[{"type": "indicator"}, {"type": "bar"}],
[{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}],
],
)
# Top left: Dataset size indicator
fig.add_trace(
go.Indicator(
mode="number",
value=total_examples,
title={"text": "Total Examples"},
number={"font": {"size": 40}}
mode="number", value=total_examples, title={"text": "Total Examples"}, number={"font": {"size": 40}}
),
row=1, col=1
row=1,
col=1,
)
# Top right: Data types distribution
dtype_counts = metadata.dtypes.value_counts()
fig.add_trace(
go.Bar(
x=[str(dt) for dt in dtype_counts.index],
y=dtype_counts.values,
name="Data Types",
showlegend=False
x=[str(dt) for dt in dtype_counts.index], y=dtype_counts.values, name="Data Types", showlegend=False
),
row=1, col=2
row=1,
col=2,
)
# Bottom left: Show distribution of numeric columns or categorical if available
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']]
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
if categorical_cols:
col = categorical_cols[0] # Show first categorical column
value_counts = metadata[col].value_counts().head(10)
fig.add_trace(
go.Bar(
x=value_counts.index,
y=value_counts.values,
name=f"{col} Distribution",
showlegend=False
),
row=2, col=1
go.Bar(x=value_counts.index, y=value_counts.values, name=f"{col} Distribution", showlegend=False),
row=2,
col=1,
)
elif numeric_cols:
# Show histogram of first numeric column
col = numeric_cols[0]
fig.add_trace(
go.Histogram(
x=metadata[col],
name=f"{col} Distribution",
showlegend=False,
nbinsx=20
),
row=2, col=1
go.Histogram(x=metadata[col], name=f"{col} Distribution", showlegend=False, nbinsx=20), row=2, col=1
)
# Bottom right: Basic statistics table
stats_data = []
display_cols = (numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5])
display_cols = numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]
for col in display_cols:
if metadata[col].dtype in ['int64', 'float64']:
stats_data.append([
if metadata[col].dtype in ["int64", "float64"]:
stats_data.append(
[
col[:15] + "..." if len(col) > 15 else col, # Truncate long column names
f"{metadata[col].mean():.3f}",
f"{metadata[col].std():.3f}",
f"{metadata[col].min():.3f}",
f"{metadata[col].max():.3f}"
])
f"{metadata[col].max():.3f}",
]
)
else:
unique_count = metadata[col].nunique()
stats_data.append([
col[:15] + "..." if len(col) > 15 else col,
"N/A", "N/A",
f"{unique_count} unique",
"N/A"
])
stats_data.append(
[col[:15] + "..." if len(col) > 15 else col, "N/A", "N/A", f"{unique_count} unique", "N/A"]
)
if stats_data:
fig.add_trace(
@ -273,30 +256,26 @@ def dataset_overview_plot(dataset) -> Figure:
values=["Column", "Mean", "Std", "Min/Unique", "Max"],
fill_color="rgba(30, 30, 30, 0.8)",
align="center",
font=dict(color="white", size=12)
font=dict(color="white", size=12),
),
cells=dict(
values=list(zip(*stats_data)),
fill_color="rgba(50, 50, 50, 0.6)",
align="center",
font=dict(color="white", size=11)
)
font=dict(color="white", size=11),
),
row=2, col=2
),
row=2,
col=2,
)
# Create informative title
total_cols = len(metadata.columns)
title = f"Dataset Overview - {total_examples} samples, {total_cols} columns"
if total_cols > 5:
title += f" (showing first 5)"
title += " (showing first 5)"
fig.update_layout(
title=title,
height=600,
showlegend=False,
template="plotly_dark"
)
fig.update_layout(title=title, height=600, showlegend=False, template="plotly_dark")
return fig
@ -304,10 +283,100 @@ def dataset_overview_plot(dataset) -> Figure:
return create_styled_error_figure(
"Dataset Overview Error",
"An error occurred while generating the dataset overview.",
f"Technical details: {str(e)}"
f"Technical details: {str(e)}",
)
def _find_class_column(metadata, class_key: str) -> str:
    """Resolve which metadata column should be used as the class label.

    Args:
        metadata: Object exposing a ``columns`` container that supports ``in``.
        class_key: Preferred column name.

    Returns:
        ``class_key`` when it is present in ``metadata.columns``; otherwise the
        first well-known fallback name that is present; otherwise ``class_key``
        unchanged (the caller decides what to do with a missing column).
    """
    available = metadata.columns
    if class_key in available:
        return class_key

    # Fallback names are tried in priority order; the first hit wins.
    fallback_names = ("class", "label", "modulation", "impairment", "use_case")
    return next((name for name in fallback_names if name in available), class_key)
def _get_sample_data(dataset, sample_idx: int):
    """Return the sample at ``sample_idx``, or a synthetic signal if indexing fails.

    Args:
        dataset: Any object supporting ``dataset[sample_idx]``.
        sample_idx: Index of the sample to fetch.

    Returns:
        ``dataset[sample_idx]`` when that lookup succeeds. If the lookup raises
        any ``Exception``, a 1024-point complex array is returned instead: a
        complex exponential plus complex Gaussian noise drawn from NumPy's
        global random state.
    """
    try:
        return dataset[sample_idx]
    except Exception:
        # Deliberate best-effort fallback: any indexing failure yields a
        # synthetic placeholder so the caller can still render something.
        n_samples = 1024
        t = np.linspace(0, 1, n_samples)
        # Cycle through five tone frequencies. The parentheses matter:
        # `0.05 * sample_idx % 5` parses as `(0.05 * sample_idx) % 5`, which
        # does not cycle and is a no-op for indices below 100.
        freq = 0.1 + 0.05 * (sample_idx % 5)
        sample_data = np.exp(1j * 2 * np.pi * freq * t)
        # Add some noise (real part drawn first, then imaginary part).
        sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
        return sample_data
def _calculate_spectrogram_params(n_samples: int) -> tuple[int, int, int, int]:
    """Derive spectrogram framing parameters from the signal length.

    Args:
        n_samples: Number of samples in the signal.

    Returns:
        A tuple ``(nperseg, hop_length, n_frames, freq_bins)`` where
        ``nperseg`` is the segment length (a quarter of the signal, clamped to
        the range 32..256), ``hop_length`` is half a segment (50% overlap),
        ``n_frames`` is the number of full segments that fit, and
        ``freq_bins`` is half the segment length.

    Raises:
        ValueError: If ``n_samples`` is less than 32.
    """
    if n_samples < 32:
        raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")

    # With n_samples >= 32 this never exceeds n_samples (max(32, n // 4) <= n),
    # so no separate "segment longer than signal" adjustment is needed.
    nperseg = min(256, max(32, n_samples // 4))
    hop_length = max(1, nperseg // 2)

    n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
    freq_bins = max(1, nperseg // 2)

    return nperseg, hop_length, n_frames, freq_bins
def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: int, freq_bins: int):
    """Build a power spectrogram from overlapping FFT frames.

    Args:
        sample_data: One-dimensional signal supporting ``len`` and slicing.
        nperseg: Number of samples per FFT segment.
        hop_length: Step between the start of consecutive segments.
        n_frames: Number of frames (output columns) to compute.
        freq_bins: Number of leading FFT bins kept per frame (output rows).

    Returns:
        A ``(freq_bins, n_frames)`` float array of squared FFT magnitudes.
        Frames that start at or beyond the end of the data stay all-zero.
    """
    total = len(sample_data)
    power = np.zeros((freq_bins, n_frames))

    for frame in range(n_frames):
        begin = frame * hop_length
        stop = min(begin + nperseg, total)
        if stop <= begin:
            # Nothing left to read for this frame; keep the zero column.
            continue

        segment = sample_data[begin:stop]
        shortfall = nperseg - len(segment)
        if shortfall > 0:
            # Zero-pad a truncated trailing segment so every FFT has nperseg points.
            segment = np.pad(segment, (0, shortfall), mode="constant")

        spectrum = np.fft.fft(segment)
        power[:, frame] = np.abs(spectrum[:freq_bins]) ** 2

    return power
def _create_spectrogram_figure(
    Sxx,
    n_frames: int,
    hop_length: int,
    n_samples: int,
    freq_bins: int,
    sample_idx: int,
    class_key: str,
    sample_metadata,
) -> Figure:
    """Render a power spectrogram as a dark-themed plotly heatmap.

    Args:
        Sxx: Power values, one row per frequency bin and one column per frame.
        n_frames: Number of time frames (length of the x axis).
        hop_length: Sample step between consecutive frames.
        n_samples: Total signal length, used to normalise the time axis.
        freq_bins: Number of frequency bins (length of the y axis).
        sample_idx: Index of the plotted sample, shown in the title.
        class_key: Metadata key whose value is appended to the title if present.
        sample_metadata: Mapping-like record for the sample; must support
            ``in`` and item lookup by ``class_key``.

    Returns:
        The configured plotly figure.
    """
    # Power in dB; the small offset keeps log10 finite for all-zero bins.
    power_db = 10 * np.log10(Sxx + 1e-10)

    # Frame start positions as a fraction of the signal, and a 0..0.5 frequency axis.
    time_axis = np.arange(n_frames) * hop_length / max(1, n_samples)
    freq_axis = np.linspace(0, 0.5, freq_bins)

    heatmap = go.Heatmap(
        z=power_db,
        x=time_axis,
        y=freq_axis,
        colorscale="viridis",
        colorbar=dict(title="Power (dB)"),
    )
    fig = go.Figure(data=heatmap)

    # Title carries the sample index and, when available, its class label.
    title_parts = [f"Sample Spectrogram (Index: {sample_idx})"]
    if class_key in sample_metadata:
        title_parts.append(f" - {class_key}: {sample_metadata[class_key]}")

    fig.update_layout(
        title="".join(title_parts),
        xaxis_title="Time",
        yaxis_title="Frequency",
        height=400,
        template="plotly_dark",
    )
    return fig
def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure:
"""Generate a spectrogram plot from a sample in the dataset."""
try:
@ -317,114 +386,36 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
return create_styled_error_figure(
"Spectrogram Not Available",
"This dataset doesn't have sufficient signal data for spectrogram visualization.",
"Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample."
"Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample.",
)
metadata = dataset.metadata
if len(metadata) == 0:
raise ValueError("Dataset is empty")
# Find class column
if class_key not in metadata.columns:
alternatives = ["class", "label", "modulation", "impairment", "use_case"]
for alt in alternatives:
if alt in metadata.columns:
class_key = alt
break
# Select sample
# Find class column and select sample
class_key = _find_class_column(metadata, class_key)
if sample_idx is None:
sample_idx = random.randint(0, len(metadata) - 1)
sample_metadata = metadata.iloc[sample_idx]
# Try to get actual sample data, fall back to synthetic
try:
sample_data = dataset[sample_idx]
except:
# Generate synthetic signal based on class
n_samples = 1024
t = np.linspace(0, 1, n_samples)
freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample
sample_data = np.exp(1j * 2 * np.pi * freq * t)
# Add some noise
sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
# Ensure complex data
# Get sample data and ensure it's complex
sample_data = _get_sample_data(dataset, sample_idx)
if not np.iscomplexobj(sample_data):
sample_data = sample_data.astype(complex)
# Simple FFT-based spectrogram
# Calculate spectrogram parameters and compute spectrogram
n_samples = len(sample_data)
nperseg, hop_length, n_frames, freq_bins = _calculate_spectrogram_params(n_samples)
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
# Ensure minimum viable data size
if n_samples < 32:
raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")
nperseg = min(256, max(32, n_samples // 4))
# Create spectrogram using numpy (no scipy dependency)
hop_length = max(1, nperseg // 2) # Prevent zero hop_length
# Ensure we can create at least one frame
if n_samples < nperseg:
nperseg = n_samples
hop_length = 1
n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
freq_bins = max(1, nperseg // 2) # Prevent zero frequency bins
Sxx = np.zeros((freq_bins, n_frames))
for i in range(n_frames):
start_idx = i * hop_length
end_idx = min(start_idx + nperseg, n_samples) # Prevent index overflow
if end_idx > start_idx: # Ensure we have data to process
windowed = sample_data[start_idx:end_idx]
# Pad if necessary to maintain nperseg size
if len(windowed) < nperseg:
windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode='constant')
fft_result = np.fft.fft(windowed)
Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2
# Convert to dB
Sxx_db = 10 * np.log10(Sxx + 1e-10)
# Create time and frequency vectors
t = np.arange(n_frames) * hop_length / max(1, n_samples) # Prevent division by zero
f = np.linspace(0, 0.5, freq_bins)
# Create plot
fig = go.Figure(data=go.Heatmap(
z=Sxx_db,
x=t,
y=f,
colorscale='viridis',
colorbar=dict(title="Power (dB)")
))
# Add title with metadata
title = f"Sample Spectrogram (Index: {sample_idx})"
if class_key in sample_metadata:
title += f" - {class_key}: {sample_metadata[class_key]}"
fig.update_layout(
title=title,
xaxis_title="Time",
yaxis_title="Frequency",
height=400,
template="plotly_dark"
)
return fig
# Create and return the figure
return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins,
sample_idx, class_key, sample_metadata)
except Exception as e:
return create_styled_error_figure(
"Spectrogram Error",
"An error occurred while generating the spectrogram plot.",
f"Technical details: {str(e)}"
f"Technical details: {str(e)}",
)