Skip to content

Commit 07367e0

Browse files
authored
Merge pull request #19 from George-Jiao/v2
init_commit for v2
2 parents cd2949e + b9c3e54 commit 07367e0

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

73 files changed

+4288
-1599
lines changed

.gitignore

+5-5
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ dmypy.json
143143
*.h5
144144
*.tar
145145
*.tar.gz
146-
*.ckpt
147146

148-
configs/local/default.yaml
149-
/data/
150-
/logs/
151-
.env
147+
# Aim logging
148+
.aim
149+
assets/
150+
logs/
151+
*/local/*

.pre-commit-config.yaml

+62-35
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ default_language_version:
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v4.4.0
6+
rev: v5.0.0
77
hooks:
88
# list of supported hooks: https://pre-commit.com/hooks.html
99
- id: trailing-whitespace
1010
- id: end-of-file-fixer
11-
- id: check-docstring-first
11+
# - id: check-docstring-first
1212
- id: check-yaml
1313
- id: debug-statements
1414
- id: detect-private-key
@@ -19,41 +19,63 @@ repos:
1919

2020
# python code formatting
2121
- repo: https://github.com/psf/black
22-
rev: 23.1.0
22+
rev: 24.4.2
2323
hooks:
2424
- id: black
25-
args: [--line-length, "99"]
25+
# args: [--line-length, "99"]
2626

2727
# python import sorting
2828
- repo: https://github.com/PyCQA/isort
29-
rev: 5.12.0
29+
rev: 5.13.2
3030
hooks:
3131
- id: isort
3232
args: ["--profile", "black", "--filter-files"]
3333

3434
# python upgrading syntax to newer version
3535
- repo: https://github.com/asottile/pyupgrade
36-
rev: v3.3.1
36+
rev: v3.15.2
3737
hooks:
3838
- id: pyupgrade
39-
args: [--py38-plus]
39+
# args: [--py38-plus]
4040

41-
# python docstring formatting
42-
- repo: https://github.com/myint/docformatter
43-
rev: v1.5.1
44-
hooks:
45-
- id: docformatter
46-
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
41+
# # python docstring formatting
42+
# - repo: https://github.com/myint/docformatter
43+
# rev: v1.7.4
44+
# hooks:
45+
# - id: docformatter
46+
# args:
47+
# [
48+
# --in-place,
49+
# --wrap-summaries=99,
50+
# --wrap-descriptions=99,
51+
# --style=sphinx,
52+
# --black,
53+
# ]
54+
55+
# # python docstring coverage checking
56+
# - repo: https://github.com/econchick/interrogate
57+
# rev: 1.5.0 # or master if you're bold
58+
# hooks:
59+
# - id: interrogate
60+
# args:
61+
# [
62+
# --verbose,
63+
# --fail-under=80,
64+
# --ignore-init-module,
65+
# --ignore-init-method,
66+
# --ignore-module,
67+
# --ignore-nested-functions,
68+
# -vv,
69+
# ]
4770

4871
# python check (PEP8), programming errors and code complexity
4972
- repo: https://github.com/PyCQA/flake8
50-
rev: 6.0.0
73+
rev: 7.0.0
5174
hooks:
5275
- id: flake8
53-
args:
54-
[
76+
args: [
5577
"--extend-ignore",
56-
"E203,E402,E501,F401,F841",
78+
"E203,E402,E501,F401,F841,E722", # add E722 to ignore bare except
5779
"--exclude",
5880
"logs/*,data/*",
5981
]
@@ -67,34 +89,36 @@ repos:
6789

6890
# yaml formatting
6991
- repo: https://github.com/pre-commit/mirrors-prettier
70-
rev: v3.0.0-alpha.6
92+
rev: v3.0.0
7193
hooks:
7294
- id: prettier
73-
types: [yaml]
74-
exclude: "environment.yaml"
95+
additional_dependencies:
96+
- prettier@3.3.3 # SEE: https://github.com/pre-commit/pre-commit/issues/3133
7597

7698
# shell scripts linter
77-
- repo: https://github.com/shellcheck-py/shellcheck-py
78-
rev: v0.9.0.2
79-
hooks:
80-
- id: shellcheck
99+
# - repo: https://github.com/shellcheck-py/shellcheck-py
100+
# rev: v0.10.0.1
101+
# hooks:
102+
# - id: shellcheck
81103

82104
# md formatting
83105
- repo: https://github.com/executablebooks/mdformat
84-
rev: 0.7.16
106+
rev: 0.7.17
85107
hooks:
86108
- id: mdformat
87109
args: ["--number"]
88110
additional_dependencies:
89111
- mdformat-gfm
90112
- mdformat-tables
91113
- mdformat_frontmatter
114+
- mdformat-beautysh
115+
- mdformat-black
92116
# - mdformat-toc
93117
# - mdformat-black
94118

95119
# word spelling linter
96120
- repo: https://github.com/codespell-project/codespell
97-
rev: v2.2.4
121+
rev: v2.3.0
98122
hooks:
99123
- id: codespell
100124
args:
@@ -103,21 +127,24 @@ repos:
103127

104128
# jupyter notebook cell output clearing
105129
- repo: https://github.com/kynan/nbstripout
106-
rev: 0.6.1
130+
rev: 0.7.1
107131
hooks:
108132
- id: nbstripout
109133

110134
# jupyter notebook linting
111135
- repo: https://github.com/nbQA-dev/nbQA
112-
rev: 1.6.3
136+
rev: 1.8.5
113137
hooks:
114138
- id: nbqa-black
115-
args: ["--line-length=99"]
139+
# args: ["--line-length=99"]
116140
- id: nbqa-isort
117141
args: ["--profile=black"]
118-
- id: nbqa-flake8
119-
args:
120-
[
121-
"--extend-ignore=E203,E402,E501,F401,F841",
122-
"--exclude=logs/*,data/*",
123-
]
142+
- id: nbqa-pyupgrade
143+
args: ["--py38-plus"]
144+
- id: nbqa-isort
145+
args: ["--float-to-top"]
146+
# - id: nbqa-flake8
147+
# args: [
148+
# # "--extend-ignore=E203,E402,E501,F401,F841",
149+
# "--exclude=logs/*,data/*",
150+
# ]
File renamed without changes.

USDSgen/data/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .make_dataloader import build_loader

USDSgen/data/datasets.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import os
2+
3+
import albumentations as A
4+
import numpy as np
5+
import torch
6+
import torchvision.transforms as T
7+
from albumentations.pytorch import ToTensorV2
8+
from PIL import Image
9+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10+
from timm.data.transforms import str_to_pil_interp
11+
from torch.utils.data import Dataset
12+
from torchvision import datasets
13+
14+
15+
def _ensure_rgb(img):
    """Return ``img`` in RGB mode, converting it only when its mode differs."""
    return img if img.mode == "RGB" else img.convert("RGB")


def _build_cls_transforms(config):
    """Create the RGB -> resize -> tensor -> normalise pipeline for classification.

    Reads ``config.data.img_size`` and ``config.data.interpolation``. A fresh
    ``T.Compose`` is returned on every call so the train and evaluation
    pipelines never share transform objects.
    """
    return T.Compose(
        [
            # A named module-level function (instead of a lambda) keeps the
            # pipeline picklable, which DataLoader worker processes started
            # with the "spawn" method require.
            T.Lambda(_ensure_rgb),
            T.Resize(
                (config.data.img_size, config.data.img_size),
                interpolation=str_to_pil_interp(config.data.interpolation),
            ),
            T.ToTensor(),
            T.Normalize(
                mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
                std=torch.tensor(IMAGENET_DEFAULT_STD),
            ),
        ]
    )


def build_cls_dataset(config, logger):
    """Build the train/val/test classification datasets described by ``config``.

    Args:
        config: Configuration object. Reads ``config.data.type``,
            ``config.data.img_size``, ``config.data.interpolation`` and
            ``config.data.path`` (``root`` plus ``split.train``,
            ``split.val`` and ``split.test``).
        logger: Object exposing ``info``; used to report the dataset sizes.

    Returns:
        The tuple ``(dataset_train, dataset_val, dataset_test)``, each a
        ``torchvision.datasets.ImageFolder``.

    Raises:
        NotImplementedError: If ``config.data.type`` is not ``"cls_imagenet"``.
    """
    # Fail fast: reject unsupported dataset types before building anything.
    if config.data.type != "cls_imagenet":
        raise NotImplementedError("We only support ImageNet Now.")

    data_path = config.data.path
    # No augmentation is currently enabled, so both pipelines are identical;
    # they are still built separately so train-only steps can be added later.
    train_transforms = _build_cls_transforms(config)
    val_transforms = _build_cls_transforms(config)

    dataset_train = datasets.ImageFolder(
        os.path.join(data_path.root, data_path.split.train),
        transform=train_transforms,
    )
    dataset_val = datasets.ImageFolder(
        os.path.join(data_path.root, data_path.split.val),
        transform=val_transforms,
    )
    dataset_test = datasets.ImageFolder(
        os.path.join(data_path.root, data_path.split.test),
        transform=val_transforms,
    )

    logger.info(
        f"Build [Cls] dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}, test images = {len(dataset_test)}"
    )
    return dataset_train, dataset_val, dataset_test
66+
67+
68+
def _resolve_seg_dataset_class(type_name):
    """Return the dataset class named ``type_name + "Dataset"`` from this module.

    The name is looked up in the module namespace instead of being passed to
    ``eval``, so the configured string can select a class but is never
    executed as code. Only proper subclasses of ``Dataset`` are accepted.

    Raises:
        NotImplementedError: If the module defines no such dataset class.
    """
    class_name = f"{type_name}Dataset"
    candidate = globals().get(class_name)
    is_dataset_class = (
        isinstance(candidate, type)
        and issubclass(candidate, Dataset)
        and candidate is not Dataset
    )
    if not is_dataset_class:
        raise NotImplementedError(
            f"Unknown segmentation dataset type {type_name!r}: "
            f"no dataset class named {class_name!r} is defined."
        )
    return candidate


def build_seg_dataset(config, logger):
    """Build the train/val/test segmentation datasets described by ``config``.

    Args:
        config: Configuration object. Reads ``config.data.type`` (the dataset
            class name without its ``Dataset`` suffix) and
            ``config.data.img_size``; ``config.data`` itself is handed to the
            dataset class constructor.
        logger: Object exposing ``info``; used to report the dataset sizes.

    Returns:
        The tuple ``(dataset_train, dataset_val, dataset_test)``.

    Raises:
        NotImplementedError: If ``config.data.type`` does not name a dataset
            class defined in this module.
    """
    # Resolve the class first so a bad config fails before any other work.
    Dataset_class = _resolve_seg_dataset_class(config.data.type)

    train_transforms = A.Compose(
        [
            A.Resize(width=config.data.img_size, height=config.data.img_size),
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # ToFloat divides by 255 first, hence max_pixel_value=1 below.
            A.ToFloat(max_value=255),
            A.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1
            ),
            ToTensorV2(),
        ]
    )
    # Evaluation pipeline: same as training minus the random augmentations.
    val_transforms = A.Compose(
        [
            A.Resize(width=config.data.img_size, height=config.data.img_size),
            A.ToFloat(max_value=255),
            A.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=1
            ),
            ToTensorV2(),
        ]
    )

    dataset_train = Dataset_class(config.data, "train", train_transforms)
    dataset_val = Dataset_class(config.data, "val", val_transforms)
    dataset_test = Dataset_class(config.data, "test", val_transforms)
    logger.info(
        f"Build [Seg] dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}, test images = {len(dataset_test)}"
    )

    return dataset_train, dataset_val, dataset_test
103+
104+
105+
class SegBaseDataset(Dataset):
    """Segmentation dataset backed by parallel ``image`` and ``mask`` folders.

    The split folder is ``DataConfig.path.root`` joined with
    ``DataConfig.path.split[stage]``. Every file found (recursively) under its
    ``image`` sub-folder is a sample; the mask is expected at the same
    relative path under the sibling ``mask`` sub-folder.
    """

    def __init__(self, DataConfig, stage, transforms=None):
        """Index the files of one split.

        Args:
            DataConfig: Data configuration. Reads ``path.root``,
                ``path.split[stage]`` and ``num_classes``.
            stage: Key into ``DataConfig.path.split`` selecting the split.
            transforms: Optional callable invoked as
                ``transforms(image=..., mask=...)`` that returns a dict.
        """
        super().__init__()
        data_folder = os.path.join(DataConfig.path.root, DataConfig.path.split[stage])
        self.num_classes = DataConfig.num_classes
        self.transforms = transforms
        self.update_datalist(data_folder)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with ``image`` and ``mask`` (transformed when ``transforms``
            is set, raw NumPy arrays otherwise) plus the source file paths
            under ``img_path`` and ``mask_path``.
        """
        image_file = self.image_list[index]
        mask_file = self.mask_list[index]

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into NumPy arrays.
        with Image.open(image_file) as raw_image:
            image = np.array(raw_image.convert("RGB"))
        with Image.open(mask_file) as raw_mask:
            # Two-class masks are binarised; otherwise pixel values are kept.
            mask_image = raw_mask.convert("1") if self.num_classes == 2 else raw_mask
            mask = np.array(mask_image).astype(int)

        if self.transforms is not None:
            sample = self.transforms(image=image, mask=mask)
        else:
            # Without transforms, still return the same dict layout instead
            # of failing on an unassigned variable.
            sample = {"image": image, "mask": mask}
        sample["img_path"] = image_file
        sample["mask_path"] = mask_file
        return sample

    def update_datalist(self, folder):
        """Populate ``image_list`` and ``mask_list`` from ``folder``.

        Files are collected recursively from ``folder/image`` and sorted so
        that sample indices are reproducible across runs and file systems.
        """
        image_path = os.path.join(folder, "image")
        mask_path = os.path.join(folder, "mask")

        filenames = sorted(
            os.path.join(root, file)
            for root, _dirs, files in os.walk(image_path)
            for file in files
        )

        self.image_list = filenames
        # Rebase each file onto the mask folder via its relative path, so only
        # the leading image folder is swapped.
        self.mask_list = [
            os.path.join(mask_path, os.path.relpath(name, image_path)) for name in filenames
        ]

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.image_list)
142+
143+
144+
class SegVocDataset(Dataset):
    """Segmentation dataset laid out in Pascal-VOC style folders.

    Sample names are read from ``<root>/ImageSets/<stage>.txt``; images are
    taken from ``<root>/JPEGImages`` and masks from
    ``<root>/SegmentationClass`` (always ``.png``).
    """

    def __init__(self, DataConfig, stage, transforms=None):
        """Index the files of one split.

        Args:
            DataConfig: Data configuration. Reads ``path.root`` and
                ``path.image_type`` (image file extension without the dot).
            stage: Split name; selects ``ImageSets/<stage>.txt``.
            transforms: Optional callable invoked as
                ``transforms(image=..., mask=...)`` that returns a dict.
        """
        super().__init__()
        self.transforms = transforms
        self.update_datalist(DataConfig.path.root, stage, DataConfig.path.image_type)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with ``image`` and ``mask`` (transformed when ``transforms``
            is set, raw NumPy arrays otherwise) plus the source file paths
            under ``img_path`` and ``mask_path``.
        """
        image_file = self.image_list[index]
        mask_file = self.mask_list[index]

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into NumPy arrays.
        with Image.open(image_file) as raw_image:
            image = np.array(raw_image.convert("RGB"))
        with Image.open(mask_file) as raw_mask:
            # NOTE(review): the mask is always binarised here, whatever the
            # number of classes -- confirm this is intended for VOC data.
            mask = np.array(raw_mask.convert("1")).astype(int)

        if self.transforms is not None:
            # Only the arrays go through the transform; the paths are attached
            # afterwards (as SegBaseDataset does) so they are never treated as
            # transform targets.
            sample = self.transforms(image=image, mask=mask)
        else:
            # Without transforms, still return the same dict layout instead
            # of failing on an unassigned variable.
            sample = {"image": image, "mask": mask}
        sample["img_path"] = image_file
        sample["mask_path"] = mask_file
        return sample

    def update_datalist(self, root, stage, image_type):
        """Populate ``image_list`` and ``mask_list`` from the split file.

        Args:
            root: Dataset root folder.
            stage: Split name; ``ImageSets/<stage>.txt`` lists one sample name
                per line.
            image_type: Image file extension without the leading dot.
        """
        split_file = os.path.join(root, "ImageSets", stage + ".txt")
        # loadtxt yields a 0-d array for a single-entry file; atleast_1d makes
        # the result iterable in that case as well.
        names = [str(name) for name in np.atleast_1d(np.loadtxt(split_file, dtype=str))]

        self.image_list = [
            os.path.join(root, "JPEGImages", name + "." + image_type) for name in names
        ]
        self.mask_list = [
            os.path.join(root, "SegmentationClass", name + ".png") for name in names
        ]

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.image_list)

0 commit comments

Comments
 (0)