
deterministic-uncertainty-quantification's Introduction

Deterministic Uncertainty Quantification (DUQ)

This repo contains the code for Uncertainty Estimation Using a Single Deep Deterministic Neural Network, which is accepted for publication at ICML 2020.

If the code or the paper has been useful in your research, please add a citation to our work:

@inproceedings{van2020uncertainty,
  title={Uncertainty Estimation Using a Single Deep Deterministic Neural Network},
  author={van Amersfoort, Joost and Smith, Lewis and Teh, Yee Whye and Gal, Yarin},
  booktitle={International Conference on Machine Learning},
  year={2020}
}

Dependencies

The code is based on PyTorch and requires a few further dependencies, listed in environment.yml. The code was tested with the versions specified in the environment file, but should work with newer versions as well (except for ignite=0.4.3). If you find an incompatibility, please let me know and I'll gladly update the code for the newest version of each library.

Datasets

Most datasets will be downloaded on the fly by Torchvision. Only NotMNIST needs to be downloaded in advance in a subfolder called data/:

mkdir -p data && cd data && curl -O "http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat"

FastFashionMNIST is based on this code. The default Torchvision implementation first creates a PIL image for every sample, which becomes a CPU bottleneck when training on a GPU. The FastFashionMNIST class avoids this and provides a significant speedup.

Running

The Two Moons experiments can be replicated using the Two Moons notebook. The FashionMNIST experiment is implemented in train_duq_fm.py. For both experiments, the paper's defaults are hardcoded and can be changed in place.

The ResNet18 based CIFAR experiments are implemented in train_duq_cifar.py. All command line parameter defaults are as listed in the experimental details in Appendix A of the paper. We additionally include a Wide ResNet based architecture.

For example, to train on CIFAR-10 with a gradient penalty weight of 0.5 and the full training set:

python train_duq_cifar.py --final_model --l_gradient_penalty 0.5

Note that omitting --final_model holds out 20% of the training data for validation, so that hyperparameter selection can be done in a responsible manner. The code also supports the Wide ResNet with --architecture WRN.
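The hold-out split described above can be pictured with a short sketch. This is an illustration under assumptions, not the repository's code: the function name `split_indices`, the fixed seed, and the shuffled index list are hypothetical choices. Only the 20% validation fraction comes from the text.

```python
import random


def split_indices(num_samples, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) as two disjoint lists of dataset indices.

    Hypothetical helper for illustration; not part of the repository.
    """
    indices = list(range(num_samples))
    # A seeded RNG keeps the split identical across runs,
    # so validation scores from different runs stay comparable.
    random.Random(seed).shuffle(indices)
    num_val = int(num_samples * val_fraction)
    return indices[num_val:], indices[:num_val]


# CIFAR-10 ships with 50,000 training images.
train_idx, val_idx = split_indices(50_000)
print(len(train_idx), len(val_idx))
```

Hyperparameters would then be chosen on the validation indices only, and the final model retrained on all indices, which is the role --final_model plays in the command above.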

I also include code for my implementation of Deep Ensembles. It's a very simple implementation that achieves good results (95% accuracy in 75 epochs using 5 models).

python train_deep_ensemble.py --dataset CIFAR10

This command will train a Deep Ensemble consisting of 5 models (the default) on CIFAR10.
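As a rough picture of how an ensemble turns several models into one prediction with an uncertainty value, the sketch below averages the softmax outputs of the members and takes the entropy of the average. It is a framework-free illustration under assumptions: the function names and the toy logits are hypothetical, and train_deep_ensemble.py may combine its members differently.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(member_logits):
    """Average per-member class probabilities; return (mean_probs, entropy)."""
    probs = [softmax(logits) for logits in member_logits]
    num_members = len(probs)
    mean_probs = [sum(col) / num_members for col in zip(*probs)]
    # Entropy of the averaged prediction: higher means more uncertain.
    entropy = -sum(p * math.log(p) for p in mean_probs if p > 0)
    return mean_probs, entropy


# Hypothetical logits from 5 members for a single input with 3 classes.
agree = [[4.0, 0.0, 0.0]] * 5
disagree = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0],
            [4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]

_, entropy_agree = ensemble_predict(agree)
_, entropy_disagree = ensemble_predict(disagree)
print(entropy_agree < entropy_disagree)
```

When the members agree, the averaged distribution stays peaked and its entropy is low. When they disagree, the average flattens and the entropy rises, which is the signal an ensemble uses to flag uncertain inputs.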

Questions

For questions about the code or the paper, feel free to open an issue or email me directly. My email can be found on my GitHub profile, my website and the paper above.

Deep Ensembles vs DUQ

deterministic-uncertainty-quantification's Issues

Some questions of paper and codes

  • I’m confused about the method of updating centroids, could you please explain it?

  • The paper mentioned:

The method of updating centroids was introduced in the Appendix of van den Oord et al. (2017) for updating quantised latent variable.

I only found the paper itself, not the Appendix of van den Oord et al. (2017). Can you provide a link to the Appendix?

  • As for the code:
def step(engine, batch):
    model.train()
    optimizer.zero_grad()

    x, y = batch
    x, y = x.cuda(), y.cuda()

    x.requires_grad_(True)

    y_pred = model(x)

    y = F.one_hot(y, num_classes).float()

    loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

    if l_gradient_penalty > 0:
        gp = calc_gradient_penalty(x, y_pred)
        loss += l_gradient_penalty * gp

    loss.backward()
    optimizer.step()

    x.requires_grad_(False)

    with torch.no_grad():
        model.eval()
        model.update_embeddings(x, y)

    return loss.item()

Is the gradient with respect to x used only for calculating the gradient penalty? How does the l_gradient_penalty * gp term of the loss backpropagate?

def update_embeddings(self, x, y):
    self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0)

    z = self.feature_extractor(x)

    z = torch.einsum("ij,mnj->imn", z, self.W)
    embedding_sum = torch.einsum("ijk,ik->jk", z, y)

    self.m = self.gamma * self.m + (1 - self.gamma) * embedding_sum

Could you please explain the process of model.update_embeddings? What is the meaning of self.N and self.m?

Thank you so much!
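One way to read the quoted update_embeddings snippet is as an exponential moving average (EMA) over two per-class statistics: self.N appears to track a smoothed count of samples per class, and self.m a smoothed sum of embeddings per class. The sketch below reproduces that arithmetic without PyTorch. It is an interpretation, not the author's explanation. The names `ema_update` and `centroids`, the list-of-lists layout, and gamma = 0.99 are hypothetical, and dividing m by N to obtain a centroid is an assumption that is not shown in the lines quoted above.

```python
def ema_update(N, m, embeddings, labels, gamma=0.99):
    """One EMA step for per-class counts N and per-class embedding sums m.

    N: list of length num_classes
    m: list of num_classes lists, each of length embedding_dim
    embeddings: list of per-sample embedding vectors
    labels: list of integer class labels, one per sample
    """
    num_classes, dim = len(N), len(m[0])
    counts = [0.0] * num_classes
    sums = [[0.0] * dim for _ in range(num_classes)]
    # Batch statistics: how many samples per class, and the sum of their embeddings.
    for z, c in zip(embeddings, labels):
        counts[c] += 1.0
        for d in range(dim):
            sums[c][d] += z[d]
    # Blend the old running values with the new batch statistics.
    new_N = [gamma * N[c] + (1 - gamma) * counts[c] for c in range(num_classes)]
    new_m = [[gamma * m[c][d] + (1 - gamma) * sums[c][d] for d in range(dim)]
             for c in range(num_classes)]
    return new_N, new_m


def centroids(N, m):
    """Assumed centroid estimate: smoothed embedding sum divided by smoothed count."""
    return [[v / N[c] for v in m[c]] for c in range(len(N))]


# Toy run: two classes with one sample each per batch, repeated many times.
N = [1.0, 1.0]
m = [[0.0, 0.0], [0.0, 0.0]]
batch_z = [[1.0, 2.0], [3.0, 4.0]]
batch_y = [0, 1]
for _ in range(2000):
    N, m = ema_update(N, m, batch_z, batch_y)
print(centroids(N, m))
```

In this toy run each class centroid approaches the mean embedding of that class, which is the behaviour one would expect from a running average that slowly forgets old batches.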

DUQ is able to estimate aleatoric uncertainty or not?

In your ICML presentation slides, you mentioned that DUQ is able to estimate aleatoric uncertainty, but in Section 2.3 of the paper, you said DUQ captures both aleatoric and epistemic uncertainty.

Compared with "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?" by Alex Kendall at NeurIPS 2017, what is the advantage of DUQ?

Compared with using a Gaussian Mixture Density Network to estimate the standard deviations (uncertainty), what is your advantage?

Thanks ahead!

Replication of Toy Example for Deep Ensemble

Thanks for the great paper.

I am trying to replicate the uncertainty plot of Deep Ensemble of the Toy Example but I couldn't get the same plot as shown in Figure 1 of the paper.

I noticed that we only have the code for DUQ. Can you also share the code for generating the uncertainty plot for Deep Ensemble?

Thanks,

Element 0 of tensors does not require grad and does not have a grad_fn

Hi,
Thank you for sharing your working code.
I wanted to try your code, but I ran into a problem with the training script:

File "directory/train_duq_cifar.py", line 103, in calc_gradient_penalty gradients = calc_gradients_input(x, y_pred)
File "directory/train_duq_cifar.py", line 95, in calc_gradients_input create_graph=True,

It occurred after 1 epoch.

As you can see on line 122 of train_duq_cifar.py, the authors set x.requires_grad_(True) to track the gradients.

I don't know why this error occurred.
Could you help me?
Thank you

A question of paper

In Figure 2 of the paper, why is the uncertainty defined as
?
If is very close to the centroid, uncertainty will be bigger in this formula, which is hard to understand.

Reproduce results

Hi,
thanks for sharing your work.

I'm trying to reproduce your results, specifically the SVHN and CIFAR-10 results. I have trained your model and am now testing it.

I produced 2 histograms of the scores (kernel_distance) for CIFAR-10 and SVHN, but they look quite different from the ones in your paper. The accuracy and AUROC are similar to the paper:

SVHN
Accuracy, AUROC: 0.9135159073448065 0.9238

obtained from

accuracy, auroc = get_cifar_svhn_ood(model)

CIFAR 10

0.9238 0.9070916430423466

obtained from the function

accuracy, auroc = get_auroc_classification(test_dataset, model)

I have attached the 2 histograms

cifar10

svhn

Do you think these results are similar to yours?

Thanks
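For context on comparing histograms with AUROC, the sketch below computes AUROC directly from two lists of scores. It is a generic illustration under assumptions, not the code behind get_cifar_svhn_ood: the function name `auroc` and the toy scores are hypothetical, and it assumes a higher score means "more in-distribution".

```python
def auroc(in_scores, ood_scores):
    """Probability that a random in-distribution score exceeds a random OOD score.

    Ties count as half a win. This pairwise form is O(n * m), fine for a demo.
    """
    wins = 0.0
    for s_in in in_scores:
        for s_ood in ood_scores:
            if s_in > s_ood:
                wins += 1.0
            elif s_in == s_ood:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


# Hypothetical certainty scores for four CIFAR-10 and four SVHN test images.
cifar_scores = [0.95, 0.90, 0.85, 0.60]
svhn_scores = [0.70, 0.40, 0.30, 0.20]
print(auroc(cifar_scores, svhn_scores))  # 0.9375
```

Because this quantity depends only on how the scores are ordered, two sets of histograms can look different (for example after rescaling the scores or with a different kernel length scale) and still give a similar AUROC.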

Error when used train_duq_fm.py

File "/deterministic-uncertainty-quantification-master/utils/evaluate_ood.py", line 44, in loop_over_dataloader
kernel_distance, pred = output.max(1)
AttributeError: 'tuple' object has no attribute 'max'
