Pytorch widgets

This commit is contained in:
ben 2025-10-09 16:55:23 -04:00
parent d919e4666c
commit 1fb55607a2

View File

@ -0,0 +1,89 @@
import torch
import plotly.graph_objects as go
from plotly.graph_objects import Figure
import numpy as np
def model_summary_plot(state_dict: dict) -> Figure:
"""Generate a summary plot of the PyTorch model state dict."""
# Count parameters by layer type
layer_info = []
for key, tensor in state_dict.items():
if 'weight' in key:
layer_name = key.replace('.weight', '')
param_count = tensor.numel()
layer_info.append({
'layer': layer_name,
'parameters': param_count,
'shape': list(tensor.shape)
})
# Create bar chart of parameter counts
fig = go.Figure(data=[
go.Bar(
x=[info['layer'] for info in layer_info],
y=[info['parameters'] for info in layer_info],
text=[f"Shape: {info['shape']}" for info in layer_info],
textposition='auto',
)
])
fig.update_layout(
title="Model Layer Parameter Counts",
xaxis_title="Layer",
yaxis_title="Number of Parameters"
)
return fig
def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure:
"""Visualize weights for a specific layer."""
if layer_name is None:
# Get first weight tensor
weight_keys = [k for k in state_dict.keys() if 'weight' in k]
if not weight_keys:
raise ValueError("No weight tensors found in state dict")
layer_name = weight_keys[0]
weights = state_dict[layer_name]
# For 2D weights, create heatmap
if len(weights.shape) == 2:
fig = go.Figure(data=go.Heatmap(
z=weights.numpy(),
colorscale='RdBu',
zmid=0
))
fig.update_layout(title=f"Weights Heatmap: {layer_name}")
else:
# For other shapes, flatten and show histogram
flat_weights = weights.flatten().numpy()
fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)])
fig.update_layout(title=f"Weight Distribution: {layer_name}")
return fig
def weight_distribution_plot(state_dict: dict) -> Figure:
"""Show distribution of weights across all layers."""
all_weights = []
layer_names = []
for key, tensor in state_dict.items():
if 'weight' in key:
all_weights.extend(tensor.flatten().numpy())
layer_names.extend([key] * tensor.numel())
fig = go.Figure(data=[
go.Histogram(
x=all_weights,
nbinsx=100,
name="All Weights"
)
])
fig.update_layout(
title="Overall Weight Distribution",
xaxis_title="Weight Value",
yaxis_title="Frequency"
)
return fig