@@ -458,14 +458,15 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
         best_metric = np.inf
         return best_metric
 
-    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
             "iteration": self._iteration,
             "epoch": self._epoch,
             "best_epoch": self._best_epoch,
             "best_metric": best_metric,
+            "current_metric": current_metric,
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
@@ -494,6 +495,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.current_metric = save_dict["current_metric"]
         self.train_time = save_dict.get("train_time", 0.0)
 
         model_state = save_dict["model_state"]
@@ -573,14 +575,16 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric, train_time=total_train_time)
+                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
 
             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
+            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
 
             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
+                self.save_checkpoint(
+                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
+                )
 
             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:
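The net effect of this change is that every checkpoint now records both the running best metric and the metric of the epoch at which it was written. A minimal sketch of inspecting such a checkpoint follows; it assumes the save_dict shown above is serialized with torch.save (as the ".pt" extension suggests) and uses a hypothetical checkpoint folder path.

import torch

# Hypothetical path: checkpoint_folder is set elsewhere by the trainer.
checkpoint_path = "checkpoints/my-model/best.pt"
save_dict = torch.load(checkpoint_path, map_location="cpu")

# Both metrics are now available: the running best and the value from the
# epoch at which this checkpoint was saved.
print("epoch:", save_dict["epoch"], "best_epoch:", save_dict["best_epoch"])
print("best_metric:", save_dict["best_metric"])
print("current_metric:", save_dict["current_metric"])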