pytorch-widgets #10

Merged
benchinnery merged 4 commits from pytorch-widgets into main 2025-11-13 11:02:18 -05:00
Showing only changes of commit 48f6b303f5 - Show all commits

View File

@ -5,7 +5,6 @@ extracting architectural information through AST parsing and static analysis.
""" """
import ast import ast
import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -283,9 +282,10 @@ def model_complexity_plot(file_path: Path) -> Figure:
def model_metadata_plot(file_path: Path) -> Figure: def model_metadata_plot(file_path: Path) -> Figure:
"""Display model metadata and information extracted from the Python file.""" """Display model metadata and information extracted from the Python file (clean, aligned layout)."""
tree, error = _parse_model_file(file_path) import textwrap
tree, error = _parse_model_file(file_path)
if error: if error:
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
@ -301,56 +301,200 @@ def model_metadata_plot(file_path: Path) -> Figure:
if isinstance(node, ast.Import): if isinstance(node, ast.Import):
for alias in node.names: for alias in node.names:
imports.append(alias.name) imports.append(alias.name)
elif isinstance(node, ast.ImportFrom): elif isinstance(node, ast.ImportFrom) and node.module:
if node.module: imports.append(node.module)
imports.append(node.module)
# Get docstring # Get docstring and wrap it
docstring = ast.get_docstring(model_class) or "No docstring available" docstring = ast.get_docstring(model_class) or "No docstring available"
if len(docstring) > 200: wrapped_doc = "<br>".join(textwrap.wrap(docstring, width=70))
docstring = docstring[:200] + "..."
# Build metadata display relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:4]
metadata_text = f"""<b style='font-size:16px;color:#63b3ed'>Model: {model_class.name}</b><br><br>""" param_count = _count_parameters(layers)
metadata_text += f"<b>📝 Description:</b><br><span style='color:#cbd5e0'>{docstring}</span><br><br>"
metadata_text += f"<b>🔢 Number of Layers:</b> {len(layers)}<br>"
metadata_text += f"<b>📦 Estimated Parameters:</b> ~{_count_parameters(layers):,}<br><br>"
metadata_text += f"<b>📚 Key Imports:</b><br>"
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5] # Define card grid (aligned 2x2)
for imp in relevant_imports: cards = [
metadata_text += f"{imp}<br>" {"x": 0.05, "y": 0.93, "width": 0.43, "height": 0.38, "title": "📦 Model Overview", "color": "#2d5f8d"},
{"x": 0.52, "y": 0.93, "width": 0.43, "height": 0.38, "title": "🔢 Statistics", "color": "#2d6b5f"},
{"x": 0.05, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📝 Description", "color": "#5d4b7a"},
{"x": 0.52, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📚 Dependencies", "color": "#7a5b3d"},
]
fig = go.Figure() fig = go.Figure()
# Draw background cards with consistent opacity
for card in cards:
fig.add_shape(
type="rect",
xref="paper",
yref="paper",
x0=card["x"],
y0=card["y"] - card["height"],
x1=card["x"] + card["width"],
y1=card["y"],
fillcolor=card["color"],
line=dict(color="#4a5568", width=2),
opacity=0.3,
layer="below",
)
# Header bar
fig.add_shape(
type="rect",
xref="paper",
yref="paper",
x0=card["x"],
y0=card["y"] - 0.07,
x1=card["x"] + card["width"],
y1=card["y"],
fillcolor=card["color"],
line=dict(width=0),
opacity=0.45,
layer="below",
)
# --- CARD 1: Model Overview ---
card = cards[0]
fig.add_annotation( fig.add_annotation(
text=metadata_text, text=f"<b>{card['title']}</b>",
xref="paper", xref="paper",
yref="paper", yref="paper",
x=0.05, x=card["x"] + 0.03,
y=0.95, y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
align="left",
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
fig.add_annotation(
text=f"<b style='font-size:26px;color:#ffffff'>{model_class.name}</b><br>"
f"<span style='color:#94a3b8;font-size:15px'>PyTorch Neural Network</span>",
xref="paper",
yref="paper",
x=card["x"] + 0.04,
y=card["y"] - 0.13,
xanchor="left", xanchor="left",
yanchor="top", yanchor="top",
showarrow=False, showarrow=False,
align="left", align="left",
borderwidth=2, font=dict(size=15, color="#cbd5e0", family="Inter, Arial, sans-serif"),
bordercolor="#4a5568",
bgcolor="#2d3748",
font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"),
) )
# --- CARD 2: Statistics ---
card = cards[1]
y_center = card["y"] - card["height"] / 2
fig.add_annotation(
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
fig.add_annotation(
text=f"<b style='font-size:44px;color:#63b3ed'>{len(layers)}</b><br>"
f"<span style='color:#94a3b8;font-size:13px;letter-spacing:1.5px'>LAYERS</span>",
xref="paper",
yref="paper",
x=card["x"] + card["width"] / 2,
y=y_center + 0.07,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
)
fig.add_annotation(
text=f"<b style='font-size:36px;color:#48bb78'>~{param_count:,}</b><br>"
f"<span style='color:#94a3b8;font-size:13px;letter-spacing:1.5px'>PARAMETERS</span>",
xref="paper",
yref="paper",
x=card["x"] + card["width"] / 2,
y=y_center - 0.10,
xanchor="center",
yanchor="middle",
showarrow=False,
align="center",
)
# --- CARD 3: Description ---
card = cards[2]
fig.add_annotation(
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
fig.add_annotation(
text=f"<span style='color:#cbd5e0;font-size:14px;line-height:1.5'>{wrapped_doc}</span>",
xref="paper",
yref="paper",
x=card["x"] + 0.04,
y=card["y"] - 0.13,
xanchor="left",
yanchor="top",
showarrow=False,
align="left",
)
# --- CARD 4: Dependencies ---
card = cards[3]
fig.add_annotation(
text=f"<b>{card['title']}</b>",
xref="paper",
yref="paper",
x=card["x"] + 0.03,
y=card["y"] - 0.02,
xanchor="left",
yanchor="middle",
showarrow=False,
font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
)
imports_text = (
"<br>".join(
[
f"<span style='color:#48bb78;font-size:16px'>▸</span> "
f"<span style='color:#e2e8f0;font-family:\"Courier New\",monospace;font-size:14px'>{imp}</span>"
for imp in relevant_imports
]
)
if relevant_imports
else "<span style='color:#94a3b8;font-style:italic;font-size:14px'>No torch imports detected</span>"
)
fig.add_annotation(
text=imports_text,
xref="paper",
yref="paper",
x=card["x"] + 0.04,
y=card["y"] - 0.13,
xanchor="left",
yanchor="top",
showarrow=False,
align="left",
)
# Layout polish
fig.update_layout( fig.update_layout(
title="Model Metadata", title=dict(
text="<b>Model Metadata</b>",
font=dict(size=20, color="#e2e8f0", family="Inter, Arial, sans-serif"),
x=0.5,
xanchor="center",
),
template="plotly_dark", template="plotly_dark",
height=450, height=500,
margin=dict(l=40, r=40, t=60, b=40), margin=dict(l=20, r=20, t=70, b=20),
plot_bgcolor="#1a202c", plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c", paper_bgcolor="#1a202c",
) )
fig.update_xaxes(visible=False) fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False) fig.update_yaxes(visible=False)
return fig return fig