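"""
Viba: a virtual background for your webcam.

Reads frames from a real camera, finds the people in them with a tensorflowjs
bodypix (mobilenet) segmentation model, replaces the background with a custom
image, and sends the modified frames to a fake v4l2loopback camera.
"""
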
import json
import logging
from asyncio import run, get_running_loop, sleep, gather
from datetime import datetime
from pathlib import Path

import click
import cv2
import numpy as np
import pyfakewebcam
import requests
import tensorflow as tf
from tfjs_graph_converter.api import load_graph_model

logger = logging.getLogger('viba')


def download_file(url, path):
    """
    Download a file from a url to a path on disk.
    """
    response = requests.get(url)
    with open(path, 'wb') as destination_file:
        destination_file.write(response.content)


class Cam:
    """
    A camera, either real (to read frames from), or fake (to send frames to).
    """
    def __init__(self, device, size, fps, is_real=True):
        self.device = device
        self.size = size
        self.fps = fps
        self.is_real = is_real

        if self.is_real:
            self.init_real_cam()
        else:
            self.init_fake_cam()

    def init_real_cam(self):
        """
        Initialize a real camera, using opencv to capture frames from it.
        """
        capturer = cv2.VideoCapture(self.device)
        capturer.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
        capturer.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
        capturer.set(cv2.CAP_PROP_FPS, self.fps)
        self.interface = capturer

        logger.info("Real camera initialized at device %s, resolution %s and fps %s",
                    self.device, self.size, self.fps)

    def init_fake_cam(self):
        """
        Initialize a fake camera, using pyfakewebcam to send frames to it.
        """
        self.interface = pyfakewebcam.FakeWebcam(self.device, self.width, self.height)

        logger.info("Fake camera initialized at device %s, resolution %s",
                    self.device, self.size)

    @property
    def width(self):
        return self.size[0]

    @property
    def height(self):
        return self.size[1]

    async def read_frame(self):
        """
        Read a real frame from a real camera.
        """
        assert self.is_real

        # the capturer can return None when the camera is being used by someone else, so we
        # just keep trying until we get a real frame
        frame = None
        while frame is None:
            await sleep(0)
            _, frame = self.interface.read()

        # why would anyone use BGR??
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    async def write_frame(self, frame):
        """
        Write a fake frame to be displayed in a fake camera.
        """
        assert not self.is_real

        frame_height, frame_width, _ = frame.shape
        if (frame_width, frame_height) != self.size:
            frame = cv2.resize(frame, self.size)

        await sleep(0)
        self.interface.schedule_frame(frame)
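
# A minimal sketch of using Cam on its own, assuming /dev/video0 is a real webcam
# (read_frame must be awaited from inside an asyncio coroutine):
#   real = Cam(device='/dev/video0', size=(640, 480), fps=30, is_real=True)
#   frame = await real.read_frame()  # an RGB numpy array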


class SegmenterModel:
    """
    A model able to segment people in images, using tensorflow.
    Currently based on mobilenet.
    """
    BASE_MODELS_PATH = Path('./models/')
    BASE_MODELS_URL = 'https://storage.googleapis.com/tfjs-models/savedmodel/{}'
    KNOWN_MODELS_URLS = {
        'mobilenet_quant4_100_stride16': 'bodypix/mobilenet/float/100/model-stride16.json',
        'mobilenet_quant4_075_stride16': 'bodypix/mobilenet/float/075/model-stride16.json',
        # there are plenty of other models, but this project has been tested and tuned up for
        # these. You can try others if you want :)
    }

    def __init__(self, model_name, segmentation_threshold, use_gpu):
        if model_name not in self.KNOWN_MODELS_URLS:
            raise ValueError("Unknown model: {}".format(model_name))

        self.model_name = model_name
        self.segmentation_threshold = segmentation_threshold
        self.use_gpu = use_gpu

        self.init_graph()

    def init_graph(self):
        """
        Get a model from the published tensorflowjs models. If it's not present, download it.
        """
        self.download_tfjs_model()
        self.graph = load_graph_model(str(self.model_path))
        logger.info("TensorflowJS model %s loaded", self.model_name)

    @property
    def model_path(self):
        """
        Build the path to the specified model.
        """
        return self.BASE_MODELS_PATH / self.model_name / 'model.json'

    def download_tfjs_model(self):
        """
        If the model from the published tensorflowjs models isn't present, download it.
        """
        model_dir_path = self.model_path.parent
        model_dir_path.mkdir(parents=True, exist_ok=True)

        if self.model_path.exists():
            logger.info("TensorflowJS model found on disk")
        else:
            logger.info("TensorflowJS model not present, will download it")
            # the model is a json file, but that file points to additional weight files, so we
            # have to first get the json, read it, and then we know which extra files we need
            # to get too
            logger.info("Downloading model definition...")
            model_url = self.BASE_MODELS_URL.format(self.KNOWN_MODELS_URLS[self.model_name])
            download_file(
                model_url,
                self.model_path,
            )

            # download the extra weight files
            parent_model_url = '/'.join(model_url.split('/')[:-1]) + '/{}'
            definition = json.loads(self.model_path.read_text())
            for weights_manifest in definition['weightsManifest']:
                for weights_file_name in weights_manifest['paths']:
                    logger.info("Downloading weights: %s...", weights_file_name)
                    download_file(
                        parent_model_url.format(weights_file_name),
                        model_dir_path / weights_file_name,
                    )
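
    # Note: after download_tfjs_model, the expected on-disk layout is:
    #   models/<model_name>/model.json  (the graph definition, with a 'weightsManifest')
    #   models/<model_name>/...         (one file per path listed in the manifest)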

    def apply_model(self, inputs):
        """
        Apply the tensorflow model to get the segmentation mask outputs.
        """
        with tf.compat.v1.Session(graph=self.graph) as sess:
            with tf.device("/gpu:0" if self.use_gpu else "/cpu:0"):
                input_tensor = self.graph.get_tensor_by_name('sub_2:0')
                results = sess.run(['float_segments:0'], feed_dict={input_tensor: inputs})

                # results will contain only one result, because we are asking for a single
                # output tensor (the segments). And we also know we have only one image, so
                # there will be only one result, and we can remove the "image index" dimension
                # with squeeze
                segments = np.squeeze(results[0], 0)

                # convert the segment values to the range (0, 1)
                segment_scores = tf.sigmoid(segments)
                # and then consider as human only those regions above the segmentation
                # threshold
                mask = tf.math.greater(segment_scores, tf.constant(self.segmentation_threshold))
                segmentation_mask = mask.eval()

                segmentation_mask = np.reshape(
                    segmentation_mask, (segmentation_mask.shape[0], segmentation_mask.shape[1])
                ).astype(np.uint8)

        return segmentation_mask

    async def get_mask(self, image):
        """
        Get the mask that defines the area occupied by people in the image.
        """
        original_height, original_width, _ = image.shape

        # normalization of inputs, and construction of a "samples" set from the single image
        # we have
        image = (image / 127.) - 1
        image_as_inputs = np.expand_dims(image, 0)

        await sleep(0)

        # the model is a big chunk of blocking code, but if we are using a GPU, then we can
        # actually leave it running and keep doing more stuff. That's why we use asyncio
        # combined with a threaded executor for the model
        loop = get_running_loop()
        segmentation_mask = await loop.run_in_executor(None, self.apply_model, image_as_inputs)

        segmentation_mask = cv2.resize(segmentation_mask, (original_width, original_height))

        return segmentation_mask
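
# A minimal sketch of using SegmenterModel on its own (again, from inside an asyncio
# coroutine, and with rgb_frame being a frame as returned by Cam.read_frame):
#   model = SegmenterModel('mobilenet_quant4_100_stride16',
#                          segmentation_threshold=0.7, use_gpu=False)
#   mask = await model.get_mask(rgb_frame)  # uint8 0/1 mask, same height/width as the frame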


class VirtualBackground:
    """
    A virtual background achieved by reading frames from a real camera, finding the location of
    people in the image, replacing the background with a custom image, and then sending the
    modified frame to a fake camera.
    """
    def __init__(self, model, real_cam, fake_cam, background_path):
        self.model = model
        self.real_cam = real_cam
        self.fake_cam = fake_cam

        self.last_frame = None
        self.frames_count = 0
        self.current_mask = None
        self.masks_count = 0

        self.load_background(background_path)

        logger.info("Everything ready to run")

    def load_background(self, background_path):
        """
        Load and prepare the background to use in the frames.
        """
        background = cv2.imread(background_path)
        background = cv2.resize(background, self.real_cam.size)
        self.background = cv2.cvtColor(background, cv2.COLOR_BGR2RGB)

    async def frames_loop(self):
        """
        Get new frames, apply the current mask and schedule them as fast as possible.
        """
        logger.info("Started frames loop")
        last_stats_at = datetime.now()

        while True:
            logger.debug("Reading a new frame...")
            frame = await self.real_cam.read_frame()
            self.last_frame = frame

            logger.debug("Enhancing the frame...")
            if self.current_mask is None:
                enhanced_frame = frame
            else:
                enhanced_frame = await self.enhance_frame(frame)

            logger.debug("Sending the enhanced frame to the fake cam...")
            await self.fake_cam.write_frame(enhanced_frame)
            self.frames_count += 1

            # print useful speed stats every 10 seconds or so
            seconds_since_stats = (datetime.now() - last_stats_at).total_seconds()
            if seconds_since_stats >= 10:
                logger.info("Camera working at %.2f fps, and %.2f masks per second",
                            self.frames_count / seconds_since_stats,
                            self.masks_count / seconds_since_stats)
                last_stats_at = datetime.now()
                self.frames_count = 0
                self.masks_count = 0

    async def mask_loop(self):
        """
        Get and refresh the mask as fast as possible.
        """
        logger.info("Started mask loop")

        while True:
            if self.last_frame is None:
                await sleep(0)
                continue

            logger.debug("Updating mask...")
            raw_mask = await self.model.get_mask(self.last_frame)
            self.current_mask = await self.post_process_mask(raw_mask)
            self.masks_count += 1

    async def enhance_frame(self, frame):
        """
        Given a frame, enhance it using the current mask, and any optional effects.
        """
        frame = frame.copy()

        # blend the person (where the mask is 1) with the background (where the mask is 0)
        for channel in range(frame.shape[2]):
            frame[:, :, channel] = (
                frame[:, :, channel] * self.current_mask
                + self.background[:, :, channel] * (1 - self.current_mask)
            )

        await sleep(0)
        return frame

    async def post_process_mask(self, mask):
        """
        Dilate and blur the mask, so the limits between the person and the background aren't
        that abrupt and visible.
        """
        mask = cv2.dilate(mask, np.ones((10, 10), np.uint8), iterations=2)
        await sleep(0)
        mask = cv2.blur(mask.astype(float), (30, 30))

        return mask

    def run(self):
        """
        Run the virtual camera.
        """
        logger.info("Main loop launched!")

        async def main_loop():
            """
            The main loop just stays alive while the frames and mask loops are alive.
            """
            await gather(self.frames_loop(), self.mask_loop())

        run(main_loop())


@click.command()
@click.option("--background", default="./sample_background.jpg",
              help="The background image to use in the webcam.")
@click.option('--use-gpu', is_flag=True,
              help="Force the use of a CUDA enabled GPU, to improve performance. Remember that "
                   "this has extra dependencies, more info in the README.")
@click.option("--real-cam-resolution", default=(640, 480), type=(int, int),
              help="The resolution of the real webcam. We highly recommend using a small value "
                   "for performance reasons, especially if you aren't using a high end GPU "
                   "with viba. The value must be a tuple with the structure: (width, height). "
                   "Example: --real-cam-resolution 640 480")
@click.option("--fake-cam-resolution", default=(960, 720), type=(int, int),
              help="The resolution of the fake webcam. We recommend using a small value "
                   "for performance reasons, but this isn't as important as the real cam "
                   "resolution. Also, useful info: some web conference services like Jitsi "
                   "ignore webcams below 720p. The value must be a tuple with the structure: "
                   "(width, height). Example: --fake-cam-resolution 640 480")
@click.option("--real-cam-fps", default=30, type=int,
              help="The speed (frames per second) of the real webcam.")
@click.option("--real-cam-device", default="/dev/video0",
              help="The linux device in which the real cam exists.")
@click.option("--fake-cam-device", default="/dev/video20",
              help="The linux device in which the fake cam exists (the one created using "
                   "v4l2loopback).")
@click.option("--model-name", default="mobilenet_quant4_100_stride16",
              type=click.Choice(list(SegmenterModel.KNOWN_MODELS_URLS.keys()),
                                case_sensitive=False),
              help="The tensorflowjs model that will be used to detect people in the video. If "
                   "you have trouble with performance, you can try using "
                   "'mobilenet_quant4_075_stride16', which is a little bit faster.")
@click.option("--segmentation-threshold", default=0.7, type=click.FloatRange(0, 1),
              help="How much of the image will be considered a 'person'. A lower value means "
                   "less confidence is required, so more regions will be considered a "
                   "'person'. A higher value means the opposite. Must be a value between "
                   "0 and 1.")
@click.option('--debug', is_flag=True,
              help="Debug mode: print a lot of extra text during execution, to debug issues.")
def main(background, use_gpu, real_cam_resolution, fake_cam_resolution, real_cam_fps,
         real_cam_device, fake_cam_device, model_name, segmentation_threshold, debug):
    """
    CLI script for Viba, the virtual background utility.
    """
    FORMAT = '%(asctime)-15s|%(levelname)s|%(name)s|%(message)s'
    logging.basicConfig(format=FORMAT, level=logging.DEBUG if debug else logging.INFO)

    viba = VirtualBackground(
        model=SegmenterModel(
            model_name=model_name,
            segmentation_threshold=segmentation_threshold,
            use_gpu=use_gpu,
        ),
        real_cam=Cam(
            device=real_cam_device,
            size=real_cam_resolution,
            fps=real_cam_fps,
            is_real=True,
        ),
        fake_cam=Cam(
            device=fake_cam_device,
            size=fake_cam_resolution,
            fps=None,
            is_real=False,
        ),
        background_path=background,
    )
    viba.run()


if __name__ == '__main__':
    main()
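
# Example invocations (the background path is illustrative; the fake cam device is the one
# created with v4l2loopback):
#   python viba.py
#   python viba.py --background ./my_background.jpg --use-gpu
#   python viba.py --model-name mobilenet_quant4_075_stride16 --segmentation-threshold 0.6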