Stochastic Segmentation with Conditional Categorical Diffusion Models

The official code repo for the paper Stochastic Segmentation with Conditional Categorical Diffusion Models, accepted at the International Conference on Computer Vision (ICCV) 2023.

Abstract:

Semantic segmentation has made significant progress in recent years thanks to deep neural networks, but the common objective of generating a single segmentation output that accurately matches the image's content may not be suitable for safety-critical domains such as medical diagnostics and autonomous driving. Instead, multiple possible correct segmentation maps may be required to reflect the true distribution of annotation maps. In this context, stochastic semantic segmentation methods must learn to predict conditional distributions of labels given the image, but this is challenging due to the typically multimodal distributions, high-dimensional output spaces, and limited annotation data. To address these challenges, we propose a conditional categorical diffusion model (CCDM) for semantic segmentation based on Denoising Diffusion Probabilistic Models. Our model is conditioned to the input image, enabling it to generate multiple segmentation label maps that account for the aleatoric uncertainty arising from divergent ground truth annotations. Our experimental results show that CCDM achieves state-of-the-art performance on LIDC, a stochastic semantic segmentation dataset, and outperforms established baselines on the classical segmentation dataset Cityscapes.
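As a rough illustration of what "categorical diffusion" means for a label map, the sketch below applies one forward noising step to each pixel independently. It assumes the uniform-noise categorical forward process commonly used in categorical diffusion, where the class probabilities for x_t given x_{t-1} are (1 - beta) * one_hot(x_{t-1}) + beta / K. It is not taken from this repository; the function names and the toy label map are hypothetical.

```python
import random


def forward_step_probs(label, num_classes, beta):
    """Class probabilities for x_t given x_{t-1} = label.

    Assumed form: (1 - beta) * one_hot(label) + beta / num_classes.
    """
    uniform = beta / num_classes
    return [(1.0 - beta) + uniform if k == label else uniform
            for k in range(num_classes)]


def noise_label_map(label_map, num_classes, beta, rng):
    """Apply one forward noising step independently to every pixel."""
    classes = list(range(num_classes))
    return [
        [rng.choices(classes, weights=forward_step_probs(y, num_classes, beta))[0]
         for y in row]
        for row in label_map
    ]


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the toy run is repeatable
    clean = [[0, 0, 1], [0, 1, 1]]  # hypothetical 2x3 binary label map
    print(noise_label_map(clean, num_classes=2, beta=0.3, rng=rng))
```

With beta = 0 the label map is returned unchanged, and as beta approaches 1 every pixel approaches a uniform draw over the classes. The trained model learns to reverse this corruption, conditioned on the input image.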

Installation

Requires Python 3.10 and Torch 1.7.0 (see requirements.txt):

pip install -r requirements.txt

Datasets

Cityscapes

Download the Cityscapes dataset. For the Cityscapes experiments, switch to the cts branch and follow the instructions in Cityscapes.md.

LIDC

For LIDCv1: We used the data available on Stefan Knegt's GitHub page.

For LIDCv2: This split of LIDC can be found on the GitHub page of Hierarchical Probabilistic U-Net or from this Google Drive link.

Training (Segmentation with multiple annotations)

For training on LIDCv1, copy the dataset to ${TMPDIR}/data_lidc.hdf5, and, in params.yml, set:

dataset_file: datasets.lidc

To run the training:

python ddpm_train.py params.yml
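Before launching a long training run, it can help to confirm that the dataset sits where the instructions above expect it. The following is a minimal sketch and not part of the repository: the helper name expected_lidc_path is hypothetical, and falling back to Python's default temporary directory when TMPDIR is unset is an assumption.

```python
import os
import tempfile
from pathlib import Path


def expected_lidc_path(env=None):
    """Return the path ${TMPDIR}/data_lidc.hdf5 named in the training instructions.

    Assumption: if TMPDIR is unset or empty, use Python's default temp dir.
    """
    env = os.environ if env is None else env
    tmpdir = env.get("TMPDIR") or tempfile.gettempdir()
    return Path(tmpdir) / "data_lidc.hdf5"


if __name__ == "__main__":
    path = expected_lidc_path()
    if path.is_file():
        print(f"Found LIDC dataset at {path}")
    else:
        print(f"Missing LIDC dataset: expected {path}")
```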

Evaluation

For evaluation on LIDC, in params_eval.yml, set:

dataset_file: datasets.lidc

To run the evaluation:

python ddpm_eval.py params_eval.yml

Pretrained models

You can find a pretrained model for LIDCv1 along with the parameter file needed for evaluating it here.

You can find the code to train and evaluate models on Cityscapes, as well as a cdm_dino_256x512 checkpoint, in the releases.

Citation

If you find our work relevant to your research, please cite:

@InProceedings{Zbinden_2023_ICCV,
    author    = {Zbinden, Lukas and Doorenbos, Lars and Pissas, Theodoros and Huber, Adrian Thomas and Sznitman, Raphael and M\'arquez-Neila, Pablo},
    title     = {Stochastic Segmentation with Conditional Categorical Diffusion Models},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {1119-1129}
}

License

The code is published under the MIT License.

Updates

  • 15/03/2023 Initial commit.
  • 10/05/2023 Added new branch for Cityscapes experiments and released cdm_dino checkpoint.
  • 14/07/2023 Accepted at ICCV 2023.

Acknowledgements

We based our implementation of DINO as a feature extractor on https://github.com/ShirAmir/dino-vit-features/blob/main/extractor.py and also make use of the official checkpoints from https://github.com/facebookresearch/dino. We thank their respective authors for open-sourcing their code.
