
PyTorch Implementation of the Probabilistic UNet

This is yet another PyTorch implementation of the Probabilistic UNet. This repo is forked from https://github.com/stefanknegt/Probabilistic-Unet-Pytorch; on top of that code, I have enabled checkpoint saving and added a visualization script.

Probabilistic U-Net paper for segmentation of ambiguous images: https://arxiv.org/abs/1806.05034. Official code repo: https://github.com/SimonKohl/probabilistic_unet.

Example Result 1 (image: result1)

Example Result 2 (image: result2)

About this repository

Please note that, in the spirit of keeping the code easy to understand, the test set is currently used as the validation set and the model checkpoint is saved based on the best test loss. This makes comparisons against methods that treat the held-out set as a true test set unfair. For a fair quantitative evaluation, you will probably need to remove the validation loop from train_model.py.

The train/val split is fixed, so the same split can be reliably reused across different scripts (training and visualization).
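For illustration, a fixed split can be reproduced across scripts by seeding the splitting step. A minimal sketch (not the repo's exact code; the dataset object and split fraction are placeholders):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Placeholder dataset standing in for the real data.
dataset = TensorDataset(torch.randn(100, 1, 128, 128), torch.zeros(100, 1, 128, 128))

def make_fixed_split(dataset, val_fraction=0.1, seed=42):
    """Split deterministically so every script sees the same train/val sets."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)  # fixed seed -> identical split each run
    return random_split(dataset, [n_train, n_val], generator=generator)

train_set, val_set = make_fixed_split(dataset)
```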

This code was tested with PyTorch 1.6. You probably need to install pydicom: pip install pydicom.

Usage

Note: the original repository mentions the need to manually add the KL divergence loss in the PyTorch code. I don't think this is necessary: torch.distributions now provides the KL divergence for Independent Normal distributions. This works in PyTorch 1.6, and probably in some older versions as well.
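For example, the KL term between a diagonal-Gaussian posterior and prior can be computed directly with torch.distributions (a minimal illustration with made-up shapes, not the repo's exact code):

```python
import torch
from torch.distributions import Normal, Independent, kl_divergence

batch_size, latent_dim = 4, 6  # made-up sizes for illustration

# Diagonal Gaussians over the latent space, wrapped so the last dim is the event dim.
posterior = Independent(Normal(torch.zeros(batch_size, latent_dim),
                               torch.ones(batch_size, latent_dim)), 1)
prior = Independent(Normal(torch.zeros(batch_size, latent_dim),
                           2.0 * torch.ones(batch_size, latent_dim)), 1)

kl = kl_divergence(posterior, prior)  # shape (batch_size,) -- no hand-written KL formula needed
```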

Training

To train the network, use python3 train_model.py. This file contains the optimization settings as well as an output directory name (out_dir); the trained model and training logs will be saved to out_dir.

To train your own model, you need to prepare a dataset that yields (input patch, segmentation labels) pairs; then you should be able to reuse pretty much the same code as in train_model.py, as in the sketch below.
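As a rough sketch of such a loop (the class name, constructor arguments, and the forward/elbo methods follow the upstream Probabilistic-Unet-Pytorch code and are assumptions here; the toy tensors stand in for a real dataset):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from probabilistic_unet import ProbabilisticUnet  # module/class name as in the upstream repo (assumption)

# Toy stand-in for a real dataset: single-channel patches with binary masks.
patches = torch.randn(16, 1, 128, 128)
masks = (torch.rand(16, 1, 128, 128) > 0.5).float()
loader = DataLoader(TensorDataset(patches, masks), batch_size=4, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
net = ProbabilisticUnet(input_channels=1, num_classes=1).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

for patch, mask in loader:
    patch, mask = patch.to(device), mask.to(device)
    net.forward(patch, mask, training=True)  # builds the prior and posterior for this batch
    loss = -net.elbo(mask)                   # negative ELBO: reconstruction term + beta * KL
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```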

Dealing with NaNs: sometimes you might encounter NaNs during training. Decreasing the learning rate is probably the easiest solution; of course, modifying the ELBO loss would be the fancier fix :)
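Continuing the training-loop sketch above, two common (generic, not repo-specific) safeguards are a smaller learning rate and gradient clipping:

```python
# Smaller learning rate than before.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)

# Inside the loop, between loss.backward() and optimizer.step():
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)  # clip exploding gradients
```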

Visualizing

To save segmentation results, run python3 visualize.py. By default, this loads the trained model provided with this repo at trained_model/model_dict.pth. To use your own trained model, set cpk_directory in visualize.py.
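A minimal sketch of what loading and sampling might look like (assuming model_dict.pth stores a plain state_dict and that the forward/sample methods match the upstream implementation; visualize.py is the reference):

```python
import torch
from probabilistic_unet import ProbabilisticUnet  # assumption: same class as used for training

device = "cuda" if torch.cuda.is_available() else "cpu"
net = ProbabilisticUnet(input_channels=1, num_classes=1).to(device)

# Load the checkpoint shipped with the repo, or point this at your own cpk_directory.
state_dict = torch.load("trained_model/model_dict.pth", map_location=device)
net.load_state_dict(state_dict)
net.eval()

with torch.no_grad():
    patch = torch.randn(1, 1, 128, 128, device=device)  # placeholder input patch
    net.forward(patch, None, training=False)            # only the prior net is needed at test time
    sample = net.sample()                                # one plausible segmentation (raw network output)
```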

LIDC Dataset

One of the datasets used in the original paper is the LIDC dataset. The original repo preprocessed this data and stored it in pickle files, which you can download here. After downloading the files, place them in a folder called 'data'. You can then train your own Probabilistic UNet on the LIDC dataset using the simple training script in train_model.py.
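A minimal sketch of a Dataset over those files (the field names 'image' and 'masks' are assumptions; inspect the pickle contents and use the repo's own loader in train_model.py as the reference):

```python
import os
import pickle
import torch
from torch.utils.data import Dataset

class LIDCPickleDataset(Dataset):
    """Sketch: wrap the preprocessed LIDC pickle files placed in the 'data' folder."""

    def __init__(self, data_dir="data"):
        self.items = []
        for fname in sorted(os.listdir(data_dir)):
            if fname.endswith(".pickle"):
                with open(os.path.join(data_dir, fname), "rb") as f:
                    self.items.extend(pickle.load(f).values())  # assumes a dict of {key: {'image', 'masks', ...}}

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image = torch.from_numpy(item["image"]).float().unsqueeze(0)    # (1, H, W)
        mask = torch.from_numpy(item["masks"][0]).float().unsqueeze(0)  # first of the four annotations
        return image, mask
```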


Issues

Segmentation results

Hi,

Thank you for your code, it is useful and helpful. I have a quick question: the final out_samples are not 0-1 segmentations. Are they softmax outputs?

Best,
Stella
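If out_samples are raw logits rather than probabilities (an assumption, not confirmed in this thread), a minimal sketch for obtaining a 0-1 mask in the binary case:

```python
import torch

out_samples = torch.randn(1, 1, 128, 128)  # placeholder for the network's raw output
probs = torch.sigmoid(out_samples)         # assumption: the outputs are logits
binary_mask = (probs > 0.5).float()        # threshold to get a 0-1 segmentation
```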
