Large-Scale-OT

PyTorch implementation of the stochastic algorithms for computing regularized optimal transport (OT) proposed in [1].

[1] proposes:

  • A stochastic algorithm (Alg. 1) for computing the optimal dual variables of the regularized OT problem, from which the regularized-OT objective can then be computed directly
  • A stochastic algorithm (Alg. 2) for learning an optimal map between the source and target probability measures, parameterized as a deep neural network

Both entropy and L2 regularizations are considered and implemented.
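
For reference, the dual problem that Alg. 1 maximizes can be written as below. This is a sketch derived under the assumption that the entropic regularizer is the integral of (ln h - 1) h and the L2 regularizer is the integral of h^2, where h is the density of the transport plan with respect to the product measure, as in [1]. The symbols are assumptions of this note: \mu and \nu are the source and target measures, c is the ground cost, and \varepsilon plays the role of reg_val.

```latex
\max_{u, v} \;
\mathbb{E}_{(X, Y) \sim \mu \times \nu}
\left[ u(X) + v(Y) - F_\varepsilon\bigl( u(X) + v(Y) - c(X, Y) \bigr) \right],
\qquad
F_\varepsilon(s) =
\begin{cases}
  \varepsilon \, e^{s / \varepsilon} & \text{(entropy)} \\[4pt]
  \dfrac{\max(s, 0)^2}{4 \varepsilon} & \text{(L2)}
\end{cases}
```

Under the same assumptions, the density of the regularized plan at the optimum is the derivative F'_\varepsilon(s), that is e^{s / \varepsilon} for entropy and \max(s, 0) / (2 \varepsilon) for L2.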

Requirements

python2 or python3
pytorch
matplotlib

Install

git clone https://github.com/vivienseguy/Large-Scale-OT.git

Usage

Start by creating the regularized-OT computation class: either PyTorchStochasticDiscreteOT or PyTorchStochasticSemiDiscreteOT depending on your setting.

from StochasticOTClasses.StochasticOTDiscrete import PyTorchStochasticDiscreteOT

discreteOTComputer = PyTorchStochasticDiscreteOT(xs, ws, xt, wt, reg_type='l2', reg_val=0.02, device_type='cpu')
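
The call above takes source samples xs with weights ws and target samples xt with weights wt. A minimal sketch of toy 2D inputs follows. It assumes that the class accepts NumPy arrays of shape (n, d) for the samples and (n,) for the weights; check the repository's own example script for the exact expected types. The sizes and Gaussian parameters are arbitrary.

```python
import numpy as np

# Toy data only: two 2D Gaussian point clouds with uniform weights.
# Assumption: the OT class accepts NumPy arrays (not confirmed by this README).
rng = np.random.default_rng(0)
n_source, n_target = 200, 300

xs = rng.normal(loc=[0.0, 0.0], scale=0.5, size=(n_source, 2))  # source samples
xt = rng.normal(loc=[3.0, 1.0], scale=0.8, size=(n_target, 2))  # target samples

ws = np.full(n_source, 1.0 / n_source)  # source weights, sum to 1
wt = np.full(n_target, 1.0 / n_target)  # target weights, sum to 1
```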

Compute the optimal dual variables with Alg. 1:

history = discreteOTComputer.learn_OT_dual_variables(epochs=1000, batch_size=50, lr=0.0005)
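
To illustrate what this step does, here is a pure-Python sketch of stochastic gradient ascent on the regularized dual for discrete measures. It is not the repository's implementation: the function name learn_dual_sketch, the steps parameter, and the squared Euclidean cost are assumptions made for illustration. Each step samples a batch from the source and from the target, then moves the dual variables along the gradient of the batch objective, which for each pair (i, j) is 1 minus the plan density.

```python
import math
import random


def sq_dist(x, y):
    """Squared Euclidean ground cost (an assumption of this sketch)."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def plan_density(slack, reg_type, eps):
    """Derivative of the dual penalty: density of the plan w.r.t. mu x nu."""
    if reg_type == 'entropy':
        return math.exp(slack / eps)
    if reg_type == 'l2':
        return max(slack, 0.0) / (2.0 * eps)
    raise ValueError("reg_type must be 'entropy' or 'l2'")


def learn_dual_sketch(xs, ws, xt, wt, reg_type='l2', eps=0.02,
                      steps=1000, batch_size=50, lr=0.0005, seed=0):
    """Stochastic gradient ascent on the dual variables u (source), v (target)."""
    rng = random.Random(seed)
    u = [0.0] * len(xs)
    v = [0.0] * len(xt)
    scale = 1.0 / (batch_size * batch_size)
    for _ in range(steps):
        # Sample indices according to the weights of each measure.
        src = rng.choices(range(len(xs)), weights=ws, k=batch_size)
        tgt = rng.choices(range(len(xt)), weights=wt, k=batch_size)
        grad_u, grad_v = {}, {}
        for i in src:
            for j in tgt:
                slack = u[i] + v[j] - sq_dist(xs[i], xt[j])
                g = (1.0 - plan_density(slack, reg_type, eps)) * scale
                grad_u[i] = grad_u.get(i, 0.0) + g
                grad_v[j] = grad_v.get(j, 0.0) + g
        # Apply the ascent step only after all gradients are accumulated.
        for i, g in grad_u.items():
            u[i] += lr * g
        for j, g in grad_v.items():
            v[j] += lr * g
    return u, v
```

On a one-point source at 0 and a one-point target at 1 with L2 regularization and eps=0.5, the iterates approach u + v = 2, the value at which the plan density equals 1.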

Once the optimal dual variables have been obtained, you can compute the OT loss stochastically:

d_stochastic = discreteOTComputer.compute_OT_MonteCarlo(epochs=20, batch_size=50)
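
A Monte Carlo estimate of this kind can be sketched in pure Python as follows. The function name ot_loss_monte_carlo_sketch, the n_batches parameter, and the squared Euclidean cost are assumptions made for illustration, and the code is not taken from the repository. It averages the dual objective over random batches drawn from the two measures, given dual variables u and v.

```python
import math
import random


def sq_dist(x, y):
    """Squared Euclidean ground cost (an assumption of this sketch)."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def dual_penalty(slack, reg_type, eps):
    """Penalty term of the regularized dual objective."""
    if reg_type == 'entropy':
        return eps * math.exp(slack / eps)
    if reg_type == 'l2':
        return max(slack, 0.0) ** 2 / (4.0 * eps)
    raise ValueError("reg_type must be 'entropy' or 'l2'")


def ot_loss_monte_carlo_sketch(u, v, xs, ws, xt, wt, reg_type='l2', eps=0.02,
                               n_batches=20, batch_size=50, seed=0):
    """Average the dual objective over sampled batches of (source, target) pairs."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_batches):
        src = rng.choices(range(len(xs)), weights=ws, k=batch_size)
        tgt = rng.choices(range(len(xt)), weights=wt, k=batch_size)
        batch_sum = 0.0
        for i in src:
            for j in tgt:
                slack = u[i] + v[j] - sq_dist(xs[i], xt[j])
                batch_sum += u[i] + v[j] - dual_penalty(slack, reg_type, eps)
        total += batch_sum / (batch_size * batch_size)
    return total / n_batches
```

The estimate equals the regularized OT value only when u and v are the optimal dual variables; for other values it is a lower bound in expectation.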

You can also learn an approximate optimal map between the two probability measures by learning the barycentric mapping (Alg. 2). The mapping is parameterized as a deep neural network that you can supply through the function's parameters; otherwise, a small default 3-layer network is used.

bp_history = discreteOTComputer.learn_barycentric_mapping(epochs=300, batch_size=50, lr=0.000002)
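
For intuition about what the network is trained to approximate, the barycentric mapping of a discrete plan can be computed in closed form when the loss is the squared Euclidean distance: each source point is sent to the plan-weighted average of the target points. The pure-Python sketch below makes that concrete. The function name barycentric_projection_sketch and the squared Euclidean cost are assumptions made for illustration; the repository instead fits a neural network by stochastic gradient descent.

```python
import math


def sq_dist(x, y):
    """Squared Euclidean ground cost (an assumption of this sketch)."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def plan_density(slack, reg_type, eps):
    """Density of the regularized plan w.r.t. the product of the two measures."""
    if reg_type == 'entropy':
        return math.exp(slack / eps)
    if reg_type == 'l2':
        return max(slack, 0.0) / (2.0 * eps)
    raise ValueError("reg_type must be 'entropy' or 'l2'")


def barycentric_projection_sketch(u, v, xs, xt, wt, reg_type='l2', eps=0.02):
    """Map each source point to the plan-weighted mean of the target points."""
    dim = len(xt[0])
    mapped = []
    for i, x in enumerate(xs):
        # Mass sent from source point i to each target point j (up to ws[i]).
        masses = [wt[j] * plan_density(u[i] + v[j] - sq_dist(x, y), reg_type, eps)
                  for j, y in enumerate(xt)]
        total = sum(masses)
        if total == 0.0:
            # No mass leaves this point (possible with L2 and non-optimal duals);
            # leaving the point unchanged is an arbitrary choice of this sketch.
            mapped.append(tuple(x))
            continue
        mapped.append(tuple(
            sum(m * y[k] for m, y in zip(masses, xt)) / total
            for k in range(dim)
        ))
    return mapped
```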

Once learned, you can apply the (approximate) optimal mapping to some sample via:

xsf = discreteOTComputer.evaluate_barycentric_mapping(xs)

You can visualize the source, target and mapped samples:

import matplotlib.pylab as pl

pl.figure()
pl.plot(xs[:, 0], xs[:, 1], '+b', label='source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='target samples')
pl.plot(xsf[:, 0], xsf[:, 1], '+g', label='mapped source samples')
pl.legend()
pl.show()  # display the figure when running as a script

References

[1] Seguy, Vivien and Damodaran, Bharath Bhushan and Flamary, Rémi and Courty, Nicolas and Rolet, Antoine and Blondel, Mathieu. Large-Scale Optimal Transport and Mapping Estimation. Proceedings of the International Conference on Learning Representations (2018)

@inproceedings{seguy2018large,
  title={Large-Scale Optimal Transport and Mapping Estimation},
  author={Seguy, Vivien and Damodaran, Bharath Bhushan and Flamary, R{\'e}mi and Courty, Nicolas and Rolet, Antoine and Blondel, Mathieu},
  booktitle={Proceedings of the International Conference on Learning Representations},
  year={2018},
}
