Skip to content

Commit 6ead450

Browse files
GT9505ZwwWayne
authored and committed
Support bbox_clip_border for the augmentations of YOLOX (#6730)
* support 'bbox_clip_border' for the augmentations of YOLOX * update based on 1-st comments * add comments * fix typos * rename remove_ouside_bboxes to find_inside_bboxes * move comments to docstring
1 parent 5612624 commit 6ead450

File tree

3 files changed

+89
-27
lines changed

3 files changed

+89
-27
lines changed

mmdet/core/bbox/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .transforms import (bbox2distance, bbox2result, bbox2roi,
1313
bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
1414
bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
15-
distance2bbox, roi2bbox)
15+
distance2bbox, find_inside_bboxes, roi2bbox)
1616

1717
__all__ = [
1818
'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
@@ -24,5 +24,5 @@
2424
'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
2525
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder',
2626
'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy',
27-
'bbox_xyxy_to_cxcywh', 'RegionAssigner'
27+
'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'find_inside_bboxes'
2828
]

mmdet/core/bbox/transforms.py

+16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,22 @@
33
import torch
44

55

6+
def find_inside_bboxes(bboxes, img_h, img_w):
    """Find bboxes that lie at least partially inside the image.

    A box is kept when ``x1 < img_w``, ``x2 > 0``, ``y1 < img_h`` and
    ``y2 > 0`` all hold. Boxes that are entirely outside the image, or that
    only touch an image border from the outside, are rejected.

    Args:
        bboxes (Tensor | np.ndarray): Boxes of shape (N, 4) in
            (x1, y1, x2, y2) order. Only elementwise comparison and ``&``
            are used, so any array type supporting them works.
        img_h (int): Image height.
        img_w (int): Image width.

    Returns:
        Tensor | np.ndarray: Boolean mask of shape (N,), of the same array
        type as ``bboxes``. ``True`` marks a box to keep. Note that this is
        a mask, not an array of integer indices; use it directly for
        boolean indexing, e.g. ``bboxes[mask]``.
    """
    x1, y1 = bboxes[:, 0], bboxes[:, 1]
    x2, y2 = bboxes[:, 2], bboxes[:, 3]
    # Strict inequalities: a box whose edge merely coincides with the image
    # border from the outside has no part inside the image.
    inside_inds = (x1 < img_w) & (x2 > 0) & (y1 < img_h) & (y2 > 0)
    return inside_inds
20+
21+
622
def bbox_flip(bboxes, img_shape, direction='horizontal'):
723
"""Flip bboxes horizontally or vertically.
824

mmdet/datasets/pipelines/transforms.py

+71-25
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from numpy import random
1111

12-
from mmdet.core import PolygonMasks
12+
from mmdet.core import PolygonMasks, find_inside_bboxes
1313
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
1414
from ..builder import PIPELINES
1515

@@ -54,8 +54,10 @@ class Resize:
5454
ratio_range (tuple[float]): (min_ratio, max_ratio)
5555
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
5656
image.
57-
bbox_clip_border (bool, optional): Whether clip the objects outside
58-
the border of the image. Defaults to True.
57+
bbox_clip_border (bool, optional): Whether to clip the objects outside
58+
the border of the image. In some dataset like MOT17, the gt bboxes
59+
are allowed to cross the border of images. Therefore, we don't
60+
need to clip the gt bboxes in these cases. Defaults to True.
5961
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
6062
These two backends generates slightly different results. Defaults
6163
to 'cv2'.
@@ -1982,6 +1984,10 @@ class Mosaic:
19821984
output. Default to (0.5, 1.5).
19831985
min_bbox_size (int | float): The minimum pixel for filtering
19841986
invalid bboxes after the mosaic pipeline. Default to 0.
1987+
bbox_clip_border (bool, optional): Whether to clip the objects outside
1988+
the border of the image. In some dataset like MOT17, the gt bboxes
1989+
are allowed to cross the border of images. Therefore, we don't
1990+
need to clip the gt bboxes in these cases. Defaults to True.
19851991
skip_filter (bool): Whether to skip filtering rules. If it
19861992
is True, the filter rule will not be applied, and the
19871993
`min_bbox_size` is invalid. Default to True.
@@ -1992,12 +1998,14 @@ def __init__(self,
19921998
img_scale=(640, 640),
19931999
center_ratio_range=(0.5, 1.5),
19942000
min_bbox_size=0,
2001+
bbox_clip_border=True,
19952002
skip_filter=True,
19962003
pad_val=114):
19972004
assert isinstance(img_scale, tuple)
19982005
self.img_scale = img_scale
19992006
self.center_ratio_range = center_ratio_range
20002007
self.min_bbox_size = min_bbox_size
2008+
self.bbox_clip_border = bbox_clip_border
20012009
self.skip_filter = skip_filter
20022010
self.pad_val = pad_val
20032011

@@ -2099,16 +2107,24 @@ def _mosaic_transform(self, results):
20992107

21002108
if len(mosaic_labels) > 0:
21012109
mosaic_bboxes = np.concatenate(mosaic_bboxes, 0)
2102-
mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0,
2103-
2 * self.img_scale[1])
2104-
mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0,
2105-
2 * self.img_scale[0])
21062110
mosaic_labels = np.concatenate(mosaic_labels, 0)
21072111

2112+
if self.bbox_clip_border:
2113+
mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0,
2114+
2 * self.img_scale[1])
2115+
mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0,
2116+
2 * self.img_scale[0])
2117+
21082118
if not self.skip_filter:
21092119
mosaic_bboxes, mosaic_labels = \
21102120
self._filter_box_candidates(mosaic_bboxes, mosaic_labels)
21112121

2122+
# remove outside bboxes
2123+
inside_inds = find_inside_bboxes(mosaic_bboxes, 2 * self.img_scale[0],
2124+
2 * self.img_scale[1])
2125+
mosaic_bboxes = mosaic_bboxes[inside_inds]
2126+
mosaic_labels = mosaic_labels[inside_inds]
2127+
21122128
results['img'] = mosaic_img
21132129
results['img_shape'] = mosaic_img.shape
21142130
results['gt_bboxes'] = mosaic_bboxes
@@ -2243,6 +2259,10 @@ class MixUp:
22432259
max_aspect_ratio (float): Aspect ratio of width and height
22442260
threshold to filter bboxes. If max(h/w, w/h) larger than this
22452261
value, the box will be removed. Default: 20.
2262+
bbox_clip_border (bool, optional): Whether to clip the objects outside
2263+
the border of the image. In some dataset like MOT17, the gt bboxes
2264+
are allowed to cross the border of images. Therefore, we don't
2265+
need to clip the gt bboxes in these cases. Defaults to True.
22462266
skip_filter (bool): Whether to skip filtering rules. If it
22472267
is True, the filter rule will not be applied, and the
22482268
`min_bbox_size` and `min_area_ratio` and `max_aspect_ratio`
@@ -2258,6 +2278,7 @@ def __init__(self,
22582278
min_bbox_size=5,
22592279
min_area_ratio=0.2,
22602280
max_aspect_ratio=20,
2281+
bbox_clip_border=True,
22612282
skip_filter=True):
22622283
assert isinstance(img_scale, tuple)
22632284
self.dynamic_scale = img_scale
@@ -2268,6 +2289,7 @@ def __init__(self,
22682289
self.min_bbox_size = min_bbox_size
22692290
self.min_area_ratio = min_area_ratio
22702291
self.max_aspect_ratio = max_aspect_ratio
2292+
self.bbox_clip_border = bbox_clip_border
22712293
self.skip_filter = skip_filter
22722294

22732295
def __call__(self, results):
@@ -2371,21 +2393,29 @@ def _mixup_transform(self, results):
23712393

23722394
# 6. adjust bbox
23732395
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
2374-
retrieve_gt_bboxes[:, 0::2] = np.clip(
2375-
retrieve_gt_bboxes[:, 0::2] * scale_ratio, 0, origin_w)
2376-
retrieve_gt_bboxes[:, 1::2] = np.clip(
2377-
retrieve_gt_bboxes[:, 1::2] * scale_ratio, 0, origin_h)
2396+
retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio
2397+
retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio
2398+
if self.bbox_clip_border:
2399+
retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2],
2400+
0, origin_w)
2401+
retrieve_gt_bboxes[:, 1::2] = np.clip(retrieve_gt_bboxes[:, 1::2],
2402+
0, origin_h)
23782403

23792404
if is_filp:
23802405
retrieve_gt_bboxes[:, 0::2] = (
23812406
origin_w - retrieve_gt_bboxes[:, 0::2][:, ::-1])
23822407

23832408
# 7. filter
23842409
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.copy()
2385-
cp_retrieve_gt_bboxes[:, 0::2] = np.clip(
2386-
cp_retrieve_gt_bboxes[:, 0::2] - x_offset, 0, target_w)
2387-
cp_retrieve_gt_bboxes[:, 1::2] = np.clip(
2388-
cp_retrieve_gt_bboxes[:, 1::2] - y_offset, 0, target_h)
2410+
cp_retrieve_gt_bboxes[:, 0::2] = \
2411+
cp_retrieve_gt_bboxes[:, 0::2] - x_offset
2412+
cp_retrieve_gt_bboxes[:, 1::2] = \
2413+
cp_retrieve_gt_bboxes[:, 1::2] - y_offset
2414+
if self.bbox_clip_border:
2415+
cp_retrieve_gt_bboxes[:, 0::2] = np.clip(
2416+
cp_retrieve_gt_bboxes[:, 0::2], 0, target_w)
2417+
cp_retrieve_gt_bboxes[:, 1::2] = np.clip(
2418+
cp_retrieve_gt_bboxes[:, 1::2], 0, target_h)
23892419

23902420
# 8. mix up
23912421
ori_img = ori_img.astype(np.float32)
@@ -2405,6 +2435,11 @@ def _mixup_transform(self, results):
24052435
mixup_gt_labels = np.concatenate(
24062436
(results['gt_labels'], retrieve_gt_labels), axis=0)
24072437

2438+
# remove outside bbox
2439+
inside_inds = find_inside_bboxes(mixup_gt_bboxes, target_h, target_w)
2440+
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
2441+
mixup_gt_labels = mixup_gt_labels[inside_inds]
2442+
24082443
results['img'] = mixup_img.astype(np.uint8)
24092444
results['img_shape'] = mixup_img.shape
24102445
results['gt_bboxes'] = mixup_gt_bboxes
@@ -2471,6 +2506,10 @@ class RandomAffine:
24712506
max_aspect_ratio (float): Aspect ratio of width and height
24722507
threshold to filter bboxes. If max(h/w, w/h) larger than this
24732508
value, the box will be removed.
2509+
bbox_clip_border (bool, optional): Whether to clip the objects outside
2510+
the border of the image. In some dataset like MOT17, the gt bboxes
2511+
are allowed to cross the border of images. Therefore, we don't
2512+
need to clip the gt bboxes in these cases. Defaults to True.
24742513
skip_filter (bool): Whether to skip filtering rules. If it
24752514
is True, the filter rule will not be applied, and the
24762515
`min_bbox_size` and `min_area_ratio` and `max_aspect_ratio`
@@ -2487,6 +2526,7 @@ def __init__(self,
24872526
min_bbox_size=2,
24882527
min_area_ratio=0.2,
24892528
max_aspect_ratio=20,
2529+
bbox_clip_border=True,
24902530
skip_filter=True):
24912531
assert 0 <= max_translate_ratio <= 1
24922532
assert scaling_ratio_range[0] <= scaling_ratio_range[1]
@@ -2500,6 +2540,7 @@ def __init__(self,
25002540
self.min_bbox_size = min_bbox_size
25012541
self.min_area_ratio = min_area_ratio
25022542
self.max_aspect_ratio = max_aspect_ratio
2543+
self.bbox_clip_border = bbox_clip_border
25032544
self.skip_filter = skip_filter
25042545

25052546
def __call__(self, results):
@@ -2560,20 +2601,25 @@ def __call__(self, results):
25602601
warp_bboxes = np.vstack(
25612602
(xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T
25622603

2563-
warp_bboxes[:, [0, 2]] = warp_bboxes[:, [0, 2]].clip(0, width)
2564-
warp_bboxes[:, [1, 3]] = warp_bboxes[:, [1, 3]].clip(0, height)
2604+
if self.bbox_clip_border:
2605+
warp_bboxes[:, [0, 2]] = \
2606+
warp_bboxes[:, [0, 2]].clip(0, width)
2607+
warp_bboxes[:, [1, 3]] = \
2608+
warp_bboxes[:, [1, 3]].clip(0, height)
25652609

2610+
# remove outside bbox
2611+
valid_index = find_inside_bboxes(warp_bboxes, height, width)
25662612
if not self.skip_filter:
25672613
# filter bboxes
2568-
valid_index = self.filter_gt_bboxes(
2614+
filter_index = self.filter_gt_bboxes(
25692615
bboxes * scaling_ratio, warp_bboxes)
2570-
results[key] = warp_bboxes[valid_index]
2571-
if key in ['gt_bboxes']:
2572-
if 'gt_labels' in results:
2573-
results['gt_labels'] = results['gt_labels'][
2574-
valid_index]
2575-
else:
2576-
results[key] = warp_bboxes
2616+
valid_index = valid_index & filter_index
2617+
2618+
results[key] = warp_bboxes[valid_index]
2619+
if key in ['gt_bboxes']:
2620+
if 'gt_labels' in results:
2621+
results['gt_labels'] = results['gt_labels'][
2622+
valid_index]
25772623

25782624
if 'gt_masks' in results:
25792625
raise NotImplementedError(

0 commit comments

Comments
 (0)