Skip to content

Commit 9714aa4

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
binary mask threshold
Summary: Change area_threshold to 0. Add some tests. Pull Request resolved: fairinternal/detectron2#412 Differential Revision: D21813820 Pulled By: ppwwyyxx fbshipit-source-id: 6f170a6aef0d88a14d438f8ea1a192dffa3560e2
1 parent 3bdf3ab commit 9714aa4

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

demo/predictor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class AsyncPredictor:
133133
"""
134134
A predictor that runs the model asynchronously, possibly on >1 GPUs.
135135
Because rendering the visualization takes considerably amount of time,
136-
this helps improve throughput when rendering videos.
136+
this helps improve throughput a little bit when rendering videos.
137137
"""
138138

139139
class _StopToken:

detectron2/utils/visualizer.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from matplotlib.backends.backend_agg import FigureCanvasAgg
1515
from PIL import Image
1616

17+
from detectron2.data import MetadataCatalog
1718
from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
1819

1920
from .colormap import random_color
@@ -306,7 +307,7 @@ def get_image(self):
306307

307308

308309
class Visualizer:
309-
def __init__(self, img_rgb, metadata, scale=1.0, instance_mode=ColorMode.IMAGE):
310+
def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
310311
"""
311312
Args:
312313
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
@@ -317,6 +318,8 @@ def __init__(self, img_rgb, metadata, scale=1.0, instance_mode=ColorMode.IMAGE):
317318
metadata (MetadataCatalog): image metadata.
318319
"""
319320
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
321+
if metadata is None:
322+
metadata = MetadataCatalog.get("__nonexist__")
320323
self.metadata = metadata
321324
self.output = VisImage(self.img, scale=scale)
322325
self.cpu_device = torch.device("cpu")
@@ -946,7 +949,7 @@ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
946949
return self.output
947950

948951
def draw_binary_mask(
949-
self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=4096
952+
self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0
950953
):
951954
"""
952955
Args:
@@ -967,8 +970,6 @@ def draw_binary_mask(
967970
if color is None:
968971
color = random_color(rgb=True, maximum=1)
969972
color = mplc.to_rgb(color)
970-
if area_threshold is None:
971-
area_threshold = 4096
972973

973974
has_valid_segment = False
974975
binary_mask = binary_mask.astype("uint8") # opencv needs uint8
@@ -979,7 +980,7 @@ def draw_binary_mask(
979980
# draw polygons for regular masks
980981
for segment in mask.polygons:
981982
area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
982-
if area < area_threshold:
983+
if area < (area_threshold or 0):
983984
continue
984985
has_valid_segment = True
985986
segment = segment.reshape(-1, 2)

tests/test_visualizer.py

+15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import unittest
7+
import cv2
78
import torch
89

910
from detectron2.data import MetadataCatalog
@@ -141,3 +142,17 @@ def test_draw_no_metadata(self):
141142

142143
v = Visualizer(img, MetadataCatalog.get("asdfasdf"))
143144
v.draw_instance_predictions(inst)
145+
146+
def test_draw_binary_mask(self):
147+
img, boxes, _, _, masks = self._random_data()
148+
img[:, :, 0] = 0 # remove red color
149+
mask = masks[0]
150+
mask_with_hole = np.zeros_like(mask).astype("uint8")
151+
mask_with_hole = cv2.rectangle(mask_with_hole, (10, 10), (50, 50), 1, 5)
152+
153+
for m in [mask, mask_with_hole]:
154+
v = Visualizer(img)
155+
o = v.draw_binary_mask(m, color="red", text="test")
156+
o = o.get_image().astype("float32")
157+
# red color is drawn on the image
158+
self.assertTrue(o[:, :, 0].sum() > 0)

0 commit comments

Comments
 (0)