Formatting Fixes
Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 17s
Test with tox / Test with tox (3.12) (pull_request) Successful in 42s
Test with tox / Test with tox (3.11) (pull_request) Successful in 49s
Build Project / Build Project (3.10) (pull_request) Successful in 1m2s
Test with tox / Test with tox (3.10) (pull_request) Failing after 58s
Build Project / Build Project (3.11) (pull_request) Successful in 1m2s
Build Project / Build Project (3.12) (pull_request) Successful in 1m2s

This commit is contained in:
ben 2025-10-21 11:48:40 -04:00
parent c7c7100d46
commit c06e58f5d6

View File

@ -1,188 +1,190 @@
import numpy as np
import plotly.graph_objects as go import plotly.graph_objects as go
from plotly.graph_objects import Figure from plotly.graph_objects import Figure
import numpy as np
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
"""Create a professional error figure with Qoherent dark theme styling."""
fig = go.Figure()
# Create a clean, centered text display using Plotly's text formatting
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
if suggestion:
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
# Add the main text annotation
fig.add_annotation(
text=main_text,
xref="paper",
yref="paper",
x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
borderwidth=2,
bordercolor="#4a5568",
bgcolor="#2d3748",
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
)
# Update layout with dark theme
fig.update_layout(
title="",
height=400,
template="plotly_dark",
margin=dict(l=40, r=40, t=40, b=40),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
font=dict(color="#e2e8f0"),
)
# Remove axes and grid
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
return fig
def model_summary_plot(state_dict: dict) -> Figure: def model_summary_plot(state_dict: dict) -> Figure:
"""Generate a summary plot of the PyTorch model state dict.""" """Generate a summary plot of the PyTorch model state dict."""
if not state_dict: if not state_dict:
# Handle empty state dict return create_styled_error_figure(
fig = go.Figure() "Empty State Dict",
fig.add_annotation( "No parameters found in state dict",
text="No parameters found in state dict", "Ensure the model state dictionary contains weight parameters"
xref="paper",
yref="paper",
x=0.5,
y=0.5,
showarrow=False,
font=dict(size=16),
) )
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark",
)
return fig
# Count parameters by layer type # Count parameters by layer type
layer_info = [] layer_info = []
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
if "weight" in key: if 'weight' in key:
try: try:
layer_name = key.replace(".weight", "") layer_name = key.replace('.weight', '')
param_count = ( param_count = (
tensor.numel() tensor.numel() if hasattr(tensor, 'numel')
if hasattr(tensor, "numel") else len(tensor.flatten()) if hasattr(tensor, 'flatten')
else len(tensor.flatten()) if hasattr(tensor, "flatten") else 0 else 0
) )
shape = ( shape = (
list(tensor.shape) list(tensor.shape) if hasattr(tensor, 'shape')
if hasattr(tensor, "shape") else [len(tensor)] if hasattr(tensor, '__len__')
else [len(tensor)] if hasattr(tensor, "__len__") else [] else []
) )
layer_info.append({"layer": layer_name, "parameters": param_count, "shape": shape}) layer_info.append({
'layer': layer_name,
'parameters': param_count,
'shape': shape
})
except Exception as e: except Exception as e:
print(f"Warning: Could not process layer {key}: {e}") print(f"Warning: Could not process layer {key}: {e}")
continue continue
if not layer_info: if not layer_info:
# Handle case where no weight layers found return create_styled_error_figure(
fig = go.Figure() "No Weight Layers Found",
fig.add_annotation( "No weight layers found in state dict",
text="No weight layers found in state dict", "Ensure the state dictionary contains layers with '.weight' parameters"
xref="paper",
yref="paper",
x=0.5,
y=0.5,
showarrow=False,
font=dict(size=16),
) )
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark",
)
return fig
# Create bar chart of parameter counts # Create bar chart of parameter counts
fig = go.Figure( fig = go.Figure(data=[
data=[ go.Bar(
go.Bar( x=[info['layer'] for info in layer_info],
x=[info["layer"] for info in layer_info], y=[info['parameters'] for info in layer_info],
y=[info["parameters"] for info in layer_info], text=[f"Shape: {info['shape']}" for info in layer_info],
text=[f"Shape: {info['shape']}" for info in layer_info], textposition='auto',
textposition="auto", )
) ])
]
)
fig.update_layout( fig.update_layout(
title="Model Layer Parameter Counts", title="Model Layer Parameter Counts",
xaxis_title="Layer", xaxis_title="Layer",
yaxis_title="Number of Parameters", yaxis_title="Number of Parameters",
template="plotly_dark", template="plotly_dark"
) )
return fig return fig
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer.""" """Visualize weights for a specific layer."""
if not state_dict: if not state_dict:
fig = go.Figure() return create_styled_error_figure(
fig.add_annotation( "Empty State Dict",
text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) "No data in state dict",
"Ensure the model state dictionary contains data"
) )
fig.update_layout(title="Layer Weights", template="plotly_dark")
return fig
if layer_name is None: if layer_name is None:
# Get first weight tensor # Get first weight tensor
weight_keys = [k for k in state_dict.keys() if "weight" in k] weight_keys = [k for k in state_dict.keys() if 'weight' in k]
if not weight_keys: if not weight_keys:
fig = go.Figure() return create_styled_error_figure(
fig.add_annotation( "No Weight Tensors Found",
text="No weight tensors found in state dict", "No weight tensors found in state dict",
xref="paper", "Ensure the state dictionary contains layers with '.weight' parameters"
yref="paper",
x=0.5,
y=0.5,
showarrow=False,
font=dict(size=16),
) )
fig.update_layout(title="Layer Weights", template="plotly_dark")
return fig
layer_name = weight_keys[0] layer_name = weight_keys[0]
try: try:
weights = state_dict[layer_name] weights = state_dict[layer_name]
# Convert to numpy if it's a torch tensor # Convert to numpy if it's a torch tensor
if hasattr(weights, "numpy"): if hasattr(weights, 'numpy'):
weights_np = weights.detach().numpy() if hasattr(weights, "detach") else weights.numpy() weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
elif hasattr(weights, "cpu"): elif hasattr(weights, 'cpu'):
weights_np = weights.cpu().detach().numpy() weights_np = weights.cpu().detach().numpy()
else: else:
weights_np = np.array(weights) weights_np = np.array(weights)
# For 2D weights, create heatmap # For 2D weights, create heatmap
if len(weights_np.shape) == 2: if len(weights_np.shape) == 2:
fig = go.Figure(data=go.Heatmap(z=weights_np, colorscale="RdBu", zmid=0)) fig = go.Figure(data=go.Heatmap(
fig.update_layout(title=f"Weights Heatmap: {layer_name}", template="plotly_dark") z=weights_np,
colorscale='RdBu',
zmid=0
))
fig.update_layout(
title=f"Weights Heatmap: {layer_name}",
template="plotly_dark"
)
else: else:
# For other shapes, flatten and show histogram # For other shapes, flatten and show histogram
flat_weights = weights_np.flatten() flat_weights = weights_np.flatten()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
fig.update_layout(title=f"Weight Distribution: {layer_name}", template="plotly_dark") fig.update_layout(
title=f"Weight Distribution: {layer_name}",
template="plotly_dark"
)
return fig return fig
except Exception as e: except Exception as e:
fig = go.Figure() return create_styled_error_figure(
fig.add_annotation( "Layer Processing Error",
text=f"Error processing layer {layer_name}: {str(e)}", f"Error processing layer {layer_name}: {str(e)}",
xref="paper", "Check that the layer name exists and contains valid tensor data"
yref="paper",
x=0.5,
y=0.5,
showarrow=False,
font=dict(size=14),
) )
fig.update_layout(title="Layer Weights - Error", template="plotly_dark")
return fig
def weight_distribution_plot(state_dict: dict) -> Figure: def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers.""" """Show distribution of weights across all layers."""
if not state_dict: if not state_dict:
fig = go.Figure() return create_styled_error_figure(
fig.add_annotation( "Empty State Dict",
text="No data in state dict", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=16) "No data in state dict",
"Ensure the model state dictionary contains data"
) )
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark",
)
return fig
all_weights = [] all_weights = []
layer_names = [] layer_names = []
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
if "weight" in key: if 'weight' in key:
try: try:
# Convert to numpy if it's a torch tensor # Convert to numpy if it's a torch tensor
if hasattr(tensor, "numpy"): if hasattr(tensor, 'numpy'):
weights_np = tensor.detach().numpy() if hasattr(tensor, "detach") else tensor.numpy() weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
elif hasattr(tensor, "cpu"): elif hasattr(tensor, 'cpu'):
weights_np = tensor.cpu().detach().numpy() weights_np = tensor.cpu().detach().numpy()
else: else:
weights_np = np.array(tensor) weights_np = np.array(tensor)
flat_weights = weights_np.flatten() flat_weights = weights_np.flatten()
all_weights.extend(flat_weights) all_weights.extend(flat_weights)
layer_names.extend([key] * len(flat_weights)) layer_names.extend([key] * len(flat_weights))
@ -191,31 +193,24 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
continue continue
if not all_weights: if not all_weights:
fig = go.Figure() return create_styled_error_figure(
fig.add_annotation( "No Weight Data Found",
text="No weight data found in state dict", "No weight data found in state dict",
xref="paper", "Ensure the state dictionary contains layers with '.weight' parameters"
yref="paper",
x=0.5,
y=0.5,
showarrow=False,
font=dict(size=16),
) )
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark",
)
return fig
fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=100, name="All Weights")]) fig = go.Figure(data=[
go.Histogram(
x=all_weights,
nbinsx=100,
name="All Weights"
)
])
fig.update_layout( fig.update_layout(
title="Overall Weight Distribution", title="Overall Weight Distribution",
xaxis_title="Weight Value", xaxis_title="Weight Value",
yaxis_title="Frequency", yaxis_title="Frequency",
template="plotly_dark", template="plotly_dark"
) )
return fig return fig