Formatting Fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.12) (pull_request) Successful in 42s
Test with tox / Test with tox (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.10) (pull_request) Successful in 1m2s
Test with tox / Test with tox (3.10) (pull_request) Failing after 58s
Build Project / Build Project (3.11) (pull_request) Successful in 1m2s
Build Project / Build Project (3.12) (pull_request) Successful in 1m2s
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.12) (pull_request) Successful in 42s
Test with tox / Test with tox (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.10) (pull_request) Successful in 1m2s
Test with tox / Test with tox (3.10) (pull_request) Failing after 58s
Build Project / Build Project (3.11) (pull_request) Successful in 1m2s
Build Project / Build Project (3.12) (pull_request) Successful in 1m2s
This commit is contained in:
parent
c7c7100d46
commit
c06e58f5d6
|
|
@ -1,188 +1,190 @@
|
||||||
import numpy as np
|
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
from plotly.graph_objects import Figure
|
from plotly.graph_objects import Figure
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
|
||||||
|
"""Create a professional error figure with Qoherent dark theme styling."""
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Create a clean, centered text display using Plotly's text formatting
|
||||||
|
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
|
||||||
|
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||||
|
|
||||||
|
if suggestion:
|
||||||
|
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||||
|
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||||
|
|
||||||
|
# Add the main text annotation
|
||||||
|
fig.add_annotation(
|
||||||
|
text=main_text,
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
|
showarrow=False,
|
||||||
|
align="center",
|
||||||
|
borderwidth=2,
|
||||||
|
bordercolor="#4a5568",
|
||||||
|
bgcolor="#2d3748",
|
||||||
|
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update layout with dark theme
|
||||||
|
fig.update_layout(
|
||||||
|
title="",
|
||||||
|
height=400,
|
||||||
|
template="plotly_dark",
|
||||||
|
margin=dict(l=40, r=40, t=40, b=40),
|
||||||
|
plot_bgcolor="#1a202c",
|
||||||
|
paper_bgcolor="#1a202c",
|
||||||
|
font=dict(color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove axes and grid
|
||||||
|
fig.update_xaxes(visible=False)
|
||||||
|
fig.update_yaxes(visible=False)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def model_summary_plot(state_dict: dict) -> Figure:
|
def model_summary_plot(state_dict: dict) -> Figure:
|
||||||
"""Generate a summary plot of the PyTorch model state dict."""
|
"""Generate a summary plot of the PyTorch model state dict."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
# Handle empty state dict
|
return create_styled_error_figure(
|
||||||
fig = go.Figure()
|
"Empty State Dict",
|
||||||
fig.add_annotation(
|
"No parameters found in state dict",
|
||||||
text="No parameters found in state dict",
|
"Ensure the model state dictionary contains weight parameters"
|
||||||
xref="paper",
|
|
||||||
yref="paper",
|
|
||||||
x=0.5,
|
|
||||||
y=0.5,
|
|
||||||
showarrow=False,
|
|
||||||
font=dict(size=16),
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Model Layer Parameter Counts",
|
|
||||||
xaxis_title="Layer",
|
|
||||||
yaxis_title="Number of Parameters",
|
|
||||||
template="plotly_dark",
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
# Count parameters by layer type
|
# Count parameters by layer type
|
||||||
layer_info = []
|
layer_info = []
|
||||||
for key, tensor in state_dict.items():
|
for key, tensor in state_dict.items():
|
||||||
if "weight" in key:
|
if 'weight' in key:
|
||||||
try:
|
try:
|
||||||
layer_name = key.replace(".weight", "")
|
layer_name = key.replace('.weight', '')
|
||||||
param_count = (
|
param_count = (
|
||||||
tensor.numel()
|
tensor.numel() if hasattr(tensor, 'numel')
|
||||||
if hasattr(tensor, "numel")
|
else len(tensor.flatten()) if hasattr(tensor, 'flatten')
|
||||||
else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0
|
else 0
|
||||||
)
|
)
|
||||||
shape = (
|
shape = (
|
||||||
list(tensor.shape)
|
list(tensor.shape) if hasattr(tensor, 'shape')
|
||||||
if hasattr(tensor, "shape")
|
else [len(tensor)] if hasattr(tensor, '__len__')
|
||||||
else [len(tensor)] if hasattr(tensor, "__len__") else []
|
else []
|
||||||
)
|
)
|
||||||
layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape})
|
layer_info.append({
|
||||||
|
'layer': layer_name,
|
||||||
|
'parameters': param_count,
|
||||||
|
'shape': shape
|
||||||
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Could not process layer {key}: {e}")
|
print(f"Warning: Could not process layer {key}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not layer_info:
|
if not layer_info:
|
||||||
# Handle case where no weight layers found
|
return create_styled_error_figure(
|
||||||
fig = go.Figure()
|
"No Weight Layers Found",
|
||||||
fig.add_annotation(
|
"No weight layers found in state dict",
|
||||||
text="No weight layers found in state dict",
|
"Ensure the state dictionary contains layers with '.weight' parameters"
|
||||||
xref="paper",
|
|
||||||
yref="paper",
|
|
||||||
x=0.5,
|
|
||||||
y=0.5,
|
|
||||||
showarrow=False,
|
|
||||||
font=dict(size=16),
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Model Layer Parameter Counts",
|
|
||||||
xaxis_title="Layer",
|
|
||||||
yaxis_title="Number of Parameters",
|
|
||||||
template="plotly_dark",
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
# Create bar chart of parameter counts
|
# Create bar chart of parameter counts
|
||||||
fig = go.Figure(
|
fig = go.Figure(data=[
|
||||||
data=[
|
go.Bar(
|
||||||
go.Bar(
|
x=[info['layer'] for info in layer_info],
|
||||||
x=[info["layer"] for info in layer_info],
|
y=[info['parameters'] for info in layer_info],
|
||||||
y=[info["parameters"] for info in layer_info],
|
text=[f"Shape: {info['shape']}" for info in layer_info],
|
||||||
text=[f"Shape: {info['shape']}" for info in layer_info],
|
textposition='auto',
|
||||||
textposition="auto",
|
)
|
||||||
)
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title="Model Layer Parameter Counts",
|
title="Model Layer Parameter Counts",
|
||||||
xaxis_title="Layer",
|
xaxis_title="Layer",
|
||||||
yaxis_title="Number of Parameters",
|
yaxis_title="Number of Parameters",
|
||||||
template="plotly_dark",
|
template="plotly_dark"
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
|
||||||
"""Visualize weights for a specific layer."""
|
"""Visualize weights for a specific layer."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"Empty State Dict",
|
||||||
text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16)
|
"No data in state dict",
|
||||||
|
"Ensure the model state dictionary contains data"
|
||||||
)
|
)
|
||||||
fig.update_layout(title="Layer Weights", template="plotly_dark")
|
|
||||||
return fig
|
|
||||||
|
|
||||||
if layer_name is None:
|
if layer_name is None:
|
||||||
# Get first weight tensor
|
# Get first weight tensor
|
||||||
weight_keys = [k for k in state_dict.keys() if "weight" in k]
|
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
|
||||||
if not weight_keys:
|
if not weight_keys:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"No Weight Tensors Found",
|
||||||
text="No weight tensors found in state dict",
|
"No weight tensors found in state dict",
|
||||||
xref="paper",
|
"Ensure the state dictionary contains layers with '.weight' parameters"
|
||||||
yref="paper",
|
|
||||||
x=0.5,
|
|
||||||
y=0.5,
|
|
||||||
showarrow=False,
|
|
||||||
font=dict(size=16),
|
|
||||||
)
|
)
|
||||||
fig.update_layout(title="Layer Weights", template="plotly_dark")
|
|
||||||
return fig
|
|
||||||
layer_name = weight_keys[0]
|
layer_name = weight_keys[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weights = state_dict[layer_name]
|
weights = state_dict[layer_name]
|
||||||
|
|
||||||
# Convert to numpy if it's a torch tensor
|
# Convert to numpy if it's a torch tensor
|
||||||
if hasattr(weights, "numpy"):
|
if hasattr(weights, 'numpy'):
|
||||||
weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy()
|
weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
|
||||||
elif hasattr(weights, "cpu"):
|
elif hasattr(weights, 'cpu'):
|
||||||
weights_np = weights.cpu().detach().numpy()
|
weights_np = weights.cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
weights_np = np.array(weights)
|
weights_np = np.array(weights)
|
||||||
|
|
||||||
# For 2D weights, create heatmap
|
# For 2D weights, create heatmap
|
||||||
if len(weights_np.shape) == 2:
|
if len(weights_np.shape) == 2:
|
||||||
fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0))
|
fig = go.Figure(data=go.Heatmap(
|
||||||
fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark")
|
z=weights_np,
|
||||||
|
colorscale='RdBu',
|
||||||
|
zmid=0
|
||||||
|
))
|
||||||
|
fig.update_layout(
|
||||||
|
title=f"Weights Heatmap: {layer_name}",
|
||||||
|
template="plotly_dark"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# For other shapes, flatten and show histogram
|
# For other shapes, flatten and show histogram
|
||||||
flat_weights = weights_np.flatten()
|
flat_weights = weights_np.flatten()
|
||||||
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
|
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
|
||||||
fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark")
|
fig.update_layout(
|
||||||
|
title=f"Weight Distribution: {layer_name}",
|
||||||
|
template="plotly_dark"
|
||||||
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"Layer Processing Error",
|
||||||
text=f"Error processing layer {layer_name}: {str(e)}",
|
f"Error processing layer {layer_name}: {str(e)}",
|
||||||
xref="paper",
|
"Check that the layer name exists and contains valid tensor data"
|
||||||
yref="paper",
|
|
||||||
x=0.5,
|
|
||||||
y=0.5,
|
|
||||||
showarrow=False,
|
|
||||||
font=dict(size=14),
|
|
||||||
)
|
)
|
||||||
fig.update_layout(title="Layer Weights - Error", template="plotly_dark")
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def weight_distribution_plot(state_dict: dict) -> Figure:
|
def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
"""Show distribution of weights across all layers."""
|
"""Show distribution of weights across all layers."""
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"Empty State Dict",
|
||||||
text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16)
|
"No data in state dict",
|
||||||
|
"Ensure the model state dictionary contains data"
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Overall Weight Distribution",
|
|
||||||
xaxis_title="Weight Value",
|
|
||||||
yaxis_title="Frequency",
|
|
||||||
template="plotly_dark",
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
all_weights = []
|
all_weights = []
|
||||||
layer_names = []
|
layer_names = []
|
||||||
|
|
||||||
for key, tensor in state_dict.items():
|
for key, tensor in state_dict.items():
|
||||||
if "weight" in key:
|
if 'weight' in key:
|
||||||
try:
|
try:
|
||||||
# Convert to numpy if it's a torch tensor
|
# Convert to numpy if it's a torch tensor
|
||||||
if hasattr(tensor, "numpy"):
|
if hasattr(tensor, 'numpy'):
|
||||||
weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy()
|
weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
|
||||||
elif hasattr(tensor, "cpu"):
|
elif hasattr(tensor, 'cpu'):
|
||||||
weights_np = tensor.cpu().detach().numpy()
|
weights_np = tensor.cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
weights_np = np.array(tensor)
|
weights_np = np.array(tensor)
|
||||||
|
|
||||||
flat_weights = weights_np.flatten()
|
flat_weights = weights_np.flatten()
|
||||||
all_weights.extend(flat_weights)
|
all_weights.extend(flat_weights)
|
||||||
layer_names.extend([key] * len(flat_weights))
|
layer_names.extend([key] * len(flat_weights))
|
||||||
|
|
@ -191,31 +193,24 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not all_weights:
|
if not all_weights:
|
||||||
fig = go.Figure()
|
return create_styled_error_figure(
|
||||||
fig.add_annotation(
|
"No Weight Data Found",
|
||||||
text="No weight data found in state dict",
|
"No weight data found in state dict",
|
||||||
xref="paper",
|
"Ensure the state dictionary contains layers with '.weight' parameters"
|
||||||
yref="paper",
|
|
||||||
x=0.5,
|
|
||||||
y=0.5,
|
|
||||||
showarrow=False,
|
|
||||||
font=dict(size=16),
|
|
||||||
)
|
)
|
||||||
fig.update_layout(
|
|
||||||
title="Overall Weight Distribution",
|
|
||||||
xaxis_title="Weight Value",
|
|
||||||
yaxis_title="Frequency",
|
|
||||||
template="plotly_dark",
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")])
|
fig = go.Figure(data=[
|
||||||
|
go.Histogram(
|
||||||
|
x=all_weights,
|
||||||
|
nbinsx=100,
|
||||||
|
name="All Weights"
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title="Overall Weight Distribution",
|
title="Overall Weight Distribution",
|
||||||
xaxis_title="Weight Value",
|
xaxis_title="Weight Value",
|
||||||
yaxis_title="Frequency",
|
yaxis_title="Frequency",
|
||||||
template="plotly_dark",
|
template="plotly_dark"
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user