Pancreas Segmentation

TensorFlow implementation of a 3D CNN U-Net with Grid Attention and DSV (deep supervision) for pancreas segmentation from CT scans.

Network architecture

  • Classic U-Net with residual connections
  • Grid Attention gave the biggest performance boost
  • DSV forces intermediate feature maps to be semantically discriminative
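The Grid Attention bullet can be made concrete with a minimal NumPy sketch of an additive attention gate in the style of Attention U-Net (Oktay et al., 2018). This is not code from this repository: the function and parameter names, the channels-last (W, H, D, C) layout, and the simplification that the coarser gating signal has already been resampled to the skip connection's grid are all assumptions. In a trained network the projections would be learned 1x1x1 convolutions rather than random matrices.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def attention_gate(x, g, w_x, w_g, b, w_psi, b_psi):
    """Additive attention gate; 1x1x1 convolutions are written as channel-wise matmuls.

    x:     skip-connection features, shape (W, H, D, Cx)
    g:     gating signal, assumed already resampled to x's grid, shape (W, H, D, Cg)
    w_x:   (Cx, Ci), w_g: (Cg, Ci), b: (Ci,)  project both inputs to Ci channels
    w_psi: (Ci, 1), b_psi: float              collapse to one attention logit per voxel
    Returns the gated features and the attention map alpha of shape (W, H, D, 1).
    """
    q = np.maximum(x @ w_x + g @ w_g + b, 0.0)  # ReLU(W_x x + W_g g + b)
    alpha = sigmoid(q @ w_psi + b_psi)          # per-voxel coefficients in [0, 1]
    return x * alpha, alpha                     # alpha broadcasts over x's channels


# Usage with random weights (illustration only, not trained parameters)
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 8, 4))
g = rng.standard_normal((8, 8, 8, 6))
gated, alpha = attention_gate(
    x, g,
    w_x=rng.standard_normal((4, 3)),
    w_g=rng.standard_normal((6, 3)),
    b=np.zeros(3),
    w_psi=rng.standard_normal((3, 1)),
    b_psi=0.0,
)
print(gated.shape, alpha.shape)
```

The attention map rescales every voxel of the skip connection before it reaches the decoder, so features outside the region the gating signal points at are suppressed.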

Training

The network was trained on the publicly available CT-82 dataset from TCIA (64/16 split between training and validation).

A weighted DSC (Dice Similarity Coefficient) is used as the loss function. The weight hyperparameter value that best served my purposes was 7, which gave:

  • recall ~95%
  • precision ~57%
  • DSC/F1 ~72% (though it is not really important for my experiments)
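The README does not spell out how the weight enters the loss. One plausible reading, assumed here and not confirmed by the source, is a soft Dice in which the false-negative term is multiplied by the weight, so a weight of 1 recovers the standard DSC and larger weights trade precision for recall, which matches the recall/precision balance above. The sketch also computes the binary metrics: for a single binary mask DSC and F1 coincide, and 2 * 0.57 * 0.95 / (0.57 + 0.95) ≈ 0.71, close to the reported ~72% given the rounding of the inputs. The function names are hypothetical.

```python
import numpy as np


def weighted_dice_loss(y_true, y_pred, fn_weight=7.0, eps=1e-6):
    """Soft Dice loss with false negatives weighted by fn_weight (assumed formulation).

    With fn_weight=1 this is the standard soft Dice loss 1 - 2*sum(p*g) / (sum(p) + sum(g)).
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    tp = np.sum(y_true * y_pred)          # soft true positives
    fp = np.sum((1.0 - y_true) * y_pred)  # soft false positives
    fn = np.sum(y_true * (1.0 - y_pred))  # soft false negatives
    dsc = (2.0 * tp + eps) / (2.0 * tp + fp + fn_weight * fn + eps)
    return 1.0 - dsc


def binary_metrics(y_true, y_pred_mask):
    """Recall, precision and DSC/F1 for binary masks."""
    t = np.asarray(y_true, dtype=bool)
    p = np.asarray(y_pred_mask, dtype=bool)
    tp = np.sum(t & p)
    recall = tp / max(np.sum(t), 1)
    precision = tp / max(np.sum(p), 1)
    dsc = 2 * tp / max(np.sum(t) + np.sum(p), 1)
    return recall, precision, dsc


# Toy 4x4x4 volume with an 8-voxel "organ"
gt = np.zeros((4, 4, 4))
gt[1:3, 1:3, 1:3] = 1
under = gt.copy()
under[1, 1, 1] = 0  # one missed voxel (false negative)
over = gt.copy()
over[0, 0, 0] = 1   # one extra voxel (false positive)

# With fn_weight=7 a missed voxel costs far more than an extra one
print(weighted_dice_loss(gt, under), weighted_dice_loss(gt, over))
print(binary_metrics(gt, over))
```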

Here is the TensorBoard.dev comparison of the weight hyperparameter values 1, 7 and 10. I recommend enabling only the validation runs and applying the tag filter f1|recall|prec.

Training process

The whole network was trained end-to-end on full volumes, without any tiling. The reasoning is to avoid artifacts where the pancreas segmentation is cut off at a tile edge.

Every CT is downscaled to 160x160x160, the maximum size that fits into a Tesla K40m (12 GB of GPU memory). Pooling is applied over the WxH dimensions only; the depth D stays constant (i.e. 160 throughout the whole network), which helps a little with segmentation recovery. Each training batch contains a single CT, so batch normalization was not used.
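Pooling over WxH only can be sketched in NumPy by splitting each pooled axis into (blocks, factor) pairs and taking the maximum inside each block. The channels-last (W, H, D, C) layout, the function name and the four pooling levels are assumptions, not details taken from the repository. In Keras, with the same axis order, the analogous layer would be `tf.keras.layers.MaxPool3D(pool_size=(2, 2, 1))`.

```python
import numpy as np


def max_pool_wh(volume, factor=2):
    """Max pooling over W and H only; depth D is left unchanged.

    volume: array of shape (W, H, D, C), channels-last (assumed layout).
    """
    w, h, d, c = volume.shape
    if w % factor or h % factor:
        raise ValueError("W and H must be divisible by the pooling factor")
    # Split W and H into (blocks, factor) pairs, then take the max inside each block
    blocks = volume.reshape(w // factor, factor, h // factor, factor, d, c)
    return blocks.max(axis=(1, 3))


# Shape progression for a 160x160x160 single-channel input
# (four pooling levels is an assumption, not taken from the repository)
x = np.zeros((160, 160, 160, 1), dtype=np.float32)
shapes = [x.shape]
for _ in range(4):
    x = max_pool_wh(x)
    shapes.append(x.shape)
print(shapes)
```

For scale, a single 160x160x160 float32 channel occupies 160 x 160 x 160 x 4 = 16,384,000 bytes (about 15.6 MiB), and because D is never pooled every level keeps a depth of 160, which is one reason memory runs out quickly on a 12 GB card.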

The optimizer is Adam with an initial learning rate of 0.002, reduced on plateau by a factor of 0.1 with a patience of 30 epochs. Training is capped at 1000 epochs.
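In Keras this schedule would most likely be `tf.keras.optimizers.Adam(learning_rate=0.002)` together with the `tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, patience=30)` callback, although the README does not say which quantity is monitored. The framework-free sketch below shows only the reduce-on-plateau logic; the class name and the use of validation loss as the monitored value are assumptions, and it omits Keras details such as `min_delta` and `cooldown`.

```python
class ReduceOnPlateau:
    """Multiply the learning rate by `factor` after `patience` epochs without improvement."""

    def __init__(self, lr=0.002, factor=0.1, patience=30, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")  # best (lowest) monitored value seen so far
        self.wait = 0             # epochs since the last improvement

    def step(self, val_loss):
        """Call once per epoch with the monitored value; returns the learning rate to use next."""
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.wait = 0
        return self.lr


# One improving epoch followed by 30 flat epochs triggers a single reduction
sched = ReduceOnPlateau(lr=0.002, factor=0.1, patience=30)
lrs = [sched.step(loss) for loss in [1.0] + [1.0] * 30]
print(lrs[0], lrs[-1])
```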

Training took ~60 hours on a single server with one NVIDIA Tesla K40m GPU. Most of the progress was achieved in the first 3-5 hours.

The TensorFlow Docker container was used as the runtime environment:

docker run --gpus all --rm -it tensorflow/tensorflow:2.3.2-gpu /bin/bash
git clone https://github.com/IvanKuchin/pancreas_segmentation.git
cd pancreas_segmentation
python train_segmentation.py

Because of GitHub's limits on large files, the learned weights are available here.

Inference

Prerequisite: TensorFlow 2.3 (you could try the latest version, but there is no guarantee that it will work).

Inference can be run on a regular laptop without a GPU and takes ~10-15 seconds.

To test segmentation on your own data:

  1. Clone this repository: git clone https://github.com/IvanKuchin/pancreas_segmentation.git
  2. Create a predict folder inside the cloned folder and put a single-pass CT in it. If the folder contains multiple passes, the result is unpredictable.
  3. Download weights.hdf5 from the link above and put it in the cloned folder.
  4. Run: python predict.py
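Before running the last step, a small check like the one below can catch a missing folder or weights file. It is a hypothetical helper, not part of the repository; it only assumes the layout from the steps above (a predict folder and weights.hdf5 in the cloned folder) and cannot verify that the CT really is single-pass.

```python
from pathlib import Path


def check_inputs(repo_dir="."):
    """Return a list of problems found before running predict.py (hypothetical helper)."""
    repo = Path(repo_dir)
    problems = []
    predict_dir = repo / "predict"
    if not predict_dir.is_dir():
        problems.append("missing folder: predict/")
    elif not any(predict_dir.iterdir()):
        problems.append("predict/ is empty: put a single-pass CT there")
    if not (repo / "weights.hdf5").is_file():
        problems.append("missing file: weights.hdf5")
    return problems


if __name__ == "__main__":
    problems = check_inputs()
    for problem in problems:
        print(problem)
    if not problems:
        print("inputs look fine, run: python predict.py")
```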

The output is prediction.nii, a NIfTI (Neuroimaging Informatics Technology Initiative) file.
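prediction.nii can be opened directly in 3D Slicer, or in Python with `nibabel.load("prediction.nii")`. For a dependency-free check of the output geometry, the sketch below reads the volume shape and voxel spacing from the 348-byte NIfTI-1 header using only the standard library. The function name is hypothetical, and it handles only uncompressed single-file NIfTI-1 (not .nii.gz or NIfTI-2).

```python
import struct


def read_nifti1_header(path):
    """Return shape, voxel spacing and datatype code from an uncompressed NIfTI-1 file."""
    with open(path, "rb") as f:
        hdr = f.read(348)
    if len(hdr) < 348:
        raise ValueError("file too short for a NIfTI-1 header")
    # sizeof_hdr (offset 0) must be 348; trying both byte orders detects endianness
    if struct.unpack("<i", hdr[0:4])[0] == 348:
        endian = "<"
    elif struct.unpack(">i", hdr[0:4])[0] == 348:
        endian = ">"
    else:
        raise ValueError("not a NIfTI-1 file")
    dim = struct.unpack(endian + "8h", hdr[40:56])       # dim[0] = number of dimensions
    datatype, bitpix = struct.unpack(endian + "2h", hdr[70:74])
    pixdim = struct.unpack(endian + "8f", hdr[76:108])   # pixdim[1..] = voxel sizes
    ndim = dim[0]
    return {
        "shape": dim[1:1 + ndim],
        "spacing": pixdim[1:1 + ndim],
        "datatype": datatype,
        "bitpix": bitpix,
    }


# Usage (after running predict.py):
# info = read_nifti1_header("prediction.nii")
# print(info["shape"], info["spacing"])
```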

All the magic happens in the last three lines of predict.py:

if __name__ == "__main__":
    pred = Predict()
    # read the CT from the predict/ folder and write the segmentation to prediction.nii
    pred.main("predict", "prediction.nii")

I used 3D Slicer to check the results visually.

The importance of the probability distribution of the source data

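The network was trained on CT-82, whose scans are all contrast-enhanced (portal venous phase, about 70 seconds after intravenous contrast injection). This means the network should only be expected to recognize scans that resemble the CT-82 probability distribution. I tried testing CT scans acquired under a different contrast protocol, and the results were unsatisfactory.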

Some results

A video recording of the segmentation results is posted on connme.ru in the group "Pancreas cancer detection".

Example prediction in 3D Slicer (prediction: green, ground truth: red)

Contributors

ivankuchin
