
Commit fc96e3f

Add current_metric while saving checkpoints (#223)

1 parent a49a357 commit fc96e3f

File tree

torch_em/self_training/fix_match.py
torch_em/self_training/mean_teacher.py
torch_em/trainer/default_trainer.py
torch_em/trainer/spoco_trainer.py

4 files changed: +16 -10 lines changed
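In short: save_checkpoint now takes the metric of the current validation epoch in addition to the running best, and stores both in the checkpoint. A minimal sketch of the new call pattern (the trainer instance and the metric values here are hypothetical, for illustration only):

    # Inside fit(), the trainer evaluates the model each epoch and tracks
    # the best validation metric seen so far (lower is better).
    current_metric = 0.42  # hypothetical metric of this validation epoch
    best_metric = 0.37     # hypothetical best metric so far

    # New signature: the current metric is passed before the best one, so
    # "latest" and "epoch-*" checkpoints also record how the model scored
    # when they were written, not just the best score overall.
    trainer.save_checkpoint("latest", current_metric, best_metric)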

torch_em/self_training/fix_match.py (+2 -2)

@@ -136,7 +136,7 @@ def __init__(
     # functionality for saving checkpoints and initialization
     #
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -152,7 +152,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
 
     # distribution alignment - encourages the distribution of the model's generated pseudo labels to match the marginal
     # distribution of pseudo labels from the source transfer

torch_em/self_training/mean_teacher.py (+2 -2)

@@ -171,7 +171,7 @@ def _momentum_update(self):
     # functionality for saving checkpoints and initialization
     #
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -188,7 +188,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
 
     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)

torch_em/trainer/default_trainer.py (+8 -4)

@@ -458,14 +458,15 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
             best_metric = np.inf
         return best_metric
 
-    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
             "iteration": self._iteration,
             "epoch": self._epoch,
             "best_epoch": self._best_epoch,
             "best_metric": best_metric,
+            "current_metric": current_metric,
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
@@ -494,6 +495,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.current_metric = save_dict["current_metric"]
         self.train_time = save_dict.get("train_time", 0.0)
 
         model_state = save_dict["model_state"]
@@ -573,14 +575,16 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric, train_time=total_train_time)
+                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
 
             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
+            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
 
             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
+                self.save_checkpoint(
+                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
+                )
 
             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:
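Since both metrics now live in the save_dict, they can be read straight off a checkpoint file. A minimal sketch of inspecting one (the path is hypothetical; note that checkpoints written before this change have no "current_metric" key, hence the .get):

    import torch

    # DefaultTrainer writes checkpoints to <checkpoint_folder>/<name>.pt;
    # this path is a made-up example.
    save_dict = torch.load("checkpoints/my-model/latest.pt", map_location="cpu")

    print("epoch:", save_dict["epoch"])
    print("best_metric:", save_dict["best_metric"])
    # New key added by this commit; absent from older checkpoints.
    print("current_metric:", save_dict.get("current_metric"))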

torch_em/trainer/spoco_trainer.py (+4 -2)

@@ -32,8 +32,10 @@ def _momentum_update(self):
         for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
             param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
-        super().save_checkpoint(name, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict)
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
+        super().save_checkpoint(
+            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
+        )
 
     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)
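SPOCOTrainer illustrates the pattern any trainer subclass with extra checkpoint state must now follow: accept current_metric and pass it on positionally before best_metric. A hypothetical subclass doing the same (the EMATrainer name and its ema_model attribute are invented for this sketch):

    from torch_em.trainer import DefaultTrainer

    class EMATrainer(DefaultTrainer):
        """Hypothetical trainer keeping an extra EMA model in its checkpoints."""

        def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
            # Match the new signature and forward both metrics, stashing the
            # extra state through **extra_save_dict as SPOCOTrainer does.
            super().save_checkpoint(
                name, current_metric, best_metric,
                ema_state=self.ema_model.state_dict(), **extra_save_dict,
            )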
