# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy

import cv2
import mmcv
import numpy as np
import torch
from torchvision.transforms import functional as F

from mmdet.apis import init_detector
from mmdet.datasets.pipelines import Compose

try:
    import ffmpegcv
except ImportError:
    raise ImportError(
        'Please install ffmpegcv with:\n\n pip install ffmpegcv')
| 18 | + |
| 19 | + |
def parse_args():
    """Parse the command-line arguments of the video demo.

    Three positional arguments are required (input video, model config and
    checkpoint). Optional flags select the inference device, the bbox score
    threshold, an output file, on-screen display, NVIDIA hardware decoding
    and the delay between displayed frames.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``video``,
            ``config``, ``checkpoint``, ``device``, ``score_thr``, ``out``,
            ``show``, ``nvdecode`` and ``wait_time``.
    """
    parser = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')
    parser.add_argument('video', help='Video file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--show', action='store_true', help='Show video')
    parser.add_argument(
        '--nvdecode', action='store_true', help='Use NVIDIA decoder')
    # The value is forwarded unchanged to ``mmcv.imshow``, whose wait time is
    # an OpenCV key-wait delay expressed in milliseconds, so the help text
    # advertises ms rather than seconds.
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (ms), 0 is block')
    return parser.parse_args()
| 41 | + |
| 42 | + |
def prefetch_img_metas(cfg, ori_wh):
    """Compute the test-pipeline image metas for a given frame size.

    An all-zero frame of the original video size is passed through the test
    pipeline once, and the meta information it produces is returned so the
    caller can reuse it.

    Args:
        cfg: Model config; only ``cfg.data.test.pipeline`` is read.
        ori_wh (tuple[int, int]): Original frame size as ``(width, height)``.

    Returns:
        The ``.data`` payload of the first ``img_metas`` entry emitted by
        the pipeline.
    """
    w, h = ori_wh
    # Apply the loader override to a private copy of the pipeline so the
    # caller's config object is left untouched.
    pipeline_cfg = copy.deepcopy(cfg.data.test.pipeline)
    # The frame is supplied as an in-memory array, not as a file path.
    pipeline_cfg[0].type = 'LoadImageFromWebcam'
    test_pipeline = Compose(pipeline_cfg)
    blank_frame = np.zeros((h, w, 3), dtype=np.uint8)
    data = test_pipeline({'img': blank_frame})
    return data['img_metas'][0].data
| 51 | + |
| 52 | + |
def process_img(frame_resize, img_metas, device):
    """Convert one resized frame into the keyword inputs of the detector.

    Args:
        frame_resize (np.ndarray): Frame in HWC layout whose shape must equal
            ``img_metas['pad_shape']``.
        img_metas (dict): Image metas; ``pad_shape`` and the ``mean`` / ``std``
            arrays of ``img_norm_cfg`` are read.
        device: Target device for the created tensors.

    Returns:
        dict: ``{'img': [tensor], 'img_metas': [[img_metas]]}`` where the
            tensor is the normalized frame in NCHW layout with batch size 1.
    """
    assert frame_resize.shape == img_metas['pad_shape']
    norm_cfg = img_metas['img_norm_cfg']

    # Move to the device first, then cast and reorder HWC -> CHW.
    chw = torch.from_numpy(frame_resize).to(device).float().permute(2, 0, 1)
    mean = torch.from_numpy(norm_cfg['mean']).to(device)
    std = torch.from_numpy(norm_cfg['std']).to(device)
    chw = F.normalize(chw, mean=mean, std=std, inplace=True)

    # Prepend the batch dimension: CHW -> NCHW.
    batch = chw.unsqueeze(0)
    return {'img': [batch], 'img_metas': [[img_metas]]}
| 63 | + |
| 64 | + |
def main():
    """Run a detector over a video and save and/or display the result.

    The video is opened twice: once at its original size (used for drawing
    the results) and once resized and padded to the model's ``pad_shape``
    (used as network input). Each pair of frames is run through the model
    and the rendered frame is shown and/or written to the output file.

    Raises:
        AssertionError: If neither ``--out`` nor ``--show`` is given.
    """
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    if args.nvdecode:
        VideoCapture = ffmpegcv.VideoCaptureNV
    else:
        VideoCapture = ffmpegcv.VideoCapture

    video_origin = VideoCapture(args.video)
    video_resize = None
    video_writer = None
    try:
        img_metas = prefetch_img_metas(
            model.cfg, (video_origin.width, video_origin.height))
        # pad_shape is (h, w, c); the reader expects (w, h).
        resize_wh = img_metas['pad_shape'][1::-1]
        # Decode in the channel order the normalization statistics refer to:
        # RGB when the pipeline converts to RGB, BGR otherwise.
        to_rgb = img_metas['img_norm_cfg'].get('to_rgb', True)
        video_resize = VideoCapture(
            args.video,
            resize=resize_wh,
            resize_keepratio=True,
            resize_keepratioalign='topleft',
            pix_fmt='rgb24' if to_rgb else 'bgr24')
        if args.out:
            video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps)

        with torch.no_grad():
            for frame_resize, frame_origin in zip(
                    mmcv.track_iter_progress(video_resize), video_origin):
                data = process_img(frame_resize, img_metas, args.device)
                result = model(return_loss=False, rescale=True, **data)[0]
                frame_mask = model.show_result(
                    frame_origin, result, score_thr=args.score_thr)
                if args.show:
                    cv2.namedWindow('video', 0)
                    # NOTE(review): wait_time is parsed as float and passed
                    # through unchanged - confirm the display backend
                    # accepts non-integer delays.
                    mmcv.imshow(frame_mask, 'video', args.wait_time)
                if args.out:
                    video_writer.write(frame_mask)
    finally:
        # Release every handle that was opened, even if a frame failed.
        for handle in (video_writer, video_origin, video_resize):
            if handle is not None:
                handle.release()
        # Only touch the GUI when a window may have been created.
        if args.show:
            cv2.destroyAllWindows()
| 110 | + |
| 111 | + |
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
0 commit comments