Commit 59b2d14

feat: add convnext model and config (#201)
Co-authored-by: GuoxiaWang <mingzilaochongtu@gmail.com>
1 parent f7704ed commit 59b2d14

16 files changed: +1086 −16 lines

README.md
Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@
 * [Swin](./task/classification/swin/)
 * [DeiT](./task/classification/deit/)
 * [CaiT](./task/classification/cait/)
+* [ConvNeXt](./task/classification/convnext)
 * [MoCo v3](./task/ssl/mocov3/)
 * [MAE](./task/ssl/mae/)
 * [ConvMAE](./task/ssl/mae/)

plsc/engine/classification/evaluation.py
Lines changed: 2 additions & 1 deletion

@@ -41,6 +41,7 @@ def default_eval(engine, epoch_id=0):
         dataset) if not engine.use_dali else engine.eval_dataloader.size
     max_iter = len(engine.eval_dataloader) - 1 if platform.system(
     ) == "Windows" else len(engine.eval_dataloader)
+
     for iter_id, batch in enumerate(engine.eval_dataloader):
         if iter_id >= max_iter:
             break
@@ -63,7 +64,6 @@ def default_eval(engine, epoch_id=0):
                 custom_black_list=engine.fp16_custom_black_list,
                 level=engine.fp16_level):
             out = engine.model(batch[0])
-
         # calc loss
         if engine.eval_loss_func is not None:
             loss_dict = engine.eval_loss_func(out, batch[1])
@@ -132,6 +132,7 @@ def default_eval(engine, epoch_id=0):
                 len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))

         tic = time.time()
+
     if engine.use_dali:
         engine.eval_dataloader.reset()

plsc/engine/classification/train.py
Lines changed: 3 additions & 1 deletion

@@ -84,7 +84,9 @@ def default_train_one_epoch(engine, epoch_id):
         if iter_id % engine.print_batch_step == 0:
             log_info(engine, batch_size, epoch_id, iter_id)
             tic = time.time()
-
+        # ema update
+        if engine.enabled_ema:
+            engine.ema.update()
     # eval model and save model if possible
     eval_metric_info = {
         "epoch": epoch_id,

plsc/engine/engine.py
Lines changed: 99 additions & 13 deletions

@@ -38,7 +38,7 @@
 from plsc.optimizer import build_optimizer
 from plsc.utils import io
 from plsc.core import recompute_warp, GradScaler, param_sync
-
+from plsc.models.utils import EMA
 from . import classification
 from . import recognition
@@ -284,13 +284,36 @@ def worker_init_fn(worker_id):
         self.data_parallel_recompute = self.config[
             "DistributedStrategy"].get("recompute", None) is not None

+        self.enabled_ema = True if "EMA" in self.config else False
+        if self.enabled_ema and self.mode == 'train':
+            ema_cfg = self.config.get("EMA", {})
+            self.ema_eval = ema_cfg.pop('ema_eval', False)
+            self.ema_eval_start_epoch = ema_cfg.pop('eval_start_epoch', 0)
+            if self.ema_eval:
+                logger.info(
+                    f'EMA evaluation is enabled and starts from epoch {self.ema_eval_start_epoch}; the best EMA state will be saved.'
+                )
+            else:
+                logger.info(
+                    'EMA is enabled; set ema_eval=True and eval_start_epoch to also enable EMA evaluation.'
+                )
+            self.ema = EMA(self.optimizer._param_groups, **ema_cfg)
+            self.ema.register()
+
     def train(self):
         assert self.mode == "train"
         self.best_metric = {
             "metric": 0.0,
             "epoch": 0,
             "global_step": 0,
         }
+
+        if self.enabled_ema and self.ema_eval:
+            self.ema_best_metric = {
+                "metric": 0.0,
+                "epoch": 0,
+                "global_step": 0,
+            }
         # key:
         # val: metrics list word
         self.output_info = dict()
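
This hunk establishes the configuration contract: the engine pops ema_eval and eval_start_epoch off the EMA section, then forwards whatever remains to the EMA constructor. A hypothetical config section (only the first two keys appear in this diff; the decay key is an assumption about what EMA() accepts):

config = {
    "EMA": {
        "ema_eval": True,        # consumed by the engine: evaluate EMA weights
        "eval_start_epoch": 10,  # consumed by the engine: first epoch of EMA eval
        "decay": 0.9999,         # assumed kwarg, forwarded via EMA(..., **ema_cfg)
    },
}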
@@ -301,6 +324,12 @@ def train(self):

         # load checkpoint and resume
         if self.config["Global"]["checkpoint"] is not None:
+            if self.enabled_ema:
+                ema_metric_info = io.load_ema_checkpoint(
+                    self.config["Global"]["checkpoint"] + '_ema', self.ema)
+                if ema_metric_info is not None and self.ema_eval:
+                    self.ema_best_metric.update(ema_metric_info)
+
             metric_info = io.load_checkpoint(
                 self.config["Global"]["checkpoint"], self.model,
                 self.optimizer, self.scaler)
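
Note the resume convention introduced here: the EMA state is looked up as a sibling of the regular checkpoint, with an _ema suffix appended to the same path. Illustratively (the directory layout is an assumption; only the suffix rule comes from the code):

    output/epoch_10        # restored by io.load_checkpoint
    output/epoch_10_ema    # restored by io.load_ema_checkpoint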
@@ -362,19 +391,36 @@ def train(self):
                     step=epoch_id,
                     writer=self.vdl_writer)

+            if self.enabled_ema and self.ema_eval and epoch_id > self.ema_eval_start_epoch:
+                self.ema.apply_shadow()
+                ema_eval_metric_info = self.eval(epoch_id)
+
+                if ema_eval_metric_info["metric"] > self.ema_best_metric[
+                        "metric"]:
+                    self.ema_best_metric = ema_eval_metric_info.copy()
+                    io.save_ema_checkpoint(
+                        self.model,
+                        self.ema,
+                        self.output_dir,
+                        self.ema_best_metric,
+                        model_name=self.config["Model"]["name"],
+                        prefix="best_model_ema",
+                        max_num_checkpoint=self.config["Global"][
+                            "max_num_latest_checkpoint"], )
+
+                logger.info("[Eval][Epoch {}][ema best metric: {}]".format(
+                    epoch_id, self.ema_best_metric["metric"]))
+                logger.scaler(
+                    name="ema_eval_metric",
+                    value=ema_eval_metric_info["metric"],
+                    step=epoch_id,
+                    writer=self.vdl_writer)
+
+                self.ema.restore()
+
             # save model
-            if epoch_id % self.save_interval == 0:
-                if self.config["Global"]["max_num_latest_checkpoint"] != 0:
-                    io.save_checkpoint(
-                        self.model,
-                        self.optimizer,
-                        self.scaler,
-                        eval_metric_info,
-                        self.output_dir,
-                        model_name=self.config["Model"]["name"],
-                        prefix="epoch_{}".format(epoch_id),
-                        max_num_checkpoint=self.config["Global"][
-                            "max_num_latest_checkpoint"], )
+            if epoch_id % self.save_interval == 0 or epoch_id == self.config[
+                    "Global"]["epochs"]:
                 # save the latest model
                 io.save_checkpoint(
                     self.model,
@@ -387,6 +433,46 @@ def train(self):
                     max_num_checkpoint=self.config["Global"][
                         "max_num_latest_checkpoint"], )

+                if self.config["Global"]["max_num_latest_checkpoint"] != 0:
+                    io.save_checkpoint(
+                        self.model,
+                        self.optimizer,
+                        self.scaler,
+                        eval_metric_info,
+                        self.output_dir,
+                        model_name=self.config["Model"]["name"],
+                        prefix="epoch_{}".format(epoch_id),
+                        max_num_checkpoint=self.config["Global"][
+                            "max_num_latest_checkpoint"], )
+
+                if self.enabled_ema:
+                    if epoch_id == self.config["Global"]["epochs"]:
+                        self.ema.apply_shadow()
+
+                    io.save_ema_checkpoint(
+                        self.model,
+                        self.ema,
+                        self.output_dir,
+                        None,
+                        model_name=self.config["Model"]["name"],
+                        prefix="latest_ema",
+                        max_num_checkpoint=self.config["Global"][
+                            "max_num_latest_checkpoint"], )
+
+                    if self.config["Global"]["max_num_latest_checkpoint"] != 0:
+                        io.save_ema_checkpoint(
+                            self.model,
+                            self.ema,
+                            self.output_dir,
+                            None,
+                            model_name=self.config["Model"]["name"],
+                            prefix="epoch_{}_ema".format(epoch_id),
+                            max_num_checkpoint=self.config["Global"][
+                                "max_num_latest_checkpoint"], )
+
+                    if epoch_id == self.config["Global"]["epochs"]:
+                        self.ema.restore()
+
         if self.vdl_writer is not None:
             self.vdl_writer.close()
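
Across these hunks the engine relies on exactly four EMA methods: register(), update(), apply_shadow(), and restore(). The real implementation lives in plsc/models/utils (another file of this commit, not shown here); a minimal sketch of the expected interface, assuming the common shadow/backup pattern and Paddle's parameter API (p.name, p.numpy(), p.set_value()):

class EMA:
    """Sketch of the interface engine.py uses; not the actual PLSC source."""

    def __init__(self, param_groups, decay=0.999):
        self.decay = decay
        # optimizer._param_groups is assumed to be a list of dicts with a
        # "params" entry, as in Paddle/PyTorch-style optimizers.
        self.params = [p for g in param_groups for p in g["params"]]
        self.shadow, self.backup = {}, {}

    def register(self):
        # Snapshot the initial weights as the starting averages.
        for p in self.params:
            self.shadow[p.name] = p.numpy().copy()

    def update(self):
        # shadow <- decay * shadow + (1 - decay) * param, once per step.
        for p in self.params:
            self.shadow[p.name] = (self.decay * self.shadow[p.name] +
                                   (1.0 - self.decay) * p.numpy())

    def apply_shadow(self):
        # Swap the averaged weights in (keeping a backup) for eval/saving.
        for p in self.params:
            self.backup[p.name] = p.numpy().copy()
            p.set_value(self.shadow[p.name])

    def restore(self):
        # Swap the live training weights back after eval/saving.
        for p in self.params:
            p.set_value(self.backup[p.name])
        self.backup = {}

This matches how train() brackets every EMA evaluation and the final save with apply_shadow()/restore(), so the optimizer always resumes from the raw weights.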

plsc/models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@
 from .mae import *
 from .convmae import *
 from .swin_transformer import *
+from .convnext import *

 __all__ = ["build_model"]
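
With the star import in place, the new model is reachable through the same factory entry point as the rest of the zoo. A hypothetical usage (the "convnext_tiny" variant name and the config keys are assumptions following PLSC's usual config-dict convention):

from plsc.models import build_model

# Hypothetical config; PLSC-style factories typically dispatch on "name".
model = build_model({"name": "convnext_tiny", "class_num": 1000})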
