liyu-dev #3

Merged
Liyux merged 6 commits from liyu-dev into main 2025-08-21 11:05:55 -04:00
Showing only changes of commit 1c7ddef5cb - Show all commits

View File

@ -13,128 +13,114 @@ from helpers.app_settings import get_app_settings
def load_validation_data(): def load_validation_data():
val_dataset = ModulationH5Dataset( val_dataset = ModulationH5Dataset(
"data/dataset/val.h5", label_name="modulation", data_key="validation_data" "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
) )
x = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
class_names = list(val_dataset.label_encoder.classes_)
x = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L) return x, y, class_names
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
class_names = list(val_dataset.label_encoder.classes_)
return x, y, class_names
def build_model_from_ckpt( def build_model_from_ckpt(
ckpt_path: str, in_channels: int, num_classes: int ckpt_path: str, in_channels: int, num_classes: int
) -> torch.nn.Module: ) -> torch.nn.Module:
""" """
Build and return a PyTorch model loaded from a checkpoint. Build and return a PyTorch model loaded from a checkpoint.
""" """
model = RFClassifier( model = RFClassifier(
model=mobilenetv3( model=mobilenetv3(
model_size="mobilenetv3_small_050", model_size="mobilenetv3_small_050",
num_classes=num_classes, num_classes=num_classes,
in_chans=in_channels, in_chans=in_channels,
) )
) )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device) checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
model.load_state_dict(checkpoint["state_dict"]) model.load_state_dict(checkpoint["state_dict"])
model.eval() model.eval()
return model return model
def evaluate_checkpoint(ckpt_path: str): def evaluate_checkpoint(ckpt_path: str):
""" """
Loads the model from checkpoint and evaluates it on a validation set. Loads the model from checkpoint and evaluates it on a validation set.
Prints classification metrics and plots a confusion matrix. Prints classification metrics and plots a confusion matrix.
""" """
# Load validation data
X_val, y_true, class_names = load_validation_data()
num_classes = len(class_names)
in_channels = X_val.shape[1]
# Load validation data # Load model
X_val, y_true, class_names = load_validation_data() model = build_model_from_ckpt(
num_classes = len(class_names) ckpt_path, in_channels=in_channels, num_classes=num_classes
in_channels = X_val.shape[1] )
# Inference
y_pred = []
with torch.no_grad():
for x in X_val:
x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32)
logits = model(x_tensor)
pred = torch.argmax(logits, dim=1).item()
y_pred.append(pred)
# Load model # Print classification report
model = build_model_from_ckpt( print("\nClassification Report:")
ckpt_path, in_channels=in_channels, num_classes=num_classes print(
) classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
)
# Inference
y_pred = []
with torch.no_grad():
for x in X_val:
x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32)
logits = model(x_tensor)
pred = torch.argmax(logits, dim=1).item()
y_pred.append(pred)
# Print classification report
print("\nClassification Report:")
print(
classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
)
print_confusion_matrix(
y_true=np.array(y_true),
y_pred=np.array(y_pred),
classes=class_names,
normalize=True,
title="Normalized Confusion Matrix",
)
print_confusion_matrix(
y_true=np.array(y_true),
y_pred=np.array(y_pred),
classes=class_names,
normalize=True,
title="Normalized Confusion Matrix",
)
def print_confusion_matrix( def print_confusion_matrix(
y_true: np.ndarray, y_true: np.ndarray,
y_pred: np.ndarray, y_pred: np.ndarray,
classes: list[str], classes: list[str],
normalize: bool = True, normalize: bool = True,
title: str = "Confusion Matrix (counts and normalized)", title: str = "Confusion Matrix (counts and normalized)",
) -> None: ) -> None:
""" """
Plot a confusion matrix showing both raw counts and (optionally) normalized values. Plot a confusion matrix showing both raw counts and (optionally) normalized values.
Args: Args:
y_true: true labels (integers 0..C-1) y_true: true labels (integers 0..C-1)
y_pred: predicted labels (same shape as y_true) y_pred: predicted labels (same shape as y_true)
classes: list of classname strings in index order classes: list of classname strings in index order
normalize: if True, each row is normalized to sum=1 normalize: if True, each row is normalized to sum=1
title: title for the plot title: title for the plot
""" """
# 1) build raw CM # 1) build raw CM
c = len(classes) c = len(classes)
cm = np.zeros((c, c), dtype=int) cm = np.zeros((c, c), dtype=int)
for t, p in zip(y_true, y_pred): for t, p in zip(y_true, y_pred):
cm[t, p] += 1 cm[t, p] += 1
# 2) normalize if requested
if normalize:
with np.errstate(divide="ignore", invalid="ignore"):
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
cm_norm = np.nan_to_num(cm_norm)
print_confusion_matrix_helper(cm_norm, classes)
else:
print_confusion_matrix_helper(cm, classes)
# 2) normalize if requested
if normalize:
with np.errstate(divide="ignore", invalid="ignore"):
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
cm_norm = np.nan_to_num(cm_norm)
print_confusion_matrix_helper(cm_norm, classes)
else:
print_confusion_matrix_helper(cm, classes)
import numpy as np import numpy as np
def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2): def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
""" """
Pretty prints a confusion matrix with x/y labels. Pretty prints a confusion matrix with x/y labels.
@ -148,13 +134,13 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=
matrix = np.array(matrix) matrix = np.array(matrix)
num_classes = matrix.shape[0] num_classes = matrix.shape[0]
labels = classes or list(range(num_classes)) labels = classes or list(range(num_classes))
# Header # Header
print(" " * 9 + "Ground Truth →") print(" " * 9 + "Ground Truth →")
header = "Pred ↓ | " + " ".join([f"{str(label):>6}" for label in labels]) header = "Pred ↓ | " + " ".join([f"{str(label):>6}" for label in labels])
print(header) print(header)
print("-" * len(header)) print("-" * len(header))
# Rows # Rows
for i in range(num_classes): for i in range(num_classes):
row_vals = matrix[i] row_vals = matrix[i]
@ -166,9 +152,9 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=
row_str = " ".join([f"{int(val):>6}" for val in row_vals]) row_str = " ".join([f"{int(val):>6}" for val in row_vals])
print(f"{str(labels[i]):>7} | {row_str}") print(f"{str(labels[i]):>7} | {row_str}")
if __name__ == "__main__": if __name__ == "__main__":
settings = get_app_settings() settings = get_app_settings()
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt")) evaluate_checkpoint(
os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
)