Skip to content

Commit 3763fcf

Browse files
Early stop fixed
1 parent de592f4 commit 3763fcf

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

boudams/trainer.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,22 @@ def __init__(self, optimizer, mode=PlateauModes.accuracy, **kwargs):
128128
self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
129129
optimizer, mode=mode.value, **kwargs) # Max because accuracy :)
130130
self.mode = mode
131+
self.steps = 0
131132

132133
def step(self, score):
134+
scheduler_steps = self.lr_scheduler.num_bad_epochs
133135
self.lr_scheduler.step(getattr(score, self.mode.name))
134-
135-
@property
136-
def steps(self):
137-
return self.lr_scheduler.num_bad_epochs
136+
# No change in number of bad epochs =
137+
# we are progressing
138+
if scheduler_steps == self.lr_scheduler.num_bad_epochs:
139+
self.steps = 0
140+
# Otherwise, we are not
141+
else:
142+
self.steps += 1
143+
144+
if self.steps >= self.patience * 2:
145+
# If we haven't progressed even by lowering twice
146+
raise EarlyStopException("No progress for %s , stoping now... " % self.steps)
138147

139148
@property
140149
def patience(self):
@@ -230,10 +239,6 @@ def run(
230239

231240
# Run a check on saving the current model
232241
best_valid_loss = self._temp_save(fid, best_valid_loss, dev_score)
233-
234-
# Advance Learning Rate if needed
235-
lr_scheduler.step(dev_score)
236-
237242
print(f'\tTrain Loss: {train_score.loss:.3f} | Perplexity: {train_score.perplexity:7.3f} | '
238243
f' Acc.: {train_score.accuracy:.3f} | '
239244
f' Lev.: {train_score.leven:.3f} | '
@@ -246,6 +251,9 @@ def run(
246251
print(lr_scheduler)
247252
print()
248253

254+
# Advance Learning Rate if needed
255+
lr_scheduler.step(dev_score)
256+
249257
if lr_scheduler.steps >= lr_patience and lr_scheduler.lr < min_lr:
250258
raise EarlyStopException()
251259

@@ -260,6 +268,7 @@ def run(
260268
break
261269
except EarlyStopException:
262270
print("Reached plateau for too long, stopping.")
271+
break
263272

264273
best_valid_loss = self._temp_save(fid, best_valid_loss, dev_score)
265274

train.sh

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
venv/bin/boudams train --device cuda --epochs 50 *.json
2+
venv/bin/boudams train --device cuda --epochs 50 *.json
3+
4+

0 commit comments

Comments
 (0)