
Echo Noise for Exact Mutual Information Calculation

Tensorflow/Keras code replicating the experiments in: https://arxiv.org/abs/1904.07199

@article{brekelmans2019exact,
  title={Exact Rate-Distortion in Autoencoders via Echo Noise},
  author={Brekelmans, Rob and Moyer, Daniel and Galstyan, Aram and Ver Steeg, Greg},
  journal={arXiv preprint arXiv:1904.07199},
  year={2019}
}

Echo noise is a flexible, data-driven alternative to Gaussian noise that admits a simple, exact expression for mutual information by construction. Applied in the autoencoder setting, we show that regularizing with I(X;Z) corresponds to the optimal choice of prior in the Evidence Lower Bound and leads to significant improvements over VAEs.

Echo Noise

For easy inclusion in other projects, the echo noise functions are included in one all-in-one file, echo_noise.py, which can be copied to a project and included directly, e.g.:

import echo_noise

There are two basic functions implemented: the noise function itself (echo_sample) and the MI calculation (echo_loss), both of which are included in echo_noise.py. Except for standard libraries, echo_noise.py has no other file dependencies.

Echo noise is meant to be used similarly to the Gaussian noise in VAEs and was implemented with VAE implementations in mind. Assuming the inference network provides z_mean and z_log_scale, a Gaussian encoder would look something like:

z = z_mean + tf.exp(z_log_scale) * tf.random.normal( tf.shape(z_mean) )

The Echo noise equivalent implemented here is:

z = echo_noise.echo_sample( [z_mean, z_log_scale] )

Similarly, VAEs often calculate a KL divergence penalty based on z_mean and z_log_scale. The Echo noise penalty, which is the mutual information I(x;z), can be computed using:

loss = ... + echo_noise.echo_loss([z_log_scale])
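For reference, these two functions correspond to the construction in the paper: the echo channel builds its noise from additional data samples x^(ℓ), and for a diagonal scale S(x) with entries s_i(x) = exp(z_log_scale_i) the mutual information has a closed form. A sketch of the relevant formulas (convergence of the series requires |s_i| < 1, which the log-sigmoid activation on z_log_scale below guarantees):

```latex
% Echo channel: noise built from other data samples x^{(\ell)}
z = f(x) + S(x)\,\varepsilon, \qquad
\varepsilon = \sum_{\ell \ge 1} \Big( \prod_{j < \ell} S(x^{(j)}) \Big) f(x^{(\ell)})

% Exact mutual information, computed by echo_loss (diagonal S):
I(x;z) = -\,\mathbb{E}_x \big[ \log \lvert \det S(x) \rvert \big]
       = -\,\mathbb{E}_x \Big[ \textstyle\sum_i \log s_i(x) \Big]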

A Keras version of this might look like the following:

z_mean = Dense(latent_dim, activation = model_utils.activations.tanh64)(h)
z_log_scale = Dense(latent_dim, activation = tf.math.log_sigmoid)(h)
z_activation = Lambda(echo_noise.echo_sample)([z_mean, z_log_scale])
echo_loss = Lambda(echo_noise.echo_loss)([z_log_scale])

These functions are also found in the experiments code, model_utils/layers.py and model_utils/losses.py.
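To show the mechanics without any repository dependencies, here is a minimal NumPy sketch of the truncated echo construction (simplified: diagonal S, x^(ℓ) drawn by within-batch permutation, series truncated at d_max — the function names `echo_sample_np` and `echo_loss_np` are hypothetical stand-ins, not the repository's implementation):

```python
import numpy as np

def echo_sample_np(f_x, log_s, d_max=16, rng=None):
    """Simplified, truncated echo-noise sample with diagonal S.

    f_x:   (batch, dim) encoder outputs f(x)
    log_s: (batch, dim) log of the diagonal scales, assumed < 0 so
           the noise series converges (cf. the log-sigmoid activation)
    """
    rng = np.random.default_rng(rng)
    batch = f_x.shape[0]
    s = np.exp(log_s)
    eps = np.zeros_like(f_x)
    scale = np.ones_like(f_x)          # running product of S(x^(j)), j < l
    for _ in range(d_max):
        idx = rng.permutation(batch)   # draw x^(l) from within the batch
        eps += scale * f_x[idx]
        scale *= s[idx]
    return f_x + s * eps               # z = f(x) + S(x) * eps

def echo_loss_np(log_s):
    """Exact MI of the echo channel: I(x;z) = -E[sum_i log s_i(x)]."""
    return -np.mean(np.sum(log_s, axis=-1))
```

The actual echo_noise.py additionally supports sampling with or without replacement and TensorFlow graph execution; this sketch only illustrates the series that echo_sample truncates.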

Instructions:

python run.py --config 'echo.json' --beta 1.0 --filename 'echo_example' --dataset 'binary_mnist'

Experiments are specified using the config files in configs/, which define the network architecture and loss functions. run.py calls model.py to parse these configs and create and train a model. You can also modify the tradeoff parameter beta, which multiplies the rate term, or specify the dataset using 'binary_mnist', 'omniglot', or 'fmnist'. Analysis tools are mostly omitted for now, but the model loss training history is saved in a pickle file.

A note about Echo sampling and batch size:

We can choose to sample training examples with or without replacement from within the batch when constructing Echo noise.
For sampling without replacement, there are two helper functions that shuffle index orderings for x^(l): permute_neighbor_indices sets the output batch_size != None and is much faster, while indices_without_replacement maintains batch_size = None (e.g. for variable batch sizes or fitting with Keras fit). These are controlled with the set_batch option.

Also be wary of leftover batches: we choose d_max samples to construct Echo noise from within the batch, so small batches (especially without replacement) may give inaccurate noise.
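To make the caveat concrete, here is a minimal NumPy sketch (not the repository's implementation; the function name is hypothetical) of without-replacement index selection: each example i needs d_max distinct partner indices excluding itself, which is only possible when the batch has more than d_max examples — hence the warning about small leftover batches.

```python
import numpy as np

def indices_without_replacement_np(batch_size, d_max, rng=None):
    """For each example i, pick d_max distinct indices from the batch,
    excluding i itself, so no example contributes to its own noise.
    Requires batch_size > d_max."""
    rng = np.random.default_rng(rng)
    out = np.empty((batch_size, d_max), dtype=int)
    for i in range(batch_size):
        others = np.delete(np.arange(batch_size), i)  # everyone but i
        out[i] = rng.choice(others, size=d_max, replace=False)
    return out
```

A leftover batch of size <= d_max would make the rng.choice call fail outright, and sizes only slightly above d_max reuse nearly the whole batch for every example.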

Comparison Methods

We compare diagonal Gaussian noise encoders ('VAE') and IAF encoders, alongside several marginal approximations: a standard Gaussian prior, a standard Gaussian with MMD penalty (info_vae.json or iaf_prior_mmd.json), Masked Autoregressive Flow (MAF), and VampPrior. All combinations can be found in the configs/ folder.
