diff --git a/docs/source/conf.py b/docs/source/conf.py
index 81c014b..9db1d19 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
project = 'ria-toolkit-oss'
copyright = '2025, Qoherent Inc'
author = 'Qoherent Inc.'
-release = '0.1.3'
+release = '0.1.4'
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/pyproject.toml b/pyproject.toml
index 1c4709b..d06e3be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ria-toolkit-oss"
-version = "0.1.3"
+version = "0.1.4"
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
license = { text = "AGPL-3.0-only" }
readme = "README.md"
diff --git a/src/ria_toolkit_oss/viz/pytorch_model.py b/src/ria_toolkit_oss/viz/pytorch_model.py
new file mode 100644
index 0000000..ef9733d
--- /dev/null
+++ b/src/ria_toolkit_oss/viz/pytorch_model.py
@@ -0,0 +1,560 @@
+"""Visualization functions for PyTorch model (.py) files.
+
+This module provides visualization capabilities for PyTorch model Python files,
+extracting architectural information through AST parsing and static analysis.
+"""
+
+import ast
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import plotly.graph_objects as go
+from plotly.graph_objects import Figure
+
+
+def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
+ """Create a professional error figure with Qoherent dark theme styling."""
+ fig = go.Figure()
+
+ # Create a clean, centered text display using Plotly's text formatting
+ main_text = f"⚠️ {title}
"
+ main_text += f"{message}"
+
+ if suggestion:
+ main_text += "
💡 Suggestion:
"
+ main_text += f"{suggestion}"
+
+ # Add the main text annotation
+ fig.add_annotation(
+ text=main_text,
+ xref="paper",
+ yref="paper",
+ x=0.5,
+ y=0.5,
+ xanchor="center",
+ yanchor="middle",
+ showarrow=False,
+ align="center",
+ borderwidth=2,
+ bordercolor="#4a5568",
+ bgcolor="#2d3748",
+ font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
+ )
+
+ # Update layout with dark theme
+ fig.update_layout(
+ title="",
+ height=400,
+ template="plotly_dark",
+ margin=dict(l=40, r=40, t=40, b=40),
+ plot_bgcolor="#1a202c",
+ paper_bgcolor="#1a202c",
+ font=dict(color="#e2e8f0"),
+ )
+
+ # Remove axes and grid
+ fig.update_xaxes(visible=False)
+ fig.update_yaxes(visible=False)
+
+ return fig
+
+
+def _parse_model_file(file_path: Path) -> Tuple[Optional[ast.Module], Optional[str]]:
+ """Parse a Python model file and return the AST and any error message."""
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ code = f.read()
+ tree = ast.parse(code, filename=str(file_path))
+ return tree, None
+ except SyntaxError as e:
+ return None, f"Syntax error in file: {e}"
+ except Exception as e:
+ return None, f"Failed to parse file: {e}"
+
+
+def _find_model_class(tree: ast.Module) -> Optional[ast.ClassDef]:
+ """Find the main model class (subclass of nn.Module) in the AST."""
+ for node in ast.walk(tree):
+ if isinstance(node, ast.ClassDef):
+ # Check if it inherits from nn.Module or torch.nn.Module
+ for base in node.bases:
+ base_name = ""
+ if isinstance(base, ast.Name):
+ base_name = base.id
+ elif isinstance(base, ast.Attribute):
+ if isinstance(base.value, ast.Name):
+ base_name = f"{base.value.id}.{base.attr}"
+
+ if "Module" in base_name or "nn.Module" in base_name:
+ return node
+ return None
+
+
+def _extract_layer_info(model_class: ast.ClassDef) -> List[Dict[str, Any]]:
+ """Extract layer information from the model's __init__ method."""
+ layers = []
+
+ # Find __init__ method
+ init_method = None
+ for node in model_class.body:
+ if isinstance(node, ast.FunctionDef) and node.name == "__init__":
+ init_method = node
+ break
+
+ if not init_method:
+ return layers
+
+ # Parse assignments in __init__
+ for node in ast.walk(init_method):
+ if isinstance(node, ast.Assign):
+ for target in node.targets:
+ if isinstance(target, ast.Attribute):
+ layer_name = target.attr
+ layer_type = _extract_layer_type(node.value)
+ if layer_type:
+ layers.append(
+ {"name": layer_name, "type": layer_type, "details": _extract_layer_params(node.value)}
+ )
+
+ return layers
+
+
+def _extract_layer_type(node: ast.expr) -> Optional[str]:
+ """Extract the layer type from an AST node."""
+ if isinstance(node, ast.Call):
+ if isinstance(node.func, ast.Name):
+ return node.func.id
+ elif isinstance(node.func, ast.Attribute):
+ return node.func.attr
+ return None
+
+
+def _extract_layer_params(node: ast.Call) -> str:
+ """Extract layer parameters as a string."""
+ params = []
+
+ # Extract positional arguments
+ for arg in node.args:
+ if isinstance(arg, ast.Constant):
+ params.append(str(arg.value))
+ elif isinstance(arg, ast.Name):
+ params.append(arg.id)
+
+ # Extract keyword arguments
+ for keyword in node.keywords:
+ if isinstance(keyword.value, ast.Constant):
+ params.append(f"{keyword.arg}={keyword.value.value}")
+ elif isinstance(keyword.value, ast.Name):
+ params.append(f"{keyword.arg}={keyword.value.id}")
+
+ return ", ".join(params)
+
+
+def _count_parameters(layers: List[Dict[str, Any]]) -> int:
+ """Estimate parameter count from layer definitions (rough estimate)."""
+ # This is a very rough estimate - actual counts would require instantiating the model
+ param_estimates = {
+ "Linear": 1000,
+ "Conv1d": 500,
+ "Conv2d": 5000,
+ "Conv3d": 10000,
+ "LSTM": 4000,
+ "GRU": 3000,
+ "TransformerEncoder": 50000,
+ "Embedding": 10000,
+ }
+
+ total = 0
+ for layer in layers:
+ layer_type = layer["type"]
+ total += param_estimates.get(layer_type, 100)
+
+ return total
+
+
+def model_architecture_plot(file_path: Path) -> Figure:
+ """Visualize the architecture of a PyTorch model from its .py file.
+
+ Parses the model file using AST to extract layers and their connections.
+ """
+ tree, error = _parse_model_file(file_path)
+
+ if error:
+ return create_styled_error_figure(
+ "Parse Error", error, "Ensure the .py file contains valid Python code with a PyTorch nn.Module class"
+ )
+
+ model_class = _find_model_class(tree)
+ if not model_class:
+ return create_styled_error_figure(
+ "No Model Found",
+ "Could not find a PyTorch nn.Module class in the file",
+ "Ensure your model class inherits from torch.nn.Module or nn.Module",
+ )
+
+ layers = _extract_layer_info(model_class)
+
+ if not layers:
+ return create_styled_error_figure(
+ "No Layers Found",
+ "Could not extract layer information from the model",
+ "Ensure your model defines layers in the __init__ method",
+ )
+
+ # Create a hierarchical visualization
+ layer_names = [f"{i+1}. {layer['name']}" for i, layer in enumerate(layers)]
+ layer_types = [layer["type"] for layer in layers]
+ layer_details = [layer["details"] for layer in layers]
+
+ # Create a bar chart showing layers
+ fig = go.Figure()
+
+ fig.add_trace(
+ go.Bar(
+ y=layer_names,
+ x=[1] * len(layer_names),
+ orientation="h",
+ text=layer_types,
+ textposition="inside",
+ hovertext=[
+ f"{name}
Type: {type_}
Params: {details}"
+ for name, type_, details in zip(layer_names, layer_types, layer_details)
+ ],
+ hoverinfo="text",
+ marker=dict(color="rgba(99, 179, 237, 0.8)", line=dict(color="rgba(99, 179, 237, 1.0)", width=2)),
+ )
+ )
+
+ fig.update_layout(
+ title=f"Model Architecture: {model_class.name}",
+ xaxis=dict(visible=False),
+ yaxis=dict(title="Layers", autorange="reversed"),
+ template="plotly_dark",
+ height=max(400, len(layers) * 40),
+ showlegend=False,
+ margin=dict(l=200, r=40, t=60, b=40),
+ )
+
+ return fig
+
+
+def model_complexity_plot(file_path: Path) -> Figure:
+ """Analyze and visualize model complexity metrics."""
+ tree, error = _parse_model_file(file_path)
+
+ if error:
+ return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
+
+ model_class = _find_model_class(tree)
+ if not model_class:
+ return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
+
+ layers = _extract_layer_info(model_class)
+
+ if not layers:
+ return create_styled_error_figure("No Layers Found", "Could not extract layer information from the model")
+
+ # Count layer types
+ layer_type_counts = {}
+ for layer in layers:
+ layer_type = layer["type"]
+ layer_type_counts[layer_type] = layer_type_counts.get(layer_type, 0) + 1
+
+ # Create pie chart of layer types
+ fig = go.Figure(
+ data=[
+ go.Pie(
+ labels=list(layer_type_counts.keys()),
+ values=list(layer_type_counts.values()),
+ hole=0.3,
+ marker=dict(colors=["#5c79ff", "#63b3ed", "#48bb78", "#f6ad55", "#fc8181"]),
+ )
+ ]
+ )
+
+ fig.update_layout(
+ title="Layer Type Distribution",
+ template="plotly_dark",
+ height=400,
+ )
+
+ return fig
+
+
+def model_metadata_plot(file_path: Path) -> Figure:
+ """Display model metadata and information extracted from the Python file (clean, aligned layout)."""
+ import textwrap
+
+ tree, error = _parse_model_file(file_path)
+ if error:
+ return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
+
+ model_class = _find_model_class(tree)
+ if not model_class:
+ return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
+
+ layers = _extract_layer_info(model_class)
+
+ # Extract imports
+ imports = []
+ for node in tree.body:
+ if isinstance(node, ast.Import):
+ for alias in node.names:
+ imports.append(alias.name)
+ elif isinstance(node, ast.ImportFrom) and node.module:
+ imports.append(node.module)
+
+ # Get docstring and wrap it
+ docstring = ast.get_docstring(model_class) or "No docstring available"
+ wrapped_doc = "
".join(textwrap.wrap(docstring, width=70))
+
+ relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:4]
+ param_count = _count_parameters(layers)
+
+ # Define card grid (aligned 2x2)
+ cards = [
+ {"x": 0.05, "y": 0.93, "width": 0.43, "height": 0.38, "title": "📦 Model Overview", "color": "#2d5f8d"},
+ {"x": 0.52, "y": 0.93, "width": 0.43, "height": 0.38, "title": "🔢 Statistics", "color": "#2d6b5f"},
+ {"x": 0.05, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📝 Description", "color": "#5d4b7a"},
+ {"x": 0.52, "y": 0.46, "width": 0.43, "height": 0.38, "title": "📚 Dependencies", "color": "#7a5b3d"},
+ ]
+
+ fig = go.Figure()
+
+ # Draw background cards with consistent opacity
+ for card in cards:
+ fig.add_shape(
+ type="rect",
+ xref="paper",
+ yref="paper",
+ x0=card["x"],
+ y0=card["y"] - card["height"],
+ x1=card["x"] + card["width"],
+ y1=card["y"],
+ fillcolor=card["color"],
+ line=dict(color="#4a5568", width=2),
+ opacity=0.3,
+ layer="below",
+ )
+ # Header bar
+ fig.add_shape(
+ type="rect",
+ xref="paper",
+ yref="paper",
+ x0=card["x"],
+ y0=card["y"] - 0.07,
+ x1=card["x"] + card["width"],
+ y1=card["y"],
+ fillcolor=card["color"],
+ line=dict(width=0),
+ opacity=0.45,
+ layer="below",
+ )
+
+ # --- CARD 1: Model Overview ---
+ card = cards[0]
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ align="left",
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ fig.add_annotation(
+ text=f"{model_class.name}
"
+ f"PyTorch Neural Network",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.04,
+ y=card["y"] - 0.13,
+ xanchor="left",
+ yanchor="top",
+ showarrow=False,
+ align="left",
+ font=dict(size=15, color="#cbd5e0", family="Inter, Arial, sans-serif"),
+ )
+
+ # --- CARD 2: Statistics ---
+ card = cards[1]
+ y_center = card["y"] - card["height"] / 2
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ fig.add_annotation(
+ text=f"{len(layers)}
"
+ f"LAYERS",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + card["width"] / 2,
+ y=y_center + 0.07,
+ xanchor="center",
+ yanchor="middle",
+ showarrow=False,
+ align="center",
+ )
+ fig.add_annotation(
+ text=f"~{param_count:,}
"
+ f"PARAMETERS",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + card["width"] / 2,
+ y=y_center - 0.10,
+ xanchor="center",
+ yanchor="middle",
+ showarrow=False,
+ align="center",
+ )
+
+ # --- CARD 3: Description ---
+ card = cards[2]
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ fig.add_annotation(
+ text=f"{wrapped_doc}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.04,
+ y=card["y"] - 0.13,
+ xanchor="left",
+ yanchor="top",
+ showarrow=False,
+ align="left",
+ )
+
+ # --- CARD 4: Dependencies ---
+ card = cards[3]
+ fig.add_annotation(
+ text=f"{card['title']}",
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.03,
+ y=card["y"] - 0.02,
+ xanchor="left",
+ yanchor="middle",
+ showarrow=False,
+ font=dict(size=15, color="#ffffff", family="Inter, Arial, sans-serif"),
+ )
+ imports_text = (
+ "
".join(
+ [
+ f"▸ "
+ f"{imp}"
+ for imp in relevant_imports
+ ]
+ )
+ if relevant_imports
+ else "No torch imports detected"
+ )
+ fig.add_annotation(
+ text=imports_text,
+ xref="paper",
+ yref="paper",
+ x=card["x"] + 0.04,
+ y=card["y"] - 0.13,
+ xanchor="left",
+ yanchor="top",
+ showarrow=False,
+ align="left",
+ )
+
+ # Layout polish
+ fig.update_layout(
+ title=dict(
+ text="Model Metadata",
+ font=dict(size=20, color="#e2e8f0", family="Inter, Arial, sans-serif"),
+ x=0.5,
+ xanchor="center",
+ ),
+ template="plotly_dark",
+ height=500,
+ margin=dict(l=20, r=20, t=70, b=20),
+ plot_bgcolor="#1a202c",
+ paper_bgcolor="#1a202c",
+ )
+ fig.update_xaxes(visible=False)
+ fig.update_yaxes(visible=False)
+ return fig
+
+
+def code_structure_plot(file_path: Path) -> Figure:
+ """Visualize the code structure and method definitions in the model."""
+ tree, error = _parse_model_file(file_path)
+
+ if error:
+ return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
+
+ model_class = _find_model_class(tree)
+ if not model_class:
+ return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
+
+ # Extract methods
+ methods = []
+ for node in model_class.body:
+ if isinstance(node, ast.FunctionDef):
+ # Count lines in method
+ if hasattr(node, "end_lineno") and hasattr(node, "lineno"):
+ lines = node.end_lineno - node.lineno + 1
+ else:
+ lines = 1
+
+ methods.append({"name": node.name, "lines": lines, "args": len(node.args.args) - 1}) # Exclude self
+
+ if not methods:
+ return create_styled_error_figure(
+ "No Methods Found", "Could not extract method information from the model class"
+ )
+
+ # Create visualization of methods
+ method_names = [m["name"] for m in methods]
+ method_lines = [m["lines"] for m in methods]
+ method_args = [m["args"] for m in methods]
+
+ fig = go.Figure()
+
+ # Bar chart for method complexity (lines of code)
+ fig.add_trace(
+ go.Bar(
+ x=method_names,
+ y=method_lines,
+ name="Lines of Code",
+ marker=dict(color="rgba(99, 179, 237, 0.8)"),
+ hovertext=[
+ f"{name}
Lines: {lines}
Arguments: {args}"
+ for name, lines, args in zip(method_names, method_lines, method_args)
+ ],
+ hoverinfo="text",
+ )
+ )
+
+ fig.update_layout(
+ title=f"Method Complexity - {model_class.name}",
+ xaxis_title="Methods",
+ yaxis_title="Lines of Code",
+ template="plotly_dark",
+ height=400,
+ showlegend=False,
+ )
+
+ return fig