Format Fixes
This commit is contained in:
parent
8105b829be
commit
c21b522a67
416
src/ria_toolkit_oss/viz/pytorch_model.py
Normal file
416
src/ria_toolkit_oss/viz/pytorch_model.py
Normal file
|
|
@ -0,0 +1,416 @@
|
||||||
|
"""Visualization functions for PyTorch model (.py) files.
|
||||||
|
|
||||||
|
This module provides visualization capabilities for PyTorch model Python files,
|
||||||
|
extracting architectural information through AST parsing and static analysis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
from plotly.graph_objects import Figure
|
||||||
|
|
||||||
|
|
||||||
|
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
|
||||||
|
"""Create a professional error figure with Qoherent dark theme styling."""
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Create a clean, centered text display using Plotly's text formatting
|
||||||
|
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
|
||||||
|
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||||
|
|
||||||
|
if suggestion:
|
||||||
|
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||||
|
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||||
|
|
||||||
|
# Add the main text annotation
|
||||||
|
fig.add_annotation(
|
||||||
|
text=main_text,
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
|
showarrow=False,
|
||||||
|
align="center",
|
||||||
|
borderwidth=2,
|
||||||
|
bordercolor="#4a5568",
|
||||||
|
bgcolor="#2d3748",
|
||||||
|
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update layout with dark theme
|
||||||
|
fig.update_layout(
|
||||||
|
title="",
|
||||||
|
height=400,
|
||||||
|
template="plotly_dark",
|
||||||
|
margin=dict(l=40, r=40, t=40, b=40),
|
||||||
|
plot_bgcolor="#1a202c",
|
||||||
|
paper_bgcolor="#1a202c",
|
||||||
|
font=dict(color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove axes and grid
|
||||||
|
fig.update_xaxes(visible=False)
|
||||||
|
fig.update_yaxes(visible=False)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_model_file(file_path: Path) -> Tuple[Optional[ast.Module], Optional[str]]:
|
||||||
|
"""Parse a Python model file and return the AST and any error message."""
|
||||||
|
try:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
code = f.read()
|
||||||
|
tree = ast.parse(code, filename=str(file_path))
|
||||||
|
return tree, None
|
||||||
|
except SyntaxError as e:
|
||||||
|
return None, f"Syntax error in file: {e}"
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to parse file: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _find_model_class(tree: ast.Module) -> Optional[ast.ClassDef]:
|
||||||
|
"""Find the main model class (subclass of nn.Module) in the AST."""
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.ClassDef):
|
||||||
|
# Check if it inherits from nn.Module or torch.nn.Module
|
||||||
|
for base in node.bases:
|
||||||
|
base_name = ""
|
||||||
|
if isinstance(base, ast.Name):
|
||||||
|
base_name = base.id
|
||||||
|
elif isinstance(base, ast.Attribute):
|
||||||
|
if isinstance(base.value, ast.Name):
|
||||||
|
base_name = f"{base.value.id}.{base.attr}"
|
||||||
|
|
||||||
|
if "Module" in base_name or "nn.Module" in base_name:
|
||||||
|
return node
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_layer_info(model_class: ast.ClassDef) -> List[Dict[str, Any]]:
|
||||||
|
"""Extract layer information from the model's __init__ method."""
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
# Find __init__ method
|
||||||
|
init_method = None
|
||||||
|
for node in model_class.body:
|
||||||
|
if isinstance(node, ast.FunctionDef) and node.name == "__init__":
|
||||||
|
init_method = node
|
||||||
|
break
|
||||||
|
|
||||||
|
if not init_method:
|
||||||
|
return layers
|
||||||
|
|
||||||
|
# Parse assignments in __init__
|
||||||
|
for node in ast.walk(init_method):
|
||||||
|
if isinstance(node, ast.Assign):
|
||||||
|
for target in node.targets:
|
||||||
|
if isinstance(target, ast.Attribute):
|
||||||
|
layer_name = target.attr
|
||||||
|
layer_type = _extract_layer_type(node.value)
|
||||||
|
if layer_type:
|
||||||
|
layers.append(
|
||||||
|
{"name": layer_name, "type": layer_type, "details": _extract_layer_params(node.value)}
|
||||||
|
)
|
||||||
|
|
||||||
|
return layers
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_layer_type(node: ast.expr) -> Optional[str]:
|
||||||
|
"""Extract the layer type from an AST node."""
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
if isinstance(node.func, ast.Name):
|
||||||
|
return node.func.id
|
||||||
|
elif isinstance(node.func, ast.Attribute):
|
||||||
|
return node.func.attr
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_layer_params(node: ast.Call) -> str:
|
||||||
|
"""Extract layer parameters as a string."""
|
||||||
|
params = []
|
||||||
|
|
||||||
|
# Extract positional arguments
|
||||||
|
for arg in node.args:
|
||||||
|
if isinstance(arg, ast.Constant):
|
||||||
|
params.append(str(arg.value))
|
||||||
|
elif isinstance(arg, ast.Name):
|
||||||
|
params.append(arg.id)
|
||||||
|
|
||||||
|
# Extract keyword arguments
|
||||||
|
for keyword in node.keywords:
|
||||||
|
if isinstance(keyword.value, ast.Constant):
|
||||||
|
params.append(f"{keyword.arg}={keyword.value.value}")
|
||||||
|
elif isinstance(keyword.value, ast.Name):
|
||||||
|
params.append(f"{keyword.arg}={keyword.value.id}")
|
||||||
|
|
||||||
|
return ", ".join(params)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_parameters(layers: List[Dict[str, Any]]) -> int:
|
||||||
|
"""Estimate parameter count from layer definitions (rough estimate)."""
|
||||||
|
# This is a very rough estimate - actual counts would require instantiating the model
|
||||||
|
param_estimates = {
|
||||||
|
"Linear": 1000,
|
||||||
|
"Conv1d": 500,
|
||||||
|
"Conv2d": 5000,
|
||||||
|
"Conv3d": 10000,
|
||||||
|
"LSTM": 4000,
|
||||||
|
"GRU": 3000,
|
||||||
|
"TransformerEncoder": 50000,
|
||||||
|
"Embedding": 10000,
|
||||||
|
}
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
for layer in layers:
|
||||||
|
layer_type = layer["type"]
|
||||||
|
total += param_estimates.get(layer_type, 100)
|
||||||
|
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def model_architecture_plot(file_path: Path) -> Figure:
|
||||||
|
"""Visualize the architecture of a PyTorch model from its .py file.
|
||||||
|
|
||||||
|
Parses the model file using AST to extract layers and their connections.
|
||||||
|
"""
|
||||||
|
tree, error = _parse_model_file(file_path)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return create_styled_error_figure(
|
||||||
|
"Parse Error", error, "Ensure the .py file contains valid Python code with a PyTorch nn.Module class"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_class = _find_model_class(tree)
|
||||||
|
if not model_class:
|
||||||
|
return create_styled_error_figure(
|
||||||
|
"No Model Found",
|
||||||
|
"Could not find a PyTorch nn.Module class in the file",
|
||||||
|
"Ensure your model class inherits from torch.nn.Module or nn.Module",
|
||||||
|
)
|
||||||
|
|
||||||
|
layers = _extract_layer_info(model_class)
|
||||||
|
|
||||||
|
if not layers:
|
||||||
|
return create_styled_error_figure(
|
||||||
|
"No Layers Found",
|
||||||
|
"Could not extract layer information from the model",
|
||||||
|
"Ensure your model defines layers in the __init__ method",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a hierarchical visualization
|
||||||
|
layer_names = [f"{i+1}. {layer['name']}" for i, layer in enumerate(layers)]
|
||||||
|
layer_types = [layer["type"] for layer in layers]
|
||||||
|
layer_details = [layer["details"] for layer in layers]
|
||||||
|
|
||||||
|
# Create a bar chart showing layers
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
y=layer_names,
|
||||||
|
x=[1] * len(layer_names),
|
||||||
|
orientation="h",
|
||||||
|
text=layer_types,
|
||||||
|
textposition="inside",
|
||||||
|
hovertext=[
|
||||||
|
f"{name}<br>Type: {type_}<br>Params: {details}"
|
||||||
|
for name, type_, details in zip(layer_names, layer_types, layer_details)
|
||||||
|
],
|
||||||
|
hoverinfo="text",
|
||||||
|
marker=dict(color="rgba(99, 179, 237, 0.8)", line=dict(color="rgba(99, 179, 237, 1.0)", width=2)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=f"Model Architecture: {model_class.name}",
|
||||||
|
xaxis=dict(visible=False),
|
||||||
|
yaxis=dict(title="Layers", autorange="reversed"),
|
||||||
|
template="plotly_dark",
|
||||||
|
height=max(400, len(layers) * 40),
|
||||||
|
showlegend=False,
|
||||||
|
margin=dict(l=200, r=40, t=60, b=40),
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def model_complexity_plot(file_path: Path) -> Figure:
|
||||||
|
"""Analyze and visualize model complexity metrics."""
|
||||||
|
tree, error = _parse_model_file(file_path)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||||
|
|
||||||
|
model_class = _find_model_class(tree)
|
||||||
|
if not model_class:
|
||||||
|
return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
|
||||||
|
|
||||||
|
layers = _extract_layer_info(model_class)
|
||||||
|
|
||||||
|
if not layers:
|
||||||
|
return create_styled_error_figure("No Layers Found", "Could not extract layer information from the model")
|
||||||
|
|
||||||
|
# Count layer types
|
||||||
|
layer_type_counts = {}
|
||||||
|
for layer in layers:
|
||||||
|
layer_type = layer["type"]
|
||||||
|
layer_type_counts[layer_type] = layer_type_counts.get(layer_type, 0) + 1
|
||||||
|
|
||||||
|
# Create pie chart of layer types
|
||||||
|
fig = go.Figure(
|
||||||
|
data=[
|
||||||
|
go.Pie(
|
||||||
|
labels=list(layer_type_counts.keys()),
|
||||||
|
values=list(layer_type_counts.values()),
|
||||||
|
hole=0.3,
|
||||||
|
marker=dict(colors=["#5c79ff", "#63b3ed", "#48bb78", "#f6ad55", "#fc8181"]),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title="Layer Type Distribution",
|
||||||
|
template="plotly_dark",
|
||||||
|
height=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def model_metadata_plot(file_path: Path) -> Figure:
|
||||||
|
"""Display model metadata and information extracted from the Python file."""
|
||||||
|
tree, error = _parse_model_file(file_path)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||||
|
|
||||||
|
model_class = _find_model_class(tree)
|
||||||
|
if not model_class:
|
||||||
|
return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
|
||||||
|
|
||||||
|
layers = _extract_layer_info(model_class)
|
||||||
|
|
||||||
|
# Extract imports
|
||||||
|
imports = []
|
||||||
|
for node in tree.body:
|
||||||
|
if isinstance(node, ast.Import):
|
||||||
|
for alias in node.names:
|
||||||
|
imports.append(alias.name)
|
||||||
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
if node.module:
|
||||||
|
imports.append(node.module)
|
||||||
|
|
||||||
|
# Get docstring
|
||||||
|
docstring = ast.get_docstring(model_class) or "No docstring available"
|
||||||
|
if len(docstring) > 200:
|
||||||
|
docstring = docstring[:200] + "..."
|
||||||
|
|
||||||
|
# Build metadata display
|
||||||
|
metadata_text = f"""<b style='font-size:16px;color:#63b3ed'>Model: {model_class.name}</b><br><br>"""
|
||||||
|
metadata_text += f"<b>📝 Description:</b><br><span style='color:#cbd5e0'>{docstring}</span><br><br>"
|
||||||
|
metadata_text += f"<b>🔢 Number of Layers:</b> {len(layers)}<br>"
|
||||||
|
metadata_text += f"<b>📦 Estimated Parameters:</b> ~{_count_parameters(layers):,}<br><br>"
|
||||||
|
metadata_text += f"<b>📚 Key Imports:</b><br>"
|
||||||
|
|
||||||
|
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5]
|
||||||
|
for imp in relevant_imports:
|
||||||
|
metadata_text += f" • {imp}<br>"
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
fig.add_annotation(
|
||||||
|
text=metadata_text,
|
||||||
|
xref="paper",
|
||||||
|
yref="paper",
|
||||||
|
x=0.05,
|
||||||
|
y=0.95,
|
||||||
|
xanchor="left",
|
||||||
|
yanchor="top",
|
||||||
|
showarrow=False,
|
||||||
|
align="left",
|
||||||
|
borderwidth=2,
|
||||||
|
bordercolor="#4a5568",
|
||||||
|
bgcolor="#2d3748",
|
||||||
|
font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title="Model Metadata",
|
||||||
|
template="plotly_dark",
|
||||||
|
height=450,
|
||||||
|
margin=dict(l=40, r=40, t=60, b=40),
|
||||||
|
plot_bgcolor="#1a202c",
|
||||||
|
paper_bgcolor="#1a202c",
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_xaxes(visible=False)
|
||||||
|
fig.update_yaxes(visible=False)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def code_structure_plot(file_path: Path) -> Figure:
|
||||||
|
"""Visualize the code structure and method definitions in the model."""
|
||||||
|
tree, error = _parse_model_file(file_path)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||||
|
|
||||||
|
model_class = _find_model_class(tree)
|
||||||
|
if not model_class:
|
||||||
|
return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
|
||||||
|
|
||||||
|
# Extract methods
|
||||||
|
methods = []
|
||||||
|
for node in model_class.body:
|
||||||
|
if isinstance(node, ast.FunctionDef):
|
||||||
|
# Count lines in method
|
||||||
|
if hasattr(node, "end_lineno") and hasattr(node, "lineno"):
|
||||||
|
lines = node.end_lineno - node.lineno + 1
|
||||||
|
else:
|
||||||
|
lines = 1
|
||||||
|
|
||||||
|
methods.append({"name": node.name, "lines": lines, "args": len(node.args.args) - 1}) # Exclude self
|
||||||
|
|
||||||
|
if not methods:
|
||||||
|
return create_styled_error_figure(
|
||||||
|
"No Methods Found", "Could not extract method information from the model class"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create visualization of methods
|
||||||
|
method_names = [m["name"] for m in methods]
|
||||||
|
method_lines = [m["lines"] for m in methods]
|
||||||
|
method_args = [m["args"] for m in methods]
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Bar chart for method complexity (lines of code)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Bar(
|
||||||
|
x=method_names,
|
||||||
|
y=method_lines,
|
||||||
|
name="Lines of Code",
|
||||||
|
marker=dict(color="rgba(99, 179, 237, 0.8)"),
|
||||||
|
hovertext=[
|
||||||
|
f"{name}<br>Lines: {lines}<br>Arguments: {args}"
|
||||||
|
for name, lines, args in zip(method_names, method_lines, method_args)
|
||||||
|
],
|
||||||
|
hoverinfo="text",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=f"Method Complexity - {model_class.name}",
|
||||||
|
xaxis_title="Methods",
|
||||||
|
yaxis_title="Lines of Code",
|
||||||
|
template="plotly_dark",
|
||||||
|
height=400,
|
||||||
|
showlegend=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
Loading…
Reference in New Issue
Block a user