5l1v3r1 / evidential-deep-learning

This project forked from aamini/evidential-deep-learning

Learn fast, scalable, and calibrated measures of uncertainty using neural networks!

Home Page: https://proceedings.neurips.cc/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf

License: Apache License 2.0

TeX 0.20% Python 98.43% Shell 0.30% MATLAB 1.08%

Evidential Deep Learning

"All models are wrong, but some — that know when they can be trusted — are useful!"

- George Box (Adapted)

This repository contains the code to reproduce Deep Evidential Regression, as published at NeurIPS 2020, as well as more general code for using evidential learning to train neural networks that learn their own measures of uncertainty directly from data!

Setup

To use this package, you must install the following dependencies first:

  • python (>=3.7)
  • tensorflow (>=2.0)
  • pytorch (support coming soon)

Then install the package to start adding evidential layers and losses to your models:

pip install evidential-deep-learning

Now you're ready to start using this package directly as part of your existing tf.keras model pipelines (Sequential, Functional, or model-subclassing):

>>> import evidential_deep_learning as edl

Example

To use evidential deep learning, replace the last layer of your model with an evidential layer and train the system end-to-end with a supported loss function. This repository supports evidential layers for both fully connected and convolutional (2D) layers. The evidential prior presented in the paper is a Normal Inverse-Gamma distribution and can be added to your model as follows:

import evidential_deep_learning as edl
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        edl.layers.DenseNormalGamma(1), # Evidential distribution!
    ]
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3), 
    loss=edl.losses.EvidentialRegression # Evidential loss!
)
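For intuition about what an evidential regression loss optimizes, here is a minimal pure-Python sketch of the per-sample objective described in the NeurIPS 2020 paper: the negative log-likelihood of the target under the Normal Inverse-Gamma evidential distribution with parameters (gamma, v, alpha, beta), plus an evidence regularizer weighted by a coefficient lambda. This is an illustrative re-derivation for a single scalar target, not the package's implementation; the function names and the `coeff` default are hypothetical.

```python
import math


def nig_nll(y: float, gamma: float, v: float, alpha: float, beta: float) -> float:
    """Negative log-likelihood of target y under the NIG evidential distribution.

    Marginalizing the NIG prior gives a Student-t likelihood; this is its
    negative log density written in terms of (gamma, v, alpha, beta).
    """
    omega = 2.0 * beta * (1.0 + v)
    return (
        0.5 * math.log(math.pi / v)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(v * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


def nig_reg(y: float, gamma: float, v: float, alpha: float) -> float:
    """Evidence regularizer: total evidence (2v + alpha) scaled by the absolute error."""
    return abs(y - gamma) * (2.0 * v + alpha)


def evidential_loss(y, gamma, v, alpha, beta, coeff=1e-2):
    """NLL plus coeff * regularizer. HYPOTHETICAL default for coeff (lambda)."""
    return nig_nll(y, gamma, v, alpha, beta) + coeff * nig_reg(y, gamma, v, alpha)


# Usage: the loss grows as the predicted mean moves away from the target.
print(evidential_loss(0.0, 0.0, 1.0, 1.0, 1.0))
print(evidential_loss(2.0, 0.0, 1.0, 1.0, 1.0))
```

The NLL alone can be reduced by inflating the evidence, so the regularizer multiplies the evidence by the prediction error: confident but wrong predictions are penalized, pushing the model toward lower evidence where it errs.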

Check out hello_world.py for an end-to-end toy example that walks through this step by step. For more complex examples that scale up to computer vision problems (where we learn to predict tens of thousands of evidential distributions simultaneously!), please refer to the NeurIPS 2020 paper and to the reproducibility section of this repo to run those examples.
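After training, each evidential output can be turned into a prediction and two kinds of uncertainty. The sketch below assumes each output is a 4-vector ordered (gamma, v, alpha, beta); that ordering is an assumption about the layer's output layout, so check the layer source before relying on it. It uses the Normal Inverse-Gamma moments from the paper: prediction E[mu] = gamma, aleatoric E[sigma^2] = beta / (alpha - 1), and epistemic Var[mu] = beta / (v (alpha - 1)). The function name `nig_uncertainty` is hypothetical.

```python
from typing import Sequence, Tuple


def nig_uncertainty(output: Sequence[float]) -> Tuple[float, float, float]:
    """Convert one evidential output into (prediction, aleatoric, epistemic).

    ASSUMPTION: `output` is ordered (gamma, v, alpha, beta), i.e. the four
    Normal Inverse-Gamma parameters for a single target.
    """
    gamma, v, alpha, beta = output
    if alpha <= 1.0:
        # The variance moments below are only finite for alpha > 1.
        raise ValueError("alpha must be greater than 1")
    prediction = gamma                      # E[mu]
    aleatoric = beta / (alpha - 1.0)        # E[sigma^2]: noise in the data
    epistemic = beta / (v * (alpha - 1.0))  # Var[mu]: uncertainty in the model
    return prediction, aleatoric, epistemic


# Usage on a small batch of made-up outputs (illustrative values only).
batch = [[0.5, 2.0, 3.0, 4.0], [1.0, 0.1, 1.5, 0.2]]
for row in batch:
    print(nig_uncertainty(row))
```

A small v (little evidence for the mean) inflates the epistemic term while leaving the aleatoric term unchanged, which is what lets the model flag inputs unlike its training data.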

Reproducibility

All of the results published in our NeurIPS paper can be reproduced with this repository. Please refer to the reproducibility section for details and instructions on obtaining each result.

Citation

If you use this code for evidential learning as part of your project or paper, please cite the following work:

@article{amini2020deep,
  title={Deep evidential regression},
  author={Amini, Alexander and Schwarting, Wilko and Soleimany, Ava and Rus, Daniela},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
