Clean up code, adjust default parameters

Liyu Xiao 2025-08-21 10:56:23 -04:00
parent a298384f7e
commit 04e5f4db8a
6 changed files with 12 additions and 16 deletions

View File

@@ -5,9 +5,6 @@ dataset:
# Number of samples per recording
recording_length: 1024
# Set this to scale the number of generated recordings
mult_factor: 5
# List of signal modulation schemes to include in the dataset
modulation_types:
- bpsk
@@ -27,7 +24,7 @@ dataset:
snr_step: 3
# Number of iterations (signal recordings) per modulation and SNR combination
num_iterations: 3
num_iterations: 100
# Modulation scheme settings; keys must match the `modulation_types` list above
# Each entry includes:
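
With mult_factor removed, the number of generated recordings is now driven directly by num_iterations per (modulation, SNR) pair. A rough sketch of the resulting dataset size, assuming illustrative values for the fields not shown in this hunk (only snr_step and num_iterations come from the diff):

# Back-of-envelope dataset size after this change; the modulation list and SNR
# bounds below are assumed example values, not taken from the repository config.
modulation_types = ["bpsk", "qpsk"]            # assumed example list
snr_start, snr_stop, snr_step = 0, 21, 3       # snr_step matches the diff; bounds are assumed
num_iterations = 100                           # new default from this commit

num_snr_steps = len(range(snr_start, snr_stop, snr_step))   # 7 steps: 0, 3, ..., 18
total_recordings = len(modulation_types) * num_snr_steps * num_iterations
print(total_recordings)                        # 2 * 7 * 100 = 1400 recordings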

View File

@@ -9,7 +9,6 @@ import yaml
@dataclass
class DataSetConfig:
num_slices: int
mult_factor: int
train_split: float
seed: int
modulation_types: list
@@ -42,7 +41,11 @@ class AppConfig:
class AppSettings:
"""Application settings, to be initialized from app.yaml configuration file."""
"""
Application settings,
to be initialized from
app.yaml configuration file.
"""
def __init__(self, config_file: str):
# Load the YAML configuration file
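
For reference, a minimal sketch of the DataSetConfig dataclass after this change, keeping only the fields visible in the hunk; the real class likely defines additional fields (for example num_iterations) that the diff does not show:

from dataclasses import dataclass

@dataclass
class DataSetConfig:
    # mult_factor is no longer part of the dataset config after this commit
    num_slices: int
    train_split: float
    seed: int
    modulation_types: list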

View File

@@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:
for modulation in settings.modulation_types:
for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
for _ in range(settings.mult_factor):
for _ in range(settings.num_iterations):
recording_length = settings.recording_length
beta = (
settings.beta
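
The generation loop after this commit, in sketch form (variable names are taken from the hunk above; the loop body is elided and only indicated by a comment):

# Minimal sketch: one recording per iteration, per SNR value, per modulation scheme.
for modulation in settings.modulation_types:
    for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
        for _ in range(settings.num_iterations):   # previously settings.mult_factor
            recording_length = settings.recording_length
            # ... build and store one recording for (modulation, snr) ...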

View File

@@ -1,5 +1,4 @@
import lightning as L
import numpy as np
import timm
import torch
from torch import nn

View File

@@ -2,15 +2,15 @@ import os
import numpy as np
import torch
from sklearn.metrics import classification_report
os.environ["NNPACK"] = "0"
from matplotlib import pyplot as plt
from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset
from sklearn.metrics import classification_report
from helpers.app_settings import get_app_settings
os.environ["NNPACK"] = "0"
def load_validation_data():
val_dataset = ModulationH5Dataset(

View File

@@ -1,23 +1,22 @@
import os
import sys
os.environ["NNPACK"] = "0"
import lightning as L
import mobilenetv3
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from modulation_dataset import ModulationH5Dataset
from helpers.app_settings import get_app_settings
os.environ["NNPACK"] = "0"
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from lightning.pytorch.callbacks import TQDMProgressBar
class CustomProgressBar(TQDMProgressBar):
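
The body of CustomProgressBar is not shown in this diff. A common pattern for a TQDMProgressBar subclass in Lightning, offered here only as a hypothetical sketch, is to override get_metrics to trim what the bar displays:

from lightning.pytorch.callbacks import TQDMProgressBar

class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        # Start from the default metrics and drop the experiment version number.
        metrics = super().get_metrics(trainer, pl_module)
        metrics.pop("v_num", None)
        return metrics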
@@ -59,8 +58,6 @@ def train_model():
print("X shape:", x.shape)
print("Y values:", y[:10])
break
unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
num_classes = len(ds_train.label_encoder.classes_)
hparams = {
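
The num_classes line above reads the fitted label encoder from the training dataset. As a standalone illustration of that idiom with scikit-learn (the labels below are assumed examples, not taken from the repository):

from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder().fit(["bpsk", "qpsk", "8psk"])   # assumed example labels
num_classes = len(label_encoder.classes_)                      # 3 classes here
print(num_classes)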