Added Training bar
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Has been cancelled

This commit is contained in:
Liyu Xiao 2025-05-26 14:27:53 -04:00
parent a092b92174
commit 92a0ed11e4

View File

@ -15,26 +15,6 @@ if project_root not in sys.path:
sys.path.insert(0, project_root)
class CleanProgressCallback(Callback):
"""Clean progress callback that only shows epoch summaries"""
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch + 1
# Get metrics
train_loss = trainer.callback_metrics.get("train_loss")
val_loss = trainer.callback_metrics.get("val_loss")
val_acc = trainer.callback_metrics.get("val_acc")
# Print clean output
print(f"Epoch {epoch}:")
if train_loss is not None:
print(f" Train Loss: {train_loss:.4f}")
if val_loss is not None:
print(f" Val Loss: {val_loss:.4f}")
if val_acc is not None:
print(f" Val Acc: {val_acc:.4f}")
print("-" * 30)
def train_model():
@ -149,12 +129,10 @@ def train_model():
enable_version_counter=False,
)
clean_progress = CleanProgressCallback()
trainer = L.Trainer(
max_epochs=epochs,
callbacks=[checkpoint_callback, clean_progress],
accelerator="gpu",
callbacks=[checkpoint_callback],
accelerator="cpu",
devices=1,
benchmark=True,
precision="bf16-mixed",