@@ -21,27 +21,59 @@ def test_dataset_wrapper():
21
21
np .random .randint (0 , 80 , num ).tolist ()
22
22
for num in np .random .randint (1 , 20 , len_a )
23
23
]
24
+ ann_info_list_a = []
25
+ for _ in range (len_a ):
26
+ height = np .random .randint (10 , 30 )
27
+ weight = np .random .randint (10 , 30 )
28
+ img = np .ones ((height , weight , 3 ))
29
+ gt_bbox = np .concatenate ([
30
+ np .random .randint (1 , 5 , (2 , 2 )),
31
+ np .random .randint (1 , 5 , (2 , 2 )) + 5
32
+ ],
33
+ axis = 1 )
34
+ gt_labels = np .random .randint (0 , 80 , 2 )
35
+ ann_info_list_a .append (
36
+ dict (gt_bboxes = gt_bbox , gt_labels = gt_labels , img = img ))
24
37
dataset_a .data_infos = MagicMock ()
25
38
dataset_a .data_infos .__len__ .return_value = len_a
26
39
dataset_a .get_cat_ids = MagicMock (
27
40
side_effect = lambda idx : cat_ids_list_a [idx ])
41
+ dataset_a .get_ann_info = MagicMock (
42
+ side_effect = lambda idx : ann_info_list_a [idx ])
28
43
dataset_b = CustomDataset (
29
44
ann_file = MagicMock (), pipeline = [], test_mode = True , img_prefix = '' )
30
45
len_b = 20
31
46
cat_ids_list_b = [
32
47
np .random .randint (0 , 80 , num ).tolist ()
33
48
for num in np .random .randint (1 , 20 , len_b )
34
49
]
50
+ ann_info_list_b = []
51
+ for _ in range (len_b ):
52
+ height = np .random .randint (10 , 30 )
53
+ weight = np .random .randint (10 , 30 )
54
+ img = np .ones ((height , weight , 3 ))
55
+ gt_bbox = np .concatenate ([
56
+ np .random .randint (1 , 5 , (2 , 2 )),
57
+ np .random .randint (1 , 5 , (2 , 2 )) + 5
58
+ ],
59
+ axis = 1 )
60
+ gt_labels = np .random .randint (0 , 80 , 2 )
61
+ ann_info_list_b .append (
62
+ dict (gt_bboxes = gt_bbox , gt_labels = gt_labels , img = img ))
35
63
dataset_b .data_infos = MagicMock ()
36
64
dataset_b .data_infos .__len__ .return_value = len_b
37
65
dataset_b .get_cat_ids = MagicMock (
38
66
side_effect = lambda idx : cat_ids_list_b [idx ])
67
+ dataset_b .get_ann_info = MagicMock (
68
+ side_effect = lambda idx : ann_info_list_b [idx ])
39
69
40
70
concat_dataset = ConcatDataset ([dataset_a , dataset_b ])
41
71
assert concat_dataset [5 ] == 5
42
72
assert concat_dataset [25 ] == 15
43
73
assert concat_dataset .get_cat_ids (5 ) == cat_ids_list_a [5 ]
44
74
assert concat_dataset .get_cat_ids (25 ) == cat_ids_list_b [15 ]
75
+ assert concat_dataset .get_ann_info (5 ) == ann_info_list_a [5 ]
76
+ assert concat_dataset .get_ann_info (25 ) == ann_info_list_b [15 ]
45
77
assert len (concat_dataset ) == len (dataset_a ) + len (dataset_b )
46
78
47
79
repeat_dataset = RepeatDataset (dataset_a , 10 )
@@ -51,6 +83,9 @@ def test_dataset_wrapper():
51
83
assert repeat_dataset .get_cat_ids (5 ) == cat_ids_list_a [5 ]
52
84
assert repeat_dataset .get_cat_ids (15 ) == cat_ids_list_a [5 ]
53
85
assert repeat_dataset .get_cat_ids (27 ) == cat_ids_list_a [7 ]
86
+ assert repeat_dataset .get_ann_info (5 ) == ann_info_list_a [5 ]
87
+ assert repeat_dataset .get_ann_info (15 ) == ann_info_list_a [5 ]
88
+ assert repeat_dataset .get_ann_info (27 ) == ann_info_list_a [7 ]
54
89
assert len (repeat_dataset ) == 10 * len (dataset_a )
55
90
56
91
category_freq = defaultdict (int )
@@ -80,6 +115,8 @@ def test_dataset_wrapper():
80
115
for idx in np .random .randint (0 , len (repeat_factor_dataset ), 3 ):
81
116
assert repeat_factor_dataset [idx ] == bisect .bisect_right (
82
117
repeat_factors_cumsum , idx )
118
+ assert repeat_factor_dataset .get_ann_info (idx ) == ann_info_list_a [
119
+ bisect .bisect_right (repeat_factors_cumsum , idx )]
83
120
84
121
img_scale = (60 , 60 )
85
122
pipeline = [
0 commit comments