clean up code, adjusted default parameters

This commit is contained in:
Liyu Xiao 2025-08-21 10:56:23 -04:00
parent a298384f7e
commit 04e5f4db8a
6 changed files with 12 additions and 16 deletions

View File

@ -5,9 +5,6 @@ dataset:
# Number of samples per recording # Number of samples per recording
recording_length: 1024 recording_length: 1024
# Set this to scale the number of generated recordings
mult_factor: 5
# List of signal modulation schemes to include in the dataset # List of signal modulation schemes to include in the dataset
modulation_types: modulation_types:
- bpsk - bpsk
@ -27,7 +24,7 @@ dataset:
snr_step: 3 snr_step: 3
# Number of iterations (signal recordings) per modulation and SNR combination # Number of iterations (signal recordings) per modulation and SNR combination
num_iterations: 3 num_iterations: 100
# Modulation scheme settings; keys must match the `modulation_types` list above # Modulation scheme settings; keys must match the `modulation_types` list above
# Each entry includes: # Each entry includes:

View File

@ -9,7 +9,6 @@ import yaml
@dataclass @dataclass
class DataSetConfig: class DataSetConfig:
num_slices: int num_slices: int
mult_factor: int
train_split: float train_split: float
seed: int seed: int
modulation_types: list modulation_types: list
@ -42,7 +41,11 @@ class AppConfig:
class AppSettings: class AppSettings:
"""Application settings, to be initialized from app.yaml configuration file.""" """
Application settings,
to be initialized from
app.yaml configuration file.
"""
def __init__(self, config_file: str): def __init__(self, config_file: str):
# Load the YAML configuration file # Load the YAML configuration file

View File

@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:
for modulation in settings.modulation_types: for modulation in settings.modulation_types:
for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step): for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
for _ in range(settings.mult_factor): for _ in range(settings.num_iterations):
recording_length = settings.recording_length recording_length = settings.recording_length
beta = ( beta = (
settings.beta settings.beta

View File

@ -1,5 +1,4 @@
import lightning as L import lightning as L
import numpy as np
import timm import timm
import torch import torch
from torch import nn from torch import nn

View File

@ -2,15 +2,15 @@ import os
import numpy as np import numpy as np
import torch import torch
from sklearn.metrics import classification_report
os.environ["NNPACK"] = "0"
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from mobilenetv3 import RFClassifier, mobilenetv3 from mobilenetv3 import RFClassifier, mobilenetv3
from modulation_dataset import ModulationH5Dataset from modulation_dataset import ModulationH5Dataset
from sklearn.metrics import classification_report
from helpers.app_settings import get_app_settings from helpers.app_settings import get_app_settings
os.environ["NNPACK"] = "0"
def load_validation_data(): def load_validation_data():
val_dataset = ModulationH5Dataset( val_dataset = ModulationH5Dataset(

View File

@ -1,23 +1,22 @@
import os import os
import sys import sys
os.environ["NNPACK"] = "0"
import lightning as L import lightning as L
import mobilenetv3 import mobilenetv3
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics import torchmetrics
from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from modulation_dataset import ModulationH5Dataset from modulation_dataset import ModulationH5Dataset
from helpers.app_settings import get_app_settings from helpers.app_settings import get_app_settings
os.environ["NNPACK"] = "0"
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, "..")) data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path: if project_root not in sys.path:
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from lightning.pytorch.callbacks import TQDMProgressBar
class CustomProgressBar(TQDMProgressBar): class CustomProgressBar(TQDMProgressBar):
@ -59,8 +58,6 @@ def train_model():
print("X shape:", x.shape) print("X shape:", x.shape)
print("Y values:", y[:10]) print("Y values:", y[:10])
break break
unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
num_classes = len(ds_train.label_encoder.classes_) num_classes = len(ds_train.label_encoder.classes_)
hparams = { hparams = {