14
14
from matplotlib .backends .backend_agg import FigureCanvasAgg
15
15
from PIL import Image
16
16
17
+ from detectron2 .data import MetadataCatalog
17
18
from detectron2 .structures import BitMasks , Boxes , BoxMode , Keypoints , PolygonMasks , RotatedBoxes
18
19
19
20
from .colormap import random_color
@@ -306,7 +307,7 @@ def get_image(self):
306
307
307
308
308
309
class Visualizer :
309
- def __init__ (self , img_rgb , metadata , scale = 1.0 , instance_mode = ColorMode .IMAGE ):
310
+ def __init__ (self , img_rgb , metadata = None , scale = 1.0 , instance_mode = ColorMode .IMAGE ):
310
311
"""
311
312
Args:
312
313
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
@@ -317,6 +318,8 @@ def __init__(self, img_rgb, metadata, scale=1.0, instance_mode=ColorMode.IMAGE):
317
318
metadata (MetadataCatalog): image metadata.
318
319
"""
319
320
self .img = np .asarray (img_rgb ).clip (0 , 255 ).astype (np .uint8 )
321
+ if metadata is None :
322
+ metadata = MetadataCatalog .get ("__nonexist__" )
320
323
self .metadata = metadata
321
324
self .output = VisImage (self .img , scale = scale )
322
325
self .cpu_device = torch .device ("cpu" )
@@ -946,7 +949,7 @@ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
946
949
return self .output
947
950
948
951
def draw_binary_mask (
949
- self , binary_mask , color = None , * , edge_color = None , text = None , alpha = 0.5 , area_threshold = 4096
952
+ self , binary_mask , color = None , * , edge_color = None , text = None , alpha = 0.5 , area_threshold = 0
950
953
):
951
954
"""
952
955
Args:
@@ -967,8 +970,6 @@ def draw_binary_mask(
967
970
if color is None :
968
971
color = random_color (rgb = True , maximum = 1 )
969
972
color = mplc .to_rgb (color )
970
- if area_threshold is None :
971
- area_threshold = 4096
972
973
973
974
has_valid_segment = False
974
975
binary_mask = binary_mask .astype ("uint8" ) # opencv needs uint8
@@ -979,7 +980,7 @@ def draw_binary_mask(
979
980
# draw polygons for regular masks
980
981
for segment in mask .polygons :
981
982
area = mask_util .area (mask_util .frPyObjects ([segment ], shape2d [0 ], shape2d [1 ]))
982
- if area < area_threshold :
983
+ if area < ( area_threshold or 0 ) :
983
984
continue
984
985
has_valid_segment = True
985
986
segment = segment .reshape (- 1 , 2 )
0 commit comments