
Reproducibility code for MMDAgg: MMD Aggregated Two-Sample Test

This GitHub repository contains the code for the reproducible experiments presented in our paper MMD Aggregated Two-Sample Test.

We provide the code to run the experiments generating Figures 1-10 and Table 2 of our paper; the resulting figures can be found in the media directory. The code for the Failing Loudly experiment (with results reported in Table 1) can be found in the FL-MMDAgg repository.

To use our MMDAgg test in practice, we recommend using our mmdagg package; more details are available in the mmdagg repository.

Our implementation supports two quantile estimation methods (wild bootstrap and permutations). The MMDAgg test aggregates over different types of kernels (e.g. Gaussian, Laplace, Inverse Multi-Quadric (IMQ), and Matérn kernels with various parameters), each with several bandwidths. In practice, we recommend aggregating over both Gaussian and Laplace kernels, each with 10 bandwidths; a sketch of how such a bandwidth collection can be formed follows.
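
As a rough illustration of how a collection of bandwidths can be formed (a sketch only, not the exact collections used by MMDAgg, which are defined in the paper), one can scale the median pairwise distance of the pooled sample by powers of 2:

import numpy as np

def bandwidth_collection(X, Y, number=10):
    # Sketch: centre the collection at the median pairwise distance of the
    # pooled sample and scale it by powers of 2 (not the paper's exact scheme).
    Z = np.concatenate([X, Y])
    distances = np.sqrt(((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1))
    median = np.median(distances[distances > 0])
    powers = np.arange(number) - number // 2
    return median * 2.0 ** powers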

Requirements

  • python 3.9

The packages in requirements.txt are required to run our tests and the ones we compare against.

Additionally, the jax and jaxlib packages are required to run the Jax implementation of MMDAgg in mmdagg/jax.py.

Installation

In a chosen directory, clone the repository and change to its directory by executing

git clone git@github.com:antoninschrab/mmdagg-paper.git
cd mmdagg-paper

We then recommend creating and activating a virtual environment by either

  • using venv:
    python3 -m venv mmdagg-env
    source mmdagg-env/bin/activate
    # can be deactivated by running:
    # deactivate
    
  • or using conda:
    conda create --name mmdagg-env python=3.9
    conda activate mmdagg-env
    # can be deactivated by running:
    # conda deactivate
    

The packages required for reproducibility of the experiments can then be installed in the virtual environment by running

python -m pip install -r requirements.txt

To use the Jax implementation of MMDAgg, Jax needs to be installed (instructions). For example, this can be done by running

  • for GPU:
    pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    # conda install -c conda-forge -c nvidia pip numpy scipy cuda-nvcc "jaxlib=0.4.1=*cuda*" jax
  • or, for CPU:
    conda install -c conda-forge -c nvidia pip jaxlib=0.4.1 jax

Reproducing the experiments of the paper

To run the experiments, the following command can be executed

python experiments.py

This command saves the results in dedicated .csv and .pkl files in a new directory user/raw. The output of this command is already provided in paper/raw. The results of the remaining experiments, saved in the results directory, can be obtained by running the Computations_mmdagg.ipynb and Computations_autotst.ipynb notebooks; the latter uses the autotst package introduced in the AutoML Two-Sample Test paper.

The actual figures of the paper can be obtained from the saved results by running the code in the figures.ipynb notebook.

All the experiments consist of embarrassingly parallel for-loops, so a significant speed-up can be obtained by using parallel computing libraries such as joblib or dask, as sketched below.
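
For instance, the independent repetitions of an experiment could be distributed across cores with joblib along these lines (a sketch: run_single_repetition is a hypothetical stand-in for one experiment iteration, and we assume the mmdagg function in mmdagg/np.py can be called as mmdagg(X, Y)):

import numpy as np
from joblib import Parallel, delayed
from mmdagg.np import mmdagg  # Numpy implementation of the test

def run_single_repetition(seed):
    # Hypothetical stand-in for one experiment repetition: draw fresh
    # samples with the given seed and run the test once.
    rng = np.random.default_rng(seed)
    X = rng.uniform(size=(200, 1))
    Y = rng.uniform(size=(200, 1))
    return mmdagg(X, Y)

# The repetitions are independent, so they can run on all available cores.
outputs = Parallel(n_jobs=-1)(delayed(run_single_repetition)(s) for s in range(200))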

Data

Half of the experiments use a down-sampled version of the MNIST dataset, which is created as a .data file in a new directory mnist_dataset when running the script experiments.py. This dataset can also be generated on its own by executing

python mnist.py

The other half of the experiments use samples drawn from a perturbed uniform density (Eq. 17 in the paper). A rejection sampler f_theta_sampler for this density is implemented in sampling.py; a generic sketch of the technique follows.
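
As a generic illustration of rejection sampling (the placeholder density below is not Eq. 17; see sampling.py for the actual density and sampler):

import numpy as np

def rejection_sample(density, upper_bound, n, rng):
    # Generic rejection sampler on [0, 1]: propose uniformly and accept a
    # proposal x with probability density(x) / upper_bound.
    samples = []
    while len(samples) < n:
        x = rng.uniform()
        if rng.uniform() * upper_bound <= density(x):
            samples.append(x)
    return np.array(samples)

# Placeholder density: uniform on [0, 1] plus one smooth perturbation which
# integrates to zero, so the total mass remains 1.
density = lambda x: 1.0 + 0.5 * np.cos(2.0 * np.pi * x)
samples = rejection_sample(density, upper_bound=1.5, n=1000, rng=np.random.default_rng(0))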

How to use MMDAgg in practice?

The MMDAgg test is implemented as the function mmdagg in mmdagg/np.py for the Numpy version and in mmdagg/jax.py for the Jax version.

The Numpy implementation of our MMDAgg test requires only the numpy and scipy packages, while the Jax implementation requires only the jax and jaxlib packages.

To use our tests in practice, we recommend using our mmdagg package, which is available in the mmdagg repository. It can be installed by running

pip install git+https://github.com/antoninschrab/mmdagg.git

Installation instructions and example code are available on the mmdagg repository.
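
As a minimal usage sketch (with hypothetical data; we assume mmdagg(X, Y) returns 1 when the null hypothesis of equal distributions is rejected and 0 otherwise, as documented in the mmdagg repository):

import numpy as np
from mmdagg.np import mmdagg  # Numpy version; mmdagg.jax provides the Jax one

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))           # first sample: standard Gaussian
Y = rng.normal(loc=0.5, size=(500, 10))  # second sample: mean shifted by 0.5

output = mmdagg(X, Y)
print("Reject the null hypothesis?", bool(output))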

We also provide code showing how to use our MMDAgg test in the demo_speed.ipynb notebook, which also contains speed comparisons between the Jax and Numpy implementations, as reported below.

Speed in s   Numpy (CPU)   Jax (CPU)   Jax (GPU)
MMDAgg       43.1          14.9        0.495

In practice, we recommend using the Jax implementation as it runs considerably faster (almost 100 times faster on GPU than the Numpy CPU version in the table above; see the demo_speed.ipynb notebook).
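
When timing the Jax implementation yourself, note that the first call includes compilation time; a sketch (assuming the test output is a Jax value, so jax.block_until_ready can wait for the asynchronous computation to finish):

import time
import jax
import jax.numpy as jnp
import numpy as np
from mmdagg.jax import mmdagg  # Jax implementation, as in mmdagg/jax.py

rng = np.random.default_rng(0)
X = jnp.array(rng.normal(size=(500, 10)))
Y = jnp.array(rng.normal(loc=0.5, size=(500, 10)))

mmdagg(X, Y)  # first call: includes compilation time
start = time.time()
output = mmdagg(X, Y)  # subsequent calls run the compiled function
jax.block_until_ready(output)  # wait for asynchronous dispatch to finish
print("time in s:", time.time() - start)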

References

Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift. Stephan Rabanser, Stephan Günnemann, Zachary C. Lipton. (paper, code)

Learning Kernel Tests Without Data Splitting. Jonas M. Kübler, Wittawat Jitkrittum, Bernhard Schölkopf, Krikamol Muandet. (paper, code)

AutoML Two-Sample Test. Jonas M. Kübler, Vincent Stimper, Simon Buchholz, Krikamol Muandet, Bernhard Schölkopf. (paper, code)

MMDAggInc

For a computationally efficient version of MMDAgg that runs in linear time, check out our paper Efficient Aggregated Kernel Tests using Incomplete U-statistics, with reproducible experiments in the agginc-paper repository and a package in the agginc repository.

Contact

If you have any issues running our code, please do not hesitate to contact Antonin Schrab.

Affiliations

Centre for Artificial Intelligence, Department of Computer Science, University College London

Gatsby Computational Neuroscience Unit, University College London

Inria London

Bibtex

@article{schrab2021mmd,
  author  = {Antonin Schrab and Ilmun Kim and M{\'e}lisande Albert and B{\'e}atrice Laurent and Benjamin Guedj and Arthur Gretton},
  title   = {{MMD} Aggregated Two-Sample Test},
  journal = {Journal of Machine Learning Research},
  year    = {2023},
  volume  = {24},
  number  = {194},
  pages   = {1--81},
  url     = {http://jmlr.org/papers/v24/21-1289.html}
}

License

MIT License (see LICENSE.md).
