@@ -516,7 +516,14 @@ def load_checkpoint(self, checkpoint="best"):
516
516
517
517
return save_dict
518
518
519
- def fit (self , iterations = None , load_from_checkpoint = None , epochs = None , save_every_kth_epoch = None ):
519
+ def fit (
520
+ self ,
521
+ iterations = None ,
522
+ load_from_checkpoint = None ,
523
+ epochs = None ,
524
+ save_every_kth_epoch = None ,
525
+ progress = None ,
526
+ ):
520
527
"""Run neural network training.
521
528
522
529
Exactly one of 'iterations' or 'epochs' has to be passed.
@@ -527,6 +534,8 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
527
534
epochs [int] - how long to train, specified in epochs (default: None)
528
535
save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
529
536
The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
537
+ progress [progress_bar] - optional progress bar for integration with external tools.
538
+ Expected to follow the tqdm interface.
530
539
"""
531
540
best_metric = self ._initialize (iterations , load_from_checkpoint , epochs )
532
541
print (
@@ -547,12 +556,14 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
547
556
validate = self ._validate
548
557
print ("Training with single precision" )
549
558
550
- progress = tqdm (
551
- total = epochs * len (self .train_loader ) if iterations is None else iterations ,
552
- desc = f"Epoch { self ._epoch } " , leave = True
553
- )
554
- msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
559
+ total_iterations = epochs * len (self .train_loader ) if iterations is None else iterations
560
+ if progress is None :
561
+ progress = tqdm (total = total_iterations , desc = f"Epoch { self ._epoch } " , leave = True )
562
+ else :
563
+ progress .total = total_iterations
564
+ progress .set_description (f"Epoch { self ._epoch } " )
555
565
566
+ msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
556
567
train_epochs = self .max_epoch - self ._epoch
557
568
t_start = time .time ()
558
569
for _ in range (train_epochs ):
0 commit comments