Linting, removed hard-coded values, sped up upload times for files
Some checks failed
Modulation Recognition Demo / ria-demo (pull_request): Failing after 1m25s
This commit is contained in:
parent 1c7ddef5cb
commit fe952b8eb6

@@ -2,11 +2,9 @@ name: Modulation Recognition Demo
 on:
   push:
-    branches:
-      [main]
+    branches: [main]
   pull_request:
-    branches:
-      [main]
+    branches: [main]

 jobs:
   ria-demo:

@@ -46,22 +44,19 @@ jobs:
           fi
           pip install -r requirements.txt

-      - name: 1. Generate Recordings
-        run: |
-          mkdir -p data/recordings
-          PYTHONPATH=. python scripts/dataset_manager/data_gen.py --output-dir data/recordings
+      - name: 📦 Compress Recordings
+        run: tar -czf recordings.tar.gz -C data/recordings .

       - name: ⬆️ Upload recordings
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: recordings
-          path: data/recordings/**
+          path: recordings.tar.gz

       - name: 2. Build HDF5 Dataset
         run: |
           mkdir -p data/dataset
           PYTHONPATH=. python scripts/dataset_manager/produce_dataset.py
         shell: bash

       - name: ⬆️ Upload Dataset
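The upload speed-up comes from shipping a single gzipped tarball instead of a directory tree of individual recording files; artifact uploads generally pay per-file overhead, so one archive moves much faster. For readers who prefer Python to the tar CLI, a rough equivalent of the new compress step (illustrative only; the workflow itself shells out to tar):

import tarfile

def pack_recordings(src_dir: str = "data/recordings",
                    out_path: str = "recordings.tar.gz") -> None:
    # arcname="." mirrors `tar -czf recordings.tar.gz -C data/recordings .`:
    # member paths are stored relative to the recordings directory itself.
    with tarfile.open(out_path, "w:gz") as tar:
        tar.add(src_dir, arcname=".")
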
(No textual change survives extraction in the next two hunks; given the "Linting" commit title, the differences here are most likely whitespace-only.)

@@ -72,16 +67,16 @@ jobs:

       - name: 3. Train Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           mkdir -p checkpoint_files
           PYTHONPATH=. python scripts/model_builder/train.py 2>/dev/null

       - name: 4. Plot Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           PYTHONPATH=. python scripts/model_builder/plot_data.py 2>/dev/null

@@ -98,7 +93,7 @@ jobs:
         run: |
           mkdir -p onnx_files
           MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/application_packager/convert_to_onnx.py 2>/dev/null

       - name: ⬆️ Upload ONNX file
         uses: actions/upload-artifact@v3
         with:

@@ -108,13 +103,13 @@ jobs:
       - name: 6. Profile ONNX model
         run: |
           PYTHONPATH=. python scripts/application_packager/profile_onnx.py

       - name: ⬆️ Upload JSON trace
         uses: actions/upload-artifact@v3
         with:
           name: profile-data
-          path: '**/onnxruntime_profile_*.json'
+          path: "**/onnxruntime_profile_*.json"

       - name: 7. Convert ONNX graph to an ORT file
         run: |
           PYTHONPATH=. python scripts/application_packager/convert_to_ort.py

@@ -5,6 +5,9 @@ dataset:
   # Number of samples per recording
   recording_length: 1024

+  # Set this to scale the number of generated recordings
+  mult_factor: 5
+
   # List of signal modulation schemes to include in the dataset
   modulation_types:
     - bpsk
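Since mult_factor now multiplies the recording count per (modulation, SNR) pair, dataset size is easy to estimate up front. A back-of-envelope sketch with placeholder values (the real modulation list and SNR sweep live elsewhere in this config):

import numpy as np

modulation_types = ["bpsk", "qpsk"]   # placeholder; the config lists more
snrs = np.arange(0, 20, 2)            # placeholder snr_start/stop/step
mult_factor = 5                       # the new config value

total = len(modulation_types) * len(snrs) * mult_factor
print(total)  # 2 modulations * 10 SNRs * 5 = 100 recordings
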
@@ -50,7 +53,7 @@ dataset:

   # Training and validation split ratios; must sum to 1
   train_split: 0.8
-  val_split : 0.2
+  val_split: 0.2

 training:
   # Number of training examples processed together before the model updates its weights
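The comment says the split ratios must sum to 1, but nothing in the diff enforces it. A guard like the following (hypothetical; not part of this commit) would fail fast on a typo:

def check_splits(train_split: float, val_split: float, tol: float = 1e-9) -> None:
    if abs(train_split + val_split - 1.0) > tol:
        raise ValueError(f"splits sum to {train_split + val_split}, expected 1.0")

check_splits(0.8, 0.2)  # passes for the values above
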
@@ -9,6 +9,7 @@ import yaml
 @dataclass
 class DataSetConfig:
     num_slices: int
+    mult_factor: int
     train_split: float
     seed: int
     modulation_types: list
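Adding mult_factor: int keeps the dataclass in step with the new YAML key. The repo's actual loader (helpers.app_settings.get_app_settings) is not shown in this diff, so treat the following binding of the YAML block to the dataclass as a hedged sketch only:

from dataclasses import dataclass, fields

import yaml

@dataclass
class DataSetConfig:  # trimmed to the fields visible in the hunk above
    num_slices: int
    mult_factor: int
    train_split: float
    seed: int
    modulation_types: list

def load_dataset_config(path: str) -> DataSetConfig:
    with open(path) as f:
        raw = yaml.safe_load(f)["dataset"]
    known = {f.name for f in fields(DataSetConfig)}
    return DataSetConfig(**{k: v for k, v in raw.items() if k in known})
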
@@ -2,9 +2,9 @@ import os

 import numpy as np
 import torch
-from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3

 from helpers.app_settings import get_app_settings
+from scripts.model_builder.mobilenetv3 import RFClassifier, mobilenetv3


 def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
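This hunk only regroups the first-party imports together under the lint pass. For context on what convert_to_onnx presumably wraps, a minimal torch.onnx.export call is sketched below; the input shape and tensor names are assumptions, not taken from the repo:

import torch

def export_onnx(model: torch.nn.Module, out_path: str = "onnx_files/model.onnx") -> None:
    model.eval()
    # Assumed shape: (batch, I/Q channels, recording_length from the config)
    dummy = torch.randn(1, 2, 1024)
    torch.onnx.export(model, dummy, out_path,
                      input_names=["iq"], output_names=["logits"])
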
@@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:

     for modulation in settings.modulation_types:
         for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
-            for i in range(3):
+            for _ in range(settings.mult_factor):
                 recording_length = settings.recording_length
                 beta = (
                     settings.beta
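Replacing the hard-coded range(3) with settings.mult_factor is the "removed hard-coded values" part of the commit. One adjacent subtlety worth remembering about the SNR loop: np.arange excludes its stop value, so snr_stop itself is never generated. Quick check with placeholder numbers:

import numpy as np

print(np.arange(0, 20, 2))  # [ 0  2  4  6  8 10 12 14 16 18] -- 20 is excluded
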
@@ -49,8 +49,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
             int(md["sps"]),
         )

-    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
-
     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
         dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
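The removed first_rec line was presumably dead code flagged by the lint pass: np.stack already consumes every record. For completeness, reading the written file back looks like this (the file path here is an assumption; dataset_name matches the function's default "data"):

import h5py

with h5py.File("data/dataset/dataset.h5", "r") as hf:  # path is an assumption
    data = hf["data"][:]  # one stacked row per recording
    print(data.shape, data.dtype)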