Documentation and formatting updates #1
|
@ -12,8 +12,8 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
|
|||
Convert a PyTorch model to ONNX format.
|
||||
|
||||
Parameters:
|
||||
output_path (str): The path to save the converted ONNX model.
|
||||
fp16 (bool): 16 float point percision
|
||||
ckpt_path (str): The path to save the converted ONNX model.
|
||||
fp16 (bool): 16 float point precision
|
||||
"""
|
||||
settings = get_app_settings()
|
||||
|
||||
|
@ -68,8 +68,6 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
settings = get_app_settings()
|
||||
|
||||
model_checkpoint = "inference_recognition_model.ckpt"
|
||||
|
||||
print("Converting to ONNX...")
|
||||
|
|
|
@ -5,8 +5,6 @@ import time
|
|||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from helpers.app_settings import get_app_settings
|
||||
|
||||
|
||||
def profile_onnx_model(
|
||||
path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
|
||||
|
@ -86,6 +84,5 @@ def profile_onnx_model(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
settings = get_app_settings()
|
||||
output_path = os.path.join("onnx_files", "inference_recognition_model.onnx")
|
||||
profile_onnx_model(output_path)
|
||||
|
|
|
@ -50,8 +50,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
|
|||
)
|
||||
|
||||
first_rec, _ = records[0] # records[0] is a tuple of (data, md)
|
||||
sample = first_rec
|
||||
shape, dtype = sample.shape, sample.dtype
|
||||
|
||||
with h5py.File(output_path, "w") as hf:
|
||||
data_arr = np.stack([rec[0] for rec in records])
|
||||
|
|
|
@ -24,11 +24,9 @@ class SqueezeExcite(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
in_chs,
|
||||
se_ratio=0.25,
|
||||
reduced_base_chs=None,
|
||||
act_layer=nn.SiLU,
|
||||
gate_fn=torch.sigmoid,
|
||||
divisor=1,
|
||||
**_,
|
||||
):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
|
@ -77,13 +75,6 @@ class GBN(torch.nn.Module):
|
|||
self.act = act
|
||||
|
||||
def forward(self, x):
|
||||
# chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)
|
||||
# res = [self.bn(x_) for x_ in chunks]
|
||||
# return self.drop(self.act(torch.cat(res, dim=0)))
|
||||
# x = self.bn(x)
|
||||
# x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# return x
|
||||
return self.drop(self.act(self.bn(x)))
|
||||
|
||||
|
||||
|
|
|
@ -142,5 +142,4 @@ def plot_confusion_matrix_with_counts(
|
|||
|
||||
if __name__ == "__main__":
|
||||
settings = get_app_settings()
|
||||
ckpt_path = os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
|
||||
evaluate_checkpoint(ckpt_path)
|
||||
evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
|
||||
|
|
Loading…
Reference in New Issue
Block a user