PyTorch Adversarial Domain Adaptation

A collection of implementations of adversarial unsupervised domain adaptation algorithms.

Domain adaptation

The goal of domain adaptation is to transfer the knowledge of a model to a different but related data distribution. The model is trained on a labeled source dataset and applied to a target dataset that is usually unlabeled. Here, the model is trained on regular MNIST images, and the goal is to get good performance on MNIST-M, a version of MNIST whose digits are blended over random patches of color photos (from BSDS500), without using any target labels.
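MNIST-M, as described by Ganin et al., is built by blending each MNIST digit over a randomly cropped patch of a BSDS500 color photo using the per-pixel absolute difference. The following is a minimal NumPy sketch of that blending step under those assumptions; the function name `blend_mnist_m` is hypothetical, and the repository's own data loading may build the images differently.

```python
import numpy as np


def blend_mnist_m(digit, background, rng):
    """Hypothetical helper: blend a grayscale digit over a random color patch.

    digit: uint8 array of shape (28, 28), values 0..255.
    background: uint8 array of shape (H, W, 3) with H, W >= 28.
    rng: a numpy.random.Generator used to pick the crop position.
    Returns a uint8 array of shape (28, 28, 3).
    """
    h, w = digit.shape
    # Pick a random patch of the same size as the digit from the photo.
    top = rng.integers(0, background.shape[0] - h + 1)
    left = rng.integers(0, background.shape[1] - w + 1)
    patch = background[top:top + h, left:left + w].astype(np.int16)
    # Broadcast the digit over the 3 color channels and take |patch - digit|.
    blended = np.abs(patch - digit.astype(np.int16)[:, :, None])
    return blended.astype(np.uint8)
```

Using the absolute difference keeps the digit visible on both dark and bright backgrounds: bright strokes invert the underlying colors, while the empty (zero) background of the digit leaves the photo unchanged.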

In adversarial domain adaptation, this problem is usually solved by training an auxiliary model called the domain discriminator, which classifies examples as coming from the source or the target distribution. The feature extractor of the original classifier is then trained to maximize the discriminator's loss, pushing source and target features to become indistinguishable, much like the generator in GAN training.
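To make the discriminator's role concrete, here is a minimal PyTorch sketch assuming 16-dimensional features and a small two-layer discriminator (both hypothetical choices, not taken from this repository). Source examples are labeled 1, target examples 0, and the feature extractor's adversarial objective is written as the negated discriminator loss; individual methods below realize this objective differently (gradient reversal, inverted labels, or a Wasserstein critic).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical discriminator over 16-dimensional feature vectors.
discriminator = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))


def discriminator_loss(source_feats, target_feats):
    # Label source examples 1 and target examples 0.
    feats = torch.cat([source_feats, target_feats])
    labels = torch.cat([torch.ones(len(source_feats)), torch.zeros(len(target_feats))])
    logits = discriminator(feats).squeeze(1)
    # Binary cross-entropy on raw logits (numerically stable).
    return F.binary_cross_entropy_with_logits(logits, labels)


def feature_extractor_adversarial_loss(source_feats, target_feats):
    # The feature extractor wants the discriminator to fail, so it maximizes
    # the discriminator's loss, i.e. minimizes its negation.
    return -discriminator_loss(source_feats, target_feats)
```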

Implemented papers

Paper: Unsupervised Domain Adaptation by Backpropagation, Ganin & Lempitsky (2014)
Link: https://arxiv.org/abs/1409.7495
Description: Inserts a gradient reversal layer that negates the discriminator's gradient before it reaches the feature extractor, so both networks can be trained simultaneously with a single backward pass.
Implementation: revgrad.py
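The core of RevGrad is the gradient reversal layer: it is the identity in the forward pass and multiplies the incoming gradient by -lambda in the backward pass. A minimal sketch using `torch.autograd.Function` follows; the names `GradientReversal`, `grad_reverse` and `lambda_` are illustrative and may not match what revgrad.py uses.

```python
import torch
from torch.autograd import Function


class GradientReversal(Function):
    """Identity in the forward pass; scales the gradient by -lambda_ in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        # Return a view so autograd records this op on the graph.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed for x, None for lambda_.
        return -ctx.lambda_ * grad_output, None


def grad_reverse(x, lambda_=1.0):
    return GradientReversal.apply(x, lambda_)
```

In a training loop, the discriminator would receive `grad_reverse(features)` instead of `features`. The domain loss is then added to the label-classification loss, and one `backward()` call trains the discriminator normally while the feature extractor receives the reversed gradient.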


Paper: Adversarial Discriminative Domain Adaptation, Tzeng et al. (2017)
Link: https://arxiv.org/abs/1702.05464
Description: Adapts the weights of a classifier pretrained on source data to produce similar features on the target data.
Implementation: adda.py
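ADDA alternates two updates: the discriminator learns to separate features of the frozen source encoder (label 1) from features of the target encoder (label 0), and the target encoder, initialized from the pretrained source weights, is trained with inverted labels so that its features are classified as source. The sketch below follows that GAN-style objective; the function `adda_step`, the layer sizes and the optimizer settings are hypothetical, and adda.py may organize training differently.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def adda_step(source_encoder, target_encoder, discriminator,
              disc_opt, target_opt, x_source, x_target):
    # 1. Discriminator update: source features -> 1, target features -> 0.
    with torch.no_grad():
        src_feats = source_encoder(x_source)
        tgt_feats = target_encoder(x_target)
    feats = torch.cat([src_feats, tgt_feats])
    labels = torch.cat([torch.ones(len(src_feats)), torch.zeros(len(tgt_feats))])
    disc_loss = F.binary_cross_entropy_with_logits(discriminator(feats).squeeze(1), labels)
    disc_opt.zero_grad()
    disc_loss.backward()
    disc_opt.step()

    # 2. Target encoder update with inverted labels: target features should look like source (1).
    tgt_logits = discriminator(target_encoder(x_target)).squeeze(1)
    gen_loss = F.binary_cross_entropy_with_logits(tgt_logits, torch.ones(len(tgt_logits)))
    target_opt.zero_grad()
    gen_loss.backward()
    target_opt.step()  # only the target encoder's parameters are stepped here
    return disc_loss.item(), gen_loss.item()


# Hypothetical setup: a tiny "pretrained" source encoder over 10-dim inputs.
torch.manual_seed(0)
source_encoder = nn.Sequential(nn.Linear(10, 8), nn.ReLU())
# The target encoder starts as a copy of the source encoder ...
target_encoder = copy.deepcopy(source_encoder)
# ... while the source encoder itself stays frozen.
for p in source_encoder.parameters():
    p.requires_grad_(False)
discriminator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
target_opt = torch.optim.Adam(target_encoder.parameters(), lr=1e-3)
```

At test time the target encoder is combined with the source classifier head, which is never retrained.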


Paper: Wasserstein Distance Guided Representation Learning, Shen et al. (2017)
Link: https://arxiv.org/abs/1707.01217
Description: Uses a domain critic to minimize the Wasserstein distance (with a gradient penalty) between domains.
Implementation: wdgrl.py
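In WDGRL, the critic estimates the Wasserstein distance as the gap between its mean output on source and on target features, and a gradient penalty on random interpolations between the two (as in WGAN-GP) keeps it approximately 1-Lipschitz. The feature extractor then minimizes this estimate alongside its classification loss. The sketch below assumes equal-sized source and target batches; the function names and `gamma=10.0` are illustrative, not taken from wdgrl.py.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, h_s, h_t):
    # Random points on straight lines between source and target features.
    alpha = torch.rand(h_s.size(0), 1)
    interpolates = (alpha * h_s.detach() + (1 - alpha) * h_t.detach()).requires_grad_(True)
    preds = critic(interpolates)
    # Gradient of the critic output w.r.t. its input; create_graph keeps it differentiable.
    grads = torch.autograd.grad(preds.sum(), interpolates, create_graph=True)[0]
    # Penalize deviations of the gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def wasserstein_estimate(critic, h_s, h_t):
    # Empirical Wasserstein-1 estimate, valid when the critic is (roughly) 1-Lipschitz.
    return critic(h_s).mean() - critic(h_t).mean()


def critic_loss(critic, h_s, h_t, gamma=10.0):
    # The critic maximizes the estimate minus the penalty, so minimize the negation.
    return -wasserstein_estimate(critic, h_s.detach(), h_t.detach()) \
        + gamma * gradient_penalty(critic, h_s, h_t)
```

The penalty is used only when updating the critic; when updating the feature extractor, `wasserstein_estimate` is computed on non-detached features so its gradient flows back into the extractor.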

Results

Method        Accuracy on MNIST-M   Parameters
Source only   0.33                  -
RevGrad       0.74                  default
ADDA          0.76                  default
WDGRL         0.78                  --k-clf 10 --wd-clf 0.1

Instructions

  1. Download the BSDS500 dataset and extract it, then point the DATA_DIR variable in config.py to that location.
  2. In a Python 3.6 environment, run:
     $ conda install pytorch torchvision numpy -c pytorch
     $ pip install tqdm opencv-python
  3. Train a model on the source dataset with:
     $ python train_source.py
  4. Choose an algorithm and pass it the pretrained network, for example:
     $ python adda.py trained_models/source.pt

Contributors

jvanvugt
