
Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization

Official PyTorch implementation of Ensemble of Averages

This repository is built upon the DomainBed repository by FAIR.

Environment:

	Python: 3.6.8
	PyTorch: 1.9.0+cu111
	Torchvision: 0.10.0+cu111
	CUDA: 11.1
	CUDNN: 8005
	NumPy: 1.19.5
	PIL: 8.4.0

Run Commands

Simple Moving Average (SMA):

First, we train models using our SMA protocol.
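
The README does not spell out the averaging schedule, so the snippet below is only a minimal sketch of a simple (uniform) moving average of model weights. It is not the repository's ERM_SMA implementation. Parameters are modeled as plain Python lists instead of PyTorch tensors so the sketch runs anywhere. The `start_iter` argument is a hypothetical knob: snapshots from earlier steps are ignored, and every later step contributes equally to the average.

```python
from typing import Dict, List

# Parameters as name -> flat list of floats (a stand-in for a PyTorch state_dict).
Params = Dict[str, List[float]]


class SimpleMovingAverage:
    """Uniform running average of model weights.

    `start_iter` is a hypothetical knob: snapshots from earlier steps are ignored.
    """

    def __init__(self, start_iter: int = 0) -> None:
        self.start_iter = start_iter
        self.count = 0  # number of snapshots averaged so far
        self.avg: Params = {}

    def update(self, params: Params, step: int) -> None:
        if step < self.start_iter:
            return
        self.count += 1
        if self.count == 1:
            # First snapshot: the average is just a copy of the weights.
            self.avg = {k: list(v) for k, v in params.items()}
            return
        for k, v in params.items():
            # Incremental mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
            self.avg[k] = [a + (x - a) / self.count for a, x in zip(self.avg[k], v)]


if __name__ == "__main__":
    sma = SimpleMovingAverage(start_iter=2)
    for step, w in enumerate([0.0, 10.0, 1.0, 2.0, 3.0]):
        sma.update({"w": [w]}, step)
    print(sma.avg)  # {'w': [2.0]}  (steps 0 and 1 are skipped; mean of 1, 2, 3)
```

The incremental form avoids storing every snapshot. In a PyTorch version, the same update would be applied to each floating-point tensor in the model's state_dict. Integer buffers such as BatchNorm's num_batches_tracked would need separate handling.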

All experiments below use the ImageNet pre-trained ResNet-50 architecture, specified as resnet50 via the arch hyper-parameter in the command. Other supported options are resnext50_swsl and regnety_16gf, corresponding to the weakly-supervised pre-trained models ResNeXt-50 32x4d from Yalniz et al. and RegNetY-16GF from Singh et al., respectively.

PACS:

python -m domainbed.scripts.sweep launch --data_dir path/to/data --output_dir erm-sma_resnet50/pacs --command_launcher multi_gpu --algorithms ERM_SMA --datasets PACS --n_hparams 3 --n_trials 2 --single_test_envs --hparams '{"arch": "resnet50"}'

VLCS:

python -m domainbed.scripts.sweep launch --data_dir path/to/data --output_dir erm-sma_resnet50/vlcs --command_launcher multi_gpu --algorithms ERM_SMA --datasets VLCS --n_hparams 3 --n_trials 2 --single_test_envs --hparams '{"arch": "resnet50"}'

OfficeHome:

python -m domainbed.scripts.sweep launch --data_dir path/to/data --output_dir erm-sma_resnet50/officehome --command_launcher multi_gpu --algorithms ERM_SMA --datasets OfficeHome --n_hparams 3 --n_trials 2 --single_test_envs --hparams '{"arch": "resnet50"}'

TerraIncognita:

python -m domainbed.scripts.sweep launch --data_dir path/to/data --output_dir erm-sma_resnet50/terra --command_launcher multi_gpu --algorithms ERM_SMA --datasets TerraIncognita --n_hparams 3 --n_trials 2 --single_test_envs --hparams '{"arch": "resnet50"}'

DomainNet (note that the number of steps is set to 15000 for this dataset, following SWAD):

python -m domainbed.scripts.sweep launch --data_dir path/to/data --output_dir erm-sma_resnet50/domainnet --command_launcher multi_gpu --algorithms ERM_SMA --datasets DomainNet --n_hparams 3 --n_trials 2 --single_test_envs --hparams '{"arch": "resnet50"}' --steps 15000
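
The five SMA commands above differ only in the dataset name, the output subdirectory, and the extra --steps 15000 flag for DomainNet. The following helper is a hypothetical convenience script, not part of the repository, that rebuilds them from those differences. It uses only the flags shown above. Naming the output directory after the arch value for other architectures (e.g. erm-sma_resnext50_swsl/...) is an assumption.

```python
import json
import shlex
from typing import List

# Dataset name -> output subdirectory, as used in the commands above.
DATASETS = {
    "PACS": "pacs",
    "VLCS": "vlcs",
    "OfficeHome": "officehome",
    "TerraIncognita": "terra",
    "DomainNet": "domainnet",
}


def sma_sweep_command(dataset: str, data_dir: str = "path/to/data",
                      arch: str = "resnet50") -> List[str]:
    """Return the argv list for one ERM_SMA sweep (hypothetical helper)."""
    out_dir = f"erm-sma_{arch}/{DATASETS[dataset]}"  # assumed naming scheme
    cmd = [
        "python", "-m", "domainbed.scripts.sweep", "launch",
        "--data_dir", data_dir,
        "--output_dir", out_dir,
        "--command_launcher", "multi_gpu",
        "--algorithms", "ERM_SMA",
        "--datasets", dataset,
        "--n_hparams", "3",
        "--n_trials", "2",
        "--single_test_envs",
        "--hparams", json.dumps({"arch": arch}),
    ]
    if dataset == "DomainNet":
        cmd += ["--steps", "15000"]  # DomainNet uses 15000 steps, following SWAD
    return cmd


def to_shell(cmd: List[str]) -> str:
    """Quote arguments for a POSIX shell (works on Python 3.6, unlike shlex.join)."""
    return " ".join(shlex.quote(arg) for arg in cmd)


if __name__ == "__main__":
    for name in DATASETS:
        print(to_shell(sma_sweep_command(name)))
```

Passing the argv list directly to subprocess.run(cmd, check=True) would launch a sweep without going through a shell.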

Ensemble of Averages (EoA)

We now use the best SMA models saved from the above runs (selected by early stopping on in-domain validation accuracy) in an ensemble, which we call EoA since these ensembles contain moving-average models.
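
As an illustration of the two steps just described, here is a minimal sketch. First, each run's best checkpoint is picked by in-domain validation accuracy. Second, the selected models' predictions are combined. The README does not state how member outputs are combined, so this sketch assumes the softmax probabilities are averaged. The checkpoint dictionaries and their "step"/"val_acc" keys are hypothetical, not the repository's saved format.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_best(checkpoints: List[Dict]) -> Dict:
    """Pick the checkpoint with the highest in-domain validation accuracy."""
    return max(checkpoints, key=lambda c: c["val_acc"])


def ensemble_predict(member_logits: List[Sequence[float]]) -> int:
    """Average the members' class probabilities and return the argmax class."""
    probs = [softmax(logits) for logits in member_logits]
    n_classes = len(probs[0])
    mean = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=lambda c: mean[c])


if __name__ == "__main__":
    runs = {
        "run_a": [{"step": 1000, "val_acc": 0.81}, {"step": 2000, "val_acc": 0.84}],
        "run_b": [{"step": 1000, "val_acc": 0.86}, {"step": 2000, "val_acc": 0.83}],
    }
    print({name: select_best(ckpts)["step"] for name, ckpts in runs.items()})
    # {'run_a': 2000, 'run_b': 1000}
    print(ensemble_predict([[2.0, 1.0, 0.0], [0.0, 3.0, 0.0], [1.0, 1.2, 0.0]]))
    # 1
```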

TIP: Use larger values of num_workers and batch_size for faster runtime.

PACS (note that the output_dir argument is the same as the value used above in the SMA experiment for the PACS dataset):

python -m domainbed.EoA --data_dir path/to/data --dataset PACS --output_dir erm-sma_resnet50/pacs --hparams '{"num_workers": 1, "batch_size": 128, "arch": "resnet50"}'

Citation

If you use our work, please cite us:

@article{arpit2021ensemble,
  title={Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization},
  author={Arpit, Devansh and Wang, Huan and Zhou, Yingbo and Xiong, Caiming},
  journal={arXiv preprint arXiv:2110.10832},
  year={2021}
}

License

This repository is released under the MIT License. See the LICENSE file for details.