diff --git a/src/ria_toolkit_oss/viz/pytorch_model.py b/src/ria_toolkit_oss/viz/pytorch_model.py
index 2376ce3..ef9733d 100644
--- a/src/ria_toolkit_oss/viz/pytorch_model.py
+++ b/src/ria_toolkit_oss/viz/pytorch_model.py
@@ -5,7 +5,6 @@ extracting architectural information through AST parsing and static analysis.
"""
import ast
-import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -283,9 +282,10 @@ def model_complexity_plot(file_path: Path) -> Figure:
def model_metadata_plot(file_path: Path) -> Figure:
- """Display model metadata and information extracted from the Python file."""
- tree, error = _parse_model_file(file_path)
+ """Display model metadata and information extracted from the Python file (clean, aligned layout)."""
+ import textwrap
+ tree, error = _parse_model_file(file_path)
if error:
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
@@ -301,56 +301,200 @@ def model_metadata_plot(file_path: Path) -> Figure:
if isinstance(node, ast.Import):
for alias in node.names:
imports.append(alias.name)
- elif isinstance(node, ast.ImportFrom):
- if node.module:
- imports.append(node.module)
+ elif isinstance(node, ast.ImportFrom) and node.module:
+ imports.append(node.module)
- # Get docstring
+ # Get docstring and wrap it
docstring = ast.get_docstring(model_class) or "No docstring available"
- if len(docstring) > 200:
- docstring = docstring[:200] + "..."
+ wrapped_doc = "
".join(textwrap.wrap(docstring, width=70))
- # Build metadata display
- metadata_text = f"""Model: {model_class.name}
"""
- metadata_text += f"📝 Description:
{docstring}
"
- metadata_text += f"🔢 Number of Layers: {len(layers)}
"
- metadata_text += f"📦 Estimated Parameters: ~{_count_parameters(layers):,}
"
- metadata_text += f"📚 Key Imports:
"
+ relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:4]
+ param_count = _count_parameters(layers)
- relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5]
- for imp in relevant_imports:
- metadata_text += f" • {imp}
"
+ # Define card grid (aligned 2x2)
+ cards = [
+ {"x": 0.05, "y": 0.93, "width": 0.43, "height": 0.38, "title": "📦 Model Overview", "color": "#2d5f8d"},
+ {"x": 0.52, "y": 0.93, "width": 0.43, "height": 0.38, "title": "🔢 Statistics", "color": "#2d6b5f"},
+ {"x": 0.05, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📝 Description", "color": "#5d4b7a"},
+ {"x": 0.52, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📚 Dependencies", "color": "#7a5b3d"},
+ ]
fig = go.Figure()
+ # Draw background cards with consistent opacity
+ for card in cards:
+ fig.add_shape(
+ type="rect",
+ xref="paper",
+ yref="paper",
+ x0=card["x"],
+ y0=card["y"] - card["height"],
+ x1=card["x"] + card["width"],
+ y1=card["y"],
+ fillcolor=card["color"],
+ line=dict(color="#4a5568", width=2),
+ opacity=0.3,
+ layer="below",
+ )
+ # Header bar
+ fig.add_shape(
+ type="rect",
+ xref="paper",
+ yref="paper",
+ x0=card["x"],
+ y0=card["y"] - 0.07,
+ x1=card["x"] + card["width"],
+ y1=card["y"],
+ fillcolor=card["color"],
+ line=dict(width=0),
+ opacity=0.45,
+ layer="below",
+ )
+
+ # --- CARD 1: Model Overview ---
+ card = cards[0]
fig.add_annotation(
- text=metadata_text,
+ text=f"{card['title']}",
xref="paper",
yref="paper",
- x=0.05,
- y=0.95,
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ align="left",
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ fig.add_annotation(
+ text=f"{model_class.name}
"
+ f"PyTorch Neural Network",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.04,
+ y=card["y"] - 0.13,
xanchor="left",
yanchor="top",
showarrow=False,
align="left",
- borderwidth=2,
- bordercolor="#4a5568",
- bgcolor="#2d3748",
- font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"),
+ font=dict(size=15, color="#cbd5e0", family="Inter, Arial, sans-serif"),
)
+ # --- CARD 2: Statistics ---
+ card = cards[1]
+ y_center = card["y"] - card["height"] / 2
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ fig.add_annotation(
+ text=f"{len(layers)}
"
+ f"LAYERS",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + card["width"] / 2,
+ y=y_center + 0.07,
+ xanchor="center",
+ yanchor="middle",
+ showarrow=False,
+ align="center",
+ )
+ fig.add_annotation(
+ text=f"~{param_count:,}
"
+ f"PARAMETERS",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + card["width"] / 2,
+ y=y_center - 0.10,
+ xanchor="center",
+ yanchor="middle",
+ showarrow=False,
+ align="center",
+ )
+
+ # --- CARD 3: Description ---
+ card = cards[2]
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ fig.add_annotation(
+ text=f"{wrapped_doc}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.04,
+ y=card["y"] - 0.13,
+ xanchor="left",
+ yanchor="top",
+ showarrow=False,
+ align="left",
+ )
+
+ # --- CARD 4: Dependencies ---
+ card = cards[3]
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ imports_text = (
+ "
".join(
+ [
+ f"▸ "
+ f"{imp}"
+ for imp in relevant_imports
+ ]
+ )
+ if relevant_imports
+ else "No torch imports detected"
+ )
+ fig.add_annotation(
+ text=imports_text,
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.04,
+ y=card["y"] - 0.13,
+ xanchor="left",
+ yanchor="top",
+ showarrow=False,
+ align="left",
+ )
+
+ # Layout polish
fig.update_layout(
- title="Model Metadata",
+ title=dict(
+ text="Model Metadata",
+ font=dict(size=20, color="#e2e8f0", family="Inter, Arial, sans-serif"),
+ x=0.5,
+ xanchor="center",
+ ),
template="plotly_dark",
- height=450,
- margin=dict(l=40, r=40, t=60, b=40),
+ height=500,
+ margin=dict(l=20, r=20, t=70, b=20),
plot_bgcolor="#1a202c",
paper_bgcolor="#1a202c",
)
-
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
-
return fig