modrec-workflow/scripts/model_builder/plot_data.py


"""Evaluate a trained modulation recognition checkpoint on the validation set:
run inference, print a classification report, and plot a confusion matrix."""

import os

import numpy as np
import torch
from sklearn.metrics import classification_report

os.environ["NNPACK"] = "0"  # Disable NNPACK.

from matplotlib import pyplot as plt
from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset
from helpers.app_settings import get_app_settings


def load_validation_data():
    """Load the validation split and return arrays plus the class names."""
    val_dataset = ModulationH5Dataset(
        "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
    )

    X = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
    y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
    class_names = list(val_dataset.label_encoder.classes_)
    return X, y, class_names
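
# Illustrative sketch (not part of the evaluation flow): the loader above can be
# called on its own to inspect the validation split; the shapes follow the
# comments in the function body.
#
#     X, y, names = load_validation_data()
#     print(X.shape)     # (N, C, L)
#     print(y.shape)     # (N,)
#     print(len(names))  # number of modulation classes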


def build_model_from_ckpt(
    ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module:
    """
    Build and return a PyTorch model loaded from a checkpoint.
    """
    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050",
            num_classes=num_classes,
            in_chans=in_channels,
        )
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model
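
# Illustrative sketch (hypothetical sizes): a quick sanity check that a checkpoint
# loads and produces one logit per class. in_channels=2, num_classes=10, and the
# frame length of 1024 are assumptions, not values taken from the dataset.
#
#     m = build_model_from_ckpt(
#         os.path.join("checkpoint_files", "inference_recognition_model.ckpt"),
#         in_channels=2,
#         num_classes=10,
#     )
#     with torch.no_grad():
#         print(m(torch.zeros(1, 2, 1024)).shape)  # expected: torch.Size([1, 10])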


def evaluate_checkpoint(ckpt_path: str):
    """
    Loads the model from checkpoint and evaluates it on a validation set.

    Prints classification metrics and plots a confusion matrix.
    """
    # Load validation data
    X_val, y_true, class_names = load_validation_data()
    num_classes = len(class_names)
    in_channels = X_val.shape[1]

    # Load model
    model = build_model_from_ckpt(
        ckpt_path, in_channels=in_channels, num_classes=num_classes
    )

    # Inference
    y_pred = []
    with torch.no_grad():
        for x in X_val:
            x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32)
            logits = model(x_tensor)
            pred = torch.argmax(logits, dim=1).item()
            y_pred.append(pred)

    # Print classification report
    print("\nClassification Report:")
    print(
        classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
    )

    plot_confusion_matrix_with_counts(
        y_true=np.array(y_true),
        y_pred=np.array(y_pred),
        classes=class_names,
        normalize=True,
        title="Normalized Confusion Matrix",
    )


def plot_confusion_matrix_with_counts(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    classes: list[str],
    normalize: bool = True,
    title: str = "Confusion Matrix (counts and normalized)",
) -> None:
    """
    Plot a confusion matrix showing both raw counts and (optionally) normalized values.

    Args:
        y_true: true labels (integers 0..C-1)
        y_pred: predicted labels (same shape as y_true)
        classes: list of class name strings in index order
        normalize: if True, each row is normalized to sum to 1
        title: title for the plot
    """
    # 1) build the raw confusion matrix
    C = len(classes)
    cm = np.zeros((C, C), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1

    # 2) normalize each row if requested
    if normalize:
        with np.errstate(divide="ignore", invalid="ignore"):
            cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
            cm_norm = np.nan_to_num(cm_norm)
    else:
        cm_norm = cm

    # 3) plot
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm_norm, interpolation="nearest")
    ax.set_title(title)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_xticks(np.arange(C))
    ax.set_yticks(np.arange(C))
    ax.set_xticklabels(classes, rotation=45, ha="right")
    ax.set_yticklabels(classes)

    # 4) annotate each cell with its count and normalized value
    for i in range(C):
        for j in range(C):
            count = cm[i, j]
            val = cm_norm[i, j]
            txt = f"{count}\n{val:.2f}"
            ax.text(j, i, txt, ha="center", va="center")

    fig.colorbar(im, ax=ax, label="Normalized value" if normalize else "Count")
    plt.tight_layout()
    plt.show()
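
# Illustrative sketch (hypothetical labels and class names): the plotting helper
# can be exercised directly with small integer-label arrays.
#
#     y_t = np.array([0, 1, 1, 2, 2, 2])
#     y_p = np.array([0, 1, 2, 2, 2, 0])
#     plot_confusion_matrix_with_counts(y_t, y_p, classes=["bpsk", "qpsk", "qam16"])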


if __name__ == "__main__":
    settings = get_app_settings()
    evaluate_checkpoint(
        os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
    )
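
# Usage note (assumes the relative paths above, e.g. data/dataset/val.h5 and
# checkpoint_files/, resolve from the current working directory):
#
#     python scripts/model_builder/plot_data.py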