Skip to content

Commit 66c30be

Browse files
authored
Refactor ctc (#215)
1 parent 507697e commit 66c30be

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

scripts/datasets/check_ctc.py

+2-4
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from torch_em.util.debug import check_loader
33
from torch_em.data.sampler import MinInstanceSampler
44

5-
ROOT = "/home/pape/Work/data/ctc/ctc-training-data"
5+
ROOT = "/scratch/projects/nim00007/sam/data/ctc/"
66

77

88
# Some of the datasets have partial sparse labels:
@@ -11,14 +11,12 @@
1111
# Maybe depends on the split?!
1212
def check_ctc_segmentation():
1313
for name in CTC_URLS.keys():
14-
if not name.startswith("DIC"):
15-
continue
1614
print("Checking dataset", name)
1715
loader = get_ctc_segmentation_loader(
1816
ROOT, name, (1, 512, 512), 1, download=True,
1917
sampler=MinInstanceSampler()
2018
)
21-
check_loader(loader, 8, instance_labels=True)
19+
check_loader(loader, 8, plt=True, save_path="ctc.png")
2220

2321

2422
if __name__ == "__main__":

torch_em/data/datasets/ctc.py

+10-5
Original file line number | Diff line number | Diff line change
@@ -40,11 +40,14 @@ def _require_ctc_dataset(path, dataset_name, download):
4040

4141
data_path = os.path.join(path, dataset_name)
4242

43-
if not os.path.exists(data_path):
44-
url, checksum = CTC_URLS[dataset_name], CTC_CHECKSUMS[dataset_name]
45-
zip_path = os.path.join(path, f"{dataset_name}.zip")
46-
util.download_source(zip_path, url, download, checksum=checksum)
47-
util.unzip(zip_path, path, remove=True)
43+
if os.path.exists(data_path):
44+
return data_path
45+
46+
os.makedirs(data_path)
47+
url, checksum = CTC_URLS[dataset_name], CTC_CHECKSUMS[dataset_name]
48+
zip_path = os.path.join(path, f"{dataset_name}.zip")
49+
util.download_source(zip_path, url, download, checksum=checksum)
50+
util.unzip(zip_path, path, remove=True)
4851

4952
return data_path
5053

@@ -101,6 +104,8 @@ def get_ctc_segmentation_dataset(
101104
splits = glob(os.path.join(data_path, "*_GT"))
102105
splits = [os.path.basename(split) for split in splits]
103106
splits = [split.rstrip("_GT") for split in splits]
107+
else:
108+
splits = split
104109

105110
image_path, label_path = _require_gt_images(data_path, splits)
106111

0 commit comments

Comments
 (0)