This is the official implementation of SLM-SAM 2.
By Yuwen Chen, Zafer Yildiz, Qihang Li, Yaqian Chen, Haoyu Dong, Hanxue Gu, Nicholas Konz, Maciej A. Mazurowski
SLM-SAM 2 is a novel video object segmentation method that accelerates volumetric medical image annotation by propagating annotations from a single slice to the remaining slices within a volume. By introducing a dynamic short-long memory module, SLM-SAM 2 achieves improved segmentation performance over SAM 2 on organs, bones, and muscles across different imaging modalities.
First, please install the PyTorch and TorchVision dependencies following the instructions here. SLM-SAM 2 can then be installed using:
cd SLM-SAM 2
pip install -e .
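After installation, a quick import check can confirm that the package is visible to Python. This is a minimal sketch that assumes the editable install exposes the same `sam2` module name as upstream SAM 2:

```python
# Minimal sanity check after `pip install -e .`.
# Assumption: the editable install exposes the upstream `sam2` package name.
import sam2

print("SLM-SAM 2 installed at:", sam2.__file__)
```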
Before finetuning, download the SAM 2 pretrained checkpoints using the following commands:
cd checkpoints && \
./download_ckpts.sh && \
cd ..
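Optionally, you can verify that a downloaded checkpoint loads correctly. The sketch below uses an example filename; substitute whichever checkpoint `download_ckpts.sh` actually fetched for the model size you plan to finetune:

```python
# Optional sanity check that a downloaded checkpoint is readable.
# The filename below is an example; adjust it to the checkpoint you downloaded.
import torch

ckpt = torch.load("checkpoints/sam2.1_hiera_tiny.pt", map_location="cpu")
print(list(ckpt.keys()))  # SAM 2 checkpoints typically contain a "model" state dict
```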
Open ./sam2/configs/sam2.1_training/slm_sam2_hiera_t_finetune.yaml and add the paths to your image folder, mask folder, and the text file listing the volumes used for training. The dataset format follows that of SAM 2:
DATA_DIRECTORY
├── images
│ ├── volume1
│ │ ├── 00000.jpg
│ │ ├── 00001.jpg
│ │ └── ...
│ └── ...
├── masks
│ ├── volume1
│ │ ├── 00000.png
│ │ ├── 00001.png
│ │ └── ...
│ └── ...
├── train.txt
└── test.txt
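If your data is stored as 3D volumes (e.g., NIfTI), a conversion step is needed to produce the per-slice JPEG/PNG layout above. The sketch below is one possible preprocessing script, not part of this repository; it assumes NIfTI inputs read with `nibabel`, slices stacked along the last axis, and that the .txt files list one volume folder name per line. Check your own data and the SAM 2 dataset documentation before relying on it:

```python
# Hypothetical preprocessing sketch: slice NIfTI volumes into the
# images/<volume>/00000.jpg and masks/<volume>/00000.png layout shown above.
import os
import nibabel as nib
import numpy as np
from PIL import Image

def export_volume(img_nii, mask_nii, out_root, volume_name):
    img = nib.load(img_nii).get_fdata()
    mask = nib.load(mask_nii).get_fdata()
    img_dir = os.path.join(out_root, "images", volume_name)
    mask_dir = os.path.join(out_root, "masks", volume_name)
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)
    for i in range(img.shape[2]):  # assumes slices along the last axis
        sl = img[:, :, i]
        # simple min-max normalization to 8-bit; adjust windowing per modality
        sl = (255 * (sl - sl.min()) / (sl.max() - sl.min() + 1e-8)).astype(np.uint8)
        Image.fromarray(sl).convert("RGB").save(os.path.join(img_dir, f"{i:05d}.jpg"))
        Image.fromarray(mask[:, :, i].astype(np.uint8)).save(
            os.path.join(mask_dir, f"{i:05d}.png"))

export_volume("case01.nii.gz", "case01_seg.nii.gz", "DATA_DIRECTORY", "volume1")
```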
Start finetuning by running:
CUDA_VISIBLE_DEVICES=[GPU_ID] python3 training/train.py \
-c configs/sam2.1_training/slm_sam2_hiera_t_finetune.yaml \
--use-cluster 0 \
--num-gpus 1
Propagate annotations by running:
CUDA_VISIBLE_DEVICES=[GPU_ID] python3 inference.py \
--test_img_folder [test image folder path] \
--test_mask_folder [test mask folder path] \
--checkpoint_folder [checkpoint path] \
--checkpoint_name [checkpoint file name] \
--cfg_name slm_sam2_hiera_t.yaml \
--test_txt_file [test text file path] \
--mask_prompt_dict [path to mask prompt dictionary] \
--output_folder [path to output folder for saving predictions]
- checkpoint_folder: directory that contains the .pt checkpoint file
- checkpoint_name: name of the .pt file
- mask_prompt_dict: dictionary mapping each volume ID to the slice index used as the mask prompt (e.g., mask_prompt_dict[volume_id] = slice_index); a minimal sketch for building such a file is shown below
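The following is a hypothetical sketch of how such a dictionary could be assembled, here using the middle slice of each test volume as the prompt. The on-disk format it assumes (a pickle file) may differ from what inference.py expects, so verify the expected format in the inference script before using it:

```python
# Hypothetical helper: map each test volume to its middle slice index
# and save the result as the mask prompt dictionary.
import os
import pickle

data_root = "DATA_DIRECTORY"
with open(os.path.join(data_root, "test.txt")) as f:
    volume_ids = [line.strip() for line in f if line.strip()]

mask_prompt_dict = {}
for vid in volume_ids:
    n_slices = len(os.listdir(os.path.join(data_root, "images", vid)))
    mask_prompt_dict[vid] = n_slices // 2  # prompt with the middle slice

with open("mask_prompt_dict.pkl", "wb") as f:  # format assumed; verify in inference.py
    pickle.dump(mask_prompt_dict, f)
```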
All code in this repository is released under the GPLv3 license.