
Commit 8999946

ppwwyyxx authored and facebook-github-bot committed
retinanet inference speedup
Reviewed By: theschnitz
Differential Revision: D23924255
fbshipit-source-id: ea85df04b0e56cc5ba7eeccb6d7d1f88300c896f
1 parent: 2618f32

File tree

- MODEL_ZOO.md
- detectron2/config/defaults.py
- detectron2/modeling/meta_arch/retinanet.py
- tools/benchmark.py

4 files changed: +20 −18 lines


MODEL_ZOO.md

+3 −3

@@ -219,7 +219,7 @@ All models available for download through this document are licensed under the
 <tr><td align="left"><a href="configs/COCO-Detection/retinanet_R_50_FPN_1x.yaml">R50</a></td>
 <td align="center">1x</td>
 <td align="center">0.205</td>
-<td align="center">0.056</td>
+<td align="center">0.041</td>
 <td align="center">4.1</td>
 <td align="center">37.4</td>
 <td align="center">190397773</td>
@@ -229,7 +229,7 @@ All models available for download through this document are licensed under the
 <tr><td align="left"><a href="configs/COCO-Detection/retinanet_R_50_FPN_3x.yaml">R50</a></td>
 <td align="center">3x</td>
 <td align="center">0.205</td>
-<td align="center">0.056</td>
+<td align="center">0.041</td>
 <td align="center">4.1</td>
 <td align="center">38.7</td>
 <td align="center">190397829</td>
@@ -239,7 +239,7 @@ All models available for download through this document are licensed under the
 <tr><td align="left"><a href="configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml">R101</a></td>
 <td align="center">3x</td>
 <td align="center">0.291</td>
-<td align="center">0.069</td>
+<td align="center">0.054</td>
 <td align="center">5.2</td>
 <td align="center">40.4</td>
 <td align="center">190397697</td>

detectron2/config/defaults.py

+1

@@ -442,6 +442,7 @@
 # Inference cls score threshold, only anchors with score > INFERENCE_TH are
 # considered for inference (to improve speed)
 _C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
+# Select topk candidates before NMS
 _C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
 _C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
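The added comment documents `TOPK_CANDIDATES_TEST` next to the other RetinaNet test-time knobs. A minimal sketch of adjusting these options on a config object, assuming only detectron2's standard `get_cfg()`; the values shown mirror the defaults and are illustrative, not recommendations.

```python
# Illustrative only: the RetinaNet test-time filtering knobs touched by this commit.
from detectron2.config import get_cfg

cfg = get_cfg()
# Only anchors with score > SCORE_THRESH_TEST are considered at inference.
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
# After thresholding, keep at most this many top-scoring candidates per
# feature level before running NMS.
cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
# IoU threshold used by NMS on the surviving candidates.
cfg.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
```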

detectron2/modeling/meta_arch/retinanet.py

+11 −10

@@ -341,19 +341,20 @@ def inference_single_image(self, anchors, box_cls, box_delta, image_size):
         # Iterate over every feature level
         for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta, anchors):
             # (HxWxAxK,)
-            box_cls_i = box_cls_i.flatten().sigmoid_()
+            predicted_prob = box_cls_i.flatten().sigmoid_()

-            # Keep top k top scoring indices only.
-            num_topk = min(self.topk_candidates, box_reg_i.size(0))
-            # torch.sort is actually faster than .topk (at least on GPUs)
-            predicted_prob, topk_idxs = box_cls_i.sort(descending=True)
-            predicted_prob = predicted_prob[:num_topk]
-            topk_idxs = topk_idxs[:num_topk]
-
-            # filter out the proposals with low confidence score
+            # Apply two filtering below to make NMS faster.
+            # 1. Keep boxes with confidence score higher than threshold
             keep_idxs = predicted_prob > self.score_threshold
             predicted_prob = predicted_prob[keep_idxs]
-            topk_idxs = topk_idxs[keep_idxs]
+            topk_idxs = torch.nonzero(keep_idxs, as_tuple=True)[0]
+
+            # 2. Keep top k top scoring boxes only
+            num_topk = min(self.topk_candidates, topk_idxs.size(0))
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, idxs = predicted_prob.sort(descending=True)
+            predicted_prob = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[idxs[:num_topk]]

             anchor_idxs = topk_idxs // self.num_classes
             classes_idxs = topk_idxs % self.num_classes
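The speedup comes from reordering the two filters: thresholding first means the sort used for the top-k selection runs only over the anchors that clear `SCORE_THRESH_TEST` (typically a small fraction of the HxWxAxK candidates per level), instead of over the full flattened tensor. Below is a standalone sketch of that threshold-then-top-k pattern; the names (`scores`, `score_thresh`, `topk`) are local to the example, not attributes of the model.

```python
# Standalone illustration of the reordered filtering (not the model code):
# threshold first, then sort only the survivors for the top-k selection.
import torch


def filter_predictions(scores: torch.Tensor, score_thresh: float, topk: int):
    """scores: flattened (H*W*A*K,) per-anchor, per-class probabilities."""
    # 1. Keep candidates with confidence score higher than the threshold.
    keep_idxs = scores > score_thresh
    scores = scores[keep_idxs]
    idxs = torch.nonzero(keep_idxs, as_tuple=True)[0]

    # 2. Keep the top-k scoring candidates only; sorting the (much smaller)
    # thresholded tensor is cheaper than sorting everything up front.
    num_topk = min(topk, idxs.numel())
    scores, order = scores.sort(descending=True)
    return scores[:num_topk], idxs[order[:num_topk]]


probs = torch.rand(100_000) ** 8  # skewed so most scores are tiny, as in practice
top_scores, top_idxs = filter_predictions(probs, score_thresh=0.05, topk=1000)
```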

tools/benchmark.py

+5 −5

@@ -125,16 +125,16 @@ def benchmark_eval(args):
     cfg.defrost()
     cfg.DATALOADER.NUM_WORKERS = 0
     data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
-    dummy_data = list(itertools.islice(data_loader, 100))
+    dummy_data = DatasetFromList(list(itertools.islice(data_loader, 100)), copy=False)

     def f():
         while True:
-            yield from DatasetFromList(dummy_data, copy=False)
+            yield from dummy_data

-    for _ in range(5):  # warmup
-        model(dummy_data[0])
+    for k in range(5):  # warmup
+        model(dummy_data[k])

-    max_iter = 400
+    max_iter = 300
     timer = Timer()
     with tqdm.tqdm(total=max_iter) as pbar:
         for idx, d in enumerate(f()):
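The benchmark change builds the `DatasetFromList` wrapper once and reuses it, instead of re-wrapping the cached samples on every pass through the infinite generator, warms up on distinct samples, and lowers `max_iter` to 300. A generic sketch of the same hoisting pattern, with no detectron2 dependency; `CachedList` is a hypothetical stand-in for `DatasetFromList`, not a real API.

```python
# Generic illustration of the pattern used in the benchmark fix: build the
# (possibly expensive) wrapper around the cached samples once, outside the
# infinite generator, instead of reconstructing it on every pass.
# "CachedList" is a stand-in for detectron2's DatasetFromList, not a real API.
import itertools


class CachedList:
    def __init__(self, items, copy=False):
        # Pretend this does something non-trivial (serialization, pinning, ...).
        self._items = list(items) if copy else items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, i):
        return self._items[i]

    def __iter__(self):
        return iter(self._items)


def make_stream(data_loader, n=100):
    # Wrap the first n samples once...
    dummy_data = CachedList(list(itertools.islice(data_loader, n)), copy=False)

    def stream():
        while True:
            # ...and only iterate over the prebuilt wrapper here.
            yield from dummy_data

    return dummy_data, stream


dummy_data, stream = make_stream(iter(range(1000)))
for k in range(5):  # warm up on distinct samples, as the patched benchmark does
    _ = dummy_data[k]
```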
