HybridAugment++: Unified Frequency Spectra Perturbations for Model Robustness (ICCV'23)

This repository contains the PyTorch implementation of our paper "HybridAugment++: Unified Frequency Spectra Perturbations for Model Robustness", accepted to ICCV 2023.

Paper Abstract

Convolutional Neural Networks (CNNs) are known to exhibit poor generalization performance under distribution shifts. Their generalization has been studied extensively, and one line of work approaches the problem from a frequency-centric perspective. These studies highlight the fact that humans and CNNs might focus on different frequency components of an image. First, inspired by these observations, we propose a simple yet effective data augmentation method, HybridAugment, that reduces the reliance of CNNs on high-frequency components, and thus improves their robustness while keeping their clean accuracy high. Second, we propose HybridAugment++, a hierarchical augmentation method that attempts to unify various frequency-spectrum augmentations. HybridAugment++ builds on HybridAugment, and also reduces the reliance of CNNs on the amplitude component of images, promoting phase information instead. This unification yields results that are competitive with or better than the state of the art on clean accuracy (CIFAR-10/100 and ImageNet), corruption benchmarks (ImageNet-C, CIFAR-10-C and CIFAR-100-C), adversarial robustness on CIFAR-10 and out-of-distribution detection on various datasets. HybridAugment and HybridAugment++ are implemented in a few lines of code and do not require extra data, ensemble models or additional networks.

Paper Highlights

📌 We propose HybridAugment and HybridAugment++, two simple data augmentation methods which force models to emphasize low-frequency components, and low-frequency/phase components of training samples, respectively. Both augmentations come with single-image and paired variants, which can be combined and work better in tandem. Such augmentations lead to models that are robust against various distribution shifts, while keeping or even improving the accuracy on clean samples.

Fig. 1: An overview of our methods HybridAugment (HA) and HybridAugment++ (HA++), and their single-image (_S) and paired (_P) variants. HA_P combines the high-frequency (HF) and low-frequency (LF) contents of two randomly selected images, whereas HA_P++ combines the HF of one image with the amplitude and LF-phase mixtures of two other images. Single-image variants perform the same procedure, but based on different augmented versions of a single image.
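
The paired variants described in the caption can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the code in this repository: the function names (`low_pass_mask`, `low_pass`, `ha_paired`, `ha_pp_paired`), the ideal circular low-pass filter in the FFT domain, the `(H, W, C)` array layout and the `cutoff` unit (FFT bins) are all assumptions made for the example, and the actual implementation may use a different filter and a different way of mixing amplitude and phase.

```python
import numpy as np


def low_pass_mask(h, w, cutoff):
    """Boolean (h, w) mask keeping FFT bins within `cutoff` of the DC bin.
    Hypothetical helper: an ideal circular filter is assumed here."""
    fy = np.fft.fftfreq(h) * h  # integer frequency index per row
    fx = np.fft.fftfreq(w) * w  # integer frequency index per column
    return (fy[:, None] ** 2 + fx[None, :] ** 2) <= cutoff ** 2


def low_pass(img, cutoff):
    """Low-frequency part of an (H, W, C) float image."""
    h, w, _ = img.shape
    mask = low_pass_mask(h, w, cutoff)[:, :, None]
    spec = np.fft.fft2(img, axes=(0, 1))
    return np.real(np.fft.ifft2(spec * mask, axes=(0, 1)))


def ha_paired(x_i, x_j, cutoff):
    """HA_P sketch: high frequencies of x_i plus low frequencies of x_j."""
    hf_i = x_i - low_pass(x_i, cutoff)
    return hf_i + low_pass(x_j, cutoff)


def ha_pp_paired(x_i, x_j, x_z, cutoff):
    """HA_P++ sketch (one possible reading of the caption): high frequencies
    of x_i plus a low-frequency part whose phase comes from x_j and whose
    amplitude comes from x_z."""
    h, w, _ = x_i.shape
    mask = low_pass_mask(h, w, cutoff)[:, :, None]
    hf_i = x_i - low_pass(x_i, cutoff)
    phase_j = np.angle(np.fft.fft2(x_j, axes=(0, 1)))
    amp_z = np.abs(np.fft.fft2(x_z, axes=(0, 1)))
    mixed_spec = mask * amp_z * np.exp(1j * phase_j)
    return hf_i + np.real(np.fft.ifft2(mixed_spec, axes=(0, 1)))


# Usage with random stand-in images (real inputs would be training samples).
rng = np.random.default_rng(0)
a, b, c = (rng.random((32, 32, 3)) for _ in range(3))
aug = ha_paired(a, b, cutoff=8)
aug_pp = ha_pp_paired(a, b, c, cutoff=8)
```

In this sketch, passing the same image for every argument returns the original image, which is a quick sanity check that the low- and high-frequency parts add back up. A smaller `cutoff` moves more of the image content into the high-frequency part.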

📌 Our methods outperform existing state-of-the-art methods on various benchmarks, including the corruption benchmark ImageNet-C. HybridAugment++ improves further with more training data (e.g. DeepAugment) and other augmentation methods (e.g. AugMix), and can be tailored to different needs by changing the cut-off frequency (e.g. higher clean accuracy vs. higher robustness). We also show that our method is not exclusive to CNNs and works quite well with transformers.

Fig. 2: Clean error and corruption robustness on ImageNet. Lower is better. The methods shown in the last four rows leverage extra data during training. † indicates training with a higher cut-off frequency.

📌 HybridAugment and HybridAugment++ are easy to implement, and do not require extra data, ensemble models or complicated augmentation regimes based on external networks.

Installation

📌 See the environment.yml file for an exported conda environment. It may list unnecessary dependencies, so the download might take a while.

📌 See the requirements.txt file for pip dependencies. It may also list unnecessary dependencies, so the download might take a while.

Datasets

📌 Both the CIFAR and ImageNet training scripts look for the datasets under the ./data/ folder, though this can be changed with the relevant flags.

📌 Links for some of the datasets: CIFAR-10-C, CIFAR-100-C, ImageNet-C.

Running the code

📌 Run the following script to train on CIFAR-10/100.

python main.py --outf output_folder --single_mode ha_p --paired_mode ha_p  --model "resnet" --dataset "cifar10"

See the script's input arguments for the other options. Use --eval to evaluate the trained model. Training and evaluation are logged under output_folder. This script evaluates on both CIFAR-10/100 and their corrupted versions.

Use the --ood_dataset flag to choose which OOD dataset you would like to test on. Put these OOD datasets under the ./data/ folder for easy experimentation.

📌 Run the following script to train on ImageNet.

python imagenet.py --arch "resnet50" --data path/to/imagenet  --multiprocessing-distributed --rank 0 --world_size 1 --single_mode ha_p --paired_mode ha_p 

See the script's input arguments for the other options. Use --evaluate to evaluate a trained model on ImageNet. This script only evaluates on ImageNet.

📌 For a fair comparison with other methods, we use the evaluation script of AugMix (see here). After downloading that repo, you can evaluate the ImageNet-trained model on ImageNet-C as follows.

python imagenet.py  --evaluate --resume path/to/checkpoint  path/to/imagenet path/to/imagenet_c

These arguments should be fine for evaluation, but refer to the relevant script for more options.

📌 Run the following script (under the ./autoattacks/ folder) to train on CIFAR-10 with adversarial training.

python train_fgsm.py --lr-max 0.20 --prob-p 0.16 --prob-s 0.90 --epochs 90 --out-dir output_folder --single_mode ha_p --paired_mode ha_p --opt-level O0

See the script for more options during training.

📌 Run the following script (under the ./autoattacks/ folder) to evaluate the adversarial robustness of trained models (works with models trained with train_fgsm.py).

python eval.py --model  path/to/model.pth  --data_dir ../data/cifar10/ --log_path path/to/log.txt

See the script for more options during evaluation.

Pretrained Weights

📌 We provide pretrained weights as well as the training/evaluation logs for most of our models.

📌 HybridAugment++ (PS) models (CIFAR-10).

| | AllConv | DenseNet | WideResNet | ResNext | ResNet18 |
|---|---|---|---|---|---|
| mCE | 10.7 | 9.5 | 8.3 | 7.9 | 8.2 |

📌 HybridAugment++ (PS) models (CIFAR-100).

| | AllConv | DenseNet | WideResNet | ResNext | ResNet18 |
|---|---|---|---|---|---|
| mCE | 34.4 | 33.4 | 31.2 | 28.8 | 29.9 |

📌 Pretrained models on ImageNet (ResNet50).

| | HA++ (PS) | HA++ (PS) † | HA++ (PS) + DA | HA++ (PS) + DA † | HA++ (PS) + DA + AM † |
|---|---|---|---|---|---|
| mCE | 67.3 | 65.8 | 58.9 | 58.1 | 56.1 |

📌 Models trained with adversarial training + our methods on CIFAR-10 (see Table 4 of our paper).

| | HA (S) | HA++ (S) | HA (P) | HA++ (P) | HA (PS) | HA++ (PS) |
|---|---|---|---|---|---|---|
| CA | 86.5 | 85.0 | 85.5 | 85.4 | 85.0 | 82.8 |
| RA | 44.1 | 45.4 | 42.1 | 43.5 | 44.8 | 46.0 |

โ— PS indicates paired-single combined variant. โ€  indicates training with a higher cut-off frequency. DA is DeepAugment, AM is AugMix.

Citation

📌 If you find our code or paper useful in your research, please consider citing our paper.

@inproceedings{yucel2023hybridaugment,
  title={HybridAugment++: Unified Frequency Spectra Perturbations for Model Robustness},
  author={Yucel, Mehmet Kerim and Cinbis, Ramazan Gokberk and Duygulu, Pinar},
  booktitle={International Conference on Computer Vision (ICCV)},
  year={2023},
}

Acknowledgements

This code base has borrowed several implementations from this link

hybrid_augment's Issues

Implementation in albumentations API

Thanks for the interesting work!

It would be quite beneficial if these methods were provided as a ready-to-use class that follows the albumentations API, which is probably the most widely used augmentation library. A pip package would also increase adoption.

Is there anything in your method that is specific to image classification? Can it be used in semantic or instance segmentation, or even in other modalities such as audio/speech?
