
Model-Based Domain Generalization

This repository contains the code needed to reproduce the results of Model-Based Domain Generalization. In particular, it includes two sub-repositories:

  1. A fork of DomainBed which can be used to reproduce our results on ColoredMNIST, PACS, and VLCS.
  2. A separate implementation that can be used to reproduce our results on Camelyon17-WILDS and FMoW-WILDS.

We also include a library of trained domain transformation models for ColoredMNIST, PACS, Camelyon17-WILDS, and FMoW-WILDS. This library will be updated in the future with more models. All models can be downloaded from the following Google Drive link:

https://drive.google.com/drive/folders/1vDlZXk_Jow3bkPTlJLlloYCxOZAwnGBv

In this README, we provide an overview describing how this code can be run. If you find this repository useful in your research, please consider citing:

@article{robey2021model,
  title={Model-Based Domain Generalization},
  author={Robey, Alexander and Pappas, George J and Hassani, Hamed},
  journal={arXiv preprint arXiv:2102.11436},
  year={2021}
}

DomainBed implementation

In the DomainBed implementation of our code, we implement our primal-dual style MBDG algorithm in ./domainbed/algorithms.py, as well as three algorithmic variants described in Appendix C: MBDA, MBDG-DA, and MBDG-Reg. These algorithms can be run using the same commands as in the original DomainBed repository (see the DomainBed README for details).

Our method is based on a primal-dual scheme for solving the Model-Based Domain Generalization constrained optimization problem. This procedure is described in Algorithm 1 in our paper. In particular, the core of our algorithm is an alternation between updating the primal variable θ (e.g., the parameters of a neural-network classifier) and updating the dual variable λ. Below, we highlight a short code snippet that outlines our method:

class MBDG(MBDG_Base):

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(MBDG, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.dual_var = torch.tensor(1.0).cuda().requires_grad_(False)

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])

        # calculate classification loss (loss(θ) in Algorithm 1)
        clean_output = self.predict(all_x)
        clean_loss = F.cross_entropy(clean_output, all_y)

        # calculate regularization term (distReg(θ) in Algorithm 1)
        dist_reg = self.calc_dist_reg(all_x, clean_output)

        # formulate the (empirical) Lagrangian Λ = loss(θ) + λ distReg(θ)
        loss = clean_loss + self.dual_var * dist_reg

        # perform primal step in θ
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # calculate constraint violation term (distReg(θ) - γ)
        const_unsat = dist_reg.detach() - self.hparams['mbdg_gamma']

        # perform dual step in λ
        self.dual_var = self.relu(self.dual_var + self.hparams['mbdg_dual_step_size'] * const_unsat)

        return {'loss': loss.item(), 'dist_reg': dist_reg.item(), 'dual_var': self.dual_var.item()}
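
The alternation above can be illustrated on a toy problem without any of the DomainBed scaffolding. The sketch below (all names, numbers, and step sizes are illustrative, not from this repository) minimizes θ² subject to θ ≥ 0.5 using the same primal-descent / dual-ascent pattern as Algorithm 1:

```python
def primal_dual(steps=20000, lr=0.01, dual_lr=0.05):
    """Toy primal-dual loop: minimize theta**2 subject to theta >= 0.5."""
    theta, lam = 2.0, 1.0
    for _ in range(steps):
        # primal step: gradient descent on the Lagrangian
        # L(theta, lam) = theta**2 + lam * (0.5 - theta)
        grad = 2 * theta - lam
        theta -= lr * grad
        # dual step: projected gradient ascent on the constraint
        # violation (0.5 - theta), with lam kept nonnegative via max
        lam = max(0.0, lam + dual_lr * (0.5 - theta))
    return theta, lam
```

At the saddle point, stationarity of the Lagrangian gives 2θ = λ, so the iterates converge to θ ≈ 0.5 and λ ≈ 1.0; in the MBDG update, the classification loss plays the role of θ² and distReg(θ) - γ plays the role of the constraint.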

In this sub-repository, we include the path to trained domain transformation models and configuration files for MUNIT as hyperparameters. In particular, these parameters can be set for each dataset in ./domainbed/hparams_registry.py:

_hparam('mbdg_model_path', model_path, lambda r: model_path)
_hparam('mbdg_config_path', config_path, lambda r: config_path)

where model_path and config_path are dataset-specific arguments.
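
As a concrete illustration of resolving these two dataset-specific values, one could map each dataset to its model path and derive the config path alongside it. The dictionary and helper below are hypothetical placeholders for illustration only, not code from ./domainbed/hparams_registry.py:

```python
# Hypothetical per-dataset paths; the real values are set in
# ./domainbed/hparams_registry.py for each dataset.
MODEL_PATHS = {
    'PACS': './models/pacs/model.pt',
    'ColoredMNIST': './models/cmnist/model.pt',
}

def paths_for(dataset):
    """Return (model_path, config_path) for a given dataset name."""
    model_path = MODEL_PATHS[dataset]
    # assume the MUNIT config sits next to the model checkpoint
    config_path = model_path.replace('model.pt', 'config.yaml')
    return model_path, config_path
```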

WILDS implementation

The WILDS datasets provide out-of-distribution validation sets to perform model selection. Our code uses these validation sets in the ./mbdg-for-wilds sub-repository. The launcher script ./dist_launch.sh can be used to train classifiers on both Camelyon17-WILDS and FMoW-WILDS.

The dataset and domain transformation models for WILDS can be set via the following flags in ./dist_launch.sh:

export MODEL_PATH=<path/to/camelyon17>/model.pt
export MUNIT_CONFIG_PATH=<path/to/camelyon17>/config.yaml

A non-distributed launch script is also provided in ./launch.sh.
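
Putting these pieces together, a run could be configured as sketched below. The paths are placeholders for wherever you downloaded the Camelyon17 transformation model and config; substitute your own locations:

```shell
# Placeholder paths; point these at your downloaded model and MUNIT config.
export MODEL_PATH="$HOME/models/camelyon17/model.pt"
export MUNIT_CONFIG_PATH="$HOME/models/camelyon17/config.yaml"

# Then launch training, e.g.:
# ./dist_launch.sh   # distributed
# ./launch.sh        # single-process
```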

Contributors

  • arobey1
  • allenpu
