diff --git a/conf/app.yaml b/conf/app.yaml index 9010e0c..9996838 100644 --- a/conf/app.yaml +++ b/conf/app.yaml @@ -21,7 +21,7 @@ training: inference: model_path: checkpoints/inference_recognition_model.ckpt num_classes: 4 - output_path: onnx_files/inference_recognition_model.onnx + output_path: results/inference_recognition_model.onnx app: build_dir: dist \ No newline at end of file diff --git a/data/training/train.py b/data/training/train.py index 2dd2284..92b74d2 100644 --- a/data/training/train.py +++ b/data/training/train.py @@ -26,19 +26,20 @@ import mobilenetv3 def train_model(): settings = get_app_settings() - dataset = settings.dataset.modulation_types + training_cfg = settings.training + dataset_cfg = settings.dataset train_flag = True batch_size = 128 epochs = 1 - checkpoint_filename = f"/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model" + checkpoint_filename = f"{training_cfg.checkpoint_path}" train_data = ( - "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5" + f"{settings.dataset_cfg.output_dir}/train.h5" ) val_data = ( - "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5" + f"{settings.dataset_cfg.output_dir}/val.h5" ) dataset_name = "Modulation Inference - Initial Model"