mazurowski-lab/SLM-SAM2

SLM-SAM2: Accelerating Volumetric Medical Image Annotation via Short-Long Memory SAM 2

This is the official implementation of SLM-SAM 2.

arXiv Paper

SLM-SAM 2 is a video object segmentation method that accelerates volumetric medical image annotation by propagating the annotation of a single slice to the remaining slices in the volume. By introducing a dynamic short-long memory module, SLM-SAM 2 improves segmentation performance over SAM 2 on organs, bones, and muscles across different imaging modalities.

Installation

First, install PyTorch and TorchVision by following the official installation instructions. SLM-SAM 2 can then be installed with:

cd SLM-SAM2

pip install -e .
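
As a quick sanity check after installation, the following minimal sketch reports which of the expected packages cannot be found by the Python import system. The helper name `missing_packages` is hypothetical, and `sam2` is only assumed to be the package name that `pip install -e .` registers (inferred from the `./sam2/configs` path used later in this README).

```python
import importlib.util


def missing_packages(names):
    """Return the subset of top-level package names that cannot be imported."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # "sam2" is an assumed package name; adjust it if the install differs.
    missing = missing_packages(["torch", "torchvision", "sam2"])
    print("missing:", missing if missing else "none")
```

An empty result only shows that the packages are discoverable; it does not verify versions or GPU support.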

Getting Started

1. Download SAM 2 Pretrained Checkpoints

Before finetuning, download the SAM 2 pretrained checkpoints with the following commands:

cd checkpoints && \
./download_ckpts.sh && \
cd ..

2. Finetuning on Medical Dataset

Open ./sam2/configs/sam2.1_training/slm_sam2_hiera_t_finetune.yaml and add the paths to the image folder, the mask folder, and the text file listing the volumes used for training. The dataset format is the same as that of SAM 2.

DATA_DIRECTORY
├── images
│   ├── volume1
│   │   ├── 00000.jpg
│   │   ├── 00001.jpg
│   │   └── ...
│   └── ...
├── masks
│   ├── volume1
│   │   ├── 00000.png
│   │   ├── 00001.png
│   │   └── ...
│   └── ...
├── train.txt
└── test.txt
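
Before launching a training run, it can help to confirm that the folder layout above is consistent. The sketch below is not part of the repository; the function names are hypothetical, and it assumes that `train.txt` lists one volume name per line and that every `.jpg` slice has a `.png` mask with the same file stem.

```python
from pathlib import Path


def read_volume_list(txt_path):
    """Read volume names, assuming one name per non-empty line."""
    lines = Path(txt_path).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


def check_volume(data_dir, volume):
    """Return a list of problems found for one volume (empty if none)."""
    img_dir = Path(data_dir) / "images" / volume
    mask_dir = Path(data_dir) / "masks" / volume
    problems = []
    if not img_dir.is_dir():
        problems.append(f"{volume}: missing image folder")
    if not mask_dir.is_dir():
        problems.append(f"{volume}: missing mask folder")
    if problems:
        return problems
    # Compare slices by file stem, e.g. 00000.jpg <-> 00000.png.
    img_stems = {p.stem for p in img_dir.glob("*.jpg")}
    mask_stems = {p.stem for p in mask_dir.glob("*.png")}
    for stem in sorted(img_stems - mask_stems):
        problems.append(f"{volume}: slice {stem} has no mask")
    for stem in sorted(mask_stems - img_stems):
        problems.append(f"{volume}: mask {stem} has no image")
    return problems


def check_dataset(data_dir, list_name="train.txt"):
    """Check every volume named in the list file under data_dir."""
    problems = []
    for volume in read_volume_list(Path(data_dir) / list_name):
        problems.extend(check_volume(data_dir, volume))
    return problems
```

Calling `check_dataset("DATA_DIRECTORY")` returns an empty list when every listed volume has matching image and mask slices; pass `list_name="test.txt"` to check the test split.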

Start finetuning by running:

CUDA_VISIBLE_DEVICES=[GPU_ID] python3 training/train.py \
    -c configs/sam2.1_training/slm_sam2_hiera_t_finetune.yaml \
    --use-cluster 0 \
    --num-gpus 1

3. Inference

Propagate annotations by running:

CUDA_VISIBLE_DEVICES=[GPU_ID] python3 inference.py \
    --test_img_folder [test image folder path] \
    --test_mask_folder [test mask folder path] \
    --checkpoint_folder [checkpoint path] \
    --checkpoint_name [checkpoint file name] \
    --cfg_name slm_sam2_hiera_t.yaml \
    --test_txt_file [test text file path] \
    --mask_prompt_dict [path to mask prompt dictionary] \
    --output_folder [path of output folder, to save predictions]

  • checkpoint_folder: directory that contains the .pt file
  • checkpoint_name: name of the .pt file
  • mask_prompt_dict: dictionary mapping each volume ID to the slice index used as the mask prompt (e.g., mask_prompt_dict[volume_id] = slice_index)
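
To illustrate the mapping described above, here is a minimal sketch that builds such a dictionary by choosing the middle annotated slice of each volume. The middle-slice rule and the function name `middle_slice_prompts` are assumptions made for illustration, not the authors' selection strategy, and slice indices are assumed to match the numeric file names (e.g. `00012.png` is slice 12). The on-disk format that `--mask_prompt_dict` expects is not specified in this README, so the sketch stops at the in-memory dictionary.

```python
from pathlib import Path


def middle_slice_prompts(mask_root, volumes):
    """Map each volume ID to the index of its middle annotated slice."""
    prompts = {}
    for volume in volumes:
        # Slice indices are read from file names such as 00012.png -> 12.
        mask_files = (Path(mask_root) / volume).glob("*.png")
        indices = sorted(int(p.stem) for p in mask_files)
        if not indices:
            raise ValueError(f"no mask slices found for volume {volume!r}")
        prompts[volume] = indices[len(indices) // 2]
    return prompts
```

For a volume with masks `00000.png` through `00004.png`, this selects slice 2 as the prompt. Any other rule, such as picking the slice with the largest mask area, can be substituted as long as the result keeps the `volume_id -> slice_index` shape.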

License

All code in this repository is released under the GPLv3 license.
