modrec-workflow/scripts/dataset_manager/produce_dataset.py

import os
from typing import List
import h5py
import numpy as np
from split_dataset import split, split_recording
from utils.io import from_npy
from helpers.app_settings import DataSetConfig, get_app_settings

meta_dtype = np.dtype(
    [
        ("rec_id", "S256"),
        ("snippet_idx", np.int32),
        ("modulation", "S32"),
        ("snr", np.int32),
        ("beta", np.float32),
        ("sps", np.int32),
    ]
)

info_dtype = np.dtype(
    [
        ("num_records", np.int32),
        ("dataset_name", "S64"),  # up to 64-byte UTF-8 strings
        ("creator", "S64"),
    ]
)
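
# Illustrative example of a single row matching meta_dtype (all values are made up):
#   (b"rec_0001", 3, b"qpsk", 10, 0.35, 8)
#   i.e. rec_id, snippet_idx, modulation, snr, beta, sps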


def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data") -> str:
    """
    Writes a list of records to an HDF5 file.

    Parameters:
        records (list): List of records to be written to the file
        output_path (str): Path to the output HDF5 file
        dataset_name (str): Name of the dataset in the HDF5 file (default: "data")

    Returns:
        str: Path to the created HDF5 file
    """
    meta_arr = np.empty(len(records), dtype=meta_dtype)
    for i, (_, md) in enumerate(records):
        meta_arr[i] = (
            md["rec_id"].encode("utf-8"),
            md["snippet_idx"],
            md["modulation"].encode("utf-8"),
            int(md["snr"]),
            float(md["beta"]),
            int(md["sps"]),
        )

    with h5py.File(output_path, "w") as hf:
        data_arr = np.stack([rec[0] for rec in records])
        dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")

        mg = hf.create_group("metadata")
        mg.create_dataset("metadata", data=meta_arr, compression="gzip")
        print(dset.shape, f"snippets created in {dataset_name}")

        info_arr = np.array(
            [
                (
                    len(records),
                    dataset_name.encode("utf-8"),
                    b"generate_dataset.py",  # already bytes
                )
            ],
            dtype=info_dtype,
        )
        mg.create_dataset("dataset_info", data=info_arr)

    return output_path
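

# A minimal read-back sketch (not called anywhere in this script): shows how to load the
# snippets and metadata written by write_hdf5_file above. The path and dataset name are
# just example values matching what generate_datasets produces below.
def _read_hdf5_example(path: str = "data/dataset/train.h5", dataset_name: str = "training_data"):
    """Illustrative only: load snippets and metadata produced by write_hdf5_file."""
    with h5py.File(path, "r") as hf:
        data = hf[dataset_name][:]  # shape (num_snippets, 2, N), float32
        meta = hf["metadata"]["metadata"][:]  # structured array: rec_id, snippet_idx, modulation, snr, beta, sps
        info = hf["metadata"]["dataset_info"][0]  # (num_records, dataset_name, creator)
    return data, meta, info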


def complex_to_channel(data: np.ndarray) -> np.ndarray:
    """
    Converts complex-valued IQ data of shape (1, N) to a 2-channel real array of shape (2, N).

    Parameters:
        data (np.ndarray): Complex-valued array of shape (1, N)

    Returns:
        np.ndarray: Real-valued array of shape (2, N) with separate real and imaginary channels
    """
    assert np.iscomplexobj(data)  # check if the data is in the form a+bi
    real = np.real(data[0])  # (N,)
    imag = np.imag(data[0])  # (N,)
    stacked = np.stack([real, imag], axis=0)  # shape (2, N)
    return stacked.astype(np.float32)
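
# Sketch usage of complex_to_channel with made-up values:
#   iq = np.array([[1 + 2j, 3 - 4j]])  # complex IQ, shape (1, 2)
#   chans = complex_to_channel(iq)     # float32, shape (2, 2)
#   chans[0] is the real (I) channel [1., 3.]; chans[1] is the imaginary (Q) channel [2., -4.]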


def generate_datasets(cfg: DataSetConfig) -> tuple:
    """
    Generates a dataset from a folder of .npy files and saves it to HDF5 files.

    Parameters:
        cfg (DataSetConfig): Dataset configuration loaded from app.yaml

    Returns:
        tuple: Paths (train_path, val_path) of the created training and validation HDF5 files
    """
    os.makedirs("data/dataset", exist_ok=True)  # make sure the output directory exists

    # we assume the recordings are in .npy format
    files = os.listdir("data/recordings")
    if not files:
        raise ValueError("No files found in the specified directory.")

    records = []
    for fname in files:
        rec = from_npy(os.path.join("data/recordings", fname))
        data = rec.data  # here data is a numpy array with the shape (1, N)
        data = complex_to_channel(data)  # convert to 2-channel real array
        md = rec.metadata  # pull metadata from the recording
        md.setdefault("rec_id", str(len(records)))  # key and type must match meta_dtype's "rec_id"
        records.append((data, md))

    # split each recording into <num_slices> snippets each
    records = split_recording(records, cfg.num_slices)
    train_records, val_records = split(records, cfg.train_split, cfg.seed)

    train_path = os.path.join("data/dataset", "train.h5")
    val_path = os.path.join("data/dataset", "val.h5")
    write_hdf5_file(train_records, train_path, "training_data")
    write_hdf5_file(val_records, val_path, "validation_data")

    return train_path, val_path
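

# Sketch of the dataset section generate_datasets expects from app.yaml; the field names
# come from the cfg attributes used above, the values are placeholders, and the actual
# schema lives in helpers.app_settings.DataSetConfig:
#   dataset:
#     num_slices: 10     # snippets per recording
#     train_split: 0.8   # fraction of snippets used for training
#     seed: 42           # RNG seed for the train/val split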


def main():
    settings = get_app_settings()
    dataset_cfg = settings.dataset

    print("📦 Generating training and validation datasets...")
    print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
    print(
        f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%"
    )
    print(" ➤ Output directory: data/dataset\n")

    train_path, val_path = generate_datasets(dataset_cfg)

    # Count number of samples in each file
    with h5py.File(train_path, "r") as f:
        num_train = f["training_data"].shape[0]
    with h5py.File(val_path, "r") as f:
        num_val = f["validation_data"].shape[0]

    print("✅ Dataset generation complete!")
    print(f" 🔹 Training samples saved to: {train_path} ({num_train} samples)")
    print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")


if __name__ == "__main__":
    main()