1
1
import os
2
2
import argparse
3
+ from glob import glob
4
+
5
+ import imageio .v3 as imageio
6
+
7
+ import torch
3
8
4
9
import torch_em
5
10
from torch_em .data .datasets import get_livecell_loader
@@ -55,8 +60,27 @@ def run_livecell_training(args):
55
60
trainer .fit (iterations = int (args .iterations ))
56
61
57
62
63
+ def run_livecell_inference (args ):
64
+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
65
+
66
+ # the vision-mamba + decoder (UNet-based) model
67
+ model = get_vimunet_model (checkpoint = args .checkpoint )
68
+
69
+ for image_path in glob (os .path .join (ROOT , "data" , "livecell" , "images" , "livecell_test_images" , "*" )):
70
+ image = imageio .imread (image_path )
71
+ tensor_image = torch .from_numpy (image )[None , None ].to (device )
72
+
73
+ predictions = model (tensor_image )
74
+ predictions = predictions .squeeze ().detach ().cpu ().numpy ()
75
+
76
+
58
77
def main (args ):
59
- run_livecell_training (args )
78
+ if args .train :
79
+ run_livecell_training (args )
80
+
81
+ if args .predict :
82
+ assert args .checkpoint is not None , "Provide the checkpoint path to the trained model."
83
+ run_livecell_inference (args )
60
84
61
85
62
86
if __name__ == "__main__" :
@@ -65,5 +89,8 @@ def main(args):
65
89
parser .add_argument ("--iterations" , type = int , default = 1e4 )
66
90
parser .add_argument ("-s" , "--save_root" , type = str , default = os .path .join (ROOT , "experiments" , "vision-mamba" ))
67
91
parser .add_argument ("--pretrained" , action = "store_true" )
92
+ parser .add_argument ("--train" , action = "store_true" )
93
+ parser .add_argument ("--predict" , action = "store_true" )
94
+ parser .add_argument ("-c" , "--checkpoint" , default = None , type = str )
68
95
args = parser .parse_args ()
69
96
main (args )
0 commit comments