
Commit 9fbba4d

Refactor Vim backbone + Add more vision mamba configs
1 parent e7bdae0 commit 9fbba4d

File tree

3 files changed: +78 −25 lines changed


experiments/vision-mamba/run_livecell.py

+10 −4

@@ -11,14 +11,13 @@
 
 import torch_em
 from torch_em.util import segmentation
+from torch_em.model import get_vimunet_model
 from torch_em.transform.raw import standardize
 from torch_em.data.datasets import get_livecell_loader
 from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask, DiceBasedDistanceLoss
 
 from elf.evaluation import mean_segmentation_accuracy
 
-from vimunet import get_vimunet_model
-
 
 ROOT = "/scratch/usr/nimanwai"
 
@@ -128,7 +127,7 @@ def run_livecell_training(args):
     output_channels = get_output_channels(args)
 
     # the vision-mamba + decoder (UNet-based) model
-    model = get_vimunet_model(out_channels=output_channels, checkpoint=checkpoint)
+    model = get_vimunet_model(out_channels=output_channels, model_type=args.model_type, checkpoint=checkpoint)
 
     save_root = get_save_root(args)
 
@@ -160,7 +159,12 @@ def run_livecell_inference(args):
     checkpoint = os.path.join(save_root, "checkpoints", "livecell-vimunet", "best.pt")
 
     # the vision-mamba + decoder (UNet-based) model
-    model = get_vimunet_model(out_channels=output_channels, checkpoint=checkpoint)
+    model = get_vimunet_model(
+        out_channels=output_channels,
+        model_type=args.model_type,
+        with_cls_token=args.with_cls_token,
+        checkpoint=checkpoint
+    )
 
     test_image_dir = os.path.join(ROOT, "data", "livecell", "images", "livecell_test_images")
     all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*"))
@@ -228,6 +232,8 @@ def main(args):
     parser.add_argument("-i", "--input", type=str, default=os.path.join(ROOT, "data", "livecell"))
     parser.add_argument("--iterations", type=int, default=1e4)
    parser.add_argument("-s", "--save_root", type=str, default=os.path.join(ROOT, "experiments", "vision-mamba"))
+    parser.add_argument("-m", "--model_type", type=str, default="vim_t")
+    parser.add_argument("--with_cls_token", action="store_true")
 
     parser.add_argument("--pretrained", action="store_true")
 

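The new `--model_type` and `--with_cls_token` flags select the Vision Mamba backbone size and whether a class token is used. A minimal sketch of how these flags feed into the refactored factory (the 3-channel output below is illustrative, not part of the commit):

import argparse

from torch_em.model import get_vimunet_model

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model_type", type=str, default="vim_t")  # vim_t / vim_s / vim_b
parser.add_argument("--with_cls_token", action="store_true")
args = parser.parse_args()

# vision-mamba encoder + UNet-based decoder; checkpoint may point to a trained "best.pt"
model = get_vimunet_model(
    out_channels=3,  # illustrative choice, e.g. foreground / boundary / distance channels
    model_type=args.model_type,
    with_cls_token=args.with_cls_token,
    checkpoint=None,
)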
torch_em/model/__init__.py

+1 −0

@@ -2,3 +2,4 @@
 from .probabilistic_unet import ProbabilisticUNet
 from .unetr import UNETR
 from .vit import get_vision_transformer
+from .vim import get_vimunet_model
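Re-exporting the factory from `torch_em/model/__init__.py` lets scripts drop the local `from vimunet import get_vimunet_model` import. Assuming the Vision Mamba dependency (`vim.models_mamba`) is installed, the import now reads:

from torch_em.model import get_vimunet_model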

experiments/vision-mamba/vimunet.py → torch_em/model/vim.py (renamed)

+67 −21

@@ -6,7 +6,7 @@
 
 import torch
 
-from torch_em.model import UNETR
+from .unetr import UNETR
 
 from vim.models_mamba import VisionMamba, rms_norm_fn, RMSNorm, layer_norm_fn
 
@@ -40,7 +40,7 @@ def forward_features(self, x, inference_params=None):
         x = x + self.pos_embed
         x = self.pos_drop(x)
 
-        # mamba impl
+        # mamba implementation
         residual = None
         hidden_states = x
         for layer in self.layers:
@@ -61,7 +61,7 @@ def forward_features(self, x, inference_params=None):
             residual = residual + self.drop_path(hidden_states)
             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
         else:
-            # Set prenorm=False here since we don't need the residual
+            # Set prenorm = False here since we don't need the residual
             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
             hidden_states = fused_add_norm_fn(
                 self.drop_path(hidden_states),
@@ -96,27 +96,73 @@ def forward(self, x, inference_params=None):
         return x  # from here, the tokens can be upsampled easily (N x H x W x C)
 
 
-def get_vimunet_model(out_channels, device=None, checkpoint=None):
+def get_vim_encoder(model_type="vim_t", with_cls_token=True):
+    if model_type == "vim_t":
+        # `vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual_with_cls_token`
+        # *has an imagenet pretrained model
+        encoder = ViM(
+            img_size=1024,
+            patch_size=16,
+            embed_dim=192,
+            depth=24,
+            rms_norm=True,
+            residual_in_fp32=True,
+            fused_add_norm=True,
+            final_pool_type='all',
+            if_abs_pos_embed=True,
+            if_rope=True,
+            if_rope_residual=True,
+            bimamba_type="v2",
+            if_cls_token=with_cls_token,
+        )
+    elif model_type == "vim_s":
+        # `vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual`
+        # AA: added a class token to the default models
+        encoder = ViM(
+            img_size=1024,
+            patch_size=16,
+            embed_dim=384,
+            depth=24,
+            rms_norm=True,
+            residual_in_fp32=True,
+            fused_add_norm=True,
+            final_pool_type='all',
+            if_abs_pos_embed=True,
+            if_rope=True,
+            if_rope_residual=True,
+            bimamba_type="v2",
+            if_cls_token=with_cls_token,
+        )
+    elif model_type == "vim_b":
+        # `vim_base_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual`
+        # AA: added a class token to the default models
+        encoder = ViM(
+            img_size=1024,
+            patch_size=16,
+            embed_dim=768,
+            depth=24,
+            rms_norm=True,
+            residual_in_fp32=True,
+            fused_add_norm=True,
+            final_pool_type='all',
+            if_abs_pos_embed=True,
+            if_rope=True,
+            if_rope_residual=True,
+            bimamba_type="v2",
+            if_cls_token=with_cls_token,
+        )
+    else:
+        raise ValueError("Choose from `vim_t` or `vim_b`")
+
+    encoder.default_cfg = _cfg()
+    return encoder
+
+
+def get_vimunet_model(out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None):
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    encoder = ViM(
-        img_size=1024,
-        patch_size=16,
-        embed_dim=192,
-        depth=24,
-        rms_norm=True,
-        residual_in_fp32=True,
-        fused_add_norm=True,
-        final_pool_type='all',
-        if_abs_pos_embed=True,
-        if_rope=True,
-        if_rope_residual=True,
-        bimamba_type="v2",
-        if_cls_token=True,
-    )
-
-    encoder.default_cfg = _cfg()
+    encoder = get_vim_encoder(model_type, with_cls_token)
 
     model_state = None
     if checkpoint is not None:

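After the refactor, `get_vim_encoder` builds the backbone alone and `get_vimunet_model` wraps it with the UNETR-based decoder; the three configs differ only in `embed_dim` (192 for `vim_t`, 384 for `vim_s`, 768 for `vim_b`), all with depth 24 and patch size 16. A short usage sketch, assuming the `vim` package (Vision Mamba) and its mamba kernels are installed; the single output channel is illustrative:

from torch_em.model.vim import get_vim_encoder
from torch_em.model import get_vimunet_model

# backbone only, e.g. to compare parameter counts of the new configs
encoder = get_vim_encoder(model_type="vim_s", with_cls_token=True)

# full segmentation model: ViM encoder + UNet-based decoder
model = get_vimunet_model(out_channels=1, model_type="vim_s", with_cls_token=True)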