Compare commits
2 Commits
e863040e19
...
c06e58f5d6
| Author | SHA1 | Date | |
|---|---|---|---|
| c06e58f5d6 | |||
| c7c7100d46 |
|
|
@ -6,18 +6,16 @@ as other ria-toolkit-oss visualization modules.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
|
import plotly.graph_objects as go
|
||||||
from plotly.subplots import make_subplots
|
from plotly.subplots import make_subplots
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import onnx
|
import onnx
|
||||||
import onnx.helper
|
import onnx.helper
|
||||||
import onnx.numpy_helper
|
import onnx.numpy_helper
|
||||||
|
|
||||||
ONNX_AVAILABLE = True
|
ONNX_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ONNX_AVAILABLE = False
|
ONNX_AVAILABLE = False
|
||||||
|
|
@ -32,25 +30,24 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
|
||||||
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||||
|
|
||||||
if suggestion:
|
if suggestion:
|
||||||
main_text += f"<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||||
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||||
|
|
||||||
# Add the main text annotation
|
# Add the main text annotation
|
||||||
fig.add_annotation(
|
fig.add_annotation(
|
||||||
text=main_text,
|
text=main_text,
|
||||||
xref="paper", yref="paper",
|
xref="paper",
|
||||||
x=0.5, y=0.5,
|
yref="paper",
|
||||||
xanchor='center', yanchor='middle',
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
showarrow=False,
|
showarrow=False,
|
||||||
align="center",
|
align="center",
|
||||||
borderwidth=2,
|
borderwidth=2,
|
||||||
bordercolor="#4a5568",
|
bordercolor="#4a5568",
|
||||||
bgcolor="#2d3748",
|
bgcolor="#2d3748",
|
||||||
font=dict(
|
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||||
family="Arial, sans-serif",
|
|
||||||
size=14,
|
|
||||||
color="#e2e8f0"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update layout with dark theme
|
# Update layout with dark theme
|
||||||
|
|
@ -61,7 +58,7 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
|
||||||
margin=dict(l=40, r=40, t=40, b=40),
|
margin=dict(l=40, r=40, t=40, b=40),
|
||||||
plot_bgcolor="#1a202c",
|
plot_bgcolor="#1a202c",
|
||||||
paper_bgcolor="#1a202c",
|
paper_bgcolor="#1a202c",
|
||||||
font=dict(color="#e2e8f0")
|
font=dict(color="#e2e8f0"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove axes and grid
|
# Remove axes and grid
|
||||||
|
|
@ -99,13 +96,15 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
# Create network diagram data
|
# Create network diagram data
|
||||||
node_info = []
|
node_info = []
|
||||||
for i, node in enumerate(nodes):
|
for i, node in enumerate(nodes):
|
||||||
node_info.append({
|
node_info.append(
|
||||||
'id': i,
|
{
|
||||||
'name': node.name or f"{node.op_type}_{i}",
|
"id": i,
|
||||||
'op_type': node.op_type,
|
"name": node.name or f"{node.op_type}_{i}",
|
||||||
'inputs': len(node.input),
|
"op_type": node.op_type,
|
||||||
'outputs': len(node.output)
|
"inputs": len(node.input),
|
||||||
})
|
"outputs": len(node.output),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Create visualization
|
# Create visualization
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
@ -115,45 +114,50 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
y_positions = [0] * len(node_info)
|
y_positions = [0] * len(node_info)
|
||||||
|
|
||||||
# Add nodes as scatter points
|
# Add nodes as scatter points
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(
|
||||||
|
go.Scatter(
|
||||||
x=x_positions,
|
x=x_positions,
|
||||||
y=y_positions,
|
y=y_positions,
|
||||||
mode='markers+text',
|
mode="markers+text",
|
||||||
marker=dict(
|
marker=dict(
|
||||||
size=[min(max(info['inputs'] + info['outputs'] + 15, 20), 50) for info in node_info],
|
size=[min(max(info["inputs"] + info["outputs"] + 15, 20), 50) for info in node_info],
|
||||||
color=px.colors.qualitative.Set3[:len(node_info)],
|
color=px.colors.qualitative.Set3[: len(node_info)],
|
||||||
opacity=0.8,
|
opacity=0.8,
|
||||||
line=dict(width=2, color='white')
|
line=dict(width=2, color="white"),
|
||||||
),
|
),
|
||||||
text=[f"{info['op_type']}" for info in node_info],
|
text=[f"{info['op_type']}" for info in node_info],
|
||||||
textposition="middle center",
|
textposition="middle center",
|
||||||
textfont=dict(size=10, color="white"),
|
textfont=dict(size=10, color="white"),
|
||||||
hovertemplate="<b>%{text}</b><br>" +
|
hovertemplate="<b>%{text}</b><br>"
|
||||||
"Name: %{customdata[0]}<br>" +
|
+ "Name: %{customdata[0]}<br>"
|
||||||
"Inputs: %{customdata[1]}<br>" +
|
+ "Inputs: %{customdata[1]}<br>"
|
||||||
"Outputs: %{customdata[2]}<br>" +
|
+ "Outputs: %{customdata[2]}<br>"
|
||||||
"<extra></extra>",
|
+ "<extra></extra>",
|
||||||
customdata=[[info['name'], info['inputs'], info['outputs']] for info in node_info],
|
customdata=[[info["name"], info["inputs"], info["outputs"]] for info in node_info],
|
||||||
name="Operators"
|
name="Operators",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Add connecting lines
|
# Add connecting lines
|
||||||
for i in range(len(node_info) - 1):
|
for i in range(len(node_info) - 1):
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(
|
||||||
x=[x_positions[i], x_positions[i+1]],
|
go.Scatter(
|
||||||
y=[y_positions[i], y_positions[i+1]],
|
x=[x_positions[i], x_positions[i + 1]],
|
||||||
mode='lines',
|
y=[y_positions[i], y_positions[i + 1]],
|
||||||
line=dict(color='gray', width=1, dash='dot'),
|
mode="lines",
|
||||||
|
line=dict(color="gray", width=1, dash="dot"),
|
||||||
showlegend=False,
|
showlegend=False,
|
||||||
hoverinfo='skip'
|
hoverinfo="skip",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
'text': f"ONNX Graph Structure<br><span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>",
|
"text": ("ONNX Graph Structure<br>"
|
||||||
'x': 0.5,
|
f"<span style='font-size:14px; color:#a0a0a0;'>{len(nodes)} Operators</span>"),
|
||||||
'xanchor': 'center',
|
"x": 0.5,
|
||||||
'font': {'size': 22}
|
"xanchor": "center",
|
||||||
|
"font": {"size": 22},
|
||||||
},
|
},
|
||||||
xaxis_title="Execution Order",
|
xaxis_title="Execution Order",
|
||||||
yaxis_title="",
|
yaxis_title="",
|
||||||
|
|
@ -162,7 +166,7 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
template="plotly_dark",
|
template="plotly_dark",
|
||||||
yaxis=dict(showticklabels=False, showgrid=False),
|
yaxis=dict(showticklabels=False, showgrid=False),
|
||||||
xaxis=dict(showgrid=False),
|
xaxis=dict(showgrid=False),
|
||||||
margin=dict(l=50, r=50, t=80, b=50)
|
margin=dict(l=50, r=50, t=80, b=50),
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
@ -170,7 +174,7 @@ def graph_structure(file_path: Path) -> go.Figure:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Graph Analysis Error",
|
"Graph Analysis Error",
|
||||||
f"Could not analyze ONNX model structure.",
|
"Could not analyze ONNX model structure.",
|
||||||
f"Error: {str(e)}"
|
f"Error: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -201,7 +205,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"No Operators",
|
"No Operators",
|
||||||
"This ONNX model contains no operators to analyze.",
|
"This ONNX model contains no operators to analyze.",
|
||||||
"Please verify the model file is valid."
|
"Please verify the model file is valid.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sort by frequency
|
# Sort by frequency
|
||||||
|
|
@ -209,9 +213,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
# Create pie chart and bar chart
|
# Create pie chart and bar chart
|
||||||
fig = make_subplots(
|
fig = make_subplots(
|
||||||
rows=2, cols=1,
|
rows=2,
|
||||||
|
cols=1,
|
||||||
subplot_titles=("Operator Distribution", "Operator Frequency"),
|
subplot_titles=("Operator Distribution", "Operator Frequency"),
|
||||||
specs=[[{"type": "pie"}], [{"type": "bar"}]]
|
specs=[[{"type": "pie"}], [{"type": "bar"}]],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pie chart for operator distribution
|
# Pie chart for operator distribution
|
||||||
|
|
@ -223,9 +228,10 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
values=list(op_values),
|
values=list(op_values),
|
||||||
textinfo="label+percent",
|
textinfo="label+percent",
|
||||||
textposition="auto",
|
textposition="auto",
|
||||||
showlegend=False
|
showlegend=False,
|
||||||
),
|
),
|
||||||
row=1, col=1
|
row=1,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bar chart for frequency
|
# Bar chart for frequency
|
||||||
|
|
@ -233,21 +239,23 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=list(op_names),
|
x=list(op_names),
|
||||||
y=list(op_values),
|
y=list(op_values),
|
||||||
marker_color=px.colors.qualitative.Set3[:len(op_names)],
|
marker_color=px.colors.qualitative.Set3[: len(op_names)],
|
||||||
showlegend=False
|
showlegend=False,
|
||||||
),
|
),
|
||||||
row=2, col=1
|
row=2,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
'text': f"ONNX Operator Analysis<br><span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>",
|
"text": ("ONNX Operator Analysis<br>"
|
||||||
'x': 0.5,
|
f"<span style='font-size:14px; color:#a0a0a0;'>{len(op_counts)} Unique Types</span>"),
|
||||||
'xanchor': 'center',
|
"x": 0.5,
|
||||||
'font': {'size': 22}
|
"xanchor": "center",
|
||||||
|
"font": {"size": 22},
|
||||||
},
|
},
|
||||||
height=700,
|
height=700,
|
||||||
template="plotly_dark"
|
template="plotly_dark",
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
@ -255,7 +263,7 @@ def operator_analysis(file_path: Path) -> go.Figure:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Operator Analysis Error",
|
"Operator Analysis Error",
|
||||||
f"Could not analyze ONNX operators.",
|
"Could not analyze ONNX operators.",
|
||||||
f"Error: {str(e)}"
|
f"Error: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -288,7 +296,7 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
try:
|
try:
|
||||||
tensor = onnx.numpy_helper.to_array(initializer)
|
tensor = onnx.numpy_helper.to_array(initializer)
|
||||||
total_params += tensor.size
|
total_params += tensor.size
|
||||||
except:
|
except Exception:
|
||||||
pass # Skip if tensor can't be loaded
|
pass # Skip if tensor can't be loaded
|
||||||
|
|
||||||
# Get model file size
|
# Get model file size
|
||||||
|
|
@ -296,10 +304,10 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
|
|
||||||
# Create metadata display
|
# Create metadata display
|
||||||
fig = make_subplots(
|
fig = make_subplots(
|
||||||
rows=2, cols=2,
|
rows=2,
|
||||||
|
cols=2,
|
||||||
subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
|
subplot_titles=("Model Size", "Architecture", "Inputs/Outputs", "Parameters"),
|
||||||
specs=[[{"type": "indicator"}, {"type": "bar"}],
|
specs=[[{"type": "indicator"}, {"type": "bar"}], [{"type": "table"}, {"type": "indicator"}]],
|
||||||
[{"type": "table"}, {"type": "indicator"}]]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model size indicator
|
# Model size indicator
|
||||||
|
|
@ -307,19 +315,20 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
go.Indicator(
|
go.Indicator(
|
||||||
mode="number+gauge",
|
mode="number+gauge",
|
||||||
value=file_size_mb,
|
value=file_size_mb,
|
||||||
title={'text': "Model Size (MB)"},
|
title={"text": "Model Size (MB)"},
|
||||||
number={'suffix': ' MB', 'valueformat': '.2f'},
|
number={"suffix": " MB", "valueformat": ".2f"},
|
||||||
gauge={
|
gauge={
|
||||||
'axis': {'range': [0, max(100, file_size_mb * 1.5)]},
|
"axis": {"range": [0, max(100, file_size_mb * 1.5)]},
|
||||||
'bar': {'color': "darkblue"},
|
"bar": {"color": "darkblue"},
|
||||||
'steps': [
|
"steps": [
|
||||||
{'range': [0, 10], 'color': "lightgreen"},
|
{"range": [0, 10], "color": "lightgreen"},
|
||||||
{'range': [10, 50], 'color': "yellow"},
|
{"range": [10, 50], "color": "yellow"},
|
||||||
{'range': [50, 100], 'color': "orange"}
|
{"range": [50, 100], "color": "orange"},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
row=1, col=1
|
row=1,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Architecture components
|
# Architecture components
|
||||||
|
|
@ -330,10 +339,11 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=arch_data,
|
x=arch_data,
|
||||||
y=arch_values,
|
y=arch_values,
|
||||||
marker_color=['blue', 'green', 'orange', 'red'],
|
marker_color=["blue", "green", "orange", "red"],
|
||||||
showlegend=False
|
showlegend=False
|
||||||
),
|
),
|
||||||
row=1, col=2
|
row=1,
|
||||||
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# I/O Table
|
# I/O Table
|
||||||
|
|
@ -346,17 +356,24 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
if inp.type and inp.type.tensor_type:
|
if inp.type and inp.type.tensor_type:
|
||||||
# Get shape
|
# Get shape
|
||||||
if inp.type.tensor_type.shape:
|
if inp.type.tensor_type.shape:
|
||||||
dims = [str(d.dim_value) if d.dim_value > 0 else "?"
|
dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in inp.type.tensor_type.shape.dim]
|
||||||
for d in inp.type.tensor_type.shape.dim]
|
|
||||||
shape = f"[{', '.join(dims)}]"
|
shape = f"[{', '.join(dims)}]"
|
||||||
|
|
||||||
# Get data type
|
# Get data type
|
||||||
elem_type = inp.type.tensor_type.elem_type
|
elem_type = inp.type.tensor_type.elem_type
|
||||||
type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32',
|
type_map = {
|
||||||
7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'}
|
1: "float32",
|
||||||
dtype = type_map.get(elem_type, f'type_{elem_type}')
|
2: "uint8",
|
||||||
|
3: "int8",
|
||||||
|
6: "int32",
|
||||||
|
7: "int64",
|
||||||
|
9: "bool",
|
||||||
|
10: "float16",
|
||||||
|
11: "double",
|
||||||
|
}
|
||||||
|
dtype = type_map.get(elem_type, f"type_{elem_type}")
|
||||||
|
|
||||||
io_data.append(['Input', inp.name[:20], shape, dtype])
|
io_data.append(["Input", inp.name[:20], shape, dtype])
|
||||||
|
|
||||||
# Add output info
|
# Add output info
|
||||||
for out in graph.output[:5]: # Limit to first 5
|
for out in graph.output[:5]: # Limit to first 5
|
||||||
|
|
@ -364,32 +381,40 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
dtype = "Unknown"
|
dtype = "Unknown"
|
||||||
if out.type and out.type.tensor_type:
|
if out.type and out.type.tensor_type:
|
||||||
if out.type.tensor_type.shape:
|
if out.type.tensor_type.shape:
|
||||||
dims = [str(d.dim_value) if d.dim_value > 0 else "?"
|
dims = [str(d.dim_value) if d.dim_value > 0 else "?" for d in out.type.tensor_type.shape.dim]
|
||||||
for d in out.type.tensor_type.shape.dim]
|
|
||||||
shape = f"[{', '.join(dims)}]"
|
shape = f"[{', '.join(dims)}]"
|
||||||
|
|
||||||
elem_type = out.type.tensor_type.elem_type
|
elem_type = out.type.tensor_type.elem_type
|
||||||
type_map = {1: 'float32', 2: 'uint8', 3: 'int8', 6: 'int32',
|
type_map = {
|
||||||
7: 'int64', 9: 'bool', 10: 'float16', 11: 'double'}
|
1: "float32",
|
||||||
dtype = type_map.get(elem_type, f'type_{elem_type}')
|
2: "uint8",
|
||||||
|
3: "int8",
|
||||||
|
6: "int32",
|
||||||
|
7: "int64",
|
||||||
|
9: "bool",
|
||||||
|
10: "float16",
|
||||||
|
11: "double",
|
||||||
|
}
|
||||||
|
dtype = type_map.get(elem_type, f"type_{elem_type}")
|
||||||
|
|
||||||
io_data.append(['Output', out.name[:20], shape, dtype])
|
io_data.append(["Output", out.name[:20], shape, dtype])
|
||||||
|
|
||||||
if io_data:
|
if io_data:
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Table(
|
go.Table(
|
||||||
header=dict(
|
header=dict(
|
||||||
values=['Type', 'Name', 'Shape', 'Data Type'],
|
values=["Type", "Name", "Shape", "Data Type"],
|
||||||
fill_color='lightblue',
|
fill_color="lightblue",
|
||||||
align='left'
|
align="left"
|
||||||
),
|
),
|
||||||
cells=dict(
|
cells=dict(
|
||||||
values=list(zip(*io_data)),
|
values=list(zip(*io_data)),
|
||||||
fill_color='white',
|
fill_color="white",
|
||||||
align='left'
|
align="left"
|
||||||
)
|
|
||||||
),
|
),
|
||||||
row=2, col=1
|
),
|
||||||
|
row=2,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parameters indicator
|
# Parameters indicator
|
||||||
|
|
@ -397,23 +422,25 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
go.Indicator(
|
go.Indicator(
|
||||||
mode="number",
|
mode="number",
|
||||||
value=total_params,
|
value=total_params,
|
||||||
title={'text': "Total Parameters"},
|
title={"text": "Total Parameters"},
|
||||||
number={'suffix': 'M', 'valueformat': '.2f'},
|
number={"suffix": "M", "valueformat": ".2f"},
|
||||||
number_font_size=30
|
number_font_size=30,
|
||||||
),
|
),
|
||||||
row=2, col=2
|
row=2,
|
||||||
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
'text': f"ONNX Model Metadata<br><span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>",
|
"text": ("ONNX Model Metadata<br>"
|
||||||
'x': 0.5,
|
f"<span style='font-size:14px; color:#a0a0a0;'>{total_params/1e6:.2f}M Parameters</span>"),
|
||||||
'xanchor': 'center',
|
"x": 0.5,
|
||||||
'font': {'size': 22}
|
"xanchor": "center",
|
||||||
|
"font": {"size": 22},
|
||||||
},
|
},
|
||||||
height=600,
|
height=600,
|
||||||
template="plotly_dark",
|
template="plotly_dark",
|
||||||
showlegend=False
|
showlegend=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
@ -421,7 +448,7 @@ def model_metadata(file_path: Path) -> go.Figure:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Metadata Analysis Error",
|
"Metadata Analysis Error",
|
||||||
f"Could not extract ONNX model metadata.",
|
"Could not extract ONNX model metadata.",
|
||||||
f"Error: {str(e)}"
|
f"Error: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -435,7 +462,7 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"ONNX Not Available",
|
"ONNX Not Available",
|
||||||
"ONNX library is required for performance analysis.",
|
"ONNX library is required for performance analysis.",
|
||||||
"Install with: pip install onnx"
|
"Install with: pip install onnx",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -452,43 +479,42 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
try:
|
try:
|
||||||
tensor = onnx.numpy_helper.to_array(initializer)
|
tensor = onnx.numpy_helper.to_array(initializer)
|
||||||
total_params += tensor.size
|
total_params += tensor.size
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Estimate memory usage (rough approximation)
|
# Estimate memory usage (rough approximation)
|
||||||
param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32
|
param_memory_mb = (total_params * 4) / (1024 * 1024) # Assume float32
|
||||||
|
|
||||||
# Count operations by complexity
|
# Count operations by complexity
|
||||||
compute_ops = ['Conv', 'MatMul', 'Gemm', 'LSTM', 'GRU']
|
compute_ops = ["Conv", "MatMul", "Gemm", "LSTM", "GRU"]
|
||||||
efficient_ops = ['Relu', 'Add', 'Mul', 'BatchNormalization', 'Dropout']
|
efficient_ops = ["Relu", "Add", "Mul", "BatchNormalization", "Dropout"]
|
||||||
|
|
||||||
compute_count = sum(1 for node in graph.node
|
compute_count = sum(1 for node in graph.node if any(op in node.op_type for op in compute_ops))
|
||||||
if any(op in node.op_type for op in compute_ops))
|
efficient_count = sum(1 for node in graph.node if any(op in node.op_type for op in efficient_ops))
|
||||||
efficient_count = sum(1 for node in graph.node
|
|
||||||
if any(op in node.op_type for op in efficient_ops))
|
|
||||||
total_ops = len(graph.node)
|
total_ops = len(graph.node)
|
||||||
other_count = total_ops - compute_count - efficient_count
|
other_count = total_ops - compute_count - efficient_count
|
||||||
|
|
||||||
# Create performance dashboard
|
# Create performance dashboard
|
||||||
fig = make_subplots(
|
fig = make_subplots(
|
||||||
rows=2, cols=2,
|
rows=2,
|
||||||
|
cols=2,
|
||||||
subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
|
subplot_titles=("Model Efficiency", "Memory Usage", "Operation Types", "Complexity Score"),
|
||||||
specs=[[{"type": "bar"}, {"type": "bar"}],
|
specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "pie"}, {"type": "indicator"}]],
|
||||||
[{"type": "pie"}, {"type": "indicator"}]]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model efficiency metrics
|
# Model efficiency metrics
|
||||||
efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"]
|
efficiency_metrics = ["Model Size (MB)", "Parameters (M)", "Total Ops"]
|
||||||
efficiency_values = [model_size_mb, total_params/1e6, total_ops]
|
efficiency_values = [model_size_mb, total_params / 1e6, total_ops]
|
||||||
|
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=efficiency_metrics,
|
x=efficiency_metrics,
|
||||||
y=efficiency_values,
|
y=efficiency_values,
|
||||||
marker_color=['blue', 'green', 'orange'],
|
marker_color=["blue", "green", "orange"],
|
||||||
showlegend=False
|
showlegend=False
|
||||||
),
|
),
|
||||||
row=1, col=1
|
row=1,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Memory usage
|
# Memory usage
|
||||||
|
|
@ -499,20 +525,22 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=memory_types,
|
x=memory_types,
|
||||||
y=memory_values,
|
y=memory_values,
|
||||||
marker_color=['purple', 'red'],
|
marker_color=["purple", "red"],
|
||||||
showlegend=False
|
showlegend=False
|
||||||
),
|
),
|
||||||
row=1, col=2
|
row=1,
|
||||||
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Operation types pie chart
|
# Operation types pie chart
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Pie(
|
go.Pie(
|
||||||
labels=['Compute Ops', 'Efficient Ops', 'Other Ops'],
|
labels=["Compute Ops", "Efficient Ops", "Other Ops"],
|
||||||
values=[compute_count, efficient_count, other_count],
|
values=[compute_count, efficient_count, other_count],
|
||||||
marker_colors=['red', 'green', 'gray']
|
marker_colors=["red", "green", "gray"],
|
||||||
),
|
),
|
||||||
row=2, col=1
|
row=2,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Complexity score (simple heuristic)
|
# Complexity score (simple heuristic)
|
||||||
|
|
@ -522,30 +550,35 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
go.Indicator(
|
go.Indicator(
|
||||||
mode="gauge+number",
|
mode="gauge+number",
|
||||||
value=complexity_score,
|
value=complexity_score,
|
||||||
title={'text': "Complexity Score"},
|
title={"text": "Complexity Score"},
|
||||||
gauge={
|
gauge={
|
||||||
'axis': {'range': [0, 100]},
|
"axis": {"range": [0, 100]},
|
||||||
'bar': {'color': "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"},
|
"bar": {
|
||||||
'steps': [
|
"color": "darkred" if complexity_score > 70 else "orange" if complexity_score > 40 else "green"
|
||||||
{'range': [0, 40], 'color': "lightgreen"},
|
},
|
||||||
{'range': [40, 70], 'color': "yellow"},
|
"steps": [
|
||||||
{'range': [70, 100], 'color': "lightcoral"}
|
{"range": [0, 40], "color": "lightgreen"},
|
||||||
]
|
{"range": [40, 70], "color": "yellow"},
|
||||||
}
|
{"range": [70, 100], "color": "lightcoral"},
|
||||||
|
],
|
||||||
|
},
|
||||||
),
|
),
|
||||||
row=2, col=2
|
row=2,
|
||||||
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title={
|
title={
|
||||||
'text': f"ONNX Performance Metrics<br><span style='font-size:14px; color:#a0a0a0;'>Complexity Score: {complexity_score:.0f}/100</span>",
|
"text": ("ONNX Performance Metrics<br>"
|
||||||
'x': 0.5,
|
f"<span style='font-size:14px; color:#a0a0a0;'>"
|
||||||
'xanchor': 'center',
|
f"Complexity Score: {complexity_score:.0f}/100</span>"),
|
||||||
'font': {'size': 22}
|
"x": 0.5,
|
||||||
|
"xanchor": "center",
|
||||||
|
"font": {"size": 22},
|
||||||
},
|
},
|
||||||
height=600,
|
height=600,
|
||||||
template="plotly_dark",
|
template="plotly_dark",
|
||||||
showlegend=False
|
showlegend=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
@ -553,6 +586,6 @@ def performance_metrics(file_path: Path) -> go.Figure:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Performance Analysis Error",
|
"Performance Analysis Error",
|
||||||
f"Could not analyze ONNX model performance.",
|
"Could not analyze ONNX model performance.",
|
||||||
f"Error: {str(e)}"
|
f"Error: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
@ -1,35 +1,79 @@
|
||||||
import torch
|
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
from plotly.graph_objects import Figure
|
from plotly.graph_objects import Figure
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
|
||||||
|
"""Create a professional error figure with Qoherent dark theme styling."""
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Create a clean, centered text display using Plotly's text formatting
|
||||||
|
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
|
||||||
|
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||||
|
|
||||||
|
if suggestion:
|
||||||
|
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||||
|
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||||
|
|
||||||
|
# Add the main text annotation
|
||||||
|
fig.add_annotation(
|
||||||
|
text=main_text,
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
|
showarrow=False,
|
||||||
|
align="center",
|
||||||
|
borderwidth=2,
|
||||||
|
bordercolor="#4a5568",
|
||||||
|
bgcolor="#2d3748",
|
||||||
|
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update layout with dark theme
|
||||||
|
fig.update_layout(
|
||||||
|
title="",
|
||||||
|
height=400,
|
||||||
|
template="plotly_dark",
|
||||||
|
margin=dict(l=40, r=40, t=40, b=40),
|
||||||
|
plot_bgcolor="#1a202c",
|
||||||
|
paper_bgcolor="#1a202c",
|
||||||
|
font=dict(color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove axes and grid
|
||||||
|
fig.update_xaxes(visible=False)
|
||||||
|
fig.update_yaxes(visible=False)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def model_summary_plot(state_dict: dict) -> Figure:
|
def model_summary_plot(state_dict: dict) -> Figure:
|
||||||
"""Generate a summary plot of the PyTorch model state dict."""
|
"""Generate a summary plot of the PyTorch model state dict."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
# Handle empty state dict
|
return create_styled_error_figure(
|
||||||
fig = go.Figure()
|
"Empty State Dict",
|
||||||
fig.add_annotation(
|
"No parameters found in state dict",
|
||||||
text="No parameters found in state dict",
|
"Ensure the model state dictionary contains weight parameters"
|
||||||
xref="paper", yref="paper",
|
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=16)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Model Layer Parameter Counts",
|
|
||||||
xaxis_title="Layer",
|
|
||||||
yaxis_title="Number of Parameters",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
# Count parameters by layer type
|
# Count parameters by layer type
|
||||||
layer_info = []
|
layer_info = []
|
||||||
for key, tensor in state_dict.items():
|
for key, tensor in state_dict.items():
|
||||||
if 'weight' in key:
|
if 'weight' in key:
|
||||||
try:
|
try:
|
||||||
layer_name = key.replace('.weight', '')
|
layer_name = key.replace('.weight', '')
|
||||||
param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0
|
param_count = (
|
||||||
shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else []
|
tensor.numel() if hasattr(tensor, 'numel')
|
||||||
|
else len(tensor.flatten()) if hasattr(tensor, 'flatten')
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
shape = (
|
||||||
|
list(tensor.shape) if hasattr(tensor, 'shape')
|
||||||
|
else [len(tensor)] if hasattr(tensor, '__len__')
|
||||||
|
else []
|
||||||
|
)
|
||||||
layer_info.append({
|
layer_info.append({
|
||||||
'layer': layer_name,
|
'layer': layer_name,
|
||||||
'parameters': param_count,
|
'parameters': param_count,
|
||||||
|
|
@ -38,24 +82,12 @@ def model_summary_plot(state_dict: dict) -> Figure:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not process layer {key}: {e}")
|
print(f"Warning: Could not process layer {key}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not layer_info:
|
if not layer_info:
|
||||||
# Handle case where no weight layers found
|
return create_styled_error_figure(
|
||||||
fig = go.Figure()
|
"No Weight Layers Found",
|
||||||
fig.add_annotation(
|
"No weight layers found in state dict",
|
||||||
text="No weight layers found in state dict",
|
"Ensure the state dictionary contains layers with '.weight' parameters"
|
||||||
xref="paper", yref="paper",
|
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=16)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Model Layer Parameter Counts",
|
|
||||||
xaxis_title="Layer",
|
|
||||||
yaxis_title="Number of Parameters",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
# Create bar chart of parameter counts
|
# Create bar chart of parameter counts
|
||||||
fig = go.Figure(data=[
|
fig = go.Figure(data=[
|
||||||
go.Bar(
|
go.Bar(
|
||||||
|
|
@ -65,53 +97,35 @@ def model_summary_plot(state_dict: dict) -> Figure:
|
||||||
textposition='auto',
|
textposition='auto',
|
||||||
)
|
)
|
||||||
])
|
])
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title="Model Layer Parameter Counts",
|
title="Model Layer Parameter Counts",
|
||||||
xaxis_title="Layer",
|
xaxis_title="Layer",
|
||||||
yaxis_title="Number of Parameters",
|
yaxis_title="Number of Parameters",
|
||||||
template="plotly_dark"
|
template="plotly_dark"
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
||||||
"""Visualize weights for a specific layer."""
|
"""Visualize weights for a specific layer."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"Empty State Dict",
|
||||||
text="No data in state dict",
|
"No data in state dict",
|
||||||
xref="paper", yref="paper",
|
"Ensure the model state dictionary contains data"
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=16)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Layer Weights",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
if layer_name is None:
|
if layer_name is None:
|
||||||
# Get first weight tensor
|
# Get first weight tensor
|
||||||
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
|
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
|
||||||
if not weight_keys:
|
if not weight_keys:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"No Weight Tensors Found",
|
||||||
text="No weight tensors found in state dict",
|
"No weight tensors found in state dict",
|
||||||
xref="paper", yref="paper",
|
"Ensure the state dictionary contains layers with '.weight' parameters"
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=16)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Layer Weights",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
layer_name = weight_keys[0]
|
layer_name = weight_keys[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weights = state_dict[layer_name]
|
weights = state_dict[layer_name]
|
||||||
|
|
||||||
# Convert to numpy if it's a torch tensor
|
# Convert to numpy if it's a torch tensor
|
||||||
if hasattr(weights, 'numpy'):
|
if hasattr(weights, 'numpy'):
|
||||||
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
|
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
|
||||||
|
|
@ -119,7 +133,6 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
||||||
weights_np = weights.cpu().detach().numpy()
|
weights_np = weights.cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
weights_np = np.array(weights)
|
weights_np = np.array(weights)
|
||||||
|
|
||||||
# For 2D weights, create heatmap
|
# For 2D weights, create heatmap
|
||||||
if len(weights_np.shape) == 2:
|
if len(weights_np.shape) == 2:
|
||||||
fig = go.Figure(data=go.Heatmap(
|
fig = go.Figure(data=go.Heatmap(
|
||||||
|
|
@ -143,36 +156,21 @@ def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"Layer Processing Error",
|
||||||
text=f"Error processing layer {layer_name}: {str(e)}",
|
f"Error processing layer {layer_name}: {str(e)}",
|
||||||
xref="paper", yref="paper",
|
"Check that the layer name exists and contains valid tensor data"
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=14)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Layer Weights - Error",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
def weight_distribution_plot(state_dict: dict) -> Figure:
|
def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
"""Show distribution of weights across all layers."""
|
"""Show distribution of weights across all layers."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"Empty State Dict",
|
||||||
text="No data in state dict",
|
"No data in state dict",
|
||||||
xref="paper", yref="paper",
|
"Ensure the model state dictionary contains data"
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=16)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Overall Weight Distribution",
|
|
||||||
xaxis_title="Weight Value",
|
|
||||||
yaxis_title="Frequency",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
all_weights = []
|
all_weights = []
|
||||||
layer_names = []
|
layer_names = []
|
||||||
|
|
@ -187,7 +185,6 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
weights_np = tensor.cpu().detach().numpy()
|
weights_np = tensor.cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
weights_np = np.array(tensor)
|
weights_np = np.array(tensor)
|
||||||
|
|
||||||
flat_weights = weights_np.flatten()
|
flat_weights = weights_np.flatten()
|
||||||
all_weights.extend(flat_weights)
|
all_weights.extend(flat_weights)
|
||||||
layer_names.extend([key] * len(flat_weights))
|
layer_names.extend([key] * len(flat_weights))
|
||||||
|
|
@ -196,20 +193,11 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not all_weights:
|
if not all_weights:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"No Weight Data Found",
|
||||||
text="No weight data found in state dict",
|
"No weight data found in state dict",
|
||||||
xref="paper", yref="paper",
|
"Ensure the state dictionary contains layers with '.weight' parameters"
|
||||||
x=0.5, y=0.5, showarrow=False,
|
|
||||||
font=dict(size=16)
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Overall Weight Distribution",
|
|
||||||
xaxis_title="Weight Value",
|
|
||||||
yaxis_title="Frequency",
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
fig = go.Figure(data=[
|
fig = go.Figure(data=[
|
||||||
go.Histogram(
|
go.Histogram(
|
||||||
|
|
@ -225,5 +213,4 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
yaxis_title="Frequency",
|
yaxis_title="Frequency",
|
||||||
template="plotly_dark"
|
template="plotly_dark"
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
@ -6,7 +6,6 @@ import random
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
from plotly.graph_objects import Figure
|
from plotly.graph_objects import Figure
|
||||||
|
|
@ -22,25 +21,24 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
|
||||||
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||||
|
|
||||||
if suggestion:
|
if suggestion:
|
||||||
main_text += f"<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||||
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||||
|
|
||||||
# Add the main text annotation
|
# Add the main text annotation
|
||||||
fig.add_annotation(
|
fig.add_annotation(
|
||||||
text=main_text,
|
text=main_text,
|
||||||
xref="paper", yref="paper",
|
xref="paper",
|
||||||
x=0.5, y=0.5,
|
yref="paper",
|
||||||
xanchor='center', yanchor='middle',
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
showarrow=False,
|
showarrow=False,
|
||||||
align="center",
|
align="center",
|
||||||
borderwidth=2,
|
borderwidth=2,
|
||||||
bordercolor="#4a5568",
|
bordercolor="#4a5568",
|
||||||
bgcolor="#2d3748",
|
bgcolor="#2d3748",
|
||||||
font=dict(
|
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||||
family="Arial, sans-serif",
|
|
||||||
size=14,
|
|
||||||
color="#e2e8f0"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update layout with dark theme
|
# Update layout with dark theme
|
||||||
|
|
@ -51,7 +49,7 @@ def create_styled_error_figure(title: str, message: str, suggestion: str = None)
|
||||||
margin=dict(l=40, r=40, t=40, b=40),
|
margin=dict(l=40, r=40, t=40, b=40),
|
||||||
plot_bgcolor="#1a202c",
|
plot_bgcolor="#1a202c",
|
||||||
paper_bgcolor="#1a202c",
|
paper_bgcolor="#1a202c",
|
||||||
font=dict(color="#e2e8f0")
|
font=dict(color="#e2e8f0"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove axes and grid
|
# Remove axes and grid
|
||||||
|
|
@ -73,7 +71,7 @@ def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
|
||||||
|
|
||||||
if plot_type == "class_distribution":
|
if plot_type == "class_distribution":
|
||||||
# Check if we have any categorical columns
|
# Check if we have any categorical columns
|
||||||
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
|
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
|
||||||
alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
|
alternatives = ["class", "label", "modulation", "impairment", "use_case", "category", "labels"]
|
||||||
|
|
||||||
has_class_col = any(alt in metadata.columns for alt in alternatives)
|
has_class_col = any(alt in metadata.columns for alt in alternatives)
|
||||||
|
|
@ -89,7 +87,7 @@ def _check_dataset_compatibility(dataset, plot_type: str) -> tuple[bool, str]:
|
||||||
|
|
||||||
# Check if we can access sample data (basic test)
|
# Check if we can access sample data (basic test)
|
||||||
try:
|
try:
|
||||||
sample_data = dataset[0] if hasattr(dataset, '__getitem__') else None
|
sample_data = dataset[0] if hasattr(dataset, "__getitem__") else None
|
||||||
if sample_data is None or len(sample_data) < 32:
|
if sample_data is None or len(sample_data) < 32:
|
||||||
return False, "Insufficient sample data for spectrogram (need at least 32 points)"
|
return False, "Insufficient sample data for spectrogram (need at least 32 points)"
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -111,7 +109,7 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Dataset Not Compatible",
|
"Dataset Not Compatible",
|
||||||
"This dataset doesn't have categorical labels needed for class distribution analysis.",
|
"This dataset doesn't have categorical labels needed for class distribution analysis.",
|
||||||
"Try using the Dataset Overview widget to explore the available data columns."
|
"Try using the Dataset Overview widget to explore the available data columns.",
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = dataset.metadata
|
metadata = dataset.metadata
|
||||||
|
|
@ -127,7 +125,7 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
|
||||||
else:
|
else:
|
||||||
# Use first categorical column
|
# Use first categorical column
|
||||||
for col in metadata.columns:
|
for col in metadata.columns:
|
||||||
if metadata[col].dtype == 'object' or metadata[col].nunique() < 50:
|
if metadata[col].dtype == "object" or metadata[col].nunique() < 50:
|
||||||
class_key = col
|
class_key = col
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -135,7 +133,8 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"No Class Labels Found",
|
"No Class Labels Found",
|
||||||
"This dataset contains numerical data without categorical labels.",
|
"This dataset contains numerical data without categorical labels.",
|
||||||
"Try using the Dataset Overview widget for data analysis, or check if your dataset has hidden categorical columns."
|
("Try using the Dataset Overview widget for data analysis, "
|
||||||
|
"or check if your dataset has hidden categorical columns."),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Count examples per class (limit to top 20 for performance)
|
# Count examples per class (limit to top 20 for performance)
|
||||||
|
|
@ -146,19 +145,15 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
|
||||||
class_counts = class_counts.sort_index()
|
class_counts = class_counts.sort_index()
|
||||||
|
|
||||||
# Create simple bar plot
|
# Create simple bar plot
|
||||||
fig = px.bar(
|
fig = px.bar(x=class_counts.index, y=class_counts.values, title=f"Class Distribution: {class_key.title()}")
|
||||||
x=class_counts.index,
|
|
||||||
y=class_counts.values,
|
|
||||||
title=f'Class Distribution: {class_key.title()}'
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_traces(texttemplate='%{y}', textposition='outside')
|
fig.update_traces(texttemplate="%{y}", textposition="outside")
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
xaxis_title=class_key.title(),
|
xaxis_title=class_key.title(),
|
||||||
yaxis_title='Number of Examples',
|
yaxis_title="Number of Examples",
|
||||||
showlegend=False,
|
showlegend=False,
|
||||||
height=400,
|
height=400,
|
||||||
template="plotly_dark"
|
template="plotly_dark",
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
@ -166,8 +161,8 @@ def class_distribution_plot(dataset, class_key: str = "modulation") -> Figure:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Class Distribution Error",
|
"Class Distribution Error",
|
||||||
f"An error occurred while generating the class distribution plot.",
|
"An error occurred while generating the class distribution plot.",
|
||||||
f"Technical details: {str(e)}"
|
f"Technical details: {str(e)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -180,91 +175,79 @@ def dataset_overview_plot(dataset) -> Figure:
|
||||||
# Create subplot with multiple charts
|
# Create subplot with multiple charts
|
||||||
|
|
||||||
# Determine subplot titles based on data type
|
# Determine subplot titles based on data type
|
||||||
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
|
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
|
||||||
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']]
|
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
|
||||||
|
|
||||||
dist_title = "Value Distribution" if categorical_cols else "Data Distribution"
|
dist_title = "Value Distribution" if categorical_cols else "Data Distribution"
|
||||||
|
|
||||||
fig = make_subplots(
|
fig = make_subplots(
|
||||||
rows=2, cols=2,
|
rows=2,
|
||||||
|
cols=2,
|
||||||
subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"),
|
subplot_titles=("Dataset Size", "Data Types", dist_title, "Statistics Summary"),
|
||||||
specs=[[{"type": "indicator"}, {"type": "bar"}],
|
specs=[
|
||||||
[{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}]]
|
[{"type": "indicator"}, {"type": "bar"}],
|
||||||
|
[{"type": "histogram" if not categorical_cols else "bar"}, {"type": "table"}],
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Top left: Dataset size indicator
|
# Top left: Dataset size indicator
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Indicator(
|
go.Indicator(
|
||||||
mode="number",
|
mode="number", value=total_examples, title={"text": "Total Examples"}, number={"font": {"size": 40}}
|
||||||
value=total_examples,
|
|
||||||
title={"text": "Total Examples"},
|
|
||||||
number={"font": {"size": 40}}
|
|
||||||
),
|
),
|
||||||
row=1, col=1
|
row=1,
|
||||||
|
col=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Top right: Data types distribution
|
# Top right: Data types distribution
|
||||||
dtype_counts = metadata.dtypes.value_counts()
|
dtype_counts = metadata.dtypes.value_counts()
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Bar(
|
go.Bar(
|
||||||
x=[str(dt) for dt in dtype_counts.index],
|
x=[str(dt) for dt in dtype_counts.index], y=dtype_counts.values, name="Data Types", showlegend=False
|
||||||
y=dtype_counts.values,
|
|
||||||
name="Data Types",
|
|
||||||
showlegend=False
|
|
||||||
),
|
),
|
||||||
row=1, col=2
|
row=1,
|
||||||
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bottom left: Show distribution of numeric columns or categorical if available
|
# Bottom left: Show distribution of numeric columns or categorical if available
|
||||||
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == 'object']
|
categorical_cols = [col for col in metadata.columns if metadata[col].dtype == "object"]
|
||||||
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ['int64', 'float64']]
|
numeric_cols = [col for col in metadata.columns if metadata[col].dtype in ["int64", "float64"]]
|
||||||
|
|
||||||
if categorical_cols:
|
if categorical_cols:
|
||||||
col = categorical_cols[0] # Show first categorical column
|
col = categorical_cols[0] # Show first categorical column
|
||||||
value_counts = metadata[col].value_counts().head(10)
|
value_counts = metadata[col].value_counts().head(10)
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Bar(
|
go.Bar(x=value_counts.index, y=value_counts.values, name=f"{col} Distribution", showlegend=False),
|
||||||
x=value_counts.index,
|
row=2,
|
||||||
y=value_counts.values,
|
col=1,
|
||||||
name=f"{col} Distribution",
|
|
||||||
showlegend=False
|
|
||||||
),
|
|
||||||
row=2, col=1
|
|
||||||
)
|
)
|
||||||
elif numeric_cols:
|
elif numeric_cols:
|
||||||
# Show histogram of first numeric column
|
# Show histogram of first numeric column
|
||||||
col = numeric_cols[0]
|
col = numeric_cols[0]
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Histogram(
|
go.Histogram(x=metadata[col], name=f"{col} Distribution", showlegend=False, nbinsx=20), row=2, col=1
|
||||||
x=metadata[col],
|
|
||||||
name=f"{col} Distribution",
|
|
||||||
showlegend=False,
|
|
||||||
nbinsx=20
|
|
||||||
),
|
|
||||||
row=2, col=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bottom right: Basic statistics table
|
# Bottom right: Basic statistics table
|
||||||
stats_data = []
|
stats_data = []
|
||||||
display_cols = (numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5])
|
display_cols = numeric_cols[:5] if len(numeric_cols) > 0 else metadata.columns[:5]
|
||||||
|
|
||||||
for col in display_cols:
|
for col in display_cols:
|
||||||
if metadata[col].dtype in ['int64', 'float64']:
|
if metadata[col].dtype in ["int64", "float64"]:
|
||||||
stats_data.append([
|
stats_data.append(
|
||||||
|
[
|
||||||
col[:15] + "..." if len(col) > 15 else col, # Truncate long column names
|
col[:15] + "..." if len(col) > 15 else col, # Truncate long column names
|
||||||
f"{metadata[col].mean():.3f}",
|
f"{metadata[col].mean():.3f}",
|
||||||
f"{metadata[col].std():.3f}",
|
f"{metadata[col].std():.3f}",
|
||||||
f"{metadata[col].min():.3f}",
|
f"{metadata[col].min():.3f}",
|
||||||
f"{metadata[col].max():.3f}"
|
f"{metadata[col].max():.3f}",
|
||||||
])
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
unique_count = metadata[col].nunique()
|
unique_count = metadata[col].nunique()
|
||||||
stats_data.append([
|
stats_data.append(
|
||||||
col[:15] + "..." if len(col) > 15 else col,
|
[col[:15] + "..." if len(col) > 15 else col, "N/A", "N/A", f"{unique_count} unique", "N/A"]
|
||||||
"N/A", "N/A",
|
)
|
||||||
f"{unique_count} unique",
|
|
||||||
"N/A"
|
|
||||||
])
|
|
||||||
|
|
||||||
if stats_data:
|
if stats_data:
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
|
|
@ -273,30 +256,26 @@ def dataset_overview_plot(dataset) -> Figure:
|
||||||
values=["Column", "Mean", "Std", "Min/Unique", "Max"],
|
values=["Column", "Mean", "Std", "Min/Unique", "Max"],
|
||||||
fill_color="rgba(30, 30, 30, 0.8)",
|
fill_color="rgba(30, 30, 30, 0.8)",
|
||||||
align="center",
|
align="center",
|
||||||
font=dict(color="white", size=12)
|
font=dict(color="white", size=12),
|
||||||
),
|
),
|
||||||
cells=dict(
|
cells=dict(
|
||||||
values=list(zip(*stats_data)),
|
values=list(zip(*stats_data)),
|
||||||
fill_color="rgba(50, 50, 50, 0.6)",
|
fill_color="rgba(50, 50, 50, 0.6)",
|
||||||
align="center",
|
align="center",
|
||||||
font=dict(color="white", size=11)
|
font=dict(color="white", size=11),
|
||||||
)
|
|
||||||
),
|
),
|
||||||
row=2, col=2
|
),
|
||||||
|
row=2,
|
||||||
|
col=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create informative title
|
# Create informative title
|
||||||
total_cols = len(metadata.columns)
|
total_cols = len(metadata.columns)
|
||||||
title = f"Dataset Overview - {total_examples} samples, {total_cols} columns"
|
title = f"Dataset Overview - {total_examples} samples, {total_cols} columns"
|
||||||
if total_cols > 5:
|
if total_cols > 5:
|
||||||
title += f" (showing first 5)"
|
title += " (showing first 5)"
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(title=title, height=600, showlegend=False, template="plotly_dark")
|
||||||
title=title,
|
|
||||||
height=600,
|
|
||||||
showlegend=False,
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
@ -304,10 +283,100 @@ def dataset_overview_plot(dataset) -> Figure:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Dataset Overview Error",
|
"Dataset Overview Error",
|
||||||
"An error occurred while generating the dataset overview.",
|
"An error occurred while generating the dataset overview.",
|
||||||
f"Technical details: {str(e)}"
|
f"Technical details: {str(e)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_class_column(metadata, class_key: str) -> str:
|
||||||
|
"""Find the appropriate class column in metadata."""
|
||||||
|
if class_key in metadata.columns:
|
||||||
|
return class_key
|
||||||
|
|
||||||
|
alternatives = ["class", "label", "modulation", "impairment", "use_case"]
|
||||||
|
for alt in alternatives:
|
||||||
|
if alt in metadata.columns:
|
||||||
|
return alt
|
||||||
|
return class_key
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sample_data(dataset, sample_idx: int):
|
||||||
|
"""Get sample data from dataset, with synthetic fallback."""
|
||||||
|
try:
|
||||||
|
return dataset[sample_idx]
|
||||||
|
except Exception:
|
||||||
|
# Generate synthetic signal based on class
|
||||||
|
n_samples = 1024
|
||||||
|
t = np.linspace(0, 1, n_samples)
|
||||||
|
freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample
|
||||||
|
sample_data = np.exp(1j * 2 * np.pi * freq * t)
|
||||||
|
# Add some noise
|
||||||
|
sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
|
||||||
|
return sample_data
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_spectrogram_params(n_samples: int) -> tuple[int, int, int, int]:
|
||||||
|
"""Calculate spectrogram parameters based on sample length."""
|
||||||
|
if n_samples < 32:
|
||||||
|
raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")
|
||||||
|
|
||||||
|
nperseg = min(256, max(32, n_samples // 4))
|
||||||
|
hop_length = max(1, nperseg // 2)
|
||||||
|
|
||||||
|
# Adjust for very short signals
|
||||||
|
if n_samples < nperseg:
|
||||||
|
nperseg = n_samples
|
||||||
|
hop_length = 1
|
||||||
|
|
||||||
|
n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
|
||||||
|
freq_bins = max(1, nperseg // 2)
|
||||||
|
|
||||||
|
return nperseg, hop_length, n_frames, freq_bins
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_spectrogram(sample_data, nperseg: int, hop_length: int, n_frames: int, freq_bins: int):
|
||||||
|
"""Compute spectrogram using FFT."""
|
||||||
|
n_samples = len(sample_data)
|
||||||
|
Sxx = np.zeros((freq_bins, n_frames))
|
||||||
|
|
||||||
|
for i in range(n_frames):
|
||||||
|
start_idx = i * hop_length
|
||||||
|
end_idx = min(start_idx + nperseg, n_samples)
|
||||||
|
|
||||||
|
if end_idx > start_idx:
|
||||||
|
windowed = sample_data[start_idx:end_idx]
|
||||||
|
|
||||||
|
# Pad if necessary to maintain nperseg size
|
||||||
|
if len(windowed) < nperseg:
|
||||||
|
windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode="constant")
|
||||||
|
|
||||||
|
fft_result = np.fft.fft(windowed)
|
||||||
|
Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2
|
||||||
|
|
||||||
|
return Sxx
|
||||||
|
|
||||||
|
|
||||||
|
def _create_spectrogram_figure(Sxx, n_frames: int, hop_length: int, n_samples: int, freq_bins: int,
|
||||||
|
sample_idx: int, class_key: str, sample_metadata) -> Figure:
|
||||||
|
"""Create the plotly figure for the spectrogram."""
|
||||||
|
# Convert to dB
|
||||||
|
Sxx_db = 10 * np.log10(Sxx + 1e-10)
|
||||||
|
|
||||||
|
# Create time and frequency vectors
|
||||||
|
t = np.arange(n_frames) * hop_length / max(1, n_samples)
|
||||||
|
f = np.linspace(0, 0.5, freq_bins)
|
||||||
|
|
||||||
|
# Create plot
|
||||||
|
fig = go.Figure(data=go.Heatmap(z=Sxx_db, x=t, y=f, colorscale="viridis", colorbar=dict(title="Power (dB)")))
|
||||||
|
|
||||||
|
# Add title with metadata
|
||||||
|
title = f"Sample Spectrogram (Index: {sample_idx})"
|
||||||
|
if class_key in sample_metadata:
|
||||||
|
title += f" - {class_key}: {sample_metadata[class_key]}"
|
||||||
|
|
||||||
|
fig.update_layout(title=title, xaxis_title="Time", yaxis_title="Frequency", height=400, template="plotly_dark")
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure:
|
def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx: Optional[int] = None) -> Figure:
|
||||||
"""Generate a spectrogram plot from a sample in the dataset."""
|
"""Generate a spectrogram plot from a sample in the dataset."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -317,114 +386,36 @@ def sample_spectrogram_plot(dataset, class_key: str = "modulation", sample_idx:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Spectrogram Not Available",
|
"Spectrogram Not Available",
|
||||||
"This dataset doesn't have sufficient signal data for spectrogram visualization.",
|
"This dataset doesn't have sufficient signal data for spectrogram visualization.",
|
||||||
"Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample."
|
"Ensure your dataset contains complex-valued signal samples with at least 32 data points per sample.",
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = dataset.metadata
|
metadata = dataset.metadata
|
||||||
|
|
||||||
if len(metadata) == 0:
|
if len(metadata) == 0:
|
||||||
raise ValueError("Dataset is empty")
|
raise ValueError("Dataset is empty")
|
||||||
|
|
||||||
# Find class column
|
# Find class column and select sample
|
||||||
if class_key not in metadata.columns:
|
class_key = _find_class_column(metadata, class_key)
|
||||||
alternatives = ["class", "label", "modulation", "impairment", "use_case"]
|
|
||||||
for alt in alternatives:
|
|
||||||
if alt in metadata.columns:
|
|
||||||
class_key = alt
|
|
||||||
break
|
|
||||||
|
|
||||||
# Select sample
|
|
||||||
if sample_idx is None:
|
if sample_idx is None:
|
||||||
sample_idx = random.randint(0, len(metadata) - 1)
|
sample_idx = random.randint(0, len(metadata) - 1)
|
||||||
|
|
||||||
sample_metadata = metadata.iloc[sample_idx]
|
sample_metadata = metadata.iloc[sample_idx]
|
||||||
|
|
||||||
# Try to get actual sample data, fall back to synthetic
|
# Get sample data and ensure it's complex
|
||||||
try:
|
sample_data = _get_sample_data(dataset, sample_idx)
|
||||||
sample_data = dataset[sample_idx]
|
|
||||||
except:
|
|
||||||
# Generate synthetic signal based on class
|
|
||||||
n_samples = 1024
|
|
||||||
t = np.linspace(0, 1, n_samples)
|
|
||||||
freq = 0.1 + 0.05 * sample_idx % 5 # Vary frequency by sample
|
|
||||||
sample_data = np.exp(1j * 2 * np.pi * freq * t)
|
|
||||||
# Add some noise
|
|
||||||
sample_data += 0.1 * (np.random.randn(n_samples) + 1j * np.random.randn(n_samples))
|
|
||||||
|
|
||||||
# Ensure complex data
|
|
||||||
if not np.iscomplexobj(sample_data):
|
if not np.iscomplexobj(sample_data):
|
||||||
sample_data = sample_data.astype(complex)
|
sample_data = sample_data.astype(complex)
|
||||||
|
|
||||||
# Simple FFT-based spectrogram
|
# Calculate spectrogram parameters and compute spectrogram
|
||||||
n_samples = len(sample_data)
|
n_samples = len(sample_data)
|
||||||
|
nperseg, hop_length, n_frames, freq_bins = _calculate_spectrogram_params(n_samples)
|
||||||
|
Sxx = _compute_spectrogram(sample_data, nperseg, hop_length, n_frames, freq_bins)
|
||||||
|
|
||||||
# Ensure minimum viable data size
|
# Create and return the figure
|
||||||
if n_samples < 32:
|
return _create_spectrogram_figure(Sxx, n_frames, hop_length, n_samples, freq_bins,
|
||||||
raise ValueError(f"Insufficient data: need at least 32 samples, got {n_samples}")
|
sample_idx, class_key, sample_metadata)
|
||||||
|
|
||||||
nperseg = min(256, max(32, n_samples // 4))
|
|
||||||
|
|
||||||
# Create spectrogram using numpy (no scipy dependency)
|
|
||||||
hop_length = max(1, nperseg // 2) # Prevent zero hop_length
|
|
||||||
|
|
||||||
# Ensure we can create at least one frame
|
|
||||||
if n_samples < nperseg:
|
|
||||||
nperseg = n_samples
|
|
||||||
hop_length = 1
|
|
||||||
|
|
||||||
n_frames = max(1, (n_samples - nperseg) // hop_length + 1)
|
|
||||||
|
|
||||||
freq_bins = max(1, nperseg // 2) # Prevent zero frequency bins
|
|
||||||
Sxx = np.zeros((freq_bins, n_frames))
|
|
||||||
|
|
||||||
for i in range(n_frames):
|
|
||||||
start_idx = i * hop_length
|
|
||||||
end_idx = min(start_idx + nperseg, n_samples) # Prevent index overflow
|
|
||||||
|
|
||||||
if end_idx > start_idx: # Ensure we have data to process
|
|
||||||
windowed = sample_data[start_idx:end_idx]
|
|
||||||
|
|
||||||
# Pad if necessary to maintain nperseg size
|
|
||||||
if len(windowed) < nperseg:
|
|
||||||
windowed = np.pad(windowed, (0, nperseg - len(windowed)), mode='constant')
|
|
||||||
|
|
||||||
fft_result = np.fft.fft(windowed)
|
|
||||||
Sxx[:, i] = np.abs(fft_result[:freq_bins]) ** 2
|
|
||||||
|
|
||||||
# Convert to dB
|
|
||||||
Sxx_db = 10 * np.log10(Sxx + 1e-10)
|
|
||||||
|
|
||||||
# Create time and frequency vectors
|
|
||||||
t = np.arange(n_frames) * hop_length / max(1, n_samples) # Prevent division by zero
|
|
||||||
f = np.linspace(0, 0.5, freq_bins)
|
|
||||||
|
|
||||||
# Create plot
|
|
||||||
fig = go.Figure(data=go.Heatmap(
|
|
||||||
z=Sxx_db,
|
|
||||||
x=t,
|
|
||||||
y=f,
|
|
||||||
colorscale='viridis',
|
|
||||||
colorbar=dict(title="Power (dB)")
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add title with metadata
|
|
||||||
title = f"Sample Spectrogram (Index: {sample_idx})"
|
|
||||||
if class_key in sample_metadata:
|
|
||||||
title += f" - {class_key}: {sample_metadata[class_key]}"
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title=title,
|
|
||||||
xaxis_title="Time",
|
|
||||||
yaxis_title="Frequency",
|
|
||||||
height=400,
|
|
||||||
template="plotly_dark"
|
|
||||||
)
|
|
||||||
|
|
||||||
return fig
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return create_styled_error_figure(
|
return create_styled_error_figure(
|
||||||
"Spectrogram Error",
|
"Spectrogram Error",
|
||||||
"An error occurred while generating the spectrogram plot.",
|
"An error occurred while generating the spectrogram plot.",
|
||||||
f"Technical details: {str(e)}"
|
f"Technical details: {str(e)}",
|
||||||
)
|
)
|
||||||
Loading…
Reference in New Issue
Block a user