1
1
import os
2
2
import argparse
3
+ import numpy as np
4
+ import pandas as pd
3
5
from glob import glob
6
+ from tqdm import tqdm
4
7
5
8
import imageio .v3 as imageio
6
9
7
10
import torch
8
11
9
12
import torch_em
13
+ from torch_em .util import segmentation
14
+ from torch_em .transform .raw import standardize
10
15
from torch_em .data .datasets import get_livecell_loader
16
+ from torch_em .loss import DiceLoss , LossWrapper , ApplyAndRemoveMask , DiceBasedDistanceLoss
17
+
18
+ from elf .evaluation import mean_segmentation_accuracy
11
19
12
20
from vimunet import get_vimunet_model
13
21
14
22
15
23
ROOT = "/scratch/usr/nimanwai"
16
24
17
-
18
- def get_loaders (path ):
19
- patch_shape = (520 , 704 )
25
+ OFFSETS = [
26
+ [- 1 , 0 ], [0 , - 1 ],
27
+ [- 3 , 0 ], [0 , - 3 ],
28
+ [- 9 , 0 ], [0 , - 9 ],
29
+ [- 27 , 0 ], [0 , - 27 ]
30
+ ]
31
+
32
+
33
+ def get_loaders (args , patch_shape = (520 , 704 )):
34
+ if args .distances :
35
+ label_trafo = torch_em .transform .label .PerObjectDistanceTransform (
36
+ distances = True ,
37
+ boundary_distances = True ,
38
+ directed_distances = False ,
39
+ foreground = True ,
40
+ min_size = 25
41
+ )
42
+ else :
43
+ label_trafo = None
20
44
21
45
train_loader = get_livecell_loader (
22
- path = path , split = "train" , patch_shape = patch_shape , batch_size = 2 , binary = True , cell_types = ["A172" ],
46
+ path = args .input ,
47
+ split = "train" ,
48
+ patch_shape = patch_shape ,
49
+ batch_size = 2 ,
50
+ label_dtype = torch .float32 ,
51
+ boundaries = args .boundaries ,
52
+ label_transform = label_trafo ,
53
+ offsets = OFFSETS if args .affinities else None ,
54
+ num_workers = 16
23
55
)
24
56
25
57
val_loader = get_livecell_loader (
26
- path = path , split = "val" , patch_shape = patch_shape , batch_size = 1 , binary = True , cell_types = ["A172" ],
58
+ path = args .input ,
59
+ split = "val" ,
60
+ patch_shape = patch_shape ,
61
+ batch_size = 1 ,
62
+ label_dtype = torch .float32 ,
63
+ boundaries = args .boundaries ,
64
+ label_transform = label_trafo ,
65
+ offsets = OFFSETS if args .affinities else None ,
66
+ num_workers = 16
27
67
)
28
68
29
69
return train_loader , val_loader
30
70
31
71
72
+ def get_output_channels (args ):
73
+ if args .boundaries :
74
+ output_channels = 2
75
+ elif args .distances :
76
+ output_channels = 3
77
+ elif args .affinities :
78
+ output_channels = (len (OFFSETS ) + 1 )
79
+
80
+ return output_channels
81
+
82
+
83
+ def get_loss_function (args ):
84
+ if args .affinities :
85
+ loss = LossWrapper (
86
+ loss = DiceLoss (),
87
+ transform = ApplyAndRemoveMask (masking_method = "multiply" )
88
+ )
89
+ elif args .distances :
90
+ loss = DiceBasedDistanceLoss (mask_distances_in_bg = True )
91
+
92
+ else :
93
+ loss = DiceLoss ()
94
+
95
+ return loss
96
+
97
+
98
+ def get_save_root (args ):
99
+ # experiment_type
100
+ if args .boundaries :
101
+ experiment_type = "boundaries"
102
+ elif args .affinities :
103
+ experiment_type = "affinities"
104
+ elif args .distances :
105
+ experiment_type = "distances"
106
+ else :
107
+ raise ValueError
108
+
109
+ # saving the model checkpoints
110
+ save_root = os .path .join (
111
+ args .save_root ,
112
+ "pretrained" if args .pretrained else "scratch" ,
113
+ experiment_type
114
+ )
115
+
116
+ return save_root
117
+
118
+
32
119
def run_livecell_training (args ):
33
120
# the dataloaders for livecell dataset
34
- train_loader , val_loader = get_loaders (path = args . input )
121
+ train_loader , val_loader = get_loaders (args )
35
122
36
123
if args .pretrained :
37
124
checkpoint = "/scratch/usr/nimanwai/models/Vim-tiny/vim_tiny_73p1.pth"
38
125
else :
39
126
checkpoint = None
40
127
128
+ output_channels = get_output_channels (args )
129
+
41
130
# the vision-mamba + decoder (UNet-based) model
42
- model = get_vimunet_model (checkpoint = checkpoint )
131
+ model = get_vimunet_model (out_channels = output_channels , checkpoint = checkpoint )
43
132
44
- # saving the model checkpoints
45
- save_root = os .path .join (
46
- args .save_root ,
47
- "pretrained" if args .pretrained else "scratch"
48
- )
133
+ save_root = get_save_root (args )
134
+
135
+ # loss function
136
+ loss = get_loss_function (args )
49
137
50
138
# trainer for the segmentation task
51
139
trainer = torch_em .default_segmentation_trainer (
@@ -54,6 +142,9 @@ def run_livecell_training(args):
54
142
train_loader = train_loader ,
55
143
val_loader = val_loader ,
56
144
learning_rate = 1e-4 ,
145
+ loss = loss ,
146
+ metric = loss ,
147
+ log_image_interval = 50 ,
57
148
save_root = save_root ,
58
149
compile_model = False
59
150
)
@@ -63,23 +154,72 @@ def run_livecell_training(args):
63
154
def run_livecell_inference (args ):
64
155
device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
65
156
157
+ output_channels = get_output_channels (args )
158
+
159
+ save_root = get_save_root (args )
160
+ checkpoint = os .path .join (save_root , "checkpoints" , "livecell-vimunet" , "best.pt" )
161
+
66
162
# the vision-mamba + decoder (UNet-based) model
67
- model = get_vimunet_model (checkpoint = args .checkpoint )
163
+ model = get_vimunet_model (out_channels = output_channels , checkpoint = checkpoint )
164
+
165
+ test_image_dir = os .path .join (ROOT , "data" , "livecell" , "images" , "livecell_test_images" )
166
+ all_test_labels = glob (os .path .join (ROOT , "data" , "livecell" , "annotations" , "livecell_test_images" , "*" , "*" ))
167
+
168
+ msa_list , sa50_list = [], []
169
+
170
+ for label_path in tqdm (all_test_labels ):
171
+ labels = imageio .imread (label_path )
172
+ image_id = os .path .split (label_path )[- 1 ]
173
+
174
+ image = imageio .imread (os .path .join (test_image_dir , image_id ))
175
+ image = standardize (image )
68
176
69
- for image_path in glob (os .path .join (ROOT , "data" , "livecell" , "images" , "livecell_test_images" , "*" )):
70
- image = imageio .imread (image_path )
71
177
tensor_image = torch .from_numpy (image )[None , None ].to (device )
72
178
73
179
predictions = model (tensor_image )
74
180
predictions = predictions .squeeze ().detach ().cpu ().numpy ()
75
181
182
+ if args .boundaries :
183
+ fg , bd = predictions
184
+ instances = segmentation .watershed_from_components (bd , fg )
185
+
186
+ elif args .affinities :
187
+ fg , affs = predictions [0 ], predictions [1 :]
188
+ instances = segmentation .mutex_watershed_segmentation (fg , affs , offsets = OFFSETS )
189
+
190
+ elif args .distances :
191
+ fg , cdist , bdist = predictions
192
+ instances = segmentation .watershed_from_center_and_boundary_distances (
193
+ cdist , bdist , fg , min_size = 50 ,
194
+ center_distance_threshold = 0.5 ,
195
+ boundary_distance_threshold = 0.6 ,
196
+ distance_smoothing = 1.0
197
+ )
198
+
199
+ msa , sa_acc = mean_segmentation_accuracy (instances , labels , return_accuracies = True )
200
+ msa_list .append (msa )
201
+ sa50_list .append (sa_acc [0 ])
202
+
203
+ res_path = os .path .join (save_root , "results.csv" )
204
+
205
+ res = {
206
+ "LiveCELL" : "Metrics" ,
207
+ "mSA" : np .mean (msa_list ),
208
+ "SA50" : np .mean (sa50_list )
209
+ }
210
+ df = pd .DataFrame .from_dict ([res ])
211
+ df .to_csv (res_path )
212
+ print (df )
213
+ print (f"The result is saved at { res_path } " )
214
+
76
215
77
216
def main (args ):
217
+ assert (args .boundaries + args .affinities + args .distances ) == 1
218
+
78
219
if args .train :
79
220
run_livecell_training (args )
80
221
81
222
if args .predict :
82
- assert args .checkpoint is not None , "Provide the checkpoint path to the trained model."
83
223
run_livecell_inference (args )
84
224
85
225
@@ -88,9 +228,15 @@ def main(args):
88
228
parser .add_argument ("-i" , "--input" , type = str , default = os .path .join (ROOT , "data" , "livecell" ))
89
229
parser .add_argument ("--iterations" , type = int , default = 1e4 )
90
230
parser .add_argument ("-s" , "--save_root" , type = str , default = os .path .join (ROOT , "experiments" , "vision-mamba" ))
231
+
91
232
parser .add_argument ("--pretrained" , action = "store_true" )
233
+
92
234
parser .add_argument ("--train" , action = "store_true" )
93
235
parser .add_argument ("--predict" , action = "store_true" )
94
- parser .add_argument ("-c" , "--checkpoint" , default = None , type = str )
236
+
237
+ parser .add_argument ("--boundaries" , action = "store_true" )
238
+ parser .add_argument ("--affinities" , action = "store_true" )
239
+ parser .add_argument ("--distances" , action = "store_true" )
240
+
95
241
args = parser .parse_args ()
96
242
main (args )
0 commit comments