Skip to content

Commit b1f40ef

Browse files
authored
Speedup the Video Inference by Accelerating data-loading Stage (#7832)
* add a faster inference for video * Fix typos * modify typo * modify the numpy array to torch gpu * fix lint * add description * add documents * fix typro * fix lint * fix lint * fix lint again * fix a mistake
1 parent 280cc7d commit b1f40ef

File tree

3 files changed

+166
-0
lines changed

3 files changed

+166
-0
lines changed

demo/video_gpuaccel_demo.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import argparse
3+
4+
import cv2
5+
import mmcv
6+
import numpy as np
7+
import torch
8+
from torchvision.transforms import functional as F
9+
10+
from mmdet.apis import init_detector
11+
from mmdet.datasets.pipelines import Compose
12+
13+
try:
14+
import ffmpegcv
15+
except ImportError:
16+
raise ImportError(
17+
'Please install ffmpegcv with:\n\n pip install ffmpegcv')
18+
19+
20+
def parse_args():
    """Parse the command-line arguments of the GPU-accelerated video demo.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``video``,
        ``config``, ``checkpoint``, ``device``, ``score_thr``, ``out``,
        ``show``, ``nvdecode`` and ``wait_time``.
    """
    cli = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')

    # Required positional arguments, registered in command-line order.
    positionals = (
        ('video', 'Video file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)

    # Optional flags; the order is kept so that ``--help`` output is stable.
    options = (
        ('--device', dict(
            default='cuda:0', help='Device used for inference')),
        ('--score-thr', dict(
            type=float, default=0.3, help='Bbox score threshold')),
        ('--out', dict(type=str, help='Output video file')),
        ('--show', dict(action='store_true', help='Show video')),
        ('--nvdecode', dict(
            action='store_true', help='Use NVIDIA decoder')),
        ('--wait-time', dict(
            type=float,
            default=1,
            help='The interval of show (s), 0 is block')),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
41+
42+
43+
def prefetch_img_metas(cfg, ori_wh):
    """Run the test pipeline once on a blank frame and return its img_metas.

    The meta information is computed a single time from an all-zero image of
    the video's size, so the caller can reuse it for every frame.

    Note:
        ``cfg`` is modified in place: the ``type`` of the first step of
        ``cfg.data.test.pipeline`` is set to ``'LoadImageFromWebcam'``
        (presumably so the pipeline accepts an in-memory array instead of a
        file path -- confirm against the mmdet pipeline docs).

    Args:
        cfg: Model config exposing ``cfg.data.test.pipeline``.
        ori_wh (tuple): ``(width, height)`` of the original video frames.

    Returns:
        The ``.data`` payload of the first ``img_metas`` entry produced by
        the test pipeline.
    """
    width, height = ori_wh
    cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
    blank_frame = np.zeros((height, width, 3), dtype=np.uint8)
    pipeline_out = Compose(cfg.data.test.pipeline)({'img': blank_frame})
    return pipeline_out['img_metas'][0].data
51+
52+
53+
def process_img(frame_resize, img_metas, device):
    """Turn one resized frame into the input dict passed to the detector.

    The frame is moved to ``device`` first and cast to float there, then
    permuted, normalized in place and given a leading batch dimension.

    Args:
        frame_resize (np.ndarray): Frame whose ``shape`` must equal
            ``img_metas['pad_shape']``.
        img_metas (dict): Meta information containing ``pad_shape`` and
            ``img_norm_cfg`` (with ``mean`` and ``std`` numpy arrays).
        device: Target device for the tensors.

    Returns:
        dict: ``{'img': [tensor], 'img_metas': [[img_metas]]}``.
    """
    assert frame_resize.shape == img_metas['pad_shape']

    norm_cfg = img_metas['img_norm_cfg']
    mean = torch.from_numpy(norm_cfg['mean']).to(device)
    std = torch.from_numpy(norm_cfg['std']).to(device)

    # Transfer the raw frame before casting so the float conversion happens
    # on the target device, then go from HWC to CHW.
    chw = torch.from_numpy(frame_resize).to(device).float().permute(2, 0, 1)
    chw = F.normalize(chw, mean=mean, std=std, inplace=True)

    batch = chw[None, :, :, :]  # NCHW
    return {'img': [batch], 'img_metas': [[img_metas]]}
63+
64+
65+
def main():
    """Run detection on every frame of a video and show and/or save it.

    The video is opened twice: ``video_origin`` yields the original frames
    used for drawing results, while ``video_resize`` yields frames already
    resized (aspect ratio kept, aligned top-left) to the ``pad_shape``
    reported by the test pipeline, so they can be fed to the model after
    normalization only.

    Raises:
        AssertionError: If neither ``--out`` nor ``--show`` is given.
    """
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    if args.nvdecode:
        VideoCapture = ffmpegcv.VideoCaptureNV
    else:
        VideoCapture = ffmpegcv.VideoCapture

    video_origin = None
    video_resize = None
    video_writer = None
    try:
        video_origin = VideoCapture(args.video)
        img_metas = prefetch_img_metas(
            model.cfg, (video_origin.width, video_origin.height))
        # pad_shape is (h, w, c); the reader expects (w, h).
        resize_wh = img_metas['pad_shape'][1::-1]
        # The mean/std in img_norm_cfg follow the channel order selected by
        # ``to_rgb``; decode in the matching order instead of always RGB so
        # configs with ``to_rgb=False`` are not fed channel-swapped frames.
        to_rgb = img_metas['img_norm_cfg'].get('to_rgb', True)
        video_resize = VideoCapture(
            args.video,
            resize=resize_wh,
            resize_keepratio=True,
            resize_keepratioalign='topleft',
            pix_fmt='rgb24' if to_rgb else 'bgr24')
        if args.out:
            video_writer = ffmpegcv.VideoWriter(
                args.out, fps=video_origin.fps)

        with torch.no_grad():
            for frame_resize, frame_origin in zip(
                    mmcv.track_iter_progress(video_resize), video_origin):
                data = process_img(frame_resize, img_metas, args.device)
                result = model(return_loss=False, rescale=True, **data)[0]
                frame_mask = model.show_result(
                    frame_origin, result, score_thr=args.score_thr)
                if args.show:
                    cv2.namedWindow('video', 0)
                    mmcv.imshow(frame_mask, 'video', args.wait_time)
                if args.out:
                    video_writer.write(frame_mask)
    finally:
        # Release readers/writer even when inference fails part-way, so the
        # output file is finalized and no decoder handle is left open.
        for handle in (video_writer, video_origin, video_resize):
            if handle is not None:
                handle.release()
        cv2.destroyAllWindows()
110+
111+
112+
if __name__ == '__main__':
113+
main()

docs/en/1_exist_data_model.md

+26
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,32 @@ python demo/video_demo.py demo/demo.mp4 \
162162
--out result.mp4
163163
```
164164

165+
#### Video demo with GPU acceleration
166+
167+
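This script performs inference on a video with GPU acceleration: frames are decoded and resized by `ffmpegcv` (optionally with the NVIDIA decoder via `--nvdecode`) and normalized on the inference device. It requires `ffmpegcv`, which can be installed with `pip install ffmpegcv`.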
168+
169+
```shell
170+
python demo/video_gpuaccel_demo.py \
171+
${VIDEO_FILE} \
172+
${CONFIG_FILE} \
173+
${CHECKPOINT_FILE} \
174+
[--device ${GPU_ID}] \
175+
[--score-thr ${SCORE_THR}] \
176+
[--nvdecode] \
177+
[--out ${OUT_FILE}] \
178+
[--show] \
179+
[--wait-time ${WAIT_TIME}]
180+
```
181+
182+
Examples:
183+
184+
```shell
185+
python demo/video_gpuaccel_demo.py demo/demo.mp4 \
186+
configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
187+
checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
188+
--nvdecode --out result.mp4
189+
```
190+
165191
## Test existing models on standard datasets
166192

167193
To evaluate a model's accuracy, one usually tests the model on some standard datasets.

docs/zh_cn/1_exist_data_model.md

+27
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,33 @@ asyncio.run(main())
160160
--out result.mp4
161161
```
162162

163+
#### 视频样例,显卡加速版本
164+
165+
这是在视频样例上进行推理的脚本,使用显卡加速。运行前需要先安装 `ffmpegcv`(`pip install ffmpegcv`)。
166+
167+
```shell
168+
python demo/video_gpuaccel_demo.py \
169+
${VIDEO_FILE} \
170+
${CONFIG_FILE} \
171+
${CHECKPOINT_FILE} \
172+
[--device ${GPU_ID}] \
173+
[--score-thr ${SCORE_THR}] \
174+
[--nvdecode] \
175+
[--out ${OUT_FILE}] \
176+
[--show] \
177+
[--wait-time ${WAIT_TIME}]
178+
179+
```
180+
181+
运行样例:
182+
183+
```shell
184+
python demo/video_gpuaccel_demo.py demo/demo.mp4 \
185+
configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
186+
checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
187+
--nvdecode --out result.mp4
188+
```
189+
163190
## 在标准数据集上测试现有模型
164191

165192
为了测试一个模型的精度,我们通常会在标准数据集上对其进行测试。MMDetection 支持多个公共数据集,包括 [COCO](https://cocodataset.org/)

0 commit comments

Comments
 (0)