Merge pull request 'liyu-dev' (#3) from liyu-dev into main
Some checks failed
Modulation Recognition Demo / ria-demo (push): Has been cancelled
Reviewed-on: https://git.riahub.ai/qoherent/modrec-workflow/pulls/3
This commit is contained in: commit 7c7a882f1f
@@ -2,11 +2,9 @@ name: Modulation Recognition Demo
 on:
   push:
-    branches:
-      [main]
+    branches: [main]
   pull_request:
-    branches:
-      [main]
+    branches: [main]


 jobs:
   ria-demo:
@@ -46,22 +44,24 @@ jobs:
           fi
           pip install -r requirements.txt

       - name: 1. Generate Recordings
         run: |
           mkdir -p data/recordings
           PYTHONPATH=. python scripts/dataset_manager/data_gen.py --output-dir data/recordings

+      - name: 📦 Compress Recordings
+        run: tar -czf recordings.tar.gz -C data/recordings .
+
       - name: ⬆️ Upload recordings
         uses: actions/upload-artifact@v3
         with:
           name: recordings
-          path: data/recordings/**
+          path: recordings.tar.gz

       - name: 2. Build HDF5 Dataset
         run: |
           mkdir -p data/dataset
           PYTHONPATH=. python scripts/dataset_manager/produce_dataset.py
         shell: bash

       - name: ⬆️ Upload Dataset
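Note: the new compress step bundles the generated recordings into a single recordings.tar.gz, so upload-artifact pushes one file instead of many small ones. For anyone consuming the artifact downstream, a minimal unpacking sketch using Python's standard tarfile module (the local destination directory is an arbitrary choice, not something this workflow defines):

import tarfile
from pathlib import Path

# Unpack the workflow artifact; "recordings.tar.gz" is the name the
# compress step produces. The destination directory is illustrative.
dest = Path("data/recordings")
dest.mkdir(parents=True, exist_ok=True)
with tarfile.open("recordings.tar.gz", "r:gz") as tar:
    tar.extractall(path=dest)  # equivalent to: tar -xzf recordings.tar.gz -C data/recordings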
@@ -72,16 +72,16 @@ jobs:

       - name: 3. Train Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           mkdir -p checkpoint_files
           PYTHONPATH=. python scripts/model_builder/train.py 2>/dev/null

       - name: 4. Plot Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           PYTHONPATH=. python scripts/model_builder/plot_data.py 2>/dev/null
@@ -98,7 +98,7 @@ jobs:
         run: |
           mkdir -p onnx_files
           MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/application_packager/convert_to_onnx.py 2>/dev/null

       - name: ⬆️ Upload ONNX file
         uses: actions/upload-artifact@v3
         with:
@@ -108,13 +108,13 @@ jobs:
       - name: 6. Profile ONNX model
         run: |
           PYTHONPATH=. python scripts/application_packager/profile_onnx.py

       - name: ⬆️ Upload JSON trace
         uses: actions/upload-artifact@v3
         with:
           name: profile-data
-          path: '**/onnxruntime_profile_*.json'
+          path: "**/onnxruntime_profile_*.json"

       - name: 7. Convert ONNX graph to an ORT file
         run: |
           PYTHONPATH=. python scripts/application_packager/convert_to_ort.py
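Note: the onnxruntime_profile_*.json glob in step 6 matches the trace files ONNX Runtime writes when session profiling is enabled. A minimal sketch of how such a trace is typically produced (the model path and input shape below are placeholders, not values taken from this repo):

import numpy as np
import onnxruntime as ort

# Enable profiling on the session; ONNX Runtime writes a JSON trace
# named onnxruntime_profile_<timestamp>.json when profiling ends.
opts = ort.SessionOptions()
opts.enable_profiling = True
session = ort.InferenceSession("model.onnx", sess_options=opts)  # placeholder path

# Run one inference so the trace contains per-node timings.
input_name = session.get_inputs()[0].name
dummy = np.random.randn(1, 2, 1024).astype(np.float32)  # assumed (N, C, L) IQ shape
session.run(None, {input_name: dummy})

trace_path = session.end_profiling()  # returns the JSON trace filename
print("Profile written to:", trace_path)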
@@ -24,7 +24,7 @@ dataset:
   snr_step: 3

   # Number of iterations (signal recordings) per modulation and SNR combination
-  num_iterations: 3
+  num_iterations: 100

   # Modulation scheme settings; keys must match the `modulation_types` list above
   # Each entry includes:
@@ -50,14 +50,14 @@ dataset:

   # Training and validation split ratios; must sum to 1
   train_split: 0.8
-  val_split : 0.2
+  val_split: 0.2

 training:
   # Number of training examples processed together before the model updates its weights
   batch_size: 256

   # Number of complete passes through the training dataset during training
-  epochs: 5
+  epochs: 30

   # Learning rate: step size for weight updates after each batch
   # Recommended range for fine-tuning: 1e-6 to 1e-4
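Note: the config bumps above (num_iterations 3 to 100, epochs 5 to 30) feed the AppSettings class shown in the next hunk. A minimal sketch of reading those keys with PyYAML; the file name app.yaml comes from the AppSettings docstring below, and the dataset/training nesting is inferred from this diff's hunk context:

import yaml

# Load the YAML configuration referenced by AppSettings.
with open("app.yaml") as f:
    config = yaml.safe_load(f)

num_iterations = config["dataset"]["num_iterations"]  # 100 after this change
epochs = config["training"]["epochs"]                 # 30 after this change

# Sanity-check the split ratios, which the comments say must sum to 1.
splits = config["dataset"]["train_split"] + config["dataset"]["val_split"]
assert abs(splits - 1.0) < 1e-9, "train_split + val_split must equal 1"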
@@ -41,7 +41,11 @@ class AppConfig:


 class AppSettings:
-    """Application settings, to be initialized from app.yaml configuration file."""
+    """
+    Application settings,
+    to be initialized from
+    app.yaml configuration file.
+    """

     def __init__(self, config_file: str):
         # Load the YAML configuration file
@@ -2,9 +2,9 @@ import os

 import numpy as np
 import torch
-from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3

 from helpers.app_settings import get_app_settings
+from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3


 def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
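Note: this hunk only reorders imports around convert_to_onnx; the function body is not part of the diff. For orientation, a hedged sketch of what a checkpoint-to-ONNX conversion along these lines typically looks like, reusing the RFClassifier/mobilenetv3 constructors visible elsewhere in this PR (the class count, channel count, input shape, and output path are assumptions, not the repo's actual values):

import torch

from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3


def convert_to_onnx_sketch(ckpt_path: str, fp16: bool = False) -> None:
    # Rebuild the model as in build_model_from_ckpt() elsewhere in this PR.
    model = RFClassifier(
        model=mobilenetv3(
            model_size="mobilenetv3_small_050", num_classes=6, in_chans=2
        )  # num_classes / in_chans are illustrative
    )
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    dummy = torch.randn(1, 2, 1024)  # assumed (N, C, L) IQ input shape
    if fp16:
        model, dummy = model.half(), dummy.half()

    # Standard torch.onnx export; the output filename is a placeholder.
    torch.onnx.export(model, dummy, "onnx_files/model.onnx", opset_version=17)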
@@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:

     for modulation in settings.modulation_types:
         for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
-            for i in range(3):
+            for _ in range(settings.num_iterations):
                 recording_length = settings.recording_length
                 beta = (
                     settings.beta
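Note: the iteration count now comes from the config (matching the num_iterations change above) instead of being hard-coded to 3. One behavior worth keeping in mind in the SNR loop: np.arange(start, stop, step) excludes the stop value, so the configured snr_stop itself is never generated. A quick illustration (start and stop are illustrative; snr_step: 3 matches the config shown earlier):

import numpy as np

print(np.arange(-6, 12, 3))      # [-6 -3  0  3  6  9]  -> 12 is excluded
# To include the endpoint, extend the stop by one step:
print(np.arange(-6, 12 + 3, 3))  # [-6 -3  0  3  6  9 12]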
@@ -49,8 +49,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data"):
             int(md["sps"]),
         )

-    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
-
     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
         dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
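Note: the removed first_rec assignment was unused. For completeness, the inverse of write_hdf5_file, reading the stacked array back, is a short h5py call; "data" is the default dataset_name above, while the file path here is illustrative (the validation file in plot_data.py uses a different data_key):

import h5py

# Read back a dataset written by write_hdf5_file().
with h5py.File("data/dataset/train.h5", "r") as hf:
    data = hf["data"][:]  # decompresses the gzip'd array into memory
print(data.shape)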
@@ -1,5 +1,4 @@
 import lightning as L
-import numpy as np
 import timm
 import torch
 from torch import nn
@@ -2,139 +2,125 @@ import os

 import numpy as np
 import torch
-from sklearn.metrics import classification_report
-
-os.environ["NNPACK"] = "0"
 from matplotlib import pyplot as plt
 from mobilenetv3 import RFClassifier, mobilenetv3
 from modulation_dataset import ModulationH5Dataset
+from sklearn.metrics import classification_report

 from helpers.app_settings import get_app_settings

+os.environ["NNPACK"] = "0"
+

 def load_validation_data():
     val_dataset = ModulationH5Dataset(
         "data/dataset/val.h5", label_name="modulation", data_key="validation_data"
     )

     x = np.stack([x.numpy() for x, _ in val_dataset])  # shape: (N, C, L)
     y = np.array([y.item() for _, y in val_dataset])  # shape: (N,)
     class_names = list(val_dataset.label_encoder.classes_)
-
-
     return x, y, class_names


 def build_model_from_ckpt(
     ckpt_path: str, in_channels: int, num_classes: int
 ) -> torch.nn.Module:
     """
     Build and return a PyTorch model loaded from a checkpoint.
     """
     model = RFClassifier(
         model=mobilenetv3(
             model_size="mobilenetv3_small_050",
             num_classes=num_classes,
             in_chans=in_channels,
         )
     )
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
     model.load_state_dict(checkpoint["state_dict"])
     model.eval()
     return model


 def evaluate_checkpoint(ckpt_path: str):
     """
     Loads the model from checkpoint and evaluates it on a validation set.
     Prints classification metrics and plots a confusion matrix.
     """
     # Load validation data
     X_val, y_true, class_names = load_validation_data()
     num_classes = len(class_names)
     in_channels = X_val.shape[1]

     # Load model
     model = build_model_from_ckpt(
         ckpt_path, in_channels=in_channels, num_classes=num_classes
     )

     # Inference
     y_pred = []
     with torch.no_grad():
         for x in X_val:
             x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32)
             logits = model(x_tensor)
             pred = torch.argmax(logits, dim=1).item()
             y_pred.append(pred)

     # Print classification report
     print("\nClassification Report:")
     print(
         classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
     )

     print_confusion_matrix(
         y_true=np.array(y_true),
         y_pred=np.array(y_pred),
         classes=class_names,
         normalize=True,
         title="Normalized Confusion Matrix",
     )


 def print_confusion_matrix(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     classes: list[str],
     normalize: bool = True,
     title: str = "Confusion Matrix (counts and normalized)",
 ) -> None:
     """
     Plot a confusion matrix showing both raw counts and (optionally) normalized values.

     Args:
         y_true: true labels (integers 0..C-1)
         y_pred: predicted labels (same shape as y_true)
         classes: list of class-name strings in index order
         normalize: if True, each row is normalized to sum=1
         title: title for the plot
     """
     # 1) build raw CM
     c = len(classes)
     cm = np.zeros((c, c), dtype=int)
     for t, p in zip(y_true, y_pred):
         cm[t, p] += 1

     # 2) normalize if requested
     if normalize:
         with np.errstate(divide="ignore", invalid="ignore"):
             cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
         cm_norm = np.nan_to_num(cm_norm)
         print_confusion_matrix_helper(cm_norm, classes)
     else:
         print_confusion_matrix_helper(cm, classes)


 import numpy as np


 def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
     """
     Pretty prints a confusion matrix with x/y labels.
@@ -148,13 +134,13 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
     matrix = np.array(matrix)
     num_classes = matrix.shape[0]
     labels = classes or list(range(num_classes))

     # Header
     print(" " * 9 + "Ground Truth →")
     header = "Pred ↓ | " + " ".join([f"{str(label):>6}" for label in labels])
     print(header)
     print("-" * len(header))

     # Rows
     for i in range(num_classes):
         row_vals = matrix[i]
@@ -166,9 +152,9 @@ def print_confusion_matrix_helper(matrix, classes=None, normalize=False, digits=2):
         row_str = " ".join([f"{int(val):>6}" for val in row_vals])
         print(f"{str(labels[i]):>7} | {row_str}")


 if __name__ == "__main__":
     settings = get_app_settings()
-    evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
+    evaluate_checkpoint(
+        os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
+    )
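Note: the inference loop in evaluate_checkpoint runs one recording at a time, which is functionally fine but slow for larger validation sets. A hedged batched variant for comparison; predict_in_batches is a hypothetical helper, the batch size is arbitrary, and model/X_val are as in the script:

import numpy as np
import torch

def predict_in_batches(model: torch.nn.Module, X_val: np.ndarray, batch_size: int = 256) -> np.ndarray:
    """Batched equivalent of the per-sample loop in evaluate_checkpoint()."""
    preds = []
    with torch.no_grad():
        for start in range(0, len(X_val), batch_size):
            batch = torch.tensor(X_val[start:start + batch_size], dtype=torch.float32)
            logits = model(batch)                      # (B, num_classes)
            preds.append(torch.argmax(logits, dim=1))  # (B,)
    return torch.cat(preds).numpy()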
@@ -1,23 +1,22 @@
 import os
 import sys

-os.environ["NNPACK"] = "0"
 import lightning as L
 import mobilenetv3
 import torch
 import torch.nn.functional as F
 import torchmetrics
-from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
 from modulation_dataset import ModulationH5Dataset

 from helpers.app_settings import get_app_settings

+os.environ["NNPACK"] = "0"
 script_dir = os.path.dirname(os.path.abspath(__file__))
 data_dir = os.path.abspath(os.path.join(script_dir, ".."))
 project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
-from lightning.pytorch.callbacks import TQDMProgressBar


 class CustomProgressBar(TQDMProgressBar):
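Note: the consolidated import brings ModelCheckpoint and TQDMProgressBar in together, and the diff does not show how they are wired up. A hypothetical sketch of a typical Lightning arrangement; the checkpoint path and filename echo names used elsewhere in this PR, while the monitored metric is an assumption:

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar

# Hypothetical wiring; the actual train_model() body is not part of this diff.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoint_files",              # matches the mkdir in the workflow
    filename="inference_recognition_model",  # matches the name used in plot_data.py
    monitor="val_loss",                      # assumed metric name
    mode="min",
)

trainer = L.Trainer(
    max_epochs=30,  # epochs value from the updated config
    callbacks=[checkpoint_cb, TQDMProgressBar()],
)
# trainer.fit(model, train_loader, val_loader)  # model/dataloaders as defined in train.py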
@@ -59,8 +58,6 @@ def train_model():
         print("X shape:", x.shape)
         print("Y values:", y[:10])
         break

-    unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
-
     num_classes = len(ds_train.label_encoder.classes_)

     hparams = {