From f430e626a61163743298f218f79b1735a8869155 Mon Sep 17 00:00:00 2001 From: ben Date: Tue, 14 Oct 2025 14:22:37 -0400 Subject: [PATCH] Pytorch state dict widget --- src/ria_toolkit_oss/viz/pytorch_state_dict.py | 198 +++++++++++++++--- 1 file changed, 169 insertions(+), 29 deletions(-) diff --git a/src/ria_toolkit_oss/viz/pytorch_state_dict.py b/src/ria_toolkit_oss/viz/pytorch_state_dict.py index 05c96f4..7db7528 100644 --- a/src/ria_toolkit_oss/viz/pytorch_state_dict.py +++ b/src/ria_toolkit_oss/viz/pytorch_state_dict.py @@ -5,17 +5,56 @@ import numpy as np def model_summary_plot(state_dict: dict) -> Figure: """Generate a summary plot of the PyTorch model state dict.""" + if not state_dict: + # Handle empty state dict + fig = go.Figure() + fig.add_annotation( + text="No parameters found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Model Layer Parameter Counts", + xaxis_title="Layer", + yaxis_title="Number of Parameters", + template="plotly_dark" + ) + return fig + # Count parameters by layer type layer_info = [] for key, tensor in state_dict.items(): if 'weight' in key: - layer_name = key.replace('.weight', '') - param_count = tensor.numel() - layer_info.append({ - 'layer': layer_name, - 'parameters': param_count, - 'shape': list(tensor.shape) - }) + try: + layer_name = key.replace('.weight', '') + param_count = tensor.numel() if hasattr(tensor, 'numel') else len(tensor.flatten()) if hasattr(tensor, 'flatten') else 0 + shape = list(tensor.shape) if hasattr(tensor, 'shape') else [len(tensor)] if hasattr(tensor, '__len__') else [] + layer_info.append({ + 'layer': layer_name, + 'parameters': param_count, + 'shape': shape + }) + except Exception as e: + print(f"Warning: Could not process layer {key}: {e}") + continue + + if not layer_info: + # Handle case where no weight layers found + fig = go.Figure() + fig.add_annotation( + text="No weight layers found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Model Layer Parameter Counts", + xaxis_title="Layer", + yaxis_title="Number of Parameters", + template="plotly_dark" + ) + return fig # Create bar chart of parameter counts fig = go.Figure(data=[ @@ -30,47 +69,147 @@ def model_summary_plot(state_dict: dict) -> Figure: fig.update_layout( title="Model Layer Parameter Counts", xaxis_title="Layer", - yaxis_title="Number of Parameters" + yaxis_title="Number of Parameters", + template="plotly_dark" ) return fig def layer_weights_plot(state_dict: dict, layer_name: str = None) -> Figure: """Visualize weights for a specific layer.""" + if not state_dict: + fig = go.Figure() + fig.add_annotation( + text="No data in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Layer Weights", + template="plotly_dark" + ) + return fig + if layer_name is None: # Get first weight tensor weight_keys = [k for k in state_dict.keys() if 'weight' in k] if not weight_keys: - raise ValueError("No weight tensors found in state dict") + fig = go.Figure() + fig.add_annotation( + text="No weight tensors found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Layer Weights", + template="plotly_dark" + ) + return fig layer_name = weight_keys[0] - weights = state_dict[layer_name] - - # For 2D weights, create heatmap - if len(weights.shape) == 2: - fig = go.Figure(data=go.Heatmap( - z=weights.numpy(), - colorscale='RdBu', - zmid=0 - )) - fig.update_layout(title=f"Weights Heatmap: {layer_name}") - else: - # For other shapes, flatten and show histogram - flat_weights = weights.flatten().numpy() - fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) - fig.update_layout(title=f"Weight Distribution: {layer_name}") - - return fig + try: + weights = state_dict[layer_name] + + # Convert to numpy if it's a torch tensor + if hasattr(weights, 'numpy'): + weights_np = weights.detach().numpy() if hasattr(weights, 'detach') else weights.numpy() + elif hasattr(weights, 'cpu'): + weights_np = weights.cpu().detach().numpy() + else: + weights_np = np.array(weights) + + # For 2D weights, create heatmap + if len(weights_np.shape) == 2: + fig = go.Figure(data=go.Heatmap( + z=weights_np, + colorscale='RdBu', + zmid=0 + )) + fig.update_layout( + title=f"Weights Heatmap: {layer_name}", + template="plotly_dark" + ) + else: + # For other shapes, flatten and show histogram + flat_weights = weights_np.flatten() + fig = go.Figure(data=[go.Histogram(x=flat_weights, nbinsx=50)]) + fig.update_layout( + title=f"Weight Distribution: {layer_name}", + template="plotly_dark" + ) + + return fig + + except Exception as e: + fig = go.Figure() + fig.add_annotation( + text=f"Error processing layer {layer_name}: {str(e)}", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=14) + ) + fig.update_layout( + title="Layer Weights - Error", + template="plotly_dark" + ) + return fig def weight_distribution_plot(state_dict: dict) -> Figure: """Show distribution of weights across all layers.""" + if not state_dict: + fig = go.Figure() + fig.add_annotation( + text="No data in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Overall Weight Distribution", + xaxis_title="Weight Value", + yaxis_title="Frequency", + template="plotly_dark" + ) + return fig + all_weights = [] layer_names = [] for key, tensor in state_dict.items(): if 'weight' in key: - all_weights.extend(tensor.flatten().numpy()) - layer_names.extend([key] * tensor.numel()) + try: + # Convert to numpy if it's a torch tensor + if hasattr(tensor, 'numpy'): + weights_np = tensor.detach().numpy() if hasattr(tensor, 'detach') else tensor.numpy() + elif hasattr(tensor, 'cpu'): + weights_np = tensor.cpu().detach().numpy() + else: + weights_np = np.array(tensor) + + flat_weights = weights_np.flatten() + all_weights.extend(flat_weights) + layer_names.extend([key] * len(flat_weights)) + except Exception as e: + print(f"Warning: Could not process weights for layer {key}: {e}") + continue + + if not all_weights: + fig = go.Figure() + fig.add_annotation( + text="No weight data found in state dict", + xref="paper", yref="paper", + x=0.5, y=0.5, showarrow=False, + font=dict(size=16) + ) + fig.update_layout( + title="Overall Weight Distribution", + xaxis_title="Weight Value", + yaxis_title="Frequency", + template="plotly_dark" + ) + return fig fig = go.Figure(data=[ go.Histogram( @@ -83,7 +222,8 @@ def weight_distribution_plot(state_dict: dict) -> Figure: fig.update_layout( title="Overall Weight Distribution", xaxis_title="Weight Value", - yaxis_title="Frequency" + yaxis_title="Frequency", + template="plotly_dark" ) return fig \ No newline at end of file