Skip to content

statmlben/rankseg

Repository files navigation

Python Github MIT PRs Welcome

RankSEG: A Consistent Ranking-based Framework for Segmentation (JMLR 2023)

RankSEG is a Python module designed for segmentation tasks, aiming to maximize Dice or IoU metrics based on estimated probabilities.

Key Features

Most segmentation methods traditionally rely on IoU and Dice as evaluation metrics. During inference and prediction, these methods typically use a threshold of 0.5 or apply argmax to the estimated probabilities to generate segmentation predictions. However, this approach does not directly optimize the IoU or Dice metrics.

Our method, RankDice, directly optimizes IoU and Dice metrics.

  • Nearly ensures improved Dice and IoU performance!
  • Seamlessly integrates with any pretrained segmentation neural network (no need to retrain the models).
  • A well-developed Python function rank_dice is available for use.

Installation

git clone https://github.com/statmlben/rankseg.git
pip install -r requirements.txt

How-to-Use (on a batch)

RankDice

## `out_prob` (batch_size, num_class, width, height) is the output probability for each pixel based on a trained neural network
from rankseg import rank_dice
predict_rd, tau_rd, cutpoint_rd = rank_dice(out_prob, app=2, device='cuda')

Other existing frameworks (Threshold and Argmax)

## `out_prob` (batch_size, num_class, width, height) is the output probability for each pixel based on a trained neural network

## Threshold
predict_T = torch.where(out_prob > .5, True, False)

## Argmax
idx = torch.argmax(out_prob.data, dim=1, keepdims=True)
predict_max = torch.zeros_like(out_prob.data, dtype=bool).scatter_(1, idx, True)

Usage in pytorch-segmentation-rankseg (in subfolder)

## rankdice
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "rankdice"

TEST, Pred (rankdice) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.51, Mean Dice 0.59 |: 100%|██████| 84/84 [01:03<00:00,  1.33it/s]

    ## TESTING Restuls for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: rankdice ## 
         test_loss      : 0.15925
         Pixel_Accuracy : 0.9879999756813049
         Mean_IoU       : 0.5099999904632568
         Mean_Dice      : 0.5929999947547913
         Class_IoU      : {0: 0.771, 1: 0.508, 2: 0.767, 3: 0.164, 4: 0.117, 5: 0.317, 6: 0.283, 7: 0.401, 8: 0.841, 9: 0.231, 10: 0.778, 11: 0.4, 12: 0.292, 13: 0.766, 14: 0.233, 15: 0.465, 16: 0.315, 17: 0.177, 18: 0.326}
         Class_Dice     : {0: 0.856, 1: 0.608, 2: 0.851, 3: 0.21, 4: 0.158, 5: 0.46, 6: 0.374, 7: 0.514, 8: 0.903, 9: 0.294, 10: 0.845, 11: 0.495, 12: 0.372, 13: 0.84, 14: 0.265, 15: 0.513, 16: 0.358, 17: 0.222, 18: 0.419}

## max
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "max"

TEST, Pred (max) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.49, Mean Dice 0.56 |: 100%|███████████| 84/84 [00:12<00:00,  6.52it/s]

    ## TESTING Restuls for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: max ## 
         test_loss      : 0.15925
         Pixel_Accuracy : 0.9879999756813049
         Mean_IoU       : 0.48500001430511475
         Mean_Dice      : 0.5649999976158142
         Class_IoU      : {0: 0.768, 1: 0.489, 2: 0.759, 3: 0.133, 4: 0.099, 5: 0.295, 6: 0.257, 7: 0.387, 8: 0.836, 9: 0.208, 10: 0.769, 11: 0.372, 12: 0.272, 13: 0.751, 14: 0.204, 15: 0.395, 16: 0.268, 17: 0.152, 18: 0.303}
         Class_Dice     : {0: 0.854, 1: 0.585, 2: 0.844, 3: 0.172, 4: 0.136, 5: 0.428, 6: 0.341, 7: 0.498, 8: 0.9, 9: 0.268, 10: 0.835, 11: 0.464, 12: 0.351, 13: 0.826, 14: 0.233, 15: 0.437, 16: 0.308, 17: 0.193, 18: 0.392}


## threshold at 0.5
$ python test.py -r saved/cityscapes/PSPNet/CrossEntropyLoss2d/T/05-04_13-08/checkpoint-epoch300.pth -p "T"

TEST, Pred (T) | Loss: 0.159, PixelAcc: 0.99, Mean IoU: 0.50, Mean Dice 0.57 |: 100%|█████████████| 84/84 [00:13<00:00,  6.45it/s]

    ## TESTING Restuls for Model: PSPNet + Loss: CrossEntropyLoss2d + predict: T ## 
         test_loss      : 0.15925
         Pixel_Accuracy : 0.9890000224113464
         Mean_IoU       : 0.4959999918937683
         Mean_Dice      : 0.574999988079071
         Class_IoU      : {0: 0.772, 1: 0.478, 2: 0.762, 3: 0.136, 4: 0.109, 5: 0.29, 6: 0.265, 7: 0.39, 8: 0.841, 9: 0.201, 10: 0.77, 11: 0.363, 12: 0.273, 13: 0.769, 14: 0.219, 15: 0.422, 16: 0.307, 17: 0.158, 18: 0.325}
         Class_Dice     : {0: 0.857, 1: 0.573, 2: 0.846, 3: 0.174, 4: 0.147, 5: 0.419, 6: 0.349, 7: 0.499, 8: 0.902, 9: 0.257, 10: 0.836, 11: 0.451, 12: 0.351, 13: 0.841, 14: 0.247, 15: 0.468, 16: 0.349, 17: 0.197, 18: 0.414}

Jupyter Notebook

Illustrative results

Results in Fine-annotated Cityscapes dataset

  • Threshold, Argmax and rankDice are performed based on the same network (in Model column) trained by the same loss (in Loss column).
  • Averaged mDice and mIoU metrics based on state-of-the-art models/losses on Fine-annotated CityScapes val set. '/' indicates not applicable since the proposed RankDice/mRankDice requires a strictly proper loss. The best performance in each model/loss is bold-faced.
  • All trained neural networks and their config.json with different network and loss are saved in this link (12G folder: network/loss/.../*.pth + config.json)
Model Loss Threshold (at 0.5) Argmax mRankDice (our)
(mDice, mIoU) ($\times .01$) (mDice, mIoU) ($\times .01$) (mDice, mIoU) ($\times .01$)
DeepLab-V3+ CE (56.00, 48.40) (54.20, 46.60) (57.80, 49.80)
(resnet101) Focal (54.10, 46.60) (53.30, 45.60) (56.50, 48.70)
BCE (49.80, 24.90) (44.20, 22.10) (54.00, 27.00)
Soft-Dice (39.50, 35.90) (39.50, 35.90) /
B-Soft-Dice (41.00, 20.50) (27.60, 13.80) /
LovaszSoftmax (55.20, 47.60) (52.30, 45.10) /
PSPNet CE (57.50, 49.60) (56.50, 48.50) (59.30, 51.00)
(resnet50) Focal (56.00, 48.20) (55.80, 47.70) (58.20, 50.00)
BCE (51.40, 25.70) (47.60, 23.80) (55.10, 27.60)
Soft-Dice (49.10, 43.50) (48.70, 43.20) /
B-Soft-Dice (46.30, 23.10) (32.70, 16.40) /
LovaszSoftmax (56.80, 48.90) (55.40, 47.70) /
FCN8 CE (51.40, 43.70) (50.50, 42.60) (53.50, 45.30)
(resnet101) Focal (48.50, 41.20) (49.60, 41.60) (51.50, 43.70)
BCE (39.40, 19.70) (39.40, 19.70) (41.30, 20.60)
Soft-Dice (28.30, 24.30) (28.30, 24.30) /
B-Soft-Dice (29.10, 14.60) (29.10, 14.60) /
LovaszSoftmax (48.10, 40.40) (42.90, 35.80) /

Results in PASCAL VOC 2012 dataset

  • Threshold, Argmax and rankDice are performed based on the same network (in Model column) trained by the same loss (in Loss column).
  • Averaged mDice and mIoU based on state-of-the-art models/losses on PASCAL VOC 2012 val set. '---' indicates that either the performance is significantly worse or the training is unstable, and '/' indicates not applicable since the proposed RankDice/mRankDice requires a strictly proper loss. The best performance in each model-loss pair is bold-faced.
  • All trained neural networks with different network and loss are saved in this link (22G folder: network/loss/.../*.pth)
Model Loss Threshold (at 0.5) Argmax mRankDice (our)
(mDice, mIoU) ($\times .01$) (mDice, mIoU) ($\times .01$) (mDice, mIoU) ($\times .01$)
DeepLab-V3+ CE (63.60, 56.70) (61.90, 55.30) (64.01, 57.01)
(resnet101) Focal (62.70, 55.01) (60.50, 53.20) (62.90, 55.10)
BCE (63.30, 31.70) (59.90, 29.90) (64.60, 32.30)
Soft-Dice --- --- /
B-Soft-Dice --- --- /
LovaszSoftmax (57.70, 51.60) (56.20, 50.30) /
PSPNet CE (64.60, 57.10) (63.20, 55.90) (65.40, 57.80)
(resnet50) Focal (64.00, 56.10) (63.90, 56.10) (66.60, 58.50)
BCE (64.20, 32.10) (65.20, 32.60) (67.10, 33.50)
Soft-Dice (59.60, 54.00) (58.80, 53.20) /
B-Soft-Dice (63.30, 31.60) (54.00. 27.00) /
LovaszSoftmax (62.00, 55.20) (60.80, 54.10) /
FCN8 CE (49.50, 41.90) (45.30, 38.40) (50.40, 42.70)
(resnet101) Focal (50.40, 41.80) (47.20, 39.30) (51.50, 42.50)
BCE (46.20, 23.10) (44.20, 22.10) (47.70, 23.80)
Soft-Dice --- --- /
B-Soft-Dice --- --- /
LovaszSoftmax (39.80, 34.30) (37.30, 32.20) /

Results in Kvasir-SEG dataset

  • Threshold, Argmax and rankDice are performed based on the same network (in Model column) trained by the same loss (in Loss column).
  • Threshold and Argmax are exactly the same in binary segmentation.
  • Averaged mDice and mIoU based on state-of-the-art models/losses on Kvasir-SEG dataset set. '---' indicates that either the performance is significantly worse or the training is unstable, and '/' indicates not applicable since the proposed RankDice/mRankDice requires a strictly proper loss. The best performance in each model-loss pair is bold-faced.
Model Loss Threshold/Argmax mRankDice (our)
(Dice, IoU) ($\times .01$) (Dice, IoU) ($\times .01$)
DeepLab-V3+ CE (87.9, 80.7) (88.3, 80.9)
(resnet101) Focal (86.5, 87.3) /
Soft-Dice (85.7, 77.8) /
LovaszSoftmax (84.3, 77.3) /
PSPNet CE (86.3, 79.2) (87.1, 79.8)
(resnet50) Focal (83.8, 75.4) /
Soft-Dice (83.5, 75.9) /
LovaszSoftmax (86.0, 79.2) /
FCN8 CE (81.9, 73.5) (82.1, 73.6)
(resnet101) Focal (78.5, 69.0) /
Soft-Dice --- ---
LovaszSoftmax (82.0, 73.4) /

More results

  • All empirical results on different losses and models can be found here

Replication

If you want to replicate the experiments in our papers, please check the folder ./pytorch-segmentation-rankseg and its README file Pytorch-segmentation-rankseg

Citation

If you like RankSEG please star 🌟 the repository and cite the following paper:

@article{dai2023rankseg,
  title={RankSEG: A Consistent Ranking-based Framework for Segmentation},
  author={Dai, Ben and Li, Chunlin},
  journal={Journal of Machine Learning Research},
  volume={24},
  number={224},
  pages={1--50},
  year={2023}
}

To-do list

PRs Welcome Github MIT

  • develop a scalable rank_dice with non-overlapping segmentation
  • debug for torch.backends.cudnn.flags(enabled=False, deterministic=True, benchmark=True) when enabled=True
  • CUDA code to speed up the implementation based on app=1