Documentation and formatting updates #1

Merged
Liyux merged 12 commits from michael-review into main 2025-07-08 10:50:41 -04:00
5 changed files with 3 additions and 20 deletions
Showing only changes of commit 53d0552fd4 - Show all commits

View File

@ -12,8 +12,8 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
Convert a PyTorch model to ONNX format.
Parameters:
output_path (str): The path to save the converted ONNX model.
fp16 (bool): 16 float point percision
ckpt_path (str): The path to save the converted ONNX model.
fp16 (bool): 16 float point precision
"""
settings = get_app_settings()
@ -68,8 +68,6 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
if __name__ == "__main__":
settings = get_app_settings()
model_checkpoint = "inference_recognition_model.ckpt"
print("Converting to ONNX...")

View File

@ -5,8 +5,6 @@ import time
import numpy as np
import onnxruntime as ort
from helpers.app_settings import get_app_settings
def profile_onnx_model(
path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
@ -86,6 +84,5 @@ def profile_onnx_model(
if __name__ == "__main__":
settings = get_app_settings()
output_path = os.path.join("onnx_files", "inference_recognition_model.onnx")
profile_onnx_model(output_path)

View File

@ -50,8 +50,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
)
first_rec, _ = records[0] # records[0] is a tuple of (data, md)
sample = first_rec
shape, dtype = sample.shape, sample.dtype
with h5py.File(output_path, "w") as hf:
data_arr = np.stack([rec[0] for rec in records])

View File

@ -24,11 +24,9 @@ class SqueezeExcite(nn.Module):
def __init__(
self,
in_chs,
se_ratio=0.25,
reduced_base_chs=None,
act_layer=nn.SiLU,
gate_fn=torch.sigmoid,
divisor=1,
**_,
):
super(SqueezeExcite, self).__init__()
@ -77,13 +75,6 @@ class GBN(torch.nn.Module):
self.act = act
def forward(self, x):
# chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)
# res = [self.bn(x_) for x_ in chunks]
# return self.drop(self.act(torch.cat(res, dim=0)))
# x = self.bn(x)
# x = self.act(x)
# x = self.drop(x)
# return x
return self.drop(self.act(self.bn(x)))

View File

@ -142,5 +142,4 @@ def plot_confusion_matrix_with_counts(
if __name__ == "__main__":
settings = get_app_settings()
ckpt_path = os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
evaluate_checkpoint(ckpt_path)
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))