Commit 9c71393

Enable automatic downloads for conic
1 parent 0acad72 commit 9c71393

3 files changed, 77 insertions(+), 48 deletions(-)

@@ -0,0 +1,28 @@
+import os
+import sys
+
+from torch_em.util.debug import check_loader
+from torch_em.data.datasets import get_conic_loader
+
+
+sys.path.append("..")
+
+
+def check_conic():
+    # from util import ROOT
+    ROOT = "/media/anwai/ANWAI/data"
+
+    loader = get_conic_loader(
+        path=os.path.join(ROOT, "conic"),
+        split="train",
+        batch_size=2,
+        patch_shape=(1, 512, 512),
+        label_choice="semantic",
+        download=True,
+    )
+
+    check_loader(loader, 8)
+
+
+if __name__ == "__main__":
+    check_conic()
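The script above drives the new download path end-to-end via get_conic_loader(..., download=True). As a complementary sketch (not part of the commit), the returned loader is a regular PyTorch DataLoader, so a batch can be pulled directly; the root directory below is a hypothetical placeholder for the hard-coded ROOT used in the script:

import os

from torch_em.data.datasets import get_conic_loader

root = "./data"  # hypothetical writable directory, stands in for the script's ROOT

loader = get_conic_loader(
    path=os.path.join(root, "conic"),
    split="train",
    batch_size=2,
    patch_shape=(1, 512, 512),
    label_choice="semantic",
    download=True,  # triggers the automatic google drive download enabled by this commit
)

# Each batch is a (raw, label) pair of torch tensors.
x, y = next(iter(loader))
print(x.shape, y.shape)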

torch_em/data/datasets/histopathology/conic.py

+35 -37
@@ -1,18 +1,16 @@
-
 """The CONIC dataset contains annotations for nucleus segmentation
 in histopathology images in H&E stained colon tissue.

-This dataset is from the publication https://doi.org/10.48550/arXiv.2303.06274.
+This dataset is from the publication https://doi.org/10.1016/j.media.2023.103047.
 Please cite it if you use this dataset for your research.
 """

 import os
-import numpy as np
 from glob import glob
-from typing import Tuple, Union, List, Literal
-import gdown
 from tqdm import tqdm
+from typing import Tuple, Union, List, Literal

+import numpy as np
 import pandas as pd

 from torch.utils.data import Dataset, DataLoader
@@ -23,17 +21,17 @@
 from sklearn.model_selection import StratifiedShuffleSplit


-URL = "https://drive.google.com/drive/folders/1il9jG7uA4-ebQ_lNmXbbF2eOK9uNwheb"
+URL = "https://drive.google.com/drive/folders/1il9jG7uA4-ebQ_lNmXbbF2eOK9uNwheb?usp=sharing"


 def _create_split_list(path, split):
-    # Ref. HoVerNet repo: https://github.com/vqdang/hover_net/blob/conic/generate_split.py
+    # source: HoVerNet repo: https://github.com/vqdang/hover_net/blob/conic/generate_split.py.
     # We take the FOLD_IDX = 0 as used for the baseline model
+
     split_csv = os.path.join(path, "split.csv")

     if os.path.exists(split_csv):
         split_df = pd.read_csv(split_csv)
-
     else:
         SEED = 5
         info = pd.read_csv(os.path.join(path, "patch_info.csv"))
@@ -46,43 +44,39 @@ def _create_split_list(path, split):
         _, cohort_sources = np.unique(cohort_sources, return_inverse=True)

         num_trials = 10
-        splitter = StratifiedShuffleSplit(
-            n_splits=num_trials,
-            train_size=0.8,
-            test_size=0.2,
-            random_state=SEED
-        )
+        splitter = StratifiedShuffleSplit(n_splits=num_trials, train_size=0.8, test_size=0.2, random_state=SEED)

         splits = {}
         split_generator = splitter.split(img_sources, cohort_sources)
         for train_indices, valid_indices in split_generator:
             train_cohorts = img_sources[train_indices]
             valid_cohorts = img_sources[valid_indices]
+
             assert np.intersect1d(train_cohorts, valid_cohorts).size == 0
+
             train_names = [
-                file_name
-                for file_name in file_names
-                for source in train_cohorts
-                if source == file_name.split('-')[0]
+                file_name for file_name in file_names for source in train_cohorts if source == file_name.split('-')[0]
             ]
             valid_names = [
-                file_name
-                for file_name in file_names
-                for source in valid_cohorts
-                if source == file_name.split('-')[0]
+                file_name for file_name in file_names for source in valid_cohorts if source == file_name.split('-')[0]
             ]
+
             train_names = np.unique(train_names)
             valid_names = np.unique(valid_names)
             print(f'Train: {len(train_names):04d} - Valid: {len(valid_names):04d}')
+
             assert np.intersect1d(train_names, valid_names).size == 0
+
             train_indices = [file_names.index(v) for v in train_names]
             valid_indices = [file_names.index(v) for v in valid_names]

             while len(train_indices) > len(valid_indices):
                 valid_indices.append(np.nan)
+
             splits['train'] = train_indices
             splits['test'] = valid_indices
             break
+
         split_df = pd.DataFrame(splits)
         split_df.to_csv(split_csv, index=False)

@@ -91,7 +85,6 @@ def _create_split_list(path, split):


 def _extract_images(split, path):
-    import h5py

     split_list = _create_split_list(path, split)

@@ -102,8 +95,9 @@ def _extract_images(split, path):
     raw = []
     semantic_masks = []

-    for idx, (image, label) in tqdm(enumerate(zip(images, labels)), desc=f"Extracting {split} data",
-                                    total=images.shape[0]):
+    for idx, (image, label) in tqdm(
+        enumerate(zip(images, labels)), desc=f"Extracting '{split}' data", total=images.shape[0]
+    ):
         if idx not in split_list:
             continue

@@ -115,37 +109,41 @@ def _extract_images(split, path):
     instance_masks = np.stack(instance_masks)
     semantic_masks = np.stack(semantic_masks)

-    output_file = os.path.join(path, f"{split}.h5")
-    with h5py.File(output_file, "a") as f:
+    import h5py
+    with h5py.File(os.path.join(path, f"{split}.h5"), "a") as f:
         f.create_dataset("raw", data=raw, compression="gzip")
         f.create_dataset("labels/instance", data=instance_masks, compression="gzip")
         f.create_dataset("labels/semantic", data=semantic_masks, compression="gzip")


-def get_conic_data(path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False):
+def get_conic_data(path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False) -> str:
     """Download the CONIC dataset for nucleus segmentation.

     Args:
         path: Filepath to a folder where the downloaded data will be saved.
         split: The choice of data split.
         download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath where the data is download for further processing.
     """
     if split not in ['train', 'test']:
         raise ValueError(f"'{split}' is not a valid split.")

-    image_files = glob(os.path.join(path, "*.h5"))
-    if len(image_files) > 0:
-        return
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(data_dir) and glob(os.path.join(data_dir, "*.h5")):
+        return data_dir

     os.makedirs(path, exist_ok=True)

-    # Load data if not in the given directory
-    if not os.path.exists(os.path.join(path, "images.npy")) and download:
-        gdown.download_folder(URL, output=path, quiet=False)
+    # Download the files from google drive.
+    util.download_source_gdrive(path=data_dir, url=URL, download=download, download_type="folder", quiet=False)

     # Extract and preprocess images for all splits
     for _split in ['train', 'test']:
-        _extract_images(_split, path)
+        _extract_images(_split, data_dir)
+
+    return data_dir


 def get_conic_paths(
@@ -161,8 +159,8 @@ def get_conic_paths(
     Returns:
         List of filepaths for the stored data.
     """
-    get_conic_data(path, split, download)
-    return os.path.join(path, f"{split}.h5")
+    data_dir = get_conic_data(path, split, download)
+    return os.path.join(data_dir, f"{split}.h5")


 def get_conic_dataset(
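The reworked get_conic_data now returns the directory that holds the per-split HDF5 files, and get_conic_paths resolves "{split}.h5" inside it. A minimal sketch (not part of the commit) of checking the datasets that _extract_images writes, assuming get_conic_paths keeps the (path, split, download) signature shown above and that "./data/conic" is a hypothetical target folder:

import h5py

from torch_em.data.datasets.histopathology.conic import get_conic_paths

# The data lands in "<path>/data/<split>.h5" after the automatic download.
train_h5 = get_conic_paths("./data/conic", split="train", download=True)

with h5py.File(train_h5, "r") as f:
    # Datasets written by _extract_images.
    print(f["raw"].shape)
    print(f["labels/instance"].shape)
    print(f["labels/semantic"].shape)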

torch_em/data/datasets/util.py

+14 -11
@@ -1,22 +1,23 @@
+import os
 import hashlib
 import inspect
-import os
 import zipfile
-
-from packaging import version
-from shutil import copyfileobj, which
-from subprocess import run
-from typing import Optional, Tuple
+import requests
+from tqdm import tqdm
 from warnings import warn
+from subprocess import run
 from xml.dom import minidom
+from packaging import version
+from shutil import copyfileobj, which
+
+from typing import Optional, Tuple, Literal

 import numpy as np
-import requests
+from skimage.draw import polygon
+
 import torch
-import torch_em

-from skimage.draw import polygon
-from tqdm import tqdm
+import torch_em
 from torch_em.transform import get_raw_transform
 from torch_em.transform.generic import ResizeLongestSideInputs, Compose

@@ -134,7 +135,7 @@ def download_source_gdrive(
     url: str,
     download: bool,
     checksum: Optional[str] = None,
-    download_type: str = "zip",
+    download_type: Literal["zip", "folder"] = "zip",
     expected_samples: int = 10000,
     quiet: bool = True,
 ) -> None:
@@ -160,6 +161,7 @@
             "Need gdown library to download data from google drive. "
             "Please install gdown: 'conda install -c conda-forge gdown==4.6.3'."
         )
+
     print("Downloading the files. Might take a few minutes...")

     if download_type == "zip":
@@ -171,6 +173,7 @@
         gdown.download_folder(url=url, output=path, quiet=quiet, remaining_ok=True)
     else:
         raise ValueError("`download_path` argument expects either `zip`/`folder`")
+
     print("Download completed.")


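With download_type now typed as Literal["zip", "folder"], dataset code can request a whole google drive folder, which is what the conic helper does. A minimal sketch (not part of the commit) of the same call outside the dataset module, with a hypothetical folder URL and output directory:

from torch_em.data.datasets import util

# Hypothetical google drive folder URL and output directory.
URL = "https://drive.google.com/drive/folders/<folder-id>"

util.download_source_gdrive(
    path="./data/conic/data",  # everything in the folder is downloaded here
    url=URL,
    download=True,             # whether downloading is allowed if the data is missing
    download_type="folder",    # "zip" (default) or "folder"
    quiet=False,               # show gdown progress output
)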