Skip to content

Commit 07a8cf5

Browse files
Enable passing external pbar to trainer
1 parent daedad0 commit 07a8cf5

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

torch_em/trainer/default_trainer.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,14 @@ def load_checkpoint(self, checkpoint="best"):
516516

517517
return save_dict
518518

519-
def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_every_kth_epoch=None):
519+
def fit(
520+
self,
521+
iterations=None,
522+
load_from_checkpoint=None,
523+
epochs=None,
524+
save_every_kth_epoch=None,
525+
progress=None,
526+
):
520527
"""Run neural network training.
521528
522529
Exactly one of 'iterations' or 'epochs' has to be passed.
@@ -527,6 +534,8 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
527534
epochs [int] - how long to train, specified in epochs (default: None)
528535
save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
529536
The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
537+
progress [progress_bar] - optional progress bar for integration with external tools.
538+
Expected to follow the tqdm interface.
530539
"""
531540
best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
532541
print(
@@ -547,12 +556,14 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
547556
validate = self._validate
548557
print("Training with single precision")
549558

550-
progress = tqdm(
551-
total=epochs * len(self.train_loader) if iterations is None else iterations,
552-
desc=f"Epoch {self._epoch}", leave=True
553-
)
554-
msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
559+
total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
560+
if progress is None:
561+
progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
562+
else:
563+
progress.total = total_iterations
564+
progress.set_description(f"Epoch {self._epoch}")
555565

566+
msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
556567
train_epochs = self.max_epoch - self._epoch
557568
t_start = time.time()
558569
for _ in range(train_epochs):

0 commit comments

Comments
 (0)