MASF

Domain Generalization via Model-Agnostic Learning of Semantic Features

We study the challenging problem of domain generalization, i.e., training a model on multi-domain source data such that it can directly generalize to unseen target domains. We adopt a model-agnostic learning paradigm with gradient-based meta-train and meta-test procedures to expose the optimization to domain shift. Further, we introduce two complementary losses which explicitly regularize the semantic structure of the feature space. Globally, we align a derived soft confusion matrix to preserve general knowledge about inter-class relationships. Locally, we promote domain-independent class-specific cohesion and separation of sample features with a metric-learning component.
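
The two complementary losses described above can be sketched in plain Python. This is a minimal illustration under our own assumptions, not the repository's TensorFlow code. `global_alignment_loss` builds one soft confusion matrix per domain (row c takes the mean logit vector of that domain's class-c samples and applies a temperature-softened softmax) and compares matching rows with a symmetrized KL divergence; we are not certain this averaging order matches the reference code exactly. `triplet_loss` is a standard hinge on squared embedding distances, where anchor and positive share a class and the negative comes from another class. All function names and the default temperature of 2.0 are illustrative choices; we assume, but have not confirmed, that the `--margin` flag in the run command under Running MASF sets this margin.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-softened softmax of a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_confusion_matrix(logits, labels, num_classes, temperature):
    """Row c: softened softmax of the mean logit vector of class-c samples."""
    rows = []
    for c in range(num_classes):
        members = [z for z, y in zip(logits, labels) if y == c]
        if not members:
            raise ValueError(f"no samples of class {c} in this domain batch")
        mean_logit = [sum(z[k] for z in members) / len(members) for k in range(num_classes)]
        rows.append(softmax(mean_logit, temperature))
    return rows

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions; eps guards against log(0)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def global_alignment_loss(logits_i, labels_i, logits_j, labels_j, num_classes, temperature=2.0):
    """Symmetrized KL between matching rows of two domains' soft confusion matrices, averaged over classes."""
    s_i = soft_confusion_matrix(logits_i, labels_i, num_classes, temperature)
    s_j = soft_confusion_matrix(logits_j, labels_j, num_classes, temperature)
    per_class = [0.5 * (kl_divergence(a, b) + kl_divergence(b, a)) for a, b in zip(s_i, s_j)]
    return sum(per_class) / num_classes

def triplet_loss(anchor, positive, negative, margin):
    """Hinge on squared Euclidean distances: max(0, d(a, p) - d(a, n) + margin)."""
    d_ap = sum((a - p) ** 2 for a, p in zip(anchor, positive))
    d_an = sum((a - n) ** 2 for a, n in zip(anchor, negative))
    return max(0.0, d_ap - d_an + margin)

# Two domains with identical per-class logits align perfectly.
logits = [[2.0, 0.0], [0.0, 2.0]]
print(global_alignment_loss(logits, [0, 1], logits, [0, 1], num_classes=2))  # → 0.0
print(triplet_loss([0.0, 0.0], [1.0, 0.0], [3.0, 0.0], margin=20.0))  # → 12.0
```

Symmetrizing the KL term makes the global loss independent of which domain is listed first, so neither domain is treated as the reference distribution.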

This is the reference implementation of the domain generalization method described in our paper:

@inproceedings{dou2019domain,
    author = {Qi Dou and Daniel C. Castro and Konstantinos Kamnitsas and Ben Glocker},
    title = {Domain Generalization via Model-Agnostic Learning of Semantic Features},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year = {2019},
}

If you make use of the code, please cite the paper in any resulting publications.

Setup

Check the dependencies in requirements.txt and, if necessary, install them with

pip install -r requirements.txt

Running MASF

Download the PACS dataset from here and put it in the data root /path/to/PACS_dataset; put the .txt file lists in '/path/to/image/filelist'.
Download the ImageNet-pretrained AlexNet weights bvlc_alexnet.npy from here.
To run MASF with art_painting as the target domain:

python main.py --dataset pacs --target_domain art_painting --inner_lr 1e-5 --outer_lr 1e-5 --metric_lr 1e-5 --margin 20
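
The gradient-based meta-train/meta-test procedure can be illustrated on a toy 1-D regression problem. This is a hypothetical sketch, not the repository's training loop: each episode holds out one source domain as meta-test, takes a virtual gradient step on the pooled meta-train domains with step size `inner_lr`, evaluates the held-out domain at the adapted weight, and differentiates through the virtual step (exact here because the squared loss has constant curvature). As a simplification we score the meta-test domain with the plain task loss, in the style of MLDG; MASF instead applies its global and local alignment losses at this stage. We assume that `--inner_lr` and `--outer_lr` correspond to the two step sizes below and that `--metric_lr` updates the metric-learning embedding, which this sketch omits. The three domains are synthetic and share the target y = 3x.

```python
import random

def grad(w, domain):
    """d/dw of the mean squared error of y_hat = w * x on one domain."""
    xs, ys = domain
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

def curvature(domain):
    """Second derivative of that loss; constant in w for a quadratic loss."""
    xs, _ = domain
    return sum(2 * x * x for x in xs) / len(xs)

def pool(domains):
    """Concatenate several (xs, ys) domains into one."""
    xs = [x for d in domains for x in d[0]]
    ys = [y for d in domains for y in d[1]]
    return xs, ys

def meta_step(w, meta_train, meta_test, inner_lr, outer_lr):
    # Meta-train: virtual gradient step on the meta-train data.
    g_tr = grad(w, meta_train)
    w_adapted = w - inner_lr * g_tr
    # Meta-test: gradient of L_te(w_adapted) w.r.t. the original w,
    # via the chain rule d(w_adapted)/dw = 1 - inner_lr * curvature.
    g_te = (1 - inner_lr * curvature(meta_train)) * grad(w_adapted, meta_test)
    # Outer update on the combined objective L_tr(w) + L_te(w_adapted).
    return w - outer_lr * (g_tr + g_te)

def train(domains, steps=200, inner_lr=0.05, outer_lr=0.05, seed=0):
    rng = random.Random(seed)
    names = sorted(domains)
    w = 0.0
    for _ in range(steps):
        held_out = rng.choice(names)  # meta-test domain for this episode
        meta_train = pool([domains[n] for n in names if n != held_out])
        w = meta_step(w, meta_train, domains[held_out], inner_lr, outer_lr)
    return w

# Synthetic source domains: same target y = 3x, different input ranges.
domains = {
    "a": ([1.0, 2.0], [3.0, 6.0]),
    "b": ([0.5, 1.5], [1.5, 4.5]),
    "c": ([-1.0, 1.0], [-3.0, 3.0]),
}
print(round(train(domains), 6))  # → 3.0
```

Because every toy domain shares the same optimum, the loop converges to w = 3; the point of the sketch is the shape of the update, where the meta-test gradient flows back through the meta-train step.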

Monitoring training with TensorBoard

TensorBoard logs of losses and gradients are stored in /log/. To view them, run

tensorboard --logdir /log/

Running on medical data

To run on a medical dataset, replace construct_alexnet_weights() and forward_alexnet() with construct_unet_weights() and forward_unet(), as demonstrated in the medical folder.