
swa_gaussian's People

Contributors

andrewgordonwilson, dependabot[bot], dnlcrl, izmailovpavel, nijkah, timgaripov, wjmaddox

swa_gaussian's Issues

Loading SWAG Checkpoint and Continue SWAG Training

Dear SWAG Team,

I am currently using your SWAG code for my own research and stumbled upon the following problem:

Continuing SGD training from a checkpoint is not a problem.
This is how I start the code:
python3 experiments/train/run_swag.py --data_path= --epochs=300 --dataset=CIFAR100 --save_freq=10 --model=VGG16 --lr_init=0.05 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.01 --cov_mat --use_test --dir= --resume=

But I am not able to continue SWAG training. I found the parameter "swa_resume", but I cannot get the code running.
This is how I start the code:
python3 experiments/train/run_swag.py --data_path= --epochs=300 --dataset=CIFAR100 --save_freq=10 --model=VGG16 --lr_init=0.05 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.01 --cov_mat --use_test --dir= --swa_resume=

I get the following error message when the code tries to initialize the SWAG object:
File "PATH/swa_gaussian-master/experiments/train/run_swag.py", line 252, in
swag_model = SWAG(
File "PATH/swa_gaussian-master/swag/posteriors/swag.py", line 49, in init
self.base = base(*args, **kwargs)
TypeError: init() got an unexpected keyword argument 'loading'

How can I get this running?

Thanks for your help,
Yours sincerely,
Florian Linsner
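
A possible workaround, based only on the traceback above (not verified against the current repo): the 'loading' keyword that run_swag.py passes when --swa_resume is set appears to be forwarded unchanged to the base network's constructor, which does not accept it. Dropping it inside SWAG.__init__ before the base model is built would avoid the TypeError:

```python
# Hypothetical two-line patch inside SWAG.__init__ in swag/posteriors/swag.py,
# just before the line shown in the traceback. It only strips the extra keyword;
# it does not change how the checkpoint itself is loaded.
kwargs.pop("loading", None)         # base models such as VGG16 do not accept this flag
self.base = base(*args, **kwargs)   # the line from the traceback (swag.py, line 49)
```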

Variance of gradient noise and optimal learning rate

Hej!

Maybe I'm wrong, but I think you use an incorrect formula for computing the optimal learning rate in https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/grad_cov/grad_cov_utils.py#L52. From the first equation in Appendix A of your paper (which is Eq. 6 in Mandt's paper) it follows that

V(ĝ(θ)) ≈ V(∇g(θ) / √B) = V(∇g(θ)) / B = C(θ) / B,

which seems to be mentioned by Mandt in Assumption 1 as well. Thus you can estimate C by estimating V(ĝ(θ)) and multiplying by B, and the optimal learning rate according to Mandt is given by

η = 2 B d / (N tr(C(θ))) ≈ 2 B d / (N tr(B V(ĝ(θ)))) = 2 d / (N tr(V(ĝ(θ)))).

To me it seems that your estimate for the optimal learning rate is off by a factor of B². This means that the minimum optimal learning rate of 3000 that you report in the appendix should actually be 3000 / 128² ≈ 0.18 (assuming you used the default batch size of 128 in https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/grad_cov/run_grad_cov.py#L23), which would be much closer to a standard learning rate of 0.1.

As I said, maybe I'm getting something completely wrong, so I'd be happy to hear how you obtain the formula in Appendix A and https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/grad_cov/grad_cov_utils.py#L52.
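
For concreteness, a small sketch of the estimator that the derivation above implies (my own illustration with hypothetical inputs, not code from grad_cov_utils.py):

```python
import torch

def optimal_lr(grad_samples: torch.Tensor, num_train: int) -> float:
    """Mandt-style optimal constant learning rate, eta = 2 d / (N tr(V(g_hat))).

    grad_samples: (S, d) tensor of S flattened minibatch gradients, all evaluated
    at the same parameter vector theta; num_train: dataset size N.
    """
    var = grad_samples.var(dim=0, unbiased=True)      # elementwise variance of g_hat
    d = grad_samples.shape[1]
    return 2.0 * d / (num_train * var.sum().item())   # tr(V(g_hat)) is the sum of variances
```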

Error with CUDA10

PyTorch/CUDA error when running experiments/train/run_swag.py on your setup with
CUDA 10.0
NVIDIA driver 410.79

This seems to be in no way your fault, see for instance here, but you should be aware that it affects this repo.

The proposed fix is to change this setting:
torch.backends.cudnn.benchmark = False
The error message still appears, but it no longer breaks the script and training continues.

$ python run_swag.py --data_path ./data/cifar --dir train --use_test --model VGG16

Preparing directory train
Using model VGG16
Loading dataset CIFAR10 from ./data/cifar
Files already downloaded and verified
You are going to run models on the test set. Are you sure?
Files already downloaded and verified
Preparing model

SGD training
THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp line=405 error=11 : invalid argument
Traceback (most recent call last):
  File "run_swag.py", line 172, in <module>
    train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
  File "/home/jakob/dev/swa_gaussian/swag/utils.py", line 69, in train_epoch
    loss, output = criterion(model, input, target)
  File "/home/jakob/dev/swa_gaussian/swag/losses.py", line 7, in cross_entropy
    output = model(input)
  File "/home/jakob/dev/swa_gaussian/venv/lib/python3.6/site-packages/torch-1.0.1.post2-py3.6-linux-x86_64.egg/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jakob/dev/swa_gaussian/swag/models/vgg.py", line 57, in forward
    x = self.features(x)
  File "/home/jakob/dev/swa_gaussian/venv/lib/python3.6/site-packages/torch-1.0.1.post2-py3.6-linux-x86_64.egg/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jakob/dev/swa_gaussian/venv/lib/python3.6/site-packages/torch-1.0.1.post2-py3.6-linux-x86_64.egg/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/jakob/dev/swa_gaussian/venv/lib/python3.6/site-packages/torch-1.0.1.post2-py3.6-linux-x86_64.egg/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jakob/dev/swa_gaussian/venv/lib/python3.6/site-packages/torch-1.0.1.post2-py3.6-linux-x86_64.egg/torch/nn/modules/conv.py", line 320, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: cuda runtime error (11) : invalid argument at /pytorch/aten/src/THC/THCGeneral.cpp:405

Reproducing UCI Regression Experiments

Hi again, I am trying to reproduce the UCI regression experiments, specifically Table 12 from your paper. Do you have an executable setup available to rerun the experiments for the SGD and SWAG columns? Thanks in advance :)

Results CSV

Hi, and thank you for the great work.

I'd like to compare some of your results with my own work. Could I get access to either the saved models or, even better, the CSVs containing the results used to produce the reliability diagrams, if you have them?

Thanks!

Chris

Reproducibility of Uncertainty Experiment

Hi,
I was trying to reproduce the uncertainty experiment. I used the same command as in the README file of this repo. For training SWAG, I ran run_swag.py --data_path=./data --epochs=300 --dataset=CIFAR10 --save_freq=300 --model=PreResNet164 --lr_init=0.1 --wd=3e-4 --swa --swa_start=161 --swa_lr=0.01 --cov_mat --use_test --dir=./snap

Afterwards, I ran uncertainty.py --data_path=./data --dataset=CIFAR10 --model=PreResNet164 --use_test --cov_mat --method=SWAG --scale=0.5 --file=<checkpoint_file> --save_path=<path_to_save_var>

The test accuracy aligns with the result from the paper, but the NLL value sky-rockets to 1000+, whereas it is around 0.12 in the paper. Can you check whether I have missed something in the command-line params, and please let me know if I need to do any other tuning to reproduce the NLL values.
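
For reference, this is how I would compute the reported metric as the mean per-example negative log-likelihood (my own sketch; if a script instead summed over the whole test set, values in the thousands would be expected, but that is only a guess):

```python
import numpy as np

def mean_nll(probs: np.ndarray, targets: np.ndarray, eps: float = 1e-12) -> float:
    """probs: (N, C) predictive probabilities; targets: (N,) integer class labels."""
    per_example = -np.log(probs[np.arange(len(targets)), targets] + eps)
    return float(per_example.mean())   # averaging (not summing) gives values on the order of 0.1
```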

Cannot find key 'n_models'

Hi @wjmaddox!

I've been trying to reproduce the results for the segmentation experiment and have hit an error I cannot seem to fix.
I'm using the commands in the README to train a SWAG model and then to evaluate it, but I end up with the following error.
Any idea what the reason could be?

python eval_ensemble.py --data_path /home/ec2-user/CamVid/ --batch_size 4 --method SWAG --scale=0.5 --loss cross_entropy --N 50 --file ./experiment_swag/checkpoint-1000.pt --save_path ./experiment_swag/output.npz

/home/ec2-user/CamVid/
Preparing model
Loading model ./experiment_swag/checkpoint-1000.pt

Traceback (most recent call last):
  File "eval_ensemble.py", line 146, in <module>
    model.load_state_dict(checkpoint["state_dict"])
  File "/home/ec2-user/swa_gaussian/swag/posteriors/swag.py", line 182, in load_state_dict
    n_models = state_dict["n_models"].item()

KeyError: 'n_models'
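
A quick diagnostic sketch (my guess at the cause, not verified: the file being loaded may be a plain SGD/base-model checkpoint rather than one saved from the SWAG object, in which case the SWAG-specific entries such as 'n_models' would be missing):

```python
import torch

# Inspect which keys the checkpoint actually contains before loading it into SWAG.
checkpoint = torch.load("./experiment_swag/checkpoint-1000.pt", map_location="cpu")
keys = list(checkpoint["state_dict"].keys())
print("has n_models:", "n_models" in keys)
print("sample keys:", keys[:10])   # a SWAG checkpoint should contain mean / sq_mean style buffers
```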

Running on CPU

I believe it would be nice if you could easily choose whether you want to run the code on CPU or GPU. My understanding is that all related parameters are hard-coded now.
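
A sketch of the kind of device handling this refers to (generic PyTorch, not code from this repo): select the device once and move the model and data to it, instead of hard-coding .cuda() calls.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(3072, 10).to(device)        # stand-in for the real network
batch = torch.randn(4, 3072, device=device)   # data created on (or moved to) the same device
output = model(batch)                         # runs unchanged on CPU or GPU
```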

RMSE UCI Regression Results Paper

Thanks for open-sourcing the code and for the interesting paper. I am looking for the RMSE on the UCI regression experiments, as you state in Section 5.5 of the paper: "We report test log-likelihoods, RMSEs and test calibration results in Appendix Tables 12 and 13 where it is possible to see that SWAG is competitive with these methods." However, Table 12 seems to contain only unnormalized test log-likelihoods, and Table 13 contains calibration results. I was therefore wondering if the RMSE results are available elsewhere?

reliability diagrams

Hi Wesley (@wjmaddox)
I was wondering if you could shed some light on the calibration plots.
I'm running the save-calibration-plots script, giving as input the predictions and targets of a VGG16 trained on CIFAR10 (5+5), but when I plot the output I get something not even close to the plots in the paper. From my reading it seems that both SGD and SWAG are under-confident. Am I doing something wrong?

[screenshot: the poster's calibration plot]

Questions about the plotting of reliability diagrams

Hello, thanks for your great code.
While plotting the reliability diagram according to your paper, I ran into some problems.
The bars of my plot are bunched together.
Could you please share the plotting code for reference?
Thanks!
[screenshots: the poster's reliability diagrams]
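
A minimal reliability-diagram sketch (my own illustration, not the authors' plotting code): bin the test predictions by confidence and compare each bin's mean confidence with its empirical accuracy.

```python
import numpy as np
import matplotlib.pyplot as plt

def reliability_diagram(probs, targets, n_bins=20):
    """probs: (N, C) predictive probabilities; targets: (N,) integer labels."""
    conf = probs.max(axis=1)                                  # predicted confidence
    correct = (probs.argmax(axis=1) == targets).astype(float)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    accs, confs = [], []
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():                                        # skip empty bins
            accs.append(correct[mask].mean())
            confs.append(conf[mask].mean())
    plt.plot([0, 1], [0, 1], "k--", label="perfect calibration")
    plt.plot(confs, accs, "o-", label="model")
    plt.xlabel("confidence")
    plt.ylabel("accuracy")
    plt.legend()
    plt.show()
```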

'CIFAR10' object has no attribute 'targets'

I get an error when I run this script:

python3 experiments/train/run_swag.py --data_path=/home/jh/data/ --epochs=300 --dataset=CIFAR10 --save_freq=300 --model=PreResNet164 --lr_init=0.1 --wd=3e-4 --swa --swa_start=161 --swa_lr=0.01 --cov_mat --use_test --dir=/home/jh/swa_gaussian/experiments/train/checkpoint

I have CUDA 8.0 and installed the package with 'python setup.py develop', but I get the following error.

Preparing directory /home/jh/swa_gaussian/experiments/train/checkpoint
Using model PreResNet164
Loading dataset CIFAR10 from /home/jh/data/

Files already downloaded and verified
Traceback (most recent call last):
  File "experiments/train/run_swag.py", line 186, in <module>
    split_classes=args.split_classes,
  File "/home/jh/swa_gaussian/swag/data.py", line 196, in loaders
    num_classes = max(train_set.targets) + 1
AttributeError: 'CIFAR10' object has no attribute 'targets'

What should I do?
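
My guess (not verified) is a torchvision version mismatch: older torchvision releases store the CIFAR10 labels in train_labels / test_labels, while swag/data.py expects the newer targets attribute. Upgrading torchvision should fix it; alternatively, a small compatibility shim around the failing line would look like this:

```python
# Hypothetical edit inside swag/data.py, replacing the line from the traceback.
# train_set is the torchvision CIFAR10 object already constructed in loaders().
if hasattr(train_set, "targets"):      # newer torchvision
    labels = train_set.targets
else:                                  # older torchvision
    labels = train_set.train_labels
num_classes = max(labels) + 1
```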

Question about KFACLaplace for BatchNorm

I couldn't find many implementations of KFACLaplace, so I was implementing it from your baseline in this repo. I noticed a few things about the handling of batch norm that I am curious about...

Non-Reproducible / Weird Uncertainty Results

Hello,

I wanted to check the uncertainty properties of my SWAG models and cannot reproduce the values from the paper.
When I train a VGG16/CIFAR10 model with the parameters from the paper and use your uncertainty script like this:

python uncertainty.py --data_path= --model=VGG16 --dataset=CIFAR10 --method=SWAG --scale=0.5 --use_test --cov_mat --file=<path/swag-300.pt> --save_path=<save_path>

It results in good accuracy, but a huge NLL:

30/30
Accuracy: 0.9366
NLL: 1990.9390262566017

Thanks in advance

Cannot understand result

Hi, I have trained on CIFAR10 with --split_classes 0, and after running /experiments/uncertainty/uncertainty.py I cannot understand the output for accuracy and NLL:

I expected a different result with higher accuracy.
Do the following results make sense?

Thanks

Samples   Accuracy   NLL
1/30      0.3055     82666.69703534538
2/30      0.3054     80883.04669379114
3/30      0.3058     79560.46427493387
4/30      0.3046     78679.55740708904
5/30      0.3046     78697.95717055122
6/30      0.3054     78423.18946819699
7/30      0.3043     78413.52605041317
8/30      0.3045     77780.88915598544
9/30      0.3042     77406.18986791647
10/30     0.3043     76929.39630398026
11/30     0.3040     76934.29220921292
12/30     0.3043     76991.46480763993
13/30     0.3049     76846.16643335675
14/30     0.3047     76986.94727699496
15/30     0.3055     77010.32154600904
16/30     0.3050     76450.6290673627
17/30     0.3054     74538.73480594362
18/30     0.3053     74561.06134984753
19/30     0.3053     74636.1909890976
20/30     0.3053     74616.04049727577
21/30     0.3053     74681.99346308134
22/30     0.3049     74770.6411489519
23/30     0.3053     74883.58924571003
24/30     0.3049     74830.6947837674
25/30     0.3046     74713.18414589943
26/30     0.3048     74794.40056861148
27/30     0.3049     74712.60328365376
28/30     0.3050     74729.760734134
29/30     0.3049     74740.55728330463
30/30     0.3047     74798.28355654998

Replicating results from paper with dropout

Hi, I am trying to replicate the results from your paper, and I was unsure how you trained the model with SWAG-dropout. I tried to use the same command script as in the README and replaced the model with a dropout variant, but the losses do not converge. Could you describe how to run an image classification experiment and get results for MC dropout?

Questions about the implementation of calculation of Low-Rank Covariance Matrix

Hello, thanks for your awesome work. I'm trying to calculate the low-rank covariance matrix $\Sigma_{\text{low-rank}} = \frac{1}{K-1} \cdot \widehat{D} \widehat{D}^{\top}$.
In my implementation, I use the ResNet18 architecture, collect its parameters, and flatten them. However, this leads to an array of shape (11172042, 1).
When I try to calculate the low-rank matrix, the matmul of (11172042, K) @ (K, 11172042) requires more than 80 TB of memory, which is practically impossible.

I thought there must be something wrong with my implementation, and I wonder how you compute this large matrix multiplication?
Thanks for your patience.
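
In case it helps, my understanding of the sampling rule in the paper is that the full $\widehat{D} \widehat{D}^{\top}$ matrix never needs to be formed: drawing a sample only requires matrix-vector products with $\widehat{D}$, which cost O(dK) memory instead of O(d²). A sketch with placeholder sizes and values:

```python
import torch

d, K = 100_000, 20                   # in practice d is ~1.1e7 for ResNet18; kept small here
theta_swa = torch.zeros(d)           # SWA mean (placeholder values)
sigma_diag = 1e-3 * torch.ones(d)    # diagonal variance estimate (placeholder)
D_hat = 1e-2 * torch.randn(d, K)     # deviation matrix, one column per collected model (placeholder)

z1 = torch.randn(d)
z2 = torch.randn(K)
sample = (
    theta_swa
    + sigma_diag.sqrt() * z1 / (2 ** 0.5)      # diagonal part of the covariance
    + (D_hat @ z2) / ((2 * (K - 1)) ** 0.5)    # low-rank part: a (d x K) matvec, never a d x d matrix
)
```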

Sampling using SWAG

First of all thank you for publicly sharing your work.

I am a senior CS bachelor student and I am using the SWAG estimate of the posterior P(theta | dataset) as part of a theoretical framework supporting the empirical results I reached so far during my thesis. I have a couple of questions concerning the SWAG class.

So if I understood correctly, the SWAG class provides a way to sample from the posterior and compute the log probability of the samples. After digging deeper into the code, I see that the output of the "compute_logprob" method depends on the values of "mean_list, var_list, covar_mat_root_list" generated by the "generate_mean_var_covar()" method as indicated in the code snippet below.

[screenshot: code snippet showing compute_logprob]

Going through "generate_mean_var_covar()" method, these values are extracted from the "mean" and "sq_mean" attributes of each sub-module as indicated in the code snippet below:

[screenshot: code snippet from generate_mean_var_covar]

So in order to get different outputs from the "compute_logprob()" method, the values of "mean" and "sq_mean" in the different sub-modules need to change. However, the only method that changes these values is the "collect_model()" method. Hence, I conjectured that I should proceed as follows:

  1. Define a base model, then define a SWAG model with the same base class.
  2. When training the base model, call swag_model.collect_model() at the end of each epoch, as this will update the parameters (the mean and covariance matrix).
  3. After training, the swag_model can be used to sample from the posterior distribution as follows:
  • swag_model.sample(): sample a set of parameters from the posterior
  • swag_model.compute_logprob(): compute the log probability of the current set of parameters (sampled with the call above)

I would greatly appreciate it if you could confirm or correct my understanding of your implementation. Thanks a lot in advance.
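
For what it's worth, this is how I would sketch that workflow in code (hedged: the constructor arguments and the collect_model / sample / compute_logprob signatures follow my reading of the snippets above, not a verified API; TinyNet and the training loop are placeholders):

```python
import torch.nn as nn
from swag.posteriors import SWAG

class TinyNet(nn.Module):               # stand-in for the real base architecture
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        return self.fc(x)

# 1. Base model plus a SWAG wrapper built from the same base class.
swag_model = SWAG(TinyNet, num_classes=10)   # covariance-related options omitted
model = TinyNet(num_classes=10)

for epoch in range(300):
    # ... one epoch of ordinary SGD training of `model` (user code) ...
    swag_model.collect_model(model)          # 2. update the running mean / sq_mean (and deviations)

# 3. After training: draw weights from the posterior and evaluate their log-probability.
swag_model.sample(scale=0.5)                 # load one parameter sample into the SWAG module
log_prob = swag_model.compute_logprob()      # log-density of the sample just drawn
```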
