
Commit 6ba315f

Merge branch 'main' into vision-mamba

2 parents: a44a227 + daedad0

File tree: 12 files changed, +139 -17 lines

New file
+20

@@ -0,0 +1,20 @@
+from torch_em.util.debug import check_loader
+from torch_em.data.datasets import get_dynamicnuclearnet_loader
+
+
+DYNAMICNUCLEARNET_ROOT = "/home/anwai/data/deepcell/"
+
+
+# NOTE: the DynamicNuclearNet data cannot be downloaded automatically.
+# you need to download it yourself from https://datasets.deepcell.org/data
+def check_dynamicnuclearnet():
+    # set this path to where you have downloaded the dynamicnuclearnet data
+    loader = get_dynamicnuclearnet_loader(
+        DYNAMICNUCLEARNET_ROOT, "train",
+        patch_shape=(512, 512), batch_size=2, download=True
+    )
+    check_loader(loader, 10, instance_labels=True, rgb=False)
+
+
+if __name__ == "__main__":
+    check_dynamicnuclearnet()

torch_em/__version__.py
+1 -1

@@ -1 +1 @@
-__version__ = "0.6.1"
+__version__ = "0.6.2"

torch_em/data/datasets/__init__.py
+1

@@ -6,6 +6,7 @@
 from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset
 from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset
 from .dsb import get_dsb_loader, get_dsb_dataset
+from .dynamicnuclearnet import get_dynamicnuclearnet_loader, get_dynamicnuclearnet_dataset
 from .hpa import get_hpa_segmentation_loader, get_hpa_segmentation_dataset
 from .isbi2012 import get_isbi_loader, get_isbi_dataset
 from .kasthuri import get_kasthuri_loader, get_kasthuri_dataset

torch_em/data/datasets/ctc.py
+1 -1

@@ -48,7 +48,7 @@ def get_ctc_url_and_checksum(dataset_name, split):
 def _require_ctc_dataset(path, dataset_name, download, split):
     dataset_names = list(CTC_CHECKSUMS["train"].keys())
     if dataset_name not in dataset_names:
-        raise ValueError(f"Inalid dataset: {dataset_name}, choose one of {dataset_names}.")
+        raise ValueError(f"Invalid dataset: {dataset_name}, choose one of {dataset_names}.")

     data_path = os.path.join(path, split, dataset_name)

torch_em/data/datasets/dynamicnuclearnet.py (new file)
+93

@@ -0,0 +1,93 @@
+import os
+from tqdm import tqdm
+from glob import glob
+
+import z5py
+import numpy as np
+import pandas as pd
+
+import torch_em
+
+from . import util
+
+
+# Automatic download is currently not possible, because of authentication
+URL = None  # TODO: here - https://datasets.deepcell.org/data
+
+
+def _create_split(path, split):
+    split_file = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0", f"{split}.npz")
+    split_folder = os.path.join(path, split)
+    os.makedirs(split_folder, exist_ok=True)
+    data = np.load(split_file, allow_pickle=True)
+
+    x, y = data["X"], data["y"]
+    metadata = data["meta"]
+    metadata = pd.DataFrame(metadata[1:], columns=metadata[0])
+
+    for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"):
+        out_path = os.path.join(split_folder, f"image_{i:04}.zarr")
+        image_channel = im[..., 0]
+        label_channel = label[..., 0]
+        chunks = image_channel.shape
+        with z5py.File(out_path, "a") as f:
+            f.create_dataset("raw", data=image_channel, compression="gzip", chunks=chunks)
+            f.create_dataset("labels", data=label_channel, compression="gzip", chunks=chunks)
+
+    os.remove(split_file)
+
+
+def _create_dataset(path, zip_path):
+    util.unzip(zip_path, path, remove=False)
+    splits = ["train", "val", "test"]
+    assert all(
+        [os.path.exists(os.path.join(path, "DynamicNuclearNet-segmentation-v1_0", f"{split}.npz")) for split in splits]
+    )
+    for split in splits:
+        _create_split(path, split)
+
+
+def get_dynamicnuclearnet_dataset(
+    path, split, patch_shape, download=False, **kwargs
+):
+    """Dataset for the segmentation of cell nuclei imaged with fluorescence microscopy.
+
+    This dataset is from the publication https://doi.org/10.1101/803205.
+    Please cite it if you use this dataset for a publication."""
+    splits = ["train", "val", "test"]
+    assert split in splits
+
+    # check if the dataset exists already
+    zip_path = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0.zip")
+    if all([os.path.exists(os.path.join(path, split)) for split in splits]):  # yes it does
+        pass
+    elif os.path.exists(zip_path):  # no it does not, but we have the zip there and can unpack it
+        _create_dataset(path, zip_path)
+    else:
+        raise RuntimeError(
+            "We do not support automatic download for the dynamic nuclear net dataset yet. "
+            f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}"
+        )
+
+    split_folder = os.path.join(path, split)
+    assert os.path.exists(split_folder)
+    data_path = glob(os.path.join(split_folder, "*.zarr"))
+    assert len(data_path) > 0
+
+    raw_key, label_key = "raw", "labels"
+
+    return torch_em.default_segmentation_dataset(
+        data_path, raw_key, data_path, label_key, patch_shape, is_seg_dataset=True, ndim=2, **kwargs
+    )
+
+
+def get_dynamicnuclearnet_loader(
+    path, split, patch_shape, batch_size, download, **kwargs
+):
+    """Dataloader for the segmentation of cell nuclei for 5 different cell lines in fluorescence microscopy.
+    See `get_dynamicnuclearnet_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_dynamicnuclearnet_dataset(path, split, patch_shape, download, **ds_kwargs)
+    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
+    return loader
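
The new loader drops into the usual torch_em workflow. A minimal usage sketch follows; the root path is hypothetical, and the zip must have been downloaded manually from https://datasets.deepcell.org/data first (automatic download raises the RuntimeError above):

from torch_em.data.datasets import get_dynamicnuclearnet_loader

# hypothetical local path; place DynamicNuclearNet-segmentation-v1_0.zip here
ROOT = "/data/dynamicnuclearnet"

# the first call unpacks the zip and converts the npz splits into one zarr
# container per image; later calls reuse the extracted train/val/test folders
loader = get_dynamicnuclearnet_loader(
    ROOT, "val", patch_shape=(512, 512), batch_size=2, download=False
)
for x, y in loader:
    print(x.shape, y.shape)  # raw patches and instance label patches
    break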

torch_em/loss/dice.py
+1

@@ -89,6 +89,7 @@ def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"):
         super().__init__()
         self.channelwise = channelwise
         self.eps = eps
+        self.reduce_channel = reduce_channel

         # all torch_em classes should store init kwargs to easily recreate the init call
         self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}
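
For context on why this one-line addition matters: `init_kwargs` on the following line reads `self.reduce_channel`, so before the fix constructing the loss raised an AttributeError. A minimal sketch of the pattern (the class names here are illustrative, not from torch_em):

class BrokenExample:
    def __init__(self, eps=1e-7, reduce_channel="sum"):
        self.eps = eps
        # bug: `self.reduce_channel` is read below but was never assigned,
        # so instantiation fails with an AttributeError
        self.init_kwargs = {"eps": self.eps, "reduce_channel": self.reduce_channel}


class FixedExample:
    def __init__(self, eps=1e-7, reduce_channel="sum"):
        self.eps = eps
        self.reduce_channel = reduce_channel  # assign before it is read
        self.init_kwargs = {"eps": self.eps, "reduce_channel": self.reduce_channel}


FixedExample()  # works; BrokenExample() would raise AttributeError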

torch_em/model/unetr.py
+5 -2

@@ -6,7 +6,7 @@
 import torch.nn.functional as F

 from .unet import Decoder, ConvBlock2d, Upsampler2d
-from .vit import get_vision_transformer, ViT_MAE, ViT_Sam
+from .vit import get_vision_transformer

 try:
     from micro_sam.util import get_sam_model
@@ -244,7 +244,10 @@ def forward(self, x):

         encoder_outputs = self.encoder(x)

-        if isinstance(self.encoder, ViT_Sam) or isinstance(self.encoder, ViT_MAE):
+        if isinstance(encoder_outputs[-1], list):
+            # `encoder_outputs` can be arranged in only two forms:
+            # - either we only return the image embeddings
+            # - or, we return the image embeddings and the "list" of global attention layers
             z12, from_encoder = encoder_outputs
         else:
             z12 = encoder_outputs
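
The dispatch now inspects the structure of the encoder output instead of the encoder's type, so any backbone that returns a pair of (embeddings, list of global attention outputs) is handled without importing its class. A minimal sketch of the two forms, with made-up tensor shapes and a hypothetical helper name:

import torch

def unpack_encoder_outputs(encoder_outputs):
    # mirrors the branch in UNETR.forward: either we only get the image
    # embeddings, or (embeddings, list of global attention outputs)
    if isinstance(encoder_outputs[-1], list):
        z12, from_encoder = encoder_outputs
    else:
        z12, from_encoder = encoder_outputs, None
    return z12, from_encoder

embeddings = torch.randn(1, 256, 64, 64)
attention_outputs = [torch.randn(1, 256, 64, 64) for _ in range(4)]

unpack_encoder_outputs(embeddings)                       # embeddings only
unpack_encoder_outputs((embeddings, attention_outputs))  # embeddings + attention list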

torch_em/self_training/fix_match.py
+2 -2

@@ -136,7 +136,7 @@ def __init__(
     # functionality for saving checkpoints and initialization
     #

-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -152,7 +152,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

     # distribution alignment - encourages the distribution of the model's generated pseudo labels to match the marginal
     # distribution of pseudo labels from the source transfer

torch_em/self_training/mean_teacher.py
+2 -2

@@ -171,7 +171,7 @@ def _momentum_update(self):
     # functionality for saving checkpoints and initialization
     #

-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -188,7 +188,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)

torch_em/trainer/default_trainer.py
+8 -4

@@ -458,14 +458,15 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
             best_metric = np.inf
         return best_metric

-    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
             "iteration": self._iteration,
             "epoch": self._epoch,
             "best_epoch": self._best_epoch,
             "best_metric": best_metric,
+            "current_metric": current_metric,
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
@@ -494,6 +495,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.current_metric = save_dict["current_metric"]
         self.train_time = save_dict.get("train_time", 0.0)

         model_state = save_dict["model_state"]
@@ -573,14 +575,16 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric, train_time=total_train_time)
+                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
+            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
+                self.save_checkpoint(
+                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
+                )

             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:

torch_em/trainer/spoco_trainer.py
+4 -2

@@ -32,8 +32,10 @@ def _momentum_update(self):
         for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
             param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
-        super().save_checkpoint(name, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict)
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
+        super().save_checkpoint(
+            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
+        )

     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)
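
Taken together, the trainer changes above (default_trainer, fix_match, mean_teacher, spoco_trainer) mean every checkpoint now records the validation metric of the epoch it was written in, next to the running best. A minimal sketch of reading both values back, with a hypothetical checkpoint path:

import torch

# hypothetical path; the trainer writes best.pt / latest.pt into its checkpoint folder
save_dict = torch.load("checkpoints/my-model/latest.pt", map_location="cpu")

print("metric when this checkpoint was saved:", save_dict["current_metric"])
print("best metric seen so far:", save_dict["best_metric"])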

torch_em/util/modelzoo.py
+1 -3

@@ -125,8 +125,6 @@ def _write_depedencies(export_folder, dependencies):
     ver = torch.__version__
     major, minor = list(map(int, ver.split(".")[:2]))
     assert major in (1, 2)
-    if major == 2:
-        warn("Modelzoo functionality is not fully tested for PyTorch 2")
     # the torch zip layout changed for a few versions:
     torch_min_version = "1.0"
     if minor > 6 and minor < 10:
@@ -363,7 +361,7 @@ def _get_axes(axis):
     if std is not None:
         preprocessing[0]["kwargs"]["std"] = std

-    elif name == "torch_em.transform.normalize_percentile":
+    elif name == "torch_em.transform.raw.normalize_percentile":

         lower, upper = kwargs.get("lower", 1.0), kwargs.get("upper", 99.0)
         axes = _get_axes(kwargs.get("axis", None))
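
The corrected string matches the fully qualified name of the percentile normalization in torch_em.transform.raw, so exported models resolve their preprocessing again. As a rough illustration of what such a transform computes (an assumed re-implementation, not the torch_em code; see torch_em.transform.raw.normalize_percentile for the real one):

import numpy as np

def normalize_percentile_sketch(x, lower=1.0, upper=99.0, axis=None, eps=1e-7):
    # rescale so the lower percentile maps to ~0 and the upper to ~1
    plo = np.percentile(x, lower, axis=axis, keepdims=True)
    phi = np.percentile(x, upper, axis=axis, keepdims=True)
    return (x - plo) / (phi - plo + eps)

image = np.random.rand(1, 512, 512).astype("float32")
print(normalize_percentile_sketch(image).mean())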
