Skip to content

Commit 2051e7f

Browse files
zhaoxin111ZwwWayne
authored andcommitted
Add 'get_ann_info' to dataset_wrappers (#6526)
* Add 'get_ann_info' to dataset_wrappers * fix format * Delete unimportant notes
1 parent 5fc01a4 commit 2051e7f

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

mmdet/datasets/dataset_wrappers.py

+46
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,28 @@ def get_cat_ids(self, idx):
6868
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
6969
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
7070

71+
def get_ann_info(self, idx):
72+
"""Get annotation of concatenated dataset by index.
73+
74+
Args:
75+
idx (int): Index of data.
76+
77+
Returns:
78+
dict: Annotation info of specified index.
79+
"""
80+
81+
if idx < 0:
82+
if -idx > len(self):
83+
raise ValueError(
84+
'absolute value of index should not exceed dataset length')
85+
idx = len(self) + idx
86+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
87+
if dataset_idx == 0:
88+
sample_idx = idx
89+
else:
90+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
91+
return self.datasets[dataset_idx].get_ann_info(sample_idx)
92+
7193
def evaluate(self, results, logger=None, **kwargs):
7294
"""Evaluate the results.
7395
@@ -165,6 +187,18 @@ def get_cat_ids(self, idx):
165187

166188
return self.dataset.get_cat_ids(idx % self._ori_len)
167189

190+
def get_ann_info(self, idx):
191+
"""Get annotation of repeat dataset by index.
192+
193+
Args:
194+
idx (int): Index of data.
195+
196+
Returns:
197+
dict: Annotation info of specified index.
198+
"""
199+
200+
return self.dataset.get_ann_info(idx % self._ori_len)
201+
168202
def __len__(self):
169203
"""Length after repetition."""
170204
return self.times * self._ori_len
@@ -280,6 +314,18 @@ def __getitem__(self, idx):
280314
ori_index = self.repeat_indices[idx]
281315
return self.dataset[ori_index]
282316

317+
def get_ann_info(self, idx):
318+
"""Get annotation of dataset by index.
319+
320+
Args:
321+
idx (int): Index of data.
322+
323+
Returns:
324+
dict: Annotation info of specified index.
325+
"""
326+
ori_index = self.repeat_indices[idx]
327+
return self.dataset.get_ann_info(ori_index)
328+
283329
def __len__(self):
284330
"""Length after repetition."""
285331
return len(self.repeat_indices)

tests/test_data/test_datasets/test_dataset_wrapper.py

+37
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,59 @@ def test_dataset_wrapper():
2121
np.random.randint(0, 80, num).tolist()
2222
for num in np.random.randint(1, 20, len_a)
2323
]
24+
ann_info_list_a = []
25+
for _ in range(len_a):
26+
height = np.random.randint(10, 30)
27+
weight = np.random.randint(10, 30)
28+
img = np.ones((height, weight, 3))
29+
gt_bbox = np.concatenate([
30+
np.random.randint(1, 5, (2, 2)),
31+
np.random.randint(1, 5, (2, 2)) + 5
32+
],
33+
axis=1)
34+
gt_labels = np.random.randint(0, 80, 2)
35+
ann_info_list_a.append(
36+
dict(gt_bboxes=gt_bbox, gt_labels=gt_labels, img=img))
2437
dataset_a.data_infos = MagicMock()
2538
dataset_a.data_infos.__len__.return_value = len_a
2639
dataset_a.get_cat_ids = MagicMock(
2740
side_effect=lambda idx: cat_ids_list_a[idx])
41+
dataset_a.get_ann_info = MagicMock(
42+
side_effect=lambda idx: ann_info_list_a[idx])
2843
dataset_b = CustomDataset(
2944
ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
3045
len_b = 20
3146
cat_ids_list_b = [
3247
np.random.randint(0, 80, num).tolist()
3348
for num in np.random.randint(1, 20, len_b)
3449
]
50+
ann_info_list_b = []
51+
for _ in range(len_b):
52+
height = np.random.randint(10, 30)
53+
weight = np.random.randint(10, 30)
54+
img = np.ones((height, weight, 3))
55+
gt_bbox = np.concatenate([
56+
np.random.randint(1, 5, (2, 2)),
57+
np.random.randint(1, 5, (2, 2)) + 5
58+
],
59+
axis=1)
60+
gt_labels = np.random.randint(0, 80, 2)
61+
ann_info_list_b.append(
62+
dict(gt_bboxes=gt_bbox, gt_labels=gt_labels, img=img))
3563
dataset_b.data_infos = MagicMock()
3664
dataset_b.data_infos.__len__.return_value = len_b
3765
dataset_b.get_cat_ids = MagicMock(
3866
side_effect=lambda idx: cat_ids_list_b[idx])
67+
dataset_b.get_ann_info = MagicMock(
68+
side_effect=lambda idx: ann_info_list_b[idx])
3969

4070
concat_dataset = ConcatDataset([dataset_a, dataset_b])
4171
assert concat_dataset[5] == 5
4272
assert concat_dataset[25] == 15
4373
assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
4474
assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
75+
assert concat_dataset.get_ann_info(5) == ann_info_list_a[5]
76+
assert concat_dataset.get_ann_info(25) == ann_info_list_b[15]
4577
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
4678

4779
repeat_dataset = RepeatDataset(dataset_a, 10)
@@ -51,6 +83,9 @@ def test_dataset_wrapper():
5183
assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
5284
assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
5385
assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
86+
assert repeat_dataset.get_ann_info(5) == ann_info_list_a[5]
87+
assert repeat_dataset.get_ann_info(15) == ann_info_list_a[5]
88+
assert repeat_dataset.get_ann_info(27) == ann_info_list_a[7]
5489
assert len(repeat_dataset) == 10 * len(dataset_a)
5590

5691
category_freq = defaultdict(int)
@@ -80,6 +115,8 @@ def test_dataset_wrapper():
80115
for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
81116
assert repeat_factor_dataset[idx] == bisect.bisect_right(
82117
repeat_factors_cumsum, idx)
118+
assert repeat_factor_dataset.get_ann_info(idx) == ann_info_list_a[
119+
bisect.bisect_right(repeat_factors_cumsum, idx)]
83120

84121
img_scale = (60, 60)
85122
pipeline = [

0 commit comments

Comments
 (0)