38
38
from plsc .optimizer import build_optimizer
39
39
from plsc .utils import io
40
40
from plsc .core import recompute_warp , GradScaler , param_sync
41
-
41
+ from plsc . models . utils import EMA
42
42
from . import classification
43
43
from . import recognition
44
44
@@ -284,13 +284,36 @@ def worker_init_fn(worker_id):
284
284
self .data_parallel_recompute = self .config [
285
285
"DistributedStrategy" ].get ("recompute" , None ) is not None
286
286
287
# EMA support is switched on simply by the presence of an "EMA" section in
# the config.
self.enabled_ema = "EMA" in self.config
# Give the evaluation attributes a value on every path, so later reads
# never hit an AttributeError when EMA is off or the mode is not 'train'.
self.ema_eval = False
self.ema_eval_start_epoch = 0
if self.enabled_ema and self.mode == 'train':
    # Work on a copy: the two evaluation-only keys are popped before the
    # remaining options are forwarded to EMA(), and popping from the shared
    # config would silently drop them for any later reader of the config.
    # `or {}` tolerates an "EMA" section that is present but empty (None).
    ema_cfg = dict(self.config.get("EMA") or {})
    self.ema_eval = ema_cfg.pop('ema_eval', False)
    self.ema_eval_start_epoch = ema_cfg.pop('eval_start_epoch', 0)
    if self.ema_eval:
        logger.info(
            f'You have enable ema evaluation and start from {self.ema_eval_start_epoch} epoch, and it will save the best ema state.'
        )
    else:
        logger.info(
            'You have enable ema, and also can set ema_eval=True and eval_start_epoch to enable ema evaluation.'
        )
    # Everything left in ema_cfg is passed through as EMA keyword arguments.
    # NOTE(review): relies on the optimizer's private `_param_groups`
    # attribute -- confirm it is populated at this point.
    self.ema = EMA(self.optimizer._param_groups, **ema_cfg)
    self.ema.register()
302
+
287
303
def train (self ):
288
304
assert self .mode == "train"
289
305
self .best_metric = {
290
306
"metric" : 0.0 ,
291
307
"epoch" : 0 ,
292
308
"global_step" : 0 ,
293
309
}
310
+
311
+ if self .enabled_ema and self .ema_eval :
312
+ self .ema_best_metric = {
313
+ "metric" : 0.0 ,
314
+ "epoch" : 0 ,
315
+ "global_step" : 0 ,
316
+ }
294
317
# key:
295
318
# val: metrics list word
296
319
self .output_info = dict ()
@@ -301,6 +324,12 @@ def train(self):
301
324
302
325
# load checkpoint and resume
303
326
if self .config ["Global" ]["checkpoint" ] is not None :
327
+ if self .enabled_ema :
328
+ ema_metric_info = io .load_ema_checkpoint (
329
+ self .config ["Global" ]["checkpoint" ] + '_ema' , self .ema )
330
+ if ema_metric_info is not None and self .ema_eval :
331
+ self .ema_best_metric .update (ema_metric_info )
332
+
304
333
metric_info = io .load_checkpoint (
305
334
self .config ["Global" ]["checkpoint" ], self .model ,
306
335
self .optimizer , self .scaler )
@@ -362,19 +391,36 @@ def train(self):
362
391
step = epoch_id ,
363
392
writer = self .vdl_writer )
364
393
394
# Extra evaluation pass with the EMA state applied to the model.
# NOTE(review): the strict `>` skips the start epoch itself, while the
# start-up log message says evaluation starts "from" that epoch -- confirm
# whether `>=` was intended.
if self.enabled_ema and self.ema_eval and epoch_id > self.ema_eval_start_epoch:
    # presumably swaps the averaged (shadow) weights into the model; verify
    # against the EMA helper
    self.ema.apply_shadow()
    try:
        ema_eval_metric_info = self.eval(epoch_id)

        # Keep the best EMA result seen so far and checkpoint it.
        if ema_eval_metric_info["metric"] > self.ema_best_metric["metric"]:
            self.ema_best_metric = ema_eval_metric_info.copy()
            io.save_ema_checkpoint(
                self.model,
                self.ema,
                self.output_dir,
                self.ema_best_metric,
                model_name=self.config["Model"]["name"],
                prefix="best_model_ema",
                max_num_checkpoint=self.config["Global"][
                    "max_num_latest_checkpoint"], )

        logger.info("[Eval][Epoch {}][ema best metric: {}]".format(
            epoch_id, self.ema_best_metric["metric"]))
        # Log the metric of THIS EMA evaluation; the previous code logged
        # the regular (non-EMA) eval metric under the "ema_eval_metric" tag.
        logger.scaler(
            name="ema_eval_metric",
            value=ema_eval_metric_info["metric"],
            step=epoch_id,
            writer=self.vdl_writer)
    finally:
        # Always undo apply_shadow(), even if evaluation or saving raised,
        # so restore() is never skipped.
        self.ema.restore()
420
+
365
421
# save model
366
- if epoch_id % self .save_interval == 0 :
367
- if self .config ["Global" ]["max_num_latest_checkpoint" ] != 0 :
368
- io .save_checkpoint (
369
- self .model ,
370
- self .optimizer ,
371
- self .scaler ,
372
- eval_metric_info ,
373
- self .output_dir ,
374
- model_name = self .config ["Model" ]["name" ],
375
- prefix = "epoch_{}" .format (epoch_id ),
376
- max_num_checkpoint = self .config ["Global" ][
377
- "max_num_latest_checkpoint" ], )
422
+ if epoch_id % self .save_interval == 0 or epoch_id == self .config [
423
+ "Global" ]["epochs" ]:
378
424
# save the latest model
379
425
io .save_checkpoint (
380
426
self .model ,
@@ -387,6 +433,46 @@ def train(self):
387
433
max_num_checkpoint = self .config ["Global" ][
388
434
"max_num_latest_checkpoint" ], )
389
435
436
# Retention limit for rolling checkpoints; a value of 0 disables the
# per-epoch snapshots below (the "latest" snapshots are always written).
ckpt_keep_limit = self.config["Global"]["max_num_latest_checkpoint"]

if ckpt_keep_limit != 0:
    # Per-epoch snapshot of model / optimizer / scaler state.
    io.save_checkpoint(
        self.model,
        self.optimizer,
        self.scaler,
        eval_metric_info,
        self.output_dir,
        model_name=self.config["Model"]["name"],
        prefix="epoch_{}".format(epoch_id),
        max_num_checkpoint=ckpt_keep_limit, )

if self.enabled_ema:
    is_final_epoch = epoch_id == self.config["Global"]["epochs"]
    # Only on the very last epoch is the EMA state applied to the model
    # around the save calls (and restored afterwards).
    if is_final_epoch:
        self.ema.apply_shadow()

    # Always write the rolling "latest" EMA checkpoint; add a per-epoch one
    # when checkpoint retention is enabled.
    ema_prefixes = ["latest_ema"]
    if ckpt_keep_limit != 0:
        ema_prefixes.append("epoch_{}_ema".format(epoch_id))

    for ema_prefix in ema_prefixes:
        io.save_ema_checkpoint(
            self.model,
            self.ema,
            self.output_dir,
            None,
            model_name=self.config["Model"]["name"],
            prefix=ema_prefix,
            max_num_checkpoint=ckpt_keep_limit, )

    if is_final_epoch:
        self.ema.restore()
475
+
390
476
if self .vdl_writer is not None :
391
477
self .vdl_writer .close ()
392
478
0 commit comments