Compare commits
No commits in common. "b8ccead21ec82843b4438edf127b39fa8ebc76f2" and "8105b829be20747757ba0ef26bccec325bd3b8fb" have entirely different histories.
b8ccead21e
...
8105b829be
|
|
@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
|||
project = 'ria-toolkit-oss'
|
||||
copyright = '2025, Qoherent Inc'
|
||||
author = 'Qoherent Inc.'
|
||||
release = '0.1.4'
|
||||
release = '0.1.3'
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "ria-toolkit-oss"
|
||||
version = "0.1.4"
|
||||
version = "0.1.3"
|
||||
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
||||
license = { text = "AGPL-3.0-only" }
|
||||
readme = "README.md"
|
||||
|
|
|
|||
|
|
@ -1,416 +0,0 @@
|
|||
"""Visualization functions for PyTorch model (.py) files.
|
||||
|
||||
This module provides visualization capabilities for PyTorch model Python files,
|
||||
extracting architectural information through AST parsing and static analysis.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import plotly.graph_objects as go
|
||||
from plotly.graph_objects import Figure
|
||||
|
||||
|
||||
def create_styled_error_figure(title: str, message: str, suggestion: str = None) -> go.Figure:
|
||||
"""Create a professional error figure with Qoherent dark theme styling."""
|
||||
fig = go.Figure()
|
||||
|
||||
# Create a clean, centered text display using Plotly's text formatting
|
||||
main_text = f"<b style='color:#f56565;font-size:18px'>⚠️ {title}</b><br><br>"
|
||||
main_text += f"<span style='color:#e2e8f0;font-size:14px'>{message}</span>"
|
||||
|
||||
if suggestion:
|
||||
main_text += "<br><br><span style='color:#63b3ed;font-size:13px'>💡 <b>Suggestion:</b></span><br>"
|
||||
main_text += f"<span style='color:#cbd5e0;font-size:12px'>{suggestion}</span>"
|
||||
|
||||
# Add the main text annotation
|
||||
fig.add_annotation(
|
||||
text=main_text,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.5,
|
||||
y=0.5,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
align="center",
|
||||
borderwidth=2,
|
||||
bordercolor="#4a5568",
|
||||
bgcolor="#2d3748",
|
||||
font=dict(family="Arial, sans-serif", size=14, color="#e2e8f0"),
|
||||
)
|
||||
|
||||
# Update layout with dark theme
|
||||
fig.update_layout(
|
||||
title="",
|
||||
height=400,
|
||||
template="plotly_dark",
|
||||
margin=dict(l=40, r=40, t=40, b=40),
|
||||
plot_bgcolor="#1a202c",
|
||||
paper_bgcolor="#1a202c",
|
||||
font=dict(color="#e2e8f0"),
|
||||
)
|
||||
|
||||
# Remove axes and grid
|
||||
fig.update_xaxes(visible=False)
|
||||
fig.update_yaxes(visible=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _parse_model_file(file_path: Path) -> Tuple[Optional[ast.Module], Optional[str]]:
|
||||
"""Parse a Python model file and return the AST and any error message."""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
code = f.read()
|
||||
tree = ast.parse(code, filename=str(file_path))
|
||||
return tree, None
|
||||
except SyntaxError as e:
|
||||
return None, f"Syntax error in file: {e}"
|
||||
except Exception as e:
|
||||
return None, f"Failed to parse file: {e}"
|
||||
|
||||
|
||||
def _find_model_class(tree: ast.Module) -> Optional[ast.ClassDef]:
|
||||
"""Find the main model class (subclass of nn.Module) in the AST."""
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# Check if it inherits from nn.Module or torch.nn.Module
|
||||
for base in node.bases:
|
||||
base_name = ""
|
||||
if isinstance(base, ast.Name):
|
||||
base_name = base.id
|
||||
elif isinstance(base, ast.Attribute):
|
||||
if isinstance(base.value, ast.Name):
|
||||
base_name = f"{base.value.id}.{base.attr}"
|
||||
|
||||
if "Module" in base_name or "nn.Module" in base_name:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
def _extract_layer_info(model_class: ast.ClassDef) -> List[Dict[str, Any]]:
|
||||
"""Extract layer information from the model's __init__ method."""
|
||||
layers = []
|
||||
|
||||
# Find __init__ method
|
||||
init_method = None
|
||||
for node in model_class.body:
|
||||
if isinstance(node, ast.FunctionDef) and node.name == "__init__":
|
||||
init_method = node
|
||||
break
|
||||
|
||||
if not init_method:
|
||||
return layers
|
||||
|
||||
# Parse assignments in __init__
|
||||
for node in ast.walk(init_method):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Attribute):
|
||||
layer_name = target.attr
|
||||
layer_type = _extract_layer_type(node.value)
|
||||
if layer_type:
|
||||
layers.append(
|
||||
{"name": layer_name, "type": layer_type, "details": _extract_layer_params(node.value)}
|
||||
)
|
||||
|
||||
return layers
|
||||
|
||||
|
||||
def _extract_layer_type(node: ast.expr) -> Optional[str]:
|
||||
"""Extract the layer type from an AST node."""
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
return node.func.id
|
||||
elif isinstance(node.func, ast.Attribute):
|
||||
return node.func.attr
|
||||
return None
|
||||
|
||||
|
||||
def _extract_layer_params(node: ast.Call) -> str:
|
||||
"""Extract layer parameters as a string."""
|
||||
params = []
|
||||
|
||||
# Extract positional arguments
|
||||
for arg in node.args:
|
||||
if isinstance(arg, ast.Constant):
|
||||
params.append(str(arg.value))
|
||||
elif isinstance(arg, ast.Name):
|
||||
params.append(arg.id)
|
||||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
if isinstance(keyword.value, ast.Constant):
|
||||
params.append(f"{keyword.arg}={keyword.value.value}")
|
||||
elif isinstance(keyword.value, ast.Name):
|
||||
params.append(f"{keyword.arg}={keyword.value.id}")
|
||||
|
||||
return ", ".join(params)
|
||||
|
||||
|
||||
def _count_parameters(layers: List[Dict[str, Any]]) -> int:
|
||||
"""Estimate parameter count from layer definitions (rough estimate)."""
|
||||
# This is a very rough estimate - actual counts would require instantiating the model
|
||||
param_estimates = {
|
||||
"Linear": 1000,
|
||||
"Conv1d": 500,
|
||||
"Conv2d": 5000,
|
||||
"Conv3d": 10000,
|
||||
"LSTM": 4000,
|
||||
"GRU": 3000,
|
||||
"TransformerEncoder": 50000,
|
||||
"Embedding": 10000,
|
||||
}
|
||||
|
||||
total = 0
|
||||
for layer in layers:
|
||||
layer_type = layer["type"]
|
||||
total += param_estimates.get(layer_type, 100)
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def model_architecture_plot(file_path: Path) -> Figure:
|
||||
"""Visualize the architecture of a PyTorch model from its .py file.
|
||||
|
||||
Parses the model file using AST to extract layers and their connections.
|
||||
"""
|
||||
tree, error = _parse_model_file(file_path)
|
||||
|
||||
if error:
|
||||
return create_styled_error_figure(
|
||||
"Parse Error", error, "Ensure the .py file contains valid Python code with a PyTorch nn.Module class"
|
||||
)
|
||||
|
||||
model_class = _find_model_class(tree)
|
||||
if not model_class:
|
||||
return create_styled_error_figure(
|
||||
"No Model Found",
|
||||
"Could not find a PyTorch nn.Module class in the file",
|
||||
"Ensure your model class inherits from torch.nn.Module or nn.Module",
|
||||
)
|
||||
|
||||
layers = _extract_layer_info(model_class)
|
||||
|
||||
if not layers:
|
||||
return create_styled_error_figure(
|
||||
"No Layers Found",
|
||||
"Could not extract layer information from the model",
|
||||
"Ensure your model defines layers in the __init__ method",
|
||||
)
|
||||
|
||||
# Create a hierarchical visualization
|
||||
layer_names = [f"{i+1}. {layer['name']}" for i, layer in enumerate(layers)]
|
||||
layer_types = [layer["type"] for layer in layers]
|
||||
layer_details = [layer["details"] for layer in layers]
|
||||
|
||||
# Create a bar chart showing layers
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
y=layer_names,
|
||||
x=[1] * len(layer_names),
|
||||
orientation="h",
|
||||
text=layer_types,
|
||||
textposition="inside",
|
||||
hovertext=[
|
||||
f"{name}<br>Type: {type_}<br>Params: {details}"
|
||||
for name, type_, details in zip(layer_names, layer_types, layer_details)
|
||||
],
|
||||
hoverinfo="text",
|
||||
marker=dict(color="rgba(99, 179, 237, 0.8)", line=dict(color="rgba(99, 179, 237, 1.0)", width=2)),
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"Model Architecture: {model_class.name}",
|
||||
xaxis=dict(visible=False),
|
||||
yaxis=dict(title="Layers", autorange="reversed"),
|
||||
template="plotly_dark",
|
||||
height=max(400, len(layers) * 40),
|
||||
showlegend=False,
|
||||
margin=dict(l=200, r=40, t=60, b=40),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def model_complexity_plot(file_path: Path) -> Figure:
|
||||
"""Analyze and visualize model complexity metrics."""
|
||||
tree, error = _parse_model_file(file_path)
|
||||
|
||||
if error:
|
||||
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||
|
||||
model_class = _find_model_class(tree)
|
||||
if not model_class:
|
||||
return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
|
||||
|
||||
layers = _extract_layer_info(model_class)
|
||||
|
||||
if not layers:
|
||||
return create_styled_error_figure("No Layers Found", "Could not extract layer information from the model")
|
||||
|
||||
# Count layer types
|
||||
layer_type_counts = {}
|
||||
for layer in layers:
|
||||
layer_type = layer["type"]
|
||||
layer_type_counts[layer_type] = layer_type_counts.get(layer_type, 0) + 1
|
||||
|
||||
# Create pie chart of layer types
|
||||
fig = go.Figure(
|
||||
data=[
|
||||
go.Pie(
|
||||
labels=list(layer_type_counts.keys()),
|
||||
values=list(layer_type_counts.values()),
|
||||
hole=0.3,
|
||||
marker=dict(colors=["#5c79ff", "#63b3ed", "#48bb78", "#f6ad55", "#fc8181"]),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title="Layer Type Distribution",
|
||||
template="plotly_dark",
|
||||
height=400,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def model_metadata_plot(file_path: Path) -> Figure:
|
||||
"""Display model metadata and information extracted from the Python file."""
|
||||
tree, error = _parse_model_file(file_path)
|
||||
|
||||
if error:
|
||||
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||
|
||||
model_class = _find_model_class(tree)
|
||||
if not model_class:
|
||||
return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
|
||||
|
||||
layers = _extract_layer_info(model_class)
|
||||
|
||||
# Extract imports
|
||||
imports = []
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
imports.append(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
imports.append(node.module)
|
||||
|
||||
# Get docstring
|
||||
docstring = ast.get_docstring(model_class) or "No docstring available"
|
||||
if len(docstring) > 200:
|
||||
docstring = docstring[:200] + "..."
|
||||
|
||||
# Build metadata display
|
||||
metadata_text = f"""<b style='font-size:16px;color:#63b3ed'>Model: {model_class.name}</b><br><br>"""
|
||||
metadata_text += f"<b>📝 Description:</b><br><span style='color:#cbd5e0'>{docstring}</span><br><br>"
|
||||
metadata_text += f"<b>🔢 Number of Layers:</b> {len(layers)}<br>"
|
||||
metadata_text += f"<b>📦 Estimated Parameters:</b> ~{_count_parameters(layers):,}<br><br>"
|
||||
metadata_text += f"<b>📚 Key Imports:</b><br>"
|
||||
|
||||
relevant_imports = [imp for imp in imports if "torch" in imp or "nn" in imp][:5]
|
||||
for imp in relevant_imports:
|
||||
metadata_text += f" • {imp}<br>"
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_annotation(
|
||||
text=metadata_text,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.05,
|
||||
y=0.95,
|
||||
xanchor="left",
|
||||
yanchor="top",
|
||||
showarrow=False,
|
||||
align="left",
|
||||
borderwidth=2,
|
||||
bordercolor="#4a5568",
|
||||
bgcolor="#2d3748",
|
||||
font=dict(family="Arial, sans-serif", size=13, color="#e2e8f0"),
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title="Model Metadata",
|
||||
template="plotly_dark",
|
||||
height=450,
|
||||
margin=dict(l=40, r=40, t=60, b=40),
|
||||
plot_bgcolor="#1a202c",
|
||||
paper_bgcolor="#1a202c",
|
||||
)
|
||||
|
||||
fig.update_xaxes(visible=False)
|
||||
fig.update_yaxes(visible=False)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def code_structure_plot(file_path: Path) -> Figure:
|
||||
"""Visualize the code structure and method definitions in the model."""
|
||||
tree, error = _parse_model_file(file_path)
|
||||
|
||||
if error:
|
||||
return create_styled_error_figure("Parse Error", error, "Ensure the .py file contains valid Python code")
|
||||
|
||||
model_class = _find_model_class(tree)
|
||||
if not model_class:
|
||||
return create_styled_error_figure("No Model Found", "Could not find a PyTorch nn.Module class in the file")
|
||||
|
||||
# Extract methods
|
||||
methods = []
|
||||
for node in model_class.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
# Count lines in method
|
||||
if hasattr(node, "end_lineno") and hasattr(node, "lineno"):
|
||||
lines = node.end_lineno - node.lineno + 1
|
||||
else:
|
||||
lines = 1
|
||||
|
||||
methods.append({"name": node.name, "lines": lines, "args": len(node.args.args) - 1}) # Exclude self
|
||||
|
||||
if not methods:
|
||||
return create_styled_error_figure(
|
||||
"No Methods Found", "Could not extract method information from the model class"
|
||||
)
|
||||
|
||||
# Create visualization of methods
|
||||
method_names = [m["name"] for m in methods]
|
||||
method_lines = [m["lines"] for m in methods]
|
||||
method_args = [m["args"] for m in methods]
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
# Bar chart for method complexity (lines of code)
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=method_names,
|
||||
y=method_lines,
|
||||
name="Lines of Code",
|
||||
marker=dict(color="rgba(99, 179, 237, 0.8)"),
|
||||
hovertext=[
|
||||
f"{name}<br>Lines: {lines}<br>Arguments: {args}"
|
||||
for name, lines, args in zip(method_names, method_lines, method_args)
|
||||
],
|
||||
hoverinfo="text",
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"Method Complexity - {model_class.name}",
|
||||
xaxis_title="Methods",
|
||||
yaxis_title="Lines of Code",
|
||||
template="plotly_dark",
|
||||
height=400,
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
Loading…
Reference in New Issue
Block a user