Lint fixes, removed hard-coded values, sped up artifact file uploads
Some checks failed
Modulation Recognition Demo / ria-demo (pull_request) Failing after 1m25s
parent 1c7ddef5cb
commit fe952b8eb6
@@ -2,11 +2,9 @@ name: Modulation Recognition Demo
 on:
   push:
-    branches:
-      [main]
+    branches: [main]
   pull_request:
-    branches:
-      [main]
+    branches: [main]
 
 jobs:
   ria-demo:
@@ -46,22 +44,19 @@ jobs:
           fi
           pip install -r requirements.txt
 
       - name: 1. Generate Recordings
         run: |
           mkdir -p data/recordings
           PYTHONPATH=. python scripts/dataset_manager/data_gen.py --output-dir data/recordings
 
+      - name: 📦 Compress Recordings
+        run: tar -czf recordings.tar.gz -C data/recordings .
+
       - name: ⬆️ Upload recordings
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: recordings
-          path: data/recordings/**
+          path: recordings.tar.gz
 
       - name: 2. Build HDF5 Dataset
         run: |
           mkdir -p data/dataset
           PYTHONPATH=. python scripts/dataset_manager/produce_dataset.py
         shell: bash
 
       - name: ⬆️ Upload Dataset
@@ -72,16 +67,16 @@ jobs:
 
       - name: 3. Train Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           mkdir -p checkpoint_files
           PYTHONPATH=. python scripts/model_builder/train.py 2>/dev/null
 
       - name: 4. Plot Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           PYTHONPATH=. python scripts/model_builder/plot_data.py 2>/dev/null
@@ -98,7 +93,7 @@ jobs:
         run: |
           mkdir -p onnx_files
           MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/application_packager/convert_to_onnx.py 2>/dev/null
 
       - name: ⬆️ Upload ONNX file
         uses: actions/upload-artifact@v3
         with:
@@ -108,13 +103,13 @@ jobs:
       - name: 6. Profile ONNX model
         run: |
           PYTHONPATH=. python scripts/application_packager/profile_onnx.py
 
       - name: ⬆️ Upload JSON trace
         uses: actions/upload-artifact@v3
         with:
           name: profile-data
-          path: '**/onnxruntime_profile_*.json'
+          path: "**/onnxruntime_profile_*.json"
 
       - name: 7. Convert ONNX graph to an ORT file
         run: |
           PYTHONPATH=. python scripts/application_packager/convert_to_ort.py
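Note on the profiling step above: the file uploaded as the profile-data artifact is the Chrome-trace-style JSON that ONNX Runtime writes when profiling is enabled. Below is a minimal sketch of summarizing such a trace offline; the event fields ("name", "dur" in microseconds), the glob pattern, and the top-10 cutoff are assumptions for illustration, not part of this repo's scripts.

# Hypothetical helper: aggregate per-op time from an onnxruntime_profile_*.json
# trace downloaded from the profile-data artifact. Assumes Chrome-trace-style
# events with "name" and "dur" (microseconds).
import glob
import json
from collections import defaultdict

trace_files = sorted(glob.glob("**/onnxruntime_profile_*.json", recursive=True))
with open(trace_files[-1]) as f:
    events = json.load(f)

totals = defaultdict(int)
for ev in events:
    if isinstance(ev, dict) and "dur" in ev:
        totals[ev.get("name", "unknown")] += ev["dur"]

# Print the ten most expensive ops by accumulated duration.
for name, dur_us in sorted(totals.items(), key=lambda kv: -kv[1])[:10]:
    print(f"{name}: {dur_us / 1000:.2f} ms")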
@@ -5,6 +5,9 @@ dataset:
   # Number of samples per recording
   recording_length: 1024
 
+  # Set this to scale the number of generated recordings
+  mult_factor: 5
+
   # List of signal modulation schemes to include in the dataset
   modulation_types:
     - bpsk
@@ -50,7 +53,7 @@ dataset:
 
   # Training and validation split ratios; must sum to 1
   train_split: 0.8
-  val_split : 0.2
+  val_split: 0.2
 
 training:
   # Number of training examples processed together before the model updates its weights
@@ -9,6 +9,7 @@ import yaml
 @dataclass
 class DataSetConfig:
     num_slices: int
+    mult_factor: int
     train_split: float
     seed: int
     modulation_types: list
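The new mult_factor field pairs with the mult_factor: 5 entry added to the dataset config above, so the generation script can read the multiplier from settings instead of a hard-coded loop count. A minimal sketch of how the YAML section could map onto the dataclass follows; the loader function, config path, and import path are illustrative assumptions, since the repo's actual get_app_settings() helper is not part of this diff.

# Illustrative loader (not the repo's get_app_settings): map the "dataset"
# section of the YAML config onto DataSetConfig, including the new mult_factor.
from dataclasses import fields

import yaml

from helpers.app_settings import DataSetConfig  # assumed import path


def load_dataset_config(path: str = "config.yaml") -> DataSetConfig:
    with open(path) as f:
        cfg = yaml.safe_load(f)["dataset"]
    # Keep only the keys the dataclass declares (num_slices, mult_factor, ...).
    names = {f.name for f in fields(DataSetConfig)}
    return DataSetConfig(**{k: v for k, v in cfg.items() if k in names})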
@@ -2,9 +2,9 @@ import os
 
 import numpy as np
 import torch
-from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3
 
 from helpers.app_settings import get_app_settings
+from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3
 
 
 def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
@@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:
 
     for modulation in settings.modulation_types:
         for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
-            for i in range(3):
+            for _ in range(settings.mult_factor):
                 recording_length = settings.recording_length
                 beta = (
                     settings.beta
@@ -49,8 +49,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
             int(md["sps"]),
         )
 
-    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
-
     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
         dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
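Since write_hdf5_file stacks all recordings into a single gzip-compressed dataset, a quick read-back check is straightforward. A minimal sketch, assuming the default dataset_name "data" shown in the signature above; the file path is illustrative.

# Sanity-check sketch for the HDF5 output of write_hdf5_file. The path is
# illustrative; "data" is the default dataset_name in the signature above.
import h5py

with h5py.File("data/dataset/dataset.h5", "r") as hf:
    dset = hf["data"]
    print(dset.shape, dset.dtype, dset.compression)  # stacked recordings, gzip
    first_recording = dset[0]  # loads a single recording into memory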