Learning Undirected Posteriors by Backpropagation through MCMC Updates

This repository offers the TensorFlow implementation of the variational autoencoders with undirected posteriors presented in this paper. It can be used to reproduce all the results reported in the paper (Tables 1, 2, and 3) for both binarized MNIST and OMNIGLOT.

This repository is largely based on our earlier implementation, available here.

For sampling from Boltzmann priors, the population annealing (PA) algorithm is used. We rely on QuPA, a sampling library released by Quadrant, which is available here. For sampling from Boltzmann posteriors, persistent contrastive divergence (PCD) is implemented in TensorFlow.
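The PCD idea can be illustrated with a minimal NumPy sketch of block Gibbs sampling on a restricted Boltzmann machine (RBM) whose Markov chains persist between calls instead of being reinitialized. This is not the repository's TensorFlow code: the energy convention, the two-partition layout, and the names `gibbs_sweep` and `PersistentChains` are assumptions for illustration, and unlike the method in the paper it does not backpropagate through the MCMC updates.

```python
import numpy as np


def sigmoid(x):
    """Logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def gibbs_sweep(z1, W, b1, b2, rng):
    """One block-Gibbs sweep for an RBM with the (assumed) energy
    E(z1, z2) = -z1.b1 - z2.b2 - z1^T W z2.

    z1: (num_chains, n1) binary states; W: (n1, n2); b1: (n1,); b2: (n2,).
    """
    # Units in one partition are conditionally independent given the other,
    # so each partition is resampled in a single vectorized step.
    p2 = sigmoid(z1.dot(W) + b2)
    z2 = (rng.random_sample(p2.shape) < p2).astype(np.float64)
    p1 = sigmoid(z2.dot(W.T) + b1)
    z1 = (rng.random_sample(p1.shape) < p1).astype(np.float64)
    return z1, z2


class PersistentChains(object):
    """Hypothetical helper: keeps chain states alive across parameter updates (PCD)."""

    def __init__(self, num_chains, n1, rng):
        self.rng = rng
        # Start every chain from uniformly random binary states.
        self.z1 = (rng.random_sample((num_chains, n1)) < 0.5).astype(np.float64)

    def sample(self, W, b1, b2, num_sweeps=1):
        if num_sweeps < 1:
            raise ValueError("num_sweeps must be at least 1")
        z1 = self.z1
        for _ in range(num_sweeps):
            z1, z2 = gibbs_sweep(z1, W, b1, b2, self.rng)
        self.z1 = z1  # persist the final state for the next call
        return z1, z2


# Usage: 8 chains on a 4 x 3 RBM with small random couplings.
rng = np.random.RandomState(0)
n1, n2 = 4, 3
W = 0.1 * rng.randn(n1, n2)
chains = PersistentChains(num_chains=8, n1=n1, rng=rng)
z1, z2 = chains.sample(W, np.zeros(n1), np.zeros(n2), num_sweeps=5)
```

Because `chains.z1` is updated in place, the next call after a parameter update resumes from where the previous chains stopped, which is what distinguishes PCD from restarting the chains at every step.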


Running the Training/Evaluation Code

The main train/evaluation script can be run locally using the following command:

python run.py \
    --log_dir=${PATH_TO_LOG_DIR} \
    --data_dir=${PATH_TO_DATA_DIR}

If the datasets are not available locally, the script downloads them automatically into the data directory.

The following flags select the settings reported in the paper:

  1. --dataset specifies the dataset used for the experiment. Currently, we support omniglot and binarized_mnist.
  2. --baseline sets the objective function used for training; its values correspond to the columns of Table 1. Use dvaes_power for DVAE# (power), pwl_relax for the PWL relaxation, gsm_relax for the Concrete relaxation, and rbm_post for DVAE## (RBM posterior).
  3. --num_latents sets the number of latent stochastic units. We examined 200 and 400.
  4. --experiment sets the experiment. Use vae for the generative models in Tables 1 and 2, or struct for the structured prediction problem.
  5. --L sets the number of stochastic layers. We used L=1, 2, 4 for directed posteriors and L=1 for undirected posteriors.
  6. --k specifies the number of samples used to estimate the variational bound for DVAE/DVAE++ and the importance-weighted bound for DVAE#.
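To make the role of --k concrete, here is a small sketch of how k samples per example can enter the two bounds, assuming the log-densities log p(x, z_i) and log q(z_i | x) have already been computed. The function names are hypothetical, the averaged multi-sample variational bound is one common reading of the DVAE/DVAE++ case, and for an undirected (RBM) posterior log q also involves a partition function, which this sketch simply takes as given.

```python
import numpy as np
from scipy.special import logsumexp


def log_weights(log_p_xz, log_q_z):
    """log w_i = log p(x, z_i) - log q(z_i | x); inputs have shape (k, batch)."""
    return np.asarray(log_p_xz, dtype=np.float64) - np.asarray(log_q_z, dtype=np.float64)


def multi_sample_elbo(log_p_xz, log_q_z):
    """Average of k single-sample variational-bound estimates, shape (batch,)."""
    return log_weights(log_p_xz, log_q_z).mean(axis=0)


def iw_bound(log_p_xz, log_q_z):
    """Importance-weighted bound log((1/k) * sum_i w_i), computed stably in log space."""
    log_w = log_weights(log_p_xz, log_q_z)
    k = log_w.shape[0]
    return logsumexp(log_w, axis=0) - np.log(k)


# Usage: k=2 samples for a batch of one example, with weights w = [1, 3].
log_p = np.array([[0.0], [np.log(3.0)]])
log_q = np.zeros_like(log_p)
elbo = multi_sample_elbo(log_p, log_q)  # equals (log 1 + log 3) / 2
iwb = iw_bound(log_p, log_q)            # equals log((1 + 3) / 2) = log 2
```

By Jensen's inequality the importance-weighted bound is never below the averaged bound, and the two coincide when k=1.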

Example:

python run.py \
    --log_dir=${PATH_TO_LOG_DIR} \
    --data_dir=${PATH_TO_DATA_DIR} \
    --L=1 \
    --baseline=rbm_post \
    --dataset=binarized_mnist \
    --experiment=vae \
    --k=1 \
    --num_latents=200
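
The example above covers a single setting. One way to enumerate the other combinations described in the flag list is a small Python helper that builds the same argument lists; `build_command` and `sweep` are hypothetical names, the log-directory naming scheme is an assumption, and the grid assumes that dvaes_power, pwl_relax, and gsm_relax are the directed-posterior baselines (L=1, 2, 4) while rbm_post is the undirected one (L=1).

```python
import itertools
import os

BASELINES = ("dvaes_power", "pwl_relax", "gsm_relax", "rbm_post")
DATASETS = ("binarized_mnist", "omniglot")
NUM_LATENTS = (200, 400)


def build_command(log_dir, data_dir, dataset, baseline, num_layers,
                  num_latents, k=1, experiment="vae"):
    """Return the argument list for one run.py invocation."""
    if dataset not in DATASETS:
        raise ValueError("unsupported dataset: {}".format(dataset))
    if baseline not in BASELINES:
        raise ValueError("unknown baseline: {}".format(baseline))
    if baseline == "rbm_post" and num_layers != 1:
        # The undirected (RBM) posterior was run with L=1 only.
        raise ValueError("rbm_post expects L=1")
    return [
        "python", "run.py",
        "--log_dir={}".format(log_dir),
        "--data_dir={}".format(data_dir),
        "--L={}".format(num_layers),
        "--baseline={}".format(baseline),
        "--dataset={}".format(dataset),
        "--experiment={}".format(experiment),
        "--k={}".format(k),
        "--num_latents={}".format(num_latents),
    ]


def sweep(log_root, data_dir, k=1):
    """Yield one argument list per (dataset, baseline, num_latents, L) setting."""
    for dataset, baseline, num_latents in itertools.product(DATASETS, BASELINES, NUM_LATENTS):
        layer_options = (1,) if baseline == "rbm_post" else (1, 2, 4)
        for num_layers in layer_options:
            run_name = "{}_{}_L{}_n{}_k{}".format(dataset, baseline, num_layers, num_latents, k)
            yield build_command(os.path.join(log_root, run_name), data_dir, dataset,
                                baseline, num_layers, num_latents, k=k)


if __name__ == "__main__":
    for cmd in sweep("logs", "data"):
        print(" ".join(cmd))
        # subprocess.check_call(cmd)  # uncomment (and import subprocess) to launch each run
```

Giving each run its own log directory keeps the runs separate when they are later compared in TensorBoard.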

Running TensorBoard

You can monitor training progress and performance on the validation and test sets using TensorBoard. Run the following command to start TensorBoard on the log directory:

tensorboard --logdir=${PATH_TO_LOG_DIR}

Prerequisites

Make sure that you have:

  • Python (version 2.7 or higher)
  • Numpy
  • Scipy
  • QuPA
  • TensorFlow (the version must be compatible with QuPA; we tested with TensorFlow 1.12.0)

Change Logs

July 13, 2018

First release, including DVAE#, PWL, Concrete, and DVAE##.


Citation

If you use this code in your research, please cite us:

@article{vahdat2019undirected,
  title={Learning Undirected Posteriors by Backpropagation through {MCMC} Updates},
  author={Vahdat, Arash and Andriyash, Evgeny and Macready, William G.},
  journal={arXiv preprint arXiv:1901.03440},
  year={2019}
}

Contributors

Arash Vahdat, Evgeny Andriyash
