diff --git a/src/ria_toolkit_oss/viz/pytorch_model.py b/src/ria_toolkit_oss/viz/pytorch_model.py index 2376ce3..ef9733d 100644 --- a/src/ria_toolkit_oss/viz/pytorch_model.py +++ b/src/ria_toolkit_oss/viz/pytorch_model.py @@ -5,7 +5,6 @@ extracting architectural information through AST parsing and static analysis. """ import ast -import re from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -283,9 +282,10 @@ def model_complexity_plot(file_path: Path) -> Figure: def model_metadata_plot(file_path: Path) -> Figure: - """Display model metadata and information extracted from the Python file.""" - tree, error = _parse_model_file(file_path) + """Display model metadata and information extracted from the Python file (clean, aligned layout).""" + import textwrap + tree, error = _parse_model_file(file_path) if error: return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code") @@ -301,56 +301,200 @@ def model_metadata_plot(file_path: Path) -> Figure: if isinstance(node, ast.Import): for alias in node.names: imports.append(alias.name) - elif isinstance(node, ast.ImportFrom): - if node.module: - imports.append(node.module) + elif isinstance(node, ast.ImportFrom) and node.module: + imports.append(node.module) - # Get docstring + # Get docstring and wrap it docstring = ast.get_docstring(model_class) or "No docstring available" - if len(docstring) > 200: - docstring = docstring[:200] + "..." + wrapped_doc = "
".join(textwrap.wrap(docstring, width=70)) - # Build metadata display - metadata_text = f"""Model: {model_class.name}

""" - metadata_text += f"📝 Description:
{docstring}

" - metadata_text += f"🔢 Number of Layers: {len(layers)}
" - metadata_text += f"📦 Estimated Parameters: ~{_count_parameters(layers):,}

" - metadata_text += f"📚 Key Imports:
" + relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:4] + param_count = _count_parameters(layers) - relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5] - for imp in relevant_imports: - metadata_text += f" • {imp}
" + # Define card grid (aligned 2x2) + cards = [ + {"x": 0.05, "y": 0.93, "width": 0.43, "height": 0.38, "title": "📦 Model Overview", "color": "#2d5f8d"}, + {"x": 0.52, "y": 0.93, "width": 0.43, "height": 0.38, "title": "🔢 Statistics", "color": "#2d6b5f"}, + {"x": 0.05, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📝 Description", "color": "#5d4b7a"}, + {"x": 0.52, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📚 Dependencies", "color": "#7a5b3d"}, + ] fig = go.Figure() + # Draw background cards with consistent opacity + for card in cards: + fig.add_shape( + type="rect", + xref="paper", + yref="paper", + x0=card["x"], + y0=card["y"] - card["height"], + x1=card["x"] + card["width"], + y1=card["y"], + fillcolor=card["color"], + line=dict(color="#4a5568", width=2), + opacity=0.3, + layer="below", + ) + # Header bar + fig.add_shape( + type="rect", + xref="paper", + yref="paper", + x0=card["x"], + y0=card["y"] - 0.07, + x1=card["x"] + card["width"], + y1=card["y"], + fillcolor=card["color"], + line=dict(width=0), + opacity=0.45, + layer="below", + ) + + # --- CARD 1: Model Overview --- + card = cards[0] fig.add_annotation( - text=metadata_text, + text=f"{card['title']}", xref="paper", yref="paper", - x=0.05, - y=0.95, + x=card["x"] + 0.03, + y=card["y"] - 0.02, + xanchor="left", + yanchor="middle", + showarrow=False, + align="left", + font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"), + ) + fig.add_annotation( + text=f"{model_class.name}
" + f"PyTorch Neural Network", + xref="paper", + yref="paper", + x=card["x"] + 0.04, + y=card["y"] - 0.13, xanchor="left", yanchor="top", showarrow=False, align="left", - borderwidth=2, - bordercolor="#4a5568", - bgcolor="#2d3748", - font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"), + font=dict(size=15, color="#cbd5e0", family="Inter, Arial, sans-serif"), ) + # --- CARD 2: Statistics --- + card = cards[1] + y_center = card["y"] - card["height"] / 2 + fig.add_annotation( + text=f"{card['title']}", + xref="paper", + yref="paper", + x=card["x"] + 0.03, + y=card["y"] - 0.02, + xanchor="left", + yanchor="middle", + showarrow=False, + font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"), + ) + fig.add_annotation( + text=f"{len(layers)}
" + f"LAYERS", + xref="paper", + yref="paper", + x=card["x"] + card["width"] / 2, + y=y_center + 0.07, + xanchor="center", + yanchor="middle", + showarrow=False, + align="center", + ) + fig.add_annotation( + text=f"~{param_count:,}
" + f"PARAMETERS", + xref="paper", + yref="paper", + x=card["x"] + card["width"] / 2, + y=y_center - 0.10, + xanchor="center", + yanchor="middle", + showarrow=False, + align="center", + ) + + # --- CARD 3: Description --- + card = cards[2] + fig.add_annotation( + text=f"{card['title']}", + xref="paper", + yref="paper", + x=card["x"] + 0.03, + y=card["y"] - 0.02, + xanchor="left", + yanchor="middle", + showarrow=False, + font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"), + ) + fig.add_annotation( + text=f"{wrapped_doc}", + xref="paper", + yref="paper", + x=card["x"] + 0.04, + y=card["y"] - 0.13, + xanchor="left", + yanchor="top", + showarrow=False, + align="left", + ) + + # --- CARD 4: Dependencies --- + card = cards[3] + fig.add_annotation( + text=f"{card['title']}", + xref="paper", + yref="paper", + x=card["x"] + 0.03, + y=card["y"] - 0.02, + xanchor="left", + yanchor="middle", + showarrow=False, + font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"), + ) + imports_text = ( + "
".join( + [ + f" " + f"{imp}" + for imp in relevant_imports + ] + ) + if relevant_imports + else "No torch imports detected" + ) + fig.add_annotation( + text=imports_text, + xref="paper", + yref="paper", + x=card["x"] + 0.04, + y=card["y"] - 0.13, + xanchor="left", + yanchor="top", + showarrow=False, + align="left", + ) + + # Layout polish fig.update_layout( - title="Model Metadata", + title=dict( + text="Model Metadata", + font=dict(size=20, color="#e2e8f0", family="Inter, Arial, sans-serif"), + x=0.5, + xanchor="center", + ), template="plotly_dark", - height=450, - margin=dict(l=40, r=40, t=60, b=40), + height=500, + margin=dict(l=20, r=20, t=70, b=20), plot_bgcolor="#1a202c", paper_bgcolor="#1a202c", ) - fig.update_xaxes(visible=False) fig.update_yaxes(visible=False) - return fig