Skip to content

Commit 2618f32

Browse files
bowenc0221facebook-github-bot
authored andcommitted
Reproduce Panoptic-DeepLab in Detectron2.
Summary: In this PR, I implemented Panoptic-DeepLab in Detectron2. This model can be served as a strong baseline for bottom-up panoptic segmentation. The following table shows results on Cityscapes val set. | Method | Backbone | Input size | PQ (our) | SQ (our) | RQ (our) | mIoU (our) | AP (our) | PQ (paper) | |-----------------------|-----------|------------|----------|----------|----------|----------|----------|----------| | Panoptic-DeepLab | R52-DC5 | 1024x2048| 61.9 | 81.5 | 75.0 | 79.8 | 31.7 | 59.8 | This re-implementation is 2.0% PQ better than the original TF implementation using exactly the same training parameters. I hypothesize this is because D2 (facebookresearch@513bf19) has better data pre-processing (PIL vs. cv2). Note: - R52: a ResNet-50 with its first 7x7 convolution replaced by 3 3x3 convolutions. This modification has been used in most semantic segmentation papers. We pre-train this backbone on ImageNet using the default recipe of [pytorch examples](https://github.com/pytorch/examples/tree/master/imagenet). - DC5 means using dilated convolution in `res5`. Pull Request resolved: fairinternal/detectron2#438 Test Plan: The code is tested on FAIR Cluster with distributed training with 4 nodes of 8 GPUs using the following command: ``` python3 d2_submitit.py --config config/Cityscapes-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_60k_bs32_crop_1024_2048.yaml --num-gpus 8 --num-machines 4 --resume ``` References: ``` inproceedings{cheng2020panoptic, title={Panoptic-DeepLab: A Simple, Strong, and Fast Baseline for Bottom-Up Panoptic Segmentation}, author={Cheng, Bowen and Collins, Maxwell D and Zhu, Yukun and Liu, Ting and Huang, Thomas S and Adam, Hartwig and Chen, Liang-Chieh}, booktitle={CVPR}, year={2020} } ``` Reviewed By: ppwwyyxx Differential Revision: D23418815 Pulled By: bowenc0221 fbshipit-source-id: fc3b5e2018b19a8594a3bcf820ab8cab3d3da770
1 parent 1557c20 commit 2618f32

27 files changed

+1786
-26
lines changed

.circleci/config.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ install_detectron2: &install_detectron2
104104
- run:
105105
name: Install Detectron2
106106
command: |
107-
pip install --progress-bar off -e .
107+
pip install --progress-bar off -e .[all]
108108
python -m detectron2.utils.collect_env
109109
110110
run_unittests: &run_unittests

.github/workflows/workflow.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ jobs:
6565
6666
- name: Build and install
6767
run: |
68-
CC=clang CXX=clang++ python -m pip install -e .
68+
CC=clang CXX=clang++ python -m pip install -e .[all]
6969
python -m detectron2.utils.collect_env
7070
- name: Run unittests
7171
run: python -m pytest -n 4 -v tests/

datasets/README.md

+13
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ cityscapes/
8888
...
8989
val/
9090
test/
91+
# below are generated Cityscapes panoptic annotation
92+
cityscapes_panoptic_train.json
93+
cityscapes_panoptic_train/
94+
cityscapes_panoptic_val.json
95+
cityscapes_panoptic_val/
96+
cityscapes_panoptic_test.json
97+
cityscapes_panoptic_test/
9198
leftImg8bit/
9299
train/
93100
val/
@@ -104,6 +111,12 @@ CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/p
104111
```
105112
These files are not needed for instance segmentation.
106113

114+
Note: to generate Cityscapes panoptic dataset, run cityscapesescript with:
115+
```
116+
CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py
117+
```
118+
These files are not needed for semantic and instance segmentation.
119+
107120
## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html):
108121
```
109122
VOC20{07,12}/

datasets/prepare_panoptic_fpn.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
import os
1010
import time
1111
from fvcore.common.download import download
12+
from panopticapi.utils import rgb2id
1213
from PIL import Image
1314

1415
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
1516

16-
from panopticapi.utils import rgb2id
17-
1817

1918
def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, id_map):
2019
panoptic = np.asarray(Image.open(input_panoptic), dtype=np.uint32)

detectron2/data/datasets/builtin.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from .builtin_meta import ADE20K_SEM_SEG_CATEGORIES, _get_builtin_metadata
2525
from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic
26+
from .cityscapes_panoptic import register_all_cityscapes_panoptic
2627
from .coco import load_sem_seg
2728
from .lvis import get_lvis_instances_meta, register_lvis_instances
2829
from .pascal_voc import register_pascal_voc
@@ -244,5 +245,6 @@ def register_all_ade20k(root):
244245
register_all_coco(_root)
245246
register_all_lvis(_root)
246247
register_all_cityscapes(_root)
248+
register_all_cityscapes_panoptic(_root)
247249
register_all_pascal_voc(_root)
248250
register_all_ade20k(_root)

detectron2/data/datasets/builtin_meta.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,30 @@
187187
("right_knee", "right_ankle", (255, 195, 77)),
188188
]
189189

190+
# All Cityscapes categories, together with their nice-looking visualization colors
191+
# It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa
192+
CITYSCAPES_CATEGORIES = [
193+
{"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"},
194+
{"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"},
195+
{"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"},
196+
{"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"},
197+
{"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"},
198+
{"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"},
199+
{"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"},
200+
{"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"},
201+
{"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"},
202+
{"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"},
203+
{"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"},
204+
{"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"},
205+
{"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"},
206+
{"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"},
207+
{"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"},
208+
{"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"},
209+
{"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"},
210+
{"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"},
211+
{"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"},
212+
]
213+
190214
# fmt: off
191215
ADE20K_SEM_SEG_CATEGORIES = [
192216
"wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa
@@ -264,7 +288,7 @@ def _get_builtin_metadata(dataset_name):
264288
CITYSCAPES_STUFF_CLASSES = [
265289
"road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
266290
"traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
267-
"truck", "bus", "train", "motorcycle", "bicycle", "license plate",
291+
"truck", "bus", "train", "motorcycle", "bicycle",
268292
]
269293
# fmt: on
270294
return {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2+
import json
3+
import logging
4+
import os
5+
from fvcore.common.file_io import PathManager
6+
7+
from detectron2.data import DatasetCatalog, MetadataCatalog
8+
from detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
9+
10+
"""
11+
This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
12+
"""
13+
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
19+
files = []
20+
# scan through the directory
21+
cities = PathManager.ls(image_dir)
22+
logger.info(f"{len(cities)} cities found in '{image_dir}'.")
23+
image_dict = {}
24+
for city in cities:
25+
city_img_dir = os.path.join(image_dir, city)
26+
for basename in PathManager.ls(city_img_dir):
27+
image_file = os.path.join(city_img_dir, basename)
28+
29+
suffix = "_leftImg8bit.png"
30+
assert basename.endswith(suffix), basename
31+
basename = os.path.basename(basename)[: -len(suffix)]
32+
33+
image_dict[basename] = image_file
34+
35+
for ann in json_info["annotations"]:
36+
image_file = image_dict.get(ann["image_id"], None)
37+
assert image_file is not None, "No image {} found for annotation {}".format(
38+
ann["image_id"], ann["file_name"]
39+
)
40+
label_file = os.path.join(gt_dir, ann["file_name"])
41+
segments_info = ann["segments_info"]
42+
43+
files.append((image_file, label_file, segments_info))
44+
45+
assert len(files), "No images found in {}".format(image_dir)
46+
assert PathManager.isfile(files[0][0]), files[0][0]
47+
assert PathManager.isfile(files[0][1]), files[0][1]
48+
return files
49+
50+
51+
def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
52+
"""
53+
Args:
54+
image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
55+
gt_dir (str): path to the raw annotations. e.g.,
56+
"~/cityscapes/gtFine/cityscapes_panoptic_train".
57+
gt_json (str): path to the json file. e.g.,
58+
"~/cityscapes/gtFine/cityscapes_panoptic_train.json".
59+
meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
60+
and "stuff_dataset_id_to_contiguous_id" to map category ids to
61+
contiguous ids for training.
62+
63+
Returns:
64+
list[dict]: a list of dicts in Detectron2 standard format. (See
65+
`Using Custom Datasets </tutorials/datasets.html>`_ )
66+
"""
67+
68+
def _convert_category_id(segment_info, meta):
69+
if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
70+
segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
71+
segment_info["category_id"]
72+
]
73+
else:
74+
segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
75+
segment_info["category_id"]
76+
]
77+
return segment_info
78+
79+
assert os.path.exists(
80+
gt_json
81+
), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa
82+
with open(gt_json) as f:
83+
json_info = json.load(f)
84+
files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
85+
ret = []
86+
for image_file, label_file, segments_info in files:
87+
sem_label_file = (
88+
image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
89+
)
90+
segments_info = [_convert_category_id(x, meta) for x in segments_info]
91+
ret.append(
92+
{
93+
"file_name": image_file,
94+
"image_id": "_".join(
95+
os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
96+
),
97+
"sem_seg_file_name": sem_label_file,
98+
"pan_seg_file_name": label_file,
99+
"segments_info": segments_info,
100+
}
101+
)
102+
assert len(ret), f"No images found in {image_dir}!"
103+
assert PathManager.isfile(
104+
ret[0]["sem_seg_file_name"]
105+
), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
106+
assert PathManager.isfile(
107+
ret[0]["pan_seg_file_name"]
108+
), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa
109+
return ret
110+
111+
112+
_RAW_CITYSCAPES_PANOPTIC_SPLITS = {
113+
"cityscapes_fine_panoptic_train": (
114+
"cityscapes/leftImg8bit/train",
115+
"cityscapes/gtFine/cityscapes_panoptic_train",
116+
"cityscapes/gtFine/cityscapes_panoptic_train.json",
117+
),
118+
"cityscapes_fine_panoptic_val": (
119+
"cityscapes/leftImg8bit/val",
120+
"cityscapes/gtFine/cityscapes_panoptic_val",
121+
"cityscapes/gtFine/cityscapes_panoptic_val.json",
122+
),
123+
# "cityscapes_fine_panoptic_test": not supported yet
124+
}
125+
126+
127+
def register_all_cityscapes_panoptic(root):
128+
meta = {}
129+
# The following metadata maps contiguous id from [0, #thing categories +
130+
# #stuff categories) to their names and colors. We have to replica of the
131+
# same name and color under "thing_*" and "stuff_*" because the current
132+
# visualization function in D2 handles thing and class classes differently
133+
# due to some heuristic used in Panoptic FPN. We keep the same naming to
134+
# enable reusing existing visualization functions.
135+
thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
136+
thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
137+
stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
138+
stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
139+
140+
meta["thing_classes"] = thing_classes
141+
meta["thing_colors"] = thing_colors
142+
meta["stuff_classes"] = stuff_classes
143+
meta["stuff_colors"] = stuff_colors
144+
145+
# There are three types of ids in panoptic:
146+
# (1) category id: like semantic segmentation, it is the class id for each
147+
# pixel. Since there are some classes not used in evaluation, the category
148+
# id is not always contiguous and thus we have two set of category ids:
149+
# - original category id: category id in the original dataset, mainly
150+
# used for evaluation.
151+
# - contiguous category id: [0, #classes), in order to train the linear
152+
# softmax classifier.
153+
# (2) instance id: this id is used to differentiate different instances from
154+
# the same category. For "stuff" classes, the instance id is always 0; for
155+
# "thing" classes, the instance id starts from 1 and 0 is reserved for
156+
# ignored instances (e.g. crowd annotation).
157+
# (3) panoptic id: this is the compact id that encode both category and
158+
# instance id by: category_id * label_divisor + instance_id. Following
159+
# the Cityscapes format, we set label_divisor = 1000.
160+
thing_dataset_id_to_contiguous_id = {}
161+
stuff_dataset_id_to_contiguous_id = {}
162+
163+
for k in CITYSCAPES_CATEGORIES:
164+
if k["isthing"] == 1:
165+
thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
166+
else:
167+
stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
168+
169+
meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
170+
meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
171+
172+
for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
173+
image_dir = os.path.join(root, image_dir)
174+
gt_dir = os.path.join(root, gt_dir)
175+
gt_json = os.path.join(root, gt_json)
176+
177+
DatasetCatalog.register(
178+
key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
179+
)
180+
MetadataCatalog.get(key).set(
181+
panoptic_root=gt_dir,
182+
image_root=image_dir,
183+
panoptic_json=gt_json,
184+
gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
185+
evaluator_type="cityscapes_panoptic_seg",
186+
ignore_label=255,
187+
label_divisor=1000,
188+
**meta,
189+
)

detectron2/evaluation/cityscapes_evaluation.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,28 @@ def process(self, inputs, outputs):
6262
basename = os.path.splitext(os.path.basename(file_name))[0]
6363
pred_txt = os.path.join(self._temp_dir, basename + "_pred.txt")
6464

65-
output = output["instances"].to(self._cpu_device)
66-
num_instances = len(output)
67-
with open(pred_txt, "w") as fout:
68-
for i in range(num_instances):
69-
pred_class = output.pred_classes[i]
70-
classes = self._metadata.thing_classes[pred_class]
71-
class_id = name2label[classes].id
72-
score = output.scores[i]
73-
mask = output.pred_masks[i].numpy().astype("uint8")
74-
png_filename = os.path.join(
75-
self._temp_dir, basename + "_{}_{}.png".format(i, classes)
76-
)
77-
78-
Image.fromarray(mask * 255).save(png_filename)
79-
fout.write("{} {} {}\n".format(os.path.basename(png_filename), class_id, score))
65+
if "instances" in output:
66+
output = output["instances"].to(self._cpu_device)
67+
num_instances = len(output)
68+
with open(pred_txt, "w") as fout:
69+
for i in range(num_instances):
70+
pred_class = output.pred_classes[i]
71+
classes = self._metadata.thing_classes[pred_class]
72+
class_id = name2label[classes].id
73+
score = output.scores[i]
74+
mask = output.pred_masks[i].numpy().astype("uint8")
75+
png_filename = os.path.join(
76+
self._temp_dir, basename + "_{}_{}.png".format(i, classes)
77+
)
78+
79+
Image.fromarray(mask * 255).save(png_filename)
80+
fout.write(
81+
"{} {} {}\n".format(os.path.basename(png_filename), class_id, score)
82+
)
83+
else:
84+
# Cityscapes requires a prediction file for every ground truth image.
85+
with open(pred_txt, "w") as fout:
86+
pass
8087

8188
def evaluate(self):
8289
"""

detectron2/evaluation/panoptic_evaluation.py

+27
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import itertools
55
import json
66
import logging
7+
import numpy as np
78
import os
89
import tempfile
910
from collections import OrderedDict
@@ -41,6 +42,7 @@ def __init__(self, dataset_name, output_dir):
4142
v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
4243
}
4344

45+
PathManager.mkdirs(output_dir)
4446
self._predictions_json = os.path.join(output_dir, "predictions.json")
4547

4648
def reset(self):
@@ -67,6 +69,31 @@ def process(self, inputs, outputs):
6769
for input, output in zip(inputs, outputs):
6870
panoptic_img, segments_info = output["panoptic_seg"]
6971
panoptic_img = panoptic_img.cpu().numpy()
72+
if segments_info is None:
73+
# If "segments_info" is None, we assume "panoptic_img" is a
74+
# H*W int32 image storing the panoptic_id in the format of
75+
# category_id * label_divisor + instance_id. We reserve -1 for
76+
# VOID label, and add 1 to panoptic_img since the official
77+
# evaluation script uses 0 for VOID label.
78+
label_divisor = self._metadata.label_divisor
79+
segments_info = []
80+
for panoptic_label in np.unique(panoptic_img):
81+
if panoptic_label == -1:
82+
# VOID region.
83+
continue
84+
pred_class = panoptic_label // label_divisor
85+
isthing = (
86+
pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
87+
)
88+
segments_info.append(
89+
{
90+
"id": int(panoptic_label) + 1,
91+
"category_id": int(pred_class),
92+
"isthing": bool(isthing),
93+
}
94+
)
95+
# Official evaluation script uses 0 for VOID label.
96+
panoptic_img += 1
7097

7198
file_name = os.path.basename(input["file_name"])
7299
file_name_png = os.path.splitext(file_name)[0] + ".png"

0 commit comments

Comments
 (0)