Skip to content

Commit 0db1b9b

Browse files
authored
[Tools] Support replacing the ${key} with the value of cfg.key (#7492)
* Support replacing config * Support replacing config * Add unit test for replace_cfig * pre-commit * fix * modify the docstring * rename function * fix a bug * fix a bug and simplify the code * simplify the code * add replace_cfg_vals for some scripts * add replace_cfg_vals for some scripts * add some unit tests
1 parent b1f40ef commit 0db1b9b

12 files changed

+194
-10
lines changed

mmdet/utils/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .compat_config import compat_cfg
44
from .logger import get_caller_name, get_root_logger, log_img_scale
55
from .misc import find_latest_checkpoint, update_data_root
6+
from .replace_cfg_vals import replace_cfg_vals
67
from .setup_env import setup_multi_processes
78
from .split_batch import split_batch
89
from .util_distribution import build_ddp, build_dp, get_device
@@ -11,5 +12,5 @@
1112
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
1213
'update_data_root', 'setup_multi_processes', 'get_caller_name',
1314
'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp',
14-
'get_device'
15+
'get_device', 'replace_cfg_vals'
1516
]

mmdet/utils/replace_cfg_vals.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import re
3+
4+
from mmcv.utils import Config
5+
6+
7+
def replace_cfg_vals(ori_cfg):
8+
"""Replace the string "${key}" with the corresponding value.
9+
10+
Replace the "${key}" with the value of ori_cfg.key in the config. And
11+
support replacing the chained ${key}. Such as, replace "${key0.key1}"
12+
with the value of cfg.key0.key1. Code is modified from `vars.py
13+
< https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501
14+
15+
Args:
16+
ori_cfg (mmcv.utils.config.Config):
17+
The origin config with "${key}" generated from a file.
18+
19+
Returns:
20+
updated_cfg [mmcv.utils.config.Config]:
21+
The config with "${key}" replaced by the corresponding value.
22+
"""
23+
24+
def get_value(cfg, key):
25+
for k in key.split('.'):
26+
cfg = cfg[k]
27+
return cfg
28+
29+
def replace_value(cfg):
30+
if isinstance(cfg, dict):
31+
return {key: replace_value(value) for key, value in cfg.items()}
32+
elif isinstance(cfg, list):
33+
return [replace_value(item) for item in cfg]
34+
elif isinstance(cfg, tuple):
35+
return tuple([replace_value(item) for item in cfg])
36+
elif isinstance(cfg, str):
37+
# the format of string cfg may be:
38+
# 1) "${key}", which will be replaced with cfg.key directly
39+
# 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
40+
# which will be replaced with the string of the cfg.key
41+
keys = pattern_key.findall(cfg)
42+
values = [get_value(ori_cfg, key[2:-1]) for key in keys]
43+
if len(keys) == 1 and keys[0] == cfg:
44+
# the format of string cfg is "${key}"
45+
cfg = values[0]
46+
else:
47+
for key, value in zip(keys, values):
48+
# the format of string cfg is
49+
# "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx"
50+
assert not isinstance(value, (dict, list, tuple)), \
51+
f'for the format of string cfg is ' \
52+
f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \
53+
f"the type of the value of '${key}' " \
54+
f'can not be dict, list, or tuple' \
55+
f'but you input {type(value)} in {cfg}'
56+
cfg = cfg.replace(key, str(value))
57+
return cfg
58+
else:
59+
return cfg
60+
61+
# the pattern of string "${key}"
62+
pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')
63+
# the type of ori_cfg._cfg_dict is mmcv.utils.config.ConfigDict
64+
updated_cfg = Config(
65+
replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
66+
# replace the model with model_wrapper
67+
if updated_cfg.get('model_wrapper', None) is not None:
68+
updated_cfg.model = updated_cfg.model_wrapper
69+
updated_cfg.pop('model_wrapper')
70+
return updated_cfg
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os.path as osp
2+
import tempfile
3+
from copy import deepcopy
4+
5+
import pytest
6+
from mmcv.utils import Config
7+
8+
from mmdet.utils import replace_cfg_vals
9+
10+
11+
def test_replace_cfg_vals():
12+
temp_file = tempfile.NamedTemporaryFile()
13+
cfg_path = f'{temp_file.name}.py'
14+
with open(cfg_path, 'w') as f:
15+
f.write('configs')
16+
17+
ori_cfg_dict = dict()
18+
ori_cfg_dict['cfg_name'] = osp.basename(temp_file.name)
19+
ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}'
20+
ori_cfg_dict['percent'] = 5
21+
ori_cfg_dict['fold'] = 1
22+
ori_cfg_dict['model_wrapper'] = dict(
23+
type='SoftTeacher', detector='${model}')
24+
ori_cfg_dict['model'] = dict(
25+
type='FasterRCNN',
26+
backbone=dict(type='ResNet'),
27+
neck=dict(type='FPN'),
28+
rpn_head=dict(type='RPNHead'),
29+
roi_head=dict(type='StandardRoIHead'),
30+
train_cfg=dict(
31+
rpn=dict(
32+
assigner=dict(type='MaxIoUAssigner'),
33+
sampler=dict(type='RandomSampler'),
34+
),
35+
rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)),
36+
rcnn=dict(
37+
assigner=dict(type='MaxIoUAssigner'),
38+
sampler=dict(type='RandomSampler'),
39+
),
40+
),
41+
test_cfg=dict(
42+
rpn=dict(nms=dict(type='nms', iou_threshold=0.7)),
43+
rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)),
44+
),
45+
)
46+
ori_cfg_dict['iou_threshold'] = dict(
47+
rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}',
48+
test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}',
49+
test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}',
50+
)
51+
52+
ori_cfg_dict['str'] = 'Hello, world!'
53+
ori_cfg_dict['dict'] = {'Hello': 'world!'}
54+
ori_cfg_dict['list'] = [
55+
'Hello, world!',
56+
]
57+
ori_cfg_dict['tuple'] = ('Hello, world!', )
58+
ori_cfg_dict['test_str'] = 'xxx${str}xxx'
59+
60+
ori_cfg = Config(ori_cfg_dict, filename=cfg_path)
61+
updated_cfg = replace_cfg_vals(deepcopy(ori_cfg))
62+
63+
assert updated_cfg.work_dir \
64+
== f'work_dirs/{osp.basename(temp_file.name)}/5/1'
65+
assert updated_cfg.model.detector == ori_cfg.model
66+
assert updated_cfg.iou_threshold.rpn_proposal_nms \
67+
== ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold
68+
assert updated_cfg.test_str == 'xxxHello, world!xxx'
69+
ori_cfg_dict['test_dict'] = 'xxx${dict}xxx'
70+
ori_cfg_dict['test_list'] = 'xxx${list}xxx'
71+
ori_cfg_dict['test_tuple'] = 'xxx${tuple}xxx'
72+
with pytest.raises(AssertionError):
73+
cfg = deepcopy(ori_cfg)
74+
cfg['test_dict'] = 'xxx${dict}xxx'
75+
updated_cfg = replace_cfg_vals(cfg)
76+
with pytest.raises(AssertionError):
77+
cfg = deepcopy(ori_cfg)
78+
cfg['test_list'] = 'xxx${list}xxx'
79+
updated_cfg = replace_cfg_vals(cfg)
80+
with pytest.raises(AssertionError):
81+
cfg = deepcopy(ori_cfg)
82+
cfg['test_tuple'] = 'xxx${tuple}xxx'
83+
updated_cfg = replace_cfg_vals(cfg)

tools/analysis_tools/analyze_results.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mmdet.core.evaluation import eval_map
1010
from mmdet.core.visualization import imshow_gt_det_bboxes
1111
from mmdet.datasets import build_dataset, get_loading_pipeline
12-
from mmdet.utils import update_data_root
12+
from mmdet.utils import replace_cfg_vals, update_data_root
1313

1414

1515
def bbox_map_eval(det_result, annotation):
@@ -188,6 +188,9 @@ def main():
188188

189189
cfg = Config.fromfile(args.config)
190190

191+
# replace the ${key} with the value of cfg.key
192+
cfg = replace_cfg_vals(cfg)
193+
191194
# update data root according to MMDET_DATASETS
192195
update_data_root(cfg)
193196

tools/analysis_tools/benchmark.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from mmdet.datasets import (build_dataloader, build_dataset,
1414
replace_ImageToTensor)
1515
from mmdet.models import build_detector
16-
from mmdet.utils import update_data_root
16+
from mmdet.utils import replace_cfg_vals, update_data_root
1717

1818

1919
def parse_args():
@@ -172,6 +172,9 @@ def main():
172172

173173
cfg = Config.fromfile(args.config)
174174

175+
# replace the ${key} with the value of cfg.key
176+
cfg = replace_cfg_vals(cfg)
177+
175178
# update data root according to MMDET_DATASETS
176179
update_data_root(cfg)
177180

tools/analysis_tools/confusion_matrix.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
1212
from mmdet.datasets import build_dataset
13-
from mmdet.utils import update_data_root
13+
from mmdet.utils import replace_cfg_vals, update_data_root
1414

1515

1616
def parse_args():
@@ -232,6 +232,9 @@ def main():
232232

233233
cfg = Config.fromfile(args.config)
234234

235+
# replace the ${key} with the value of cfg.key
236+
cfg = replace_cfg_vals(cfg)
237+
235238
# update data root according to MMDET_DATASETS
236239
update_data_root(cfg)
237240

tools/analysis_tools/eval_metric.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mmcv import Config, DictAction
66

77
from mmdet.datasets import build_dataset
8-
from mmdet.utils import update_data_root
8+
from mmdet.utils import replace_cfg_vals, update_data_root
99

1010

1111
def parse_args():
@@ -50,6 +50,9 @@ def main():
5050

5151
cfg = Config.fromfile(args.config)
5252

53+
# replace the ${key} with the value of cfg.key
54+
cfg = replace_cfg_vals(cfg)
55+
5356
# update data root according to MMDET_DATASETS
5457
update_data_root(cfg)
5558

tools/analysis_tools/optimize_anchors.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from mmdet.core import bbox_cxcywh_to_xyxy, bbox_overlaps, bbox_xyxy_to_cxcywh
3131
from mmdet.datasets import build_dataset
32-
from mmdet.utils import get_root_logger, update_data_root
32+
from mmdet.utils import get_root_logger, replace_cfg_vals, update_data_root
3333

3434

3535
def parse_args():
@@ -325,6 +325,9 @@ def main():
325325
cfg = args.config
326326
cfg = Config.fromfile(cfg)
327327

328+
# replace the ${key} with the value of cfg.key
329+
cfg = replace_cfg_vals(cfg)
330+
328331
# update data root according to MMDET_DATASETS
329332
update_data_root(cfg)
330333

tools/misc/browse_dataset.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mmdet.core.utils import mask2ndarray
1212
from mmdet.core.visualization import imshow_det_bboxes
1313
from mmdet.datasets.builder import build_dataset
14-
from mmdet.utils import update_data_root
14+
from mmdet.utils import replace_cfg_vals, update_data_root
1515

1616

1717
def parse_args():
@@ -57,6 +57,9 @@ def skip_pipeline_steps(config):
5757

5858
cfg = Config.fromfile(config_path)
5959

60+
# replace the ${key} with the value of cfg.key
61+
cfg = replace_cfg_vals(cfg)
62+
6063
# update data root according to MMDET_DATASETS
6164
update_data_root(cfg)
6265

tools/misc/print_config.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mmcv import Config, DictAction
66

7-
from mmdet.utils import update_data_root
7+
from mmdet.utils import replace_cfg_vals, update_data_root
88

99

1010
def parse_args():
@@ -45,6 +45,9 @@ def main():
4545

4646
cfg = Config.fromfile(args.config)
4747

48+
# replace the ${key} with the value of cfg.key
49+
cfg = replace_cfg_vals(cfg)
50+
4851
# update data root according to MMDET_DATASETS
4952
update_data_root(cfg)
5053

tools/test.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
replace_ImageToTensor)
1818
from mmdet.models import build_detector
1919
from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
20-
setup_multi_processes, update_data_root)
20+
replace_cfg_vals, setup_multi_processes,
21+
update_data_root)
2122

2223

2324
def parse_args():
@@ -134,6 +135,9 @@ def main():
134135

135136
cfg = Config.fromfile(args.config)
136137

138+
# replace the ${key} with the value of cfg.key
139+
cfg = replace_cfg_vals(cfg)
140+
137141
# update data root according to MMDET_DATASETS
138142
update_data_root(cfg)
139143

tools/train.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from mmdet.datasets import build_dataset
1919
from mmdet.models import build_detector
2020
from mmdet.utils import (collect_env, get_device, get_root_logger,
21-
setup_multi_processes, update_data_root)
21+
replace_cfg_vals, setup_multi_processes,
22+
update_data_root)
2223

2324

2425
def parse_args():
@@ -109,6 +110,9 @@ def main():
109110

110111
cfg = Config.fromfile(args.config)
111112

113+
# replace the ${key} with the value of cfg.key
114+
cfg = replace_cfg_vals(cfg)
115+
112116
# update data root according to MMDET_DATASETS
113117
update_data_root(cfg)
114118

@@ -142,6 +146,7 @@ def main():
142146
# use config filename as default work_dir if cfg.work_dir is None
143147
cfg.work_dir = osp.join('./work_dirs',
144148
osp.splitext(osp.basename(args.config))[0])
149+
145150
if args.resume_from is not None:
146151
cfg.resume_from = args.resume_from
147152
cfg.auto_resume = args.auto_resume

0 commit comments

Comments
 (0)