-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_DIS_onnx.py
135 lines (110 loc) · 3.71 KB
/
demo_DIS_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python
import os
import copy
import time
import argparse
import cv2 as cv
import numpy as np
import onnxruntime
def run_inference(onnx_session, image, score_th=None):
    """Segment the salient object in one BGR image with a DIS ONNX model.

    Args:
        onnx_session: An ``onnxruntime.InferenceSession``. Its first input's
            shape is read as NCHW: index 2 is the height and index 3 the
            width. These are assumed to be fixed integers, because they are
            passed straight to ``cv.resize`` as the target size.
        image: Input image in BGR channel order, as produced by OpenCV.
        score_th: Optional threshold applied to the min-max normalized mask.
            Pixels below it become 0 and all others become 255. When
            ``None``, the soft mask is returned scaled to 0-255.

    Returns:
        A ``uint8`` mask at the model's input resolution, not at the size of
        ``image``. It is the first model output with singleton axes squeezed
        out. It is presumably 2-D (H, W) for a (1, 1, H, W) output; TODO
        confirm against the model.
    """
    # ONNX input size (NCHW)
    input_shape = onnx_session.get_inputs()[0].shape
    input_width = input_shape[3]
    input_height = input_shape[2]

    # Pre-process: resize, BGR->RGB, normalize, HWC->NCHW, float32 cast
    input_image = cv.resize(image, dsize=(input_width, input_height))
    input_image = cv.cvtColor(input_image, cv.COLOR_BGR2RGB)
    mean = [0.5, 0.5, 0.5]
    std = [1.0, 1.0, 1.0]
    input_image = (input_image / 255.0 - mean) / std
    input_image = input_image.transpose(2, 0, 1)
    input_image = np.expand_dims(input_image, axis=0)
    input_image = input_image.astype('float32')

    # Inference
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    result = onnx_session.run([output_name], {input_name: input_image})

    # Post-process: squeeze, then min-max normalize to [0, 1], then apply the
    # optional threshold, then scale to [0, 255] and cast to uint8.
    # No sigmoid is applied here; normalization rescales whatever range the
    # model emits.
    mask = np.squeeze(result[0])
    min_value = np.min(mask)
    max_value = np.max(mask)
    value_range = max_value - min_value
    if value_range > 0:
        mask = (mask - min_value) / value_range
    else:
        # A constant output has no contrast. Dividing would give NaN (0/0),
        # whose uint8 cast is undefined and which thresholds to all-255.
        # Treat it as all background instead.
        mask = np.zeros_like(mask)
    if score_th is not None:
        mask = np.where(mask < score_th, 0, 1)
    mask = mask * 255
    return mask.astype('uint8')
def main():
    """Run DIS segmentation on a camera or video file and display the results.

    Command-line options:
        --device    Camera index passed to ``cv.VideoCapture`` (default 0).
        --movie     Path to a video file. When given, it overrides --device.
        --model     Path to the ONNX model
                    (default 'model/isnet-general-use.onnx').
        --score_th  Optional threshold forwarded to ``run_inference``.

    Three windows are shown:
        - the input frame annotated with the inference time;
        - the mask;
        - the input frame where the mask is non-zero, and white elsewhere.

    Press ESC to quit. The loop also ends when the video source runs out of
    frames.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--movie', type=str, default=None)
    parser.add_argument(
        '--model',
        type=str,
        default='model/isnet-general-use.onnx',
    )
    parser.add_argument('--score_th', type=float, default=None)
    args = parser.parse_args()

    model_path = args.model
    score_th = args.score_th

    # Initialize video capture: a movie path takes precedence over the camera.
    cap_device = args.device
    if args.movie is not None:
        cap_device = args.movie
    cap = cv.VideoCapture(cap_device)
    if not cap.isOpened():
        # Without this check, the loop below would break on the first read
        # and exit silently.
        raise SystemExit(f'Could not open video source: {cap_device!r}')

    try:
        # Load model. CUDA is listed first, with CPU as the fallback provider.
        onnx_session = onnxruntime.InferenceSession(
            model_path,
            providers=[
                'CUDAExecutionProvider',
                'CPUExecutionProvider',
            ],
        )

        while True:
            # Capture read
            ret, frame = cap.read()
            if not ret:
                break
            debug_image = frame.copy()

            # Inference execution. Only this call is timed, so capture
            # latency is not reported as inference time.
            start_time = time.perf_counter()
            mask = run_inference(onnx_session, frame, score_th)
            elapsed_time = time.perf_counter() - start_time

            # Resize the mask from model resolution back to the frame size.
            mask = cv.resize(
                mask,
                dsize=(debug_image.shape[1], debug_image.shape[0]),
            )

            # Mask overlay: keep input pixels where the mask is non-zero,
            # and use white elsewhere.
            overlay_image = np.full_like(debug_image, 255)
            mask = np.stack((mask,) * 3, axis=-1)
            mask_image = np.where(mask, debug_image, overlay_image)

            # Inference elapsed time
            elapsed_time_text = f'Elapsed time: {elapsed_time * 1000:.1f}ms'
            cv.putText(debug_image, elapsed_time_text, (10, 30),
                       cv.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2,
                       cv.LINE_AA)

            # Show first, then let waitKey process GUI events for this frame.
            cv.imshow('DIS : Input', debug_image)
            cv.imshow('DIS : Output', mask)
            cv.imshow('DIS : Mask', mask_image)
            key = cv.waitKey(1)
            if key == 27:  # ESC
                break
    finally:
        # Release the capture and close windows even if inference raises.
        cap.release()
        cv.destroyAllWindows()
if __name__ == '__main__':
    # On first run, fetch the default ISNet model. Only the default path is
    # checked; a custom --model path is never downloaded here.
    default_model_path = './model/isnet-general-use.onnx'
    if not os.path.exists(default_model_path):
        # Imported lazily: gdown is only needed when the model must be
        # downloaded.
        import gdown
        MODEL_PATH_URL = 'https://drive.google.com/uc?id=10DU7o4HcjUP3JdBYjvWDO7gknTTRk1Zd'
        # Create the target directory first, so the download does not rely
        # on ./model already existing.
        os.makedirs(os.path.dirname(default_model_path), exist_ok=True)
        gdown.download(
            MODEL_PATH_URL,
            default_model_path,
            use_cookies=False,
        )
    main()