
Commit e7bdae0

Update inference + add all instance segmentation setups

1 parent ef05d9c · commit e7bdae0
File tree: 3 files changed (+172 -31 lines)

experiments/vision-mamba/.gitignore (+1)

```diff
@@ -1,2 +1,3 @@
 *.out
 *.sh
+*.png
```
+163 -17

```diff
@@ -1,51 +1,139 @@
 import os
 import argparse
+import numpy as np
+import pandas as pd
 from glob import glob
+from tqdm import tqdm

 import imageio.v3 as imageio

 import torch

 import torch_em
+from torch_em.util import segmentation
+from torch_em.transform.raw import standardize
 from torch_em.data.datasets import get_livecell_loader
+from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss
+
+from elf.evaluation import mean_segmentation_accuracy

 from vimunet import get_vimunet_model


 ROOT = "/scratch/usr/nimanwai"

-
-def get_loaders(path):
-    patch_shape = (520, 704)
+OFFSETS = [
+    [-1, 0], [0, -1],
+    [-3, 0], [0, -3],
+    [-9, 0], [0, -9],
+    [-27, 0], [0, -27]
+]
+
+
+def get_loaders(args, patch_shape=(520, 704)):
+    if args.distances:
+        label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
+            distances=True,
+            boundary_distances=True,
+            directed_distances=False,
+            foreground=True,
+            min_size=25
+        )
+    else:
+        label_trafo = None

     train_loader = get_livecell_loader(
-        path=path, split="train", patch_shape=patch_shape, batch_size=2, binary=True, cell_types=["A172"],
+        path=args.input,
+        split="train",
+        patch_shape=patch_shape,
+        batch_size=2,
+        label_dtype=torch.float32,
+        boundaries=args.boundaries,
+        label_transform=label_trafo,
+        offsets=OFFSETS if args.affinities else None,
+        num_workers=16
     )

     val_loader = get_livecell_loader(
-        path=path, split="val", patch_shape=patch_shape, batch_size=1, binary=True, cell_types=["A172"],
+        path=args.input,
+        split="val",
+        patch_shape=patch_shape,
+        batch_size=1,
+        label_dtype=torch.float32,
+        boundaries=args.boundaries,
+        label_transform=label_trafo,
+        offsets=OFFSETS if args.affinities else None,
+        num_workers=16
     )

     return train_loader, val_loader


+def get_output_channels(args):
+    if args.boundaries:
+        output_channels = 2
+    elif args.distances:
+        output_channels = 3
+    elif args.affinities:
+        output_channels = (len(OFFSETS) + 1)
+
+    return output_channels
+
+
+def get_loss_function(args):
+    if args.affinities:
+        loss = LossWrapper(
+            loss=DiceLoss(),
+            transform=ApplyAndRemoveMask(masking_method="multiply")
+        )
+    elif args.distances:
+        loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
+
+    else:
+        loss = DiceLoss()
+
+    return loss
+
+
+def get_save_root(args):
+    # experiment_type
+    if args.boundaries:
+        experiment_type = "boundaries"
+    elif args.affinities:
+        experiment_type = "affinities"
+    elif args.distances:
+        experiment_type = "distances"
+    else:
+        raise ValueError
+
+    # saving the model checkpoints
+    save_root = os.path.join(
+        args.save_root,
+        "pretrained" if args.pretrained else "scratch",
+        experiment_type
+    )
+
+    return save_root
+
+
 def run_livecell_training(args):
     # the dataloaders for livecell dataset
-    train_loader, val_loader = get_loaders(path=args.input)
+    train_loader, val_loader = get_loaders(args)

     if args.pretrained:
         checkpoint = "/scratch/usr/nimanwai/models/Vim-tiny/vim_tiny_73p1.pth"
     else:
         checkpoint = None

+    output_channels = get_output_channels(args)
+
     # the vision-mamba + decoder (UNet-based) model
-    model = get_vimunet_model(checkpoint=checkpoint)
+    model = get_vimunet_model(out_channels=output_channels, checkpoint=checkpoint)

-    # saving the model checkpoints
-    save_root = os.path.join(
-        args.save_root,
-        "pretrained" if args.pretrained else "scratch"
-    )
+    save_root = get_save_root(args)
+
+    # loss function
+    loss = get_loss_function(args)

     # trainer for the segmentation task
     trainer = torch_em.default_segmentation_trainer(
```
```diff
@@ -54,6 +142,9 @@ def run_livecell_training(args):
         train_loader=train_loader,
         val_loader=val_loader,
         learning_rate=1e-4,
+        loss=loss,
+        metric=loss,
+        log_image_interval=50,
         save_root=save_root,
         compile_model=False
     )
```
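The three mutually exclusive setups pair targets with losses: `--boundaries` trains a 2-channel model (foreground + boundaries) with plain Dice, `--affinities` trains `len(OFFSETS) + 1` channels with a masked Dice loss, and `--distances` trains 3 channels with `DiceBasedDistanceLoss`. Not part of the commit, a minimal sketch of what the distances label transform produces, assuming a toy label array and that the transform stacks the channels as configured above:

```python
import numpy as np
import torch_em

# toy instance labels (hypothetical), same patch shape the loaders use
labels = np.zeros((520, 704), dtype="int64")
labels[100:160, 200:260] = 1
labels[300:380, 400:480] = 2

# the same transform get_loaders builds for the distances setup
trafo = torch_em.transform.label.PerObjectDistanceTransform(
    distances=True, boundary_distances=True, directed_distances=False,
    foreground=True, min_size=25
)
target = trafo(labels)
print(target.shape)  # expected: (3, 520, 704) -> foreground, center distances, boundary distances
```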
```diff
@@ -63,23 +154,72 @@
 def run_livecell_inference(args):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+    output_channels = get_output_channels(args)
+
+    save_root = get_save_root(args)
+    checkpoint = os.path.join(save_root, "checkpoints", "livecell-vimunet", "best.pt")
+
     # the vision-mamba + decoder (UNet-based) model
-    model = get_vimunet_model(checkpoint=args.checkpoint)
+    model = get_vimunet_model(out_channels=output_channels, checkpoint=checkpoint)
+
+    test_image_dir = os.path.join(ROOT, "data", "livecell", "images", "livecell_test_images")
+    all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*"))
+
+    msa_list, sa50_list = [], []
+
+    for label_path in tqdm(all_test_labels):
+        labels = imageio.imread(label_path)
+        image_id = os.path.split(label_path)[-1]
+
+        image = imageio.imread(os.path.join(test_image_dir, image_id))
+        image = standardize(image)

-    for image_path in glob(os.path.join(ROOT, "data", "livecell", "images", "livecell_test_images", "*")):
-        image = imageio.imread(image_path)
         tensor_image = torch.from_numpy(image)[None, None].to(device)

         predictions = model(tensor_image)
         predictions = predictions.squeeze().detach().cpu().numpy()

+        if args.boundaries:
+            fg, bd = predictions
+            instances = segmentation.watershed_from_components(bd, fg)
+
+        elif args.affinities:
+            fg, affs = predictions[0], predictions[1:]
+            instances = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)
+
+        elif args.distances:
+            fg, cdist, bdist = predictions
+            instances = segmentation.watershed_from_center_and_boundary_distances(
+                cdist, bdist, fg, min_size=50,
+                center_distance_threshold=0.5,
+                boundary_distance_threshold=0.6,
+                distance_smoothing=1.0
+            )
+
+        msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
+        msa_list.append(msa)
+        sa50_list.append(sa_acc[0])
+
+    res_path = os.path.join(save_root, "results.csv")
+
+    res = {
+        "LiveCELL": "Metrics",
+        "mSA": np.mean(msa_list),
+        "SA50": np.mean(sa50_list)
+    }
+    df = pd.DataFrame.from_dict([res])
+    df.to_csv(res_path)
+    print(df)
+    print(f"The result is saved at {res_path}")
+

 def main(args):
+    assert (args.boundaries + args.affinities + args.distances) == 1
+
     if args.train:
         run_livecell_training(args)

     if args.predict:
-        assert args.checkpoint is not None, "Provide the checkpoint path to the trained model."
         run_livecell_inference(args)


```
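In the affinities branch above, channel 0 of the prediction is foreground and the remaining channels are one affinity map per entry in `OFFSETS`, which `mutex_watershed_segmentation` then decodes. A minimal sketch of that channel split, using random stand-in data instead of real model output (an assumption, purely for illustration):

```python
import numpy as np

OFFSETS = [
    [-1, 0], [0, -1],
    [-3, 0], [0, -3],
    [-9, 0], [0, -9],
    [-27, 0], [0, -27]
]

# random stand-in for a (1 + len(OFFSETS), H, W) model prediction
predictions = np.random.rand(len(OFFSETS) + 1, 520, 704).astype("float32")

fg, affs = predictions[0], predictions[1:]  # same split as the inference code
assert affs.shape[0] == len(OFFSETS)        # one affinity channel per offset
```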
```diff
@@ -88,9 +228,15 @@ def main(args):
     parser.add_argument("-i", "--input", type=str, default=os.path.join(ROOT, "data", "livecell"))
     parser.add_argument("--iterations", type=int, default=1e4)
     parser.add_argument("-s", "--save_root", type=str, default=os.path.join(ROOT, "experiments", "vision-mamba"))
+
     parser.add_argument("--pretrained", action="store_true")
+
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--predict", action="store_true")
-    parser.add_argument("-c", "--checkpoint", default=None, type=str)
+
+    parser.add_argument("--boundaries", action="store_true")
+    parser.add_argument("--affinities", action="store_true")
+    parser.add_argument("--distances", action="store_true")
+
     args = parser.parse_args()
     main(args)
```
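With these changes, exactly one of `--boundaries`, `--affinities`, or `--distances` must be passed (the assert in `main` enforces it), and inference now resolves the checkpoint from the save root instead of a `-c` flag. A hypothetical invocation for the distances setup, assuming the script is saved as `train_livecell.py` (the file name is not shown in this view):

```
python train_livecell.py --train --distances --pretrained
python train_livecell.py --predict --distances --pretrained
```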

experiments/vision-mamba/vimunet.py (+8 -14)

```diff
@@ -4,8 +4,6 @@

 # pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth

-from collections import OrderedDict
-
 import torch

 from torch_em.model import UNETR
```

```diff
@@ -98,7 +96,7 @@ def forward(self, x, inference_params=None):
         return x  # from here, the tokens can be upsampled easily (N x H x W x C)


-def get_vimunet_model(device=None, checkpoint=None):
+def get_vimunet_model(out_channels, device=None, checkpoint=None):
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

```

```diff
@@ -120,34 +118,30 @@ def get_vimunet_model(device=None, checkpoint=None):

     encoder.default_cfg = _cfg()

+    model_state = None
     if checkpoint is not None:
         state = torch.load(checkpoint, map_location="cpu")

         if checkpoint.endswith(".pth"):  # from Vim
             encoder_state = state["model"]
+            encoder.load_state_dict(encoder_state)

         else:  # from torch_em
             model_state = state["model_state"]

-            encoder_prefix = "encoder."
-            encoder_state = []
-            for k, v in model_state.items():
-                if k.startswith(encoder_prefix):
-                    encoder_state.append((k[len(encoder_prefix):], v))
-
-            encoder_state = OrderedDict(encoder_state)
-
-        encoder.load_state_dict(encoder_state)
-
     encoder.img_size = encoder.patch_embed.img_size[0]

     model = UNETR(
         encoder=encoder,
-        out_channels=1,
+        out_channels=out_channels,
         resize_input=False,
         use_skip_connection=False,
         final_activation="Sigmoid"
     )
+
+    if model_state is not None:
+        model.load_state_dict(model_state)
+
     model.to(device)

     return model
```
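Since `out_channels` is now a required argument, callers size the decoder head to the setup. A minimal usage sketch, not part of the commit; the channel counts follow `get_output_channels` in the training script and the checkpoint path is hypothetical:

```python
from vimunet import get_vimunet_model

model = get_vimunet_model(out_channels=2)  # boundaries: foreground + boundary channel
model = get_vimunet_model(out_channels=3)  # distances: foreground + center + boundary distances
model = get_vimunet_model(out_channels=9)  # affinities: foreground + 8 offset channels

# a torch_em ".pt" checkpoint now restores the full model state,
# while a Vim ".pth" checkpoint initializes only the encoder
model = get_vimunet_model(
    out_channels=3,
    checkpoint="checkpoints/livecell-vimunet/best.pt"  # hypothetical path
)
```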
