Widget Support Panels #5

Merged
benchinnery merged 9 commits from new_widgets_support into main 2025-10-22 12:04:53 -04:00
Showing only changes of commit f430e626a6 - Show all commits

View File

@ -5,17 +5,56 @@ import numpy as np
def model_summary_plot(state_dict: dict) -> Figure: def model_summary_plot(state_dict: dict) -> Figure:
"""Generate a summary plot of the PyTorch model state dict.""" """Generate a summary plot of the PyTorch model state dict."""
if not state_dict:
# Handle empty state dict
fig = go.Figure()
fig.add_annotation(
text="No parameters found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
)
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark"
)
return fig
# Count parameters by layer type # Count parameters by layer type
layer_info = [] layer_info = []
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
if 'weight' in key: if 'weight' in key:
layer_name = key.replace('.weight', '') try:
param_count = tensor.numel() layer_name = key.replace('.weight', '')
layer_info.append({ param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0
'layer': layer_name, shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else []
'parameters': param_count, layer_info.append({
'shape': list(tensor.shape) 'layer': layer_name,
}) 'parameters': param_count,
'shape': shape
})
except Exception as e:
print(f"Warning: Could not process layer {key}: {e}")
continue
if not layer_info:
# Handle case where no weight layers found
fig = go.Figure()
fig.add_annotation(
text="No weight layers found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
)
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters",
template="plotly_dark"
)
return fig
# Create bar chart of parameter counts # Create bar chart of parameter counts
fig = go.Figure(data=[ fig = go.Figure(data=[
@ -30,47 +69,147 @@ def model_summary_plot(state_dict: dict) -> Figure:
fig.update_layout( fig.update_layout(
title="Model Layer Parameter Counts", title="Model Layer Parameter Counts",
xaxis_title="Layer", xaxis_title="Layer",
yaxis_title="Number of Parameters" yaxis_title="Number of Parameters",
template="plotly_dark"
) )
return fig return fig
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer.""" """Visualize weights for a specific layer."""
if not state_dict:
fig = go.Figure()
fig.add_annotation(
text="No data in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
)
fig.update_layout(
title="Layer Weights",
template="plotly_dark"
)
return fig
if layer_name is None: if layer_name is None:
# Get first weight tensor # Get first weight tensor
weight_keys = [k for k in state_dict.keys() if 'weight' in k] weight_keys = [k for k in state_dict.keys() if 'weight' in k]
if not weight_keys: if not weight_keys:
raise ValueError("No weight tensors found in state dict") fig = go.Figure()
fig.add_annotation(
text="No weight tensors found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
)
fig.update_layout(
title="Layer Weights",
template="plotly_dark"
)
return fig
layer_name = weight_keys[0] layer_name = weight_keys[0]
weights = state_dict[layer_name] try:
weights = state_dict[layer_name]
# For 2D weights, create heatmap # Convert to numpy if it's a torch tensor
if len(weights.shape) == 2: if hasattr(weights, 'numpy'):
fig = go.Figure(data=go.Heatmap( weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy()
z=weights.numpy(), elif hasattr(weights, 'cpu'):
colorscale='RdBu', weights_np = weights.cpu().detach().numpy()
zmid=0 else:
)) weights_np = np.array(weights)
fig.update_layout(title=f"Weights Heatmap: {layer_name}")
else:
# For other shapes, flatten and show histogram
flat_weights = weights.flatten().numpy()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
fig.update_layout(title=f"Weight Distribution: {layer_name}")
return fig # For 2D weights, create heatmap
if len(weights_np.shape) == 2:
fig = go.Figure(data=go.Heatmap(
z=weights_np,
colorscale='RdBu',
zmid=0
))
fig.update_layout(
title=f"Weights Heatmap: {layer_name}",
template="plotly_dark"
)
else:
# For other shapes, flatten and show histogram
flat_weights = weights_np.flatten()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
fig.update_layout(
title=f"Weight Distribution: {layer_name}",
template="plotly_dark"
)
return fig
except Exception as e:
fig = go.Figure()
fig.add_annotation(
text=f"Error processing layer {layer_name}: {str(e)}",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=14)
)
fig.update_layout(
title="Layer Weights - Error",
template="plotly_dark"
)
return fig
def weight_distribution_plot(state_dict: dict) -> Figure: def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers.""" """Show distribution of weights across all layers."""
if not state_dict:
fig = go.Figure()
fig.add_annotation(
text="No data in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
)
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark"
)
return fig
all_weights = [] all_weights = []
layer_names = [] layer_names = []
for key, tensor in state_dict.items(): for key, tensor in state_dict.items():
if 'weight' in key: if 'weight' in key:
all_weights.extend(tensor.flatten().numpy()) try:
layer_names.extend([key] * tensor.numel()) # Convert to numpy if it's a torch tensor
if hasattr(tensor, 'numpy'):
weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy()
elif hasattr(tensor, 'cpu'):
weights_np = tensor.cpu().detach().numpy()
else:
weights_np = np.array(tensor)
flat_weights = weights_np.flatten()
all_weights.extend(flat_weights)
layer_names.extend([key] * len(flat_weights))
except Exception as e:
print(f"Warning: Could not process weights for layer {key}: {e}")
continue
if not all_weights:
fig = go.Figure()
fig.add_annotation(
text="No weight data found in state dict",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False,
font=dict(size=16)
)
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency",
template="plotly_dark"
)
return fig
fig = go.Figure(data=[ fig = go.Figure(data=[
go.Histogram( go.Histogram(
@ -83,7 +222,8 @@ def weight_distribution_plot(state_dict: dict) -> Figure:
fig.update_layout( fig.update_layout(
title="Overall Weight Distribution", title="Overall Weight Distribution",
xaxis_title="Weight Value", xaxis_title="Weight Value",
yaxis_title="Frequency" yaxis_title="Frequency",
template="plotly_dark"
) )
return fig return fig