Skip to content

Commit ef05d9c

Browse files
committed
Add inference script
1 parent 31ba444 commit ef05d9c

File tree

2 files changed: +46 −2 lines changed

experiments/vision-mamba/run_livecell.py

+28-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,10 @@
11
import os
22
import argparse
3+
from glob import glob
4+
5+
import imageio.v3 as imageio
6+
7+
import torch
38

49
import torch_em
510
from torch_em.data.datasets import get_livecell_loader
@@ -55,8 +60,27 @@ def run_livecell_training(args):
5560
trainer.fit(iterations=int(args.iterations))
5661

5762

63+
def run_livecell_inference(args):
    """Run inference with the trained ViM-UNet model on the LIVECell test images.

    Loads the model from ``args.checkpoint`` and runs a forward pass over every
    image in the LIVECell test split.

    NOTE(review): predictions are computed but neither saved nor evaluated —
    presumably follow-up code will persist them; confirm intended behavior.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # the vision-mamba + decoder (UNet-based) model
    model = get_vimunet_model(checkpoint=args.checkpoint)
    # inference mode: disable dropout / use running batch-norm statistics
    model.eval()

    for image_path in glob(os.path.join(ROOT, "data", "livecell", "images", "livecell_test_images", "*")):
        image = imageio.imread(image_path)
        # add batch and channel axes — assumes a single-channel 2d image
        # of shape (H, W) -> (1, 1, H, W); TODO confirm against the dataset
        tensor_image = torch.from_numpy(image)[None, None].to(device)

        # no_grad avoids building the autograd graph: without it, inference
        # over the whole test split accumulates graph memory needlessly
        with torch.no_grad():
            predictions = model(tensor_image)
        predictions = predictions.squeeze().detach().cpu().numpy()
5877
def main(args):
    """Dispatch to training and/or inference based on the parsed CLI flags.

    Raises:
        ValueError: if ``--predict`` is given without ``--checkpoint``.
    """
    if args.train:
        run_livecell_training(args)

    if args.predict:
        # explicit exception instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip this validation
        if args.checkpoint is None:
            raise ValueError("Provide the checkpoint path to the trained model.")
        run_livecell_inference(args)
6084

6185

6286
if __name__ == "__main__":
@@ -65,5 +89,8 @@ def main(args):
6589
parser.add_argument("--iterations", type=int, default=1e4)
6690
parser.add_argument("-s", "--save_root", type=str, default=os.path.join(ROOT, "experiments", "vision-mamba"))
6791
parser.add_argument("--pretrained", action="store_true")
92+
parser.add_argument("--train", action="store_true")
93+
parser.add_argument("--predict", action="store_true")
94+
parser.add_argument("-c", "--checkpoint", default=None, type=str)
6895
args = parser.parse_args()
6996
main(args)

experiments/vision-mamba/vimunet.py

+18-1
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44

55
# pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth
66

7+
from collections import OrderedDict
8+
79
import torch
810

911
from torch_em.model import UNETR
@@ -101,6 +103,7 @@ def get_vimunet_model(device=None, checkpoint=None):
101103
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102104

103105
encoder = ViM(
106+
img_size=1024,
104107
patch_size=16,
105108
embed_dim=192,
106109
depth=24,
@@ -119,7 +122,21 @@ def get_vimunet_model(device=None, checkpoint=None):
119122

120123
if checkpoint is not None:
121124
state = torch.load(checkpoint, map_location="cpu")
122-
encoder_state = state["model"]
125+
126+
if checkpoint.endswith(".pth"): # from Vim
127+
encoder_state = state["model"]
128+
129+
else: # from torch_em
130+
model_state = state["model_state"]
131+
132+
encoder_prefix = "encoder."
133+
encoder_state = []
134+
for k, v in model_state.items():
135+
if k.startswith(encoder_prefix):
136+
encoder_state.append((k[len(encoder_prefix):], v))
137+
138+
encoder_state = OrderedDict(encoder_state)
139+
123140
encoder.load_state_dict(encoder_state)
124141

125142
encoder.img_size = encoder.patch_embed.img_size[0]

0 commit comments

Comments (0)