Skip to content

Commit 507697e

Browse files
authored
Add AutoPET dataset (#213)
Add AutoPET dataset
1 parent 5af5259 commit 507697e

File tree

5 files changed

+155
-5
lines changed

5 files changed

+155
-5
lines changed

scripts/datasets/check_autopet.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from torch_em.util.debug import check_loader
2+
from torch_em.data.datasets.medical import get_autopet_loader
3+
from torch_em.data import MinInstanceSampler
4+
5+
AUTOPET_ROOT = "/scratch/projects/nim00007/data/autopet/"
6+
7+
8+
# TODO: need to rescale the inputs using raw transform (preferably to 8-bit)
9+
def check_autopet():
    """Create the AutoPET loader for 2d patches, print its length and hand it to `check_loader`.

    The loader is requested with both modalities (`modality=None`), downloads the data if needed
    and uses a `MinInstanceSampler` for patch sampling.
    """
    loader_settings = dict(
        path=AUTOPET_ROOT,
        patch_shape=(1, 512, 512),
        batch_size=2,
        ndim=2,
        download=True,
        modality=None,
        sampler=MinInstanceSampler(),
    )
    autopet_loader = get_autopet_loader(**loader_settings)

    n_batches = len(autopet_loader)
    print(f"Length of the loader: {n_batches}")

    # Passes 8 as the second argument of `check_loader` and requests "autopet.png" as output path.
    check_loader(autopet_loader, 8, plt=True, save_path="autopet.png")


if __name__ == "__main__":
    check_autopet()
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .autopet import get_autopet_loader
12
from .btcv import get_btcv_dataset, get_btcv_loader
+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import os
2+
from glob import glob
3+
from typing import Tuple, Optional, Union
4+
5+
import torch
6+
7+
import torch_em
8+
9+
from .. import util
10+
11+
12+
AUTOPET_DATA = "http://193.196.20.155/data/autoPET/data/nifti.zip"
13+
CHECKSUM = "0ac2186ea6d936ff41ce605c6a9588aeb20f031085589897dbab22fc82a12972"
14+
15+
16+
def _assort_autopet_dataset(path, download):
    """Download and unpack the AutoPET archive below `path`, unless it was unpacked before.

    Arguments:
        path: The root folder. The archive is stored as `<path>/autopet.zip` and is
            extracted into `<path>/AutoPET-II`.
        download: Forwarded to `util.download_source`.
    """
    target_dir = os.path.join(path, "AutoPET-II")
    # An existing target folder is treated as the marker for "data is already prepared".
    if os.path.exists(target_dir):
        return

    # Only create the root folder up front. The target folder must not exist before the archive
    # is available: otherwise a failed (or refused) download would leave an empty `AutoPET-II`
    # folder behind and every later call would return early without any data.
    if path:
        os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "autopet.zip")
    print("The AutoPET data is not available yet and will be downloaded.")
    print("Note that this dataset is large, so this step can take several hours (depending on your internet).")
    util.download_source(path=zip_path, url=AUTOPET_DATA, download=download, checksum=CHECKSUM)

    os.makedirs(target_dir, exist_ok=True)
    # The archive is kept (`remove=False`) after extraction.
    util.unzip(zip_path, target_dir, remove=False)
27+
28+
29+
def _get_paths(path, modality):
30+
root_dir = os.path.join(path, "AutoPET-II", "FDG-PET-CT-Lesions", "*", "*")
31+
ct_paths = sorted(glob(os.path.join(root_dir, "CTres.nii.gz")))
32+
pet_paths = sorted(glob(os.path.join(root_dir, "SUV.nii.gz")))
33+
label_paths = sorted(glob(os.path.join(root_dir, "SEG.nii.gz")))
34+
if modality is None:
35+
raw_paths = [(ct_path, pet_path) for ct_path, pet_path in zip(ct_paths, pet_paths)]
36+
elif modality == "CT":
37+
raw_paths = ct_paths
38+
elif modality == "PET":
39+
raw_paths = pet_paths
40+
else:
41+
raise ValueError("Choose from the available modalities: `CT` / `PET`")
42+
43+
return raw_paths, label_paths
44+
45+
46+
def get_autopet_dataset(
    path: str,
    patch_shape: Tuple[int, ...],
    ndim: int,
    modality: Optional[str] = None,
    download: bool = False,
    **kwargs
) -> torch.utils.data.Dataset:
    """Dataset for lesion segmentation in whole-body FDG-PET/CT scans.

    This dataset is from the `AutoPET II - Automated Lesion Segmentation in PET/CT - Domain Generalization` challenge.
    Link: https://autopet-ii.grand-challenge.org/
    Please cite it if you use this dataset for publication.

    Arguments:
        path: The folder where the archive (`autopet.zip`) is stored and where the data is
            extracted to (into the subfolder `AutoPET-II`).
        patch_shape: The patch shape (for 2d or 3d patches)
        ndim: The dimensions of the inputs (use `2` for getting 2d patches, and `3` for getting 3d patches)
        modality: The modality for using the AutoPET dataset, either `CT` or `PET`.
            - (default: None) If passed `None`, it takes both the modalities as inputs
        download: Forwarded to the download helper; see `util.download_source`.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        dataset: The segmentation dataset for the respective modalities.
    """
    # `isinstance` only accepts `typing.Union[...]` from Python 3.10 on and raises a TypeError
    # on older versions; the tuple form is equivalent and works everywhere.
    assert isinstance(modality, (str, type(None))), \
        f"Expected `modality` to be a string or None, got {type(modality)}"

    _assort_autopet_dataset(path, download)
    raw_paths, label_paths = _get_paths(path, modality)

    # With `modality=None` each raw entry is a (CT, PET) pair, which is loaded with channels.
    dataset = torch_em.default_segmentation_dataset(
        raw_paths, "data", label_paths, "data",
        patch_shape, ndim=ndim, with_channels=modality is None,
        **kwargs
    )

    if "sampler" in kwargs:
        # Raise the number of sampling attempts when a sampler is used. The setting is applied
        # to every member dataset, or to the dataset itself if it has no `datasets` attribute.
        for ds in getattr(dataset, "datasets", [dataset]):
            ds.max_sampling_attempts = 5000
    return dataset
84+
85+
86+
def get_autopet_loader(
    path, patch_shape, batch_size, ndim, modality=None, download=False, **kwargs
):
    """Dataloader for lesion segmentation in whole-body FDG-PET/CT scans. See `get_autopet_dataset` for details.

    The extra keyword arguments are split with `util.split_kwargs`: one part is passed to
    `get_autopet_dataset`, the other part to `torch_em.get_data_loader`.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_autopet_dataset(
        path=path, patch_shape=patch_shape, ndim=ndim, modality=modality, download=download, **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size=batch_size, **dataloader_kwargs)

torch_em/segmentation.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,18 @@ def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
3737

3838

3939
def check_paths(raw_paths, label_paths):
40-
if type(raw_paths) != type(label_paths):
40+
if not isinstance(raw_paths, type(label_paths)):
4141
raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")
4242

4343
def _check_path(path):
44-
if not os.path.exists(path):
45-
raise ValueError(f"Could not find path {path}")
44+
if isinstance(path, str):
45+
if not os.path.exists(path):
46+
raise ValueError(f"Could not find path {path}")
47+
else:
48+
# check for single path or multiple paths (for same volume - supports multi-modal inputs)
49+
for per_path in path:
50+
if not os.path.exists(per_path):
51+
raise ValueError(f"Could not find path {per_path}")
4652

4753
if isinstance(raw_paths, str):
4854
_check_path(raw_paths)

torch_em/util/image.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# TODO this should be partially refactored into elf.io before the next elf release
22
# and then be used in image_stack_wrapper as well
33
import os
4+
import numpy as np
45

56
from elf.io import open_file
67
try:
@@ -38,8 +39,33 @@ def load_image(image_path, memmap=True):
3839
return imageio.imread(image_path)
3940

4041

42+
class MultiDatasetWrapper:
    """Present several equally shaped datasets as one array-like with a leading channel axis.

    The wrapped objects need a `shape` attribute and must support indexing with a tuple.
    Indexing the wrapper reads the requested region from every wrapped dataset, stacks the
    results along a new first axis and then applies the channel index to that axis.
    """

    def __init__(self, *file_datasets):
        if not file_datasets:
            raise ValueError("MultiDatasetWrapper needs at least one dataset.")

        # Make sure we have the same shapes (compared as tuples, so list-like shapes work too).
        reference_shape = tuple(file_datasets[0].shape)
        assert all(reference_shape == tuple(ds.shape) for ds in file_datasets), \
            "All datasets must have the same shape."
        self.file_datasets = file_datasets

        # The number of wrapped datasets becomes the leading (channel) axis.
        self.shape = (len(self.file_datasets),) + reference_shape
        self.ndim = len(self.shape)

    def __getitem__(self, index):
        # A bare int / slice only addresses the channel axis; normalize it to a tuple so that
        # it can be split into the channel part and the spatial part.
        if not isinstance(index, tuple):
            index = (index,)
        channel_index, spatial_index = index[:1], index[1:]

        # Read the same spatial region from all datasets, then select the requested channels.
        data = np.stack([ds[spatial_index] for ds in self.file_datasets])
        return data[channel_index]
60+
61+
4162
def load_data(path, key, mode="r"):
    """Load data from a single file or from several files that are combined along a new first axis.

    Arguments:
        path: A single file path (string or `os.PathLike`) or an iterable of file paths.
        key: The internal key of the data in the file(s). If `None`, the file(s) are read
            with `load_image`; otherwise they are opened with `open_file` and indexed by `key`.
        mode: The mode passed to `open_file`.

    Returns:
        For a single file, the result of `load_image` or the object stored under `key`.
        For several files, the stacked images (`key` is `None`) or a `MultiDatasetWrapper`
        around the per-file objects.
    """
    # `os.PathLike` objects (e.g. `pathlib.Path`) are single paths as well; treating them as
    # a collection of paths would fail when iterating over them.
    have_single_file = isinstance(path, (str, os.PathLike))

    if have_single_file:
        return load_image(path) if key is None else open_file(path, mode=mode)[key]

    if key is None:
        return np.stack([load_image(p) for p in path])
    return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])

0 commit comments

Comments
 (0)