memo's Introduction

MEMO: Test Time Robustness via Adaptation and Augmentation

These directories contain code for reproducing the MEMO results for the CIFAR-10 and ImageNet distribution shift test sets.

Please note: this code has been cleaned up since the version that generated the results in the paper. Though the code is likely correct and should produce the same results, some discrepancies may not have been caught. Minor differences in results may arise from stochasticity, but please report any major differences or bugs by submitting an issue, and we will aim to resolve them promptly.

Setup

First, create an Anaconda environment with requirements.txt, e.g.,

conda create -n memo python=3.8 -y -q --file requirements.txt
conda activate memo

After doing so, you will need to pip install tqdm. For the robust vision transformer models, you will also need to pip install timm einops.
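As a quick sanity check after installation, the following minimal sketch reports which of the extra packages named above are still missing from the active environment. The helper name missing_packages is hypothetical and not part of this repository.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    # tqdm is always needed; timm and einops only for the robust vision transformer models.
    missing = missing_packages(["tqdm", "timm", "einops"])
    if missing:
        print("Missing packages, try: pip install " + " ".join(missing))
    else:
        print("All extra packages are installed.")
```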

CIFAR-10 Experiments

The cifar-10-exps directory contains code for the CIFAR-10 experiments. You can run bash script_c10.sh for the full set of experiments. Alternatively, you can run python script_test_c10.py directly with the experiment you wish to run (see script_c10.sh for more details).

For convenience, we provide the ResNet26 model that we trained in results/cifar10_rn26_gn/ckpt.pth. We do not provide the datasets themselves, though you can download the non-standard test sets here:

After downloading and setting up the datasets, make sure to modify the dataroot variable on line 8 of script_test_c10.py.

ImageNet Experiments

The imagenet-exps directory contains code for the ImageNet experiments. You can run bash script_in.sh for the full set of experiments, though this is very slow. You can again run python script_test_in.py directly with the experiment you wish to run. For the corrupted image datasets, you may wish to slightly modify the code to only run one corruption-level pair (and then parallelize).
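One way to parallelize over corruption-level pairs is to enumerate all pairs and hand each worker a disjoint slice. The sketch below assumes the 15 standard ImageNet-C corruption names and severity levels 1 to 5; the function name shard_pairs is hypothetical, and how each pair is passed to script_test_in.py depends on how you modify that script.

```python
from itertools import product

# The 15 standard ImageNet-C corruption types (assumed to match your dataset layout).
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]
LEVELS = [1, 2, 3, 4, 5]


def shard_pairs(num_workers, worker_id):
    """Return the (corruption, level) pairs assigned to one worker, round-robin."""
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    pairs = list(product(CORRUPTIONS, LEVELS))
    return pairs[worker_id::num_workers]


if __name__ == "__main__":
    # Example: print the pairs that worker 0 of 4 would run.
    for corruption, level in shard_pairs(num_workers=4, worker_id=0):
        print(corruption, level)
```

Round-robin slicing keeps the shards within one pair of each other in size, and every pair is assigned to exactly one worker.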

As an example, we provide the baseline ResNet-50 model from torchvision in results/imagenet_rn50/ckpt.pth. Including all of the pretrained model weights would be prohibitively large. We did not train our own models for ImageNet, and all other models we used can be downloaded:

We also experimented with a baseline ResNeXt-101 (32x8d) model, which we obtained from torchvision.

Please note: some of these models provide the weights in slightly different conventions, thus loading the downloaded state_dict may not directly work, and the keys in the state_dict may need to be modified to match with the code. We have done this modification already for the baseline ResNet-50 model, and thus this ckpt.pth can be used as a template for modifying other model checkpoints.
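The key modification described above can be sketched as a plain dictionary rewrite. This is a minimal illustration, not the procedure the authors used: the helper names remap_keys and find_mismatches are hypothetical, and the "module." prefix (left by nn.DataParallel) is only one common mismatch; which keys actually differ for each downloaded model is not stated here.

```python
from collections import OrderedDict


def remap_keys(state_dict, strip_prefix="module.", add_prefix=""):
    """Return a copy of state_dict with strip_prefix removed and add_prefix prepended."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        if strip_prefix and key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        remapped[add_prefix + key] = value
    return remapped


def find_mismatches(expected_keys, state_dict):
    """Return (keys the model expects but are absent, keys present but unexpected)."""
    expected, actual = set(expected_keys), set(state_dict)
    return sorted(expected - actual), sorted(actual - expected)
```

With PyTorch, a typical use would be to load the downloaded file with torch.load(path, map_location="cpu"), compare its keys against net.state_dict().keys() using find_mismatches, remap until both lists are empty, and save the result under a 'state_dict' entry, as in the provided ckpt.pth.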

We again do not provide the datasets themselves, though you can download the test sets here:

After downloading and setting up the datasets, again make sure to modify the dataroot variable on line 8 of script_test_in.py.

Paper

Please use the following citation:

@article{memo,
    author={Zhang, M. and Levine, S. and Finn, C.},
    title={{MEMO}: Test Time Robustness via Adaptation and Augmentation},
    journal={arXiv preprint arXiv:2110.09506},
    year={2021},
}

The paper can be found on arXiv (arXiv:2110.09506).

Acknowledgments

The design of this code was adapted from the TTT codebases. Other parts of the code that were adapted from third party sources are credited via comments in the code itself.

memo's Issues

Issue of ablation study performance

I am verifying the effect of pairwise cross entropy, and implemented it with the following code:

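import torch.nn as nn
import torch.nn.functional as F

class SoftCrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_hat, y):
        # Log-probabilities of the prediction over dimension 0
        p = F.log_softmax(y_hat, 0)
        # Weighted negative log-likelihood, normalized by the sum of the target y
        loss = -(y * p).sum() / y.sum()
        return loss

softce = SoftCrossEntropyLoss().cuda()

def pairwise_entropy(outputs):
    # Log-softmax of the outputs, returned alongside the loss
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)
    loss, B = 0, outputs.shape[0]
    # Average the loss over all ordered pairs (i, j) with i != j
    for i in range(B):
        for j in range(B):
            if i == j:
                continue
            loss += softce(outputs[i], outputs[j])
    loss /= B * (B - 1)
    return loss, logits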

However, this implementation does not achieve any performance gain on level 5 Gaussian noise compared with the original model. Could you kindly provide the code for the pairwise cross entropy used in the ablation study?
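For reference, here is a framework-free sketch of one plausible form of a pairwise cross entropy over the predictions for B augmented copies: the mean over ordered pairs (i, j), i != j, of H(p_j, p_i), where p_i is the softmax of the i-th row of logits. This is an assumption about the ablation objective, not the authors' code. One difference from the snippet in this issue is that the target here is a probability vector, whereas the snippet passes raw logits as y; whether that explains the reported gap is not established.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def cross_entropy(target_logits, pred_logits):
    """H(p_target, p_pred) = -sum_c p_target[c] * log p_pred[c]."""
    target_probs = [math.exp(v) for v in log_softmax(target_logits)]
    pred_log_probs = log_softmax(pred_logits)
    return -sum(p * lp for p, lp in zip(target_probs, pred_log_probs))


def pairwise_cross_entropy(outputs):
    """Mean cross entropy over all ordered pairs of rows in outputs (B x C logits)."""
    B = len(outputs)
    if B < 2:
        raise ValueError("need at least two augmented predictions")
    total = 0.0
    for i in range(B):
        for j in range(B):
            if i != j:
                total += cross_entropy(outputs[j], outputs[i])
    return total / (B * (B - 1))
```

A small reference like this can be used to check a tensor implementation on a few hand-written logits before running a full evaluation.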

About the adapt procedure.

In /test_adapt.py, line 54:
net.load_state_dict(ckpt['state_dict'])
Does this mean that the model adapts and predicts on a single image, but the adapted parameters are not saved?
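The pattern asked about here, reloading the checkpoint before handling each test input, can be illustrated with a toy sketch. Everything below is hypothetical (ToyModel, adapt_step, and run_episodic are stand-ins, and the update rule is a placeholder rather than a gradient step on the adaptation loss). It only shows how resetting before each input keeps one input's adaptation from carrying over to the next; whether this matches the intent of line 54 is for the authors to confirm.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a network: a single scalar weight."""

    def __init__(self, weight):
        self.weight = weight

    def load_state_dict(self, state):
        self.weight = state["weight"]


def adapt_step(model, x, lr=0.1):
    # Placeholder update; real code would take a gradient step on the adaptation loss.
    model.weight -= lr * x


def run_episodic(model, checkpoint, inputs):
    """Reset to the checkpoint, adapt on one input, predict, and repeat per input."""
    predictions = []
    for x in inputs:
        model.load_state_dict(copy.deepcopy(checkpoint))  # discard earlier adaptation
        adapt_step(model, x)
        predictions.append(model.weight * x)
    return predictions
```

Because of the reset, the prediction for each input depends only on the checkpoint and that input, not on the order in which inputs are processed.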

it will report an error here in loss.backward()

Hello, I want to use MEMO to test on COCO. After computing the marginal entropy and starting backpropagation to compute the gradient, an error is reported in loss.backward():
Error detected in SigmoidBackward0. Traceback of forward call that caused the error:
File "test.py", line 636, in <module>
test(opt.data,
File "test.py", line 235, in test
adapt_single(image, optimizer, batch_size, model, device)
File "test.py", line 62, in adapt_single
outputs,train_output = model(inputs) # [batch_size, num_boxes, 85]
File "miniconda3/envs/memo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "miniconda3/envs/memo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "models/yolo.py", line 599, in forward
return self.forward_once(x, profile) # single-scale inference, train
File "models/yolo.py", line 625, in forward_once
x = m(x) # run
File "miniconda3/envs/memo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "miniconda3/envs/memo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "models/yolo.py", line 153, in fuseforward
y = x[i].sigmoid()
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)

Reproducing paper's results.

Hello,

First of all, thank you so much for this amazing work.

I am trying to reproduce the results from the paper using the code in this repository. It seems that using default parameters, it is not possible to reproduce the results, and there is a large gap between the numbers. I tried to increase the number of iterations (niter) from 1 to 5, and it helped a lot, but still, the numbers are not the same. Since each run takes too much time, I was wondering if you could share the parameters used to generate the paper's results.

ImageNet-A

Model            Run-1   Run-2   Paper
ResNet-50        100.0   100.0   100.0
+Memo             99.9    99.8    99.1
DeepAug+AugMix    96.1    96.1    96.1
+Memo             95.5    95.6    94.8
MoEx+CutMix       92.0    92.0    91.9
+Memo             91.5    91.4    89.0

ImageNet-R

Model            Run-1   Run-2   Paper
ResNet-50         63.8    63.8    63.9
+Memo             61.3    61.3    58.8
DeepAug+AugMix    53.2    53.2    53.2
+Memo             51.3    51.3    49.2
MoEx+CutMix       64.5    64.5    64.5
+Memo             61.8    61.9    59.4

Details and logs can be found here.

Thank you so much.
