

Inference Suboptimality in Variational Autoencoders

This repository contains code for a project undertaken as part of the Advanced Topics in Machine Learning course (HT 2020) at Oxford. The code here was written by Mikhail Andrenkov, Maxence Draguet, Sebastian Lee, and Diane Magnin.

It is a reproduction of the code for the paper Inference Suboptimality in Variational Autoencoders by Cremer, Li & Duvenaud.

Our code contains the following features relevant to replicating results from the paper:

  • Flexible encoder/decoder architectures
  • Various approximate posteriors including:
    • Factorised Gaussian
    • R-NVP flows
    • R-NVP flows with auxiliary variables
  • Relevant binarised image datasets including:
    • MNIST
    • Fashion-MNIST
    • CIFAR-10
  • Local optimisation training loop
  • AIS and IWAE log-likelihood estimators

Additionally, we implemented a planar-flow approximate posterior.
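As a sketch of how the IWAE estimator listed above works (illustrative code, not the repository's implementation): the K-sample bound averages the importance weights w_k = p(x, z_k) / q(z_k | x) inside the log.

```python
import math

def iwae_bound(log_weights):
    """IWAE log-likelihood bound from K importance weights.

    log_weights[k] = log p(x, z_k) - log q(z_k | x) for samples z_k ~ q(z | x).
    Returns log( (1/K) * sum_k exp(log_weights[k]) ), computed stably.
    """
    k = len(log_weights)
    m = max(log_weights)  # log-sum-exp trick for numerical stability
    return m + math.log(sum(math.exp(lw - m) for lw in log_weights)) - math.log(k)
```

With K = 1 this reduces to the standard single-sample ELBO estimate; the bound tightens as K grows.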

Prerequisites

To run this code you will need the following:

  • Python 3.7+

Our code uses PyTorch. We include a requirements file (requirements.txt). We recommend creating a virtual environment (using conda or virtualenv) for this code base, e.g.

python3 -m venv aml; source aml/bin/activate

From there, all Python prerequisites should be satisfied by running

pip3 install -r requirements.txt

On Windows, running experiments with a GPU requires Python 3.7.5. Our code is compatible with CUDA 10.1.

Datasets

We do not provide the datasets directly in this repository. However, we use modified versions of standard datasets (e.g. MNIST, CIFAR-10) that can be loaded with the torchvision datasets module. To retrieve the datasets and apply the requisite modifications (the binarisation specified by Larochelle et al.), run:

python data/get_datasets.py
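For reference, the binarisation itself can be sketched as follows (an illustrative stand-in for what the script produces; the function name and exact scheme here are ours, not the repository's):

```python
import numpy as np

def binarise(images, rng=None, stochastic=True):
    """Binarise grayscale images with intensities in [0, 1].

    stochastic=True samples each pixel as Bernoulli(intensity), in the
    spirit of the binarisation of Larochelle et al.; stochastic=False
    thresholds at 0.5 instead.
    """
    images = np.asarray(images, dtype=np.float64)
    if stochastic:
        rng = rng or np.random.default_rng(0)
        return (rng.random(images.shape) < images).astype(np.float64)
    return (images > 0.5).astype(np.float64)
```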

Running Code

Standalone experiments can be run from the experiments folder using the main.py script. An experiment is configured via the base_config.yaml file for its general attributes, as well as specific config files in the additional_configs/ folder (e.g. for setting the parameters of a flow module).
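The layering of the general and specific config files can be pictured as a recursive dictionary merge (a sketch of the idea only; the repository's actual merge logic may differ):

```python
def merge_configs(base, *overrides):
    """Recursively merge override dicts into a base config dict.

    Later overrides win; nested dicts are merged key by key. This mirrors
    layering base_config.yaml with files from additional_configs/.
    """
    merged = dict(base)
    for override in overrides:
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge_configs(merged[key], value)
            else:
                merged[key] = value
    return merged
```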

Running a specific experiment from the paper can be done via the relevant hard-coded configuration files in the experiment_list folder, which match the specifications of the paper. For example, to reproduce the configuration of a fully-factorised Gaussian approximate posterior with an amortised inference network (𝓛(VAE[q]) | qFFG from Table 2 in the paper), run from the experiments folder:

python main.py -config experiment_list/expA/base_config.yaml -additional_configs experiment_list/expA/additional_configs/

Alternatively, all results from a given experiment can be run at once in sequence using the bash script in the relevant experiment folder.

Accessing Experimental Results

Results of an experiment are saved by default in experiments/results/X/ where X is a timestamp for the experiment. There you will find a copy of the configuration used to run that experiment, a .csv file logging relevant metrics (e.g. train/test loss), and TensorBoard event files. To view the TensorBoard logs, navigate to this folder and run:

tensorboard --logdir .

Alternatively, run the command from elsewhere and modify the path accordingly. Plots of an experiment run can also be made by running the plot_from_df.py script from the experiments/plotting folder, passing the path of the folder containing the .csv file to the -save_path flag.
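The logged .csv can also be inspected without the plotting script, e.g. with a small parser (the column names below are illustrative, not the repository's actual headers):

```python
import csv
import io

def read_metrics(csv_text):
    """Parse a metrics CSV into a dict mapping column name -> list of floats."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    return {name: [float(row[name]) for row in rows] for name in rows[0]}

# Hypothetical two-step log.
log = "step,train_loss,test_loss\n0,120.4,121.0\n1,98.2,99.5\n"
metrics = read_metrics(log)
```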

Weights of the models trained in a given experiment are also saved by default in experiments/saved_models/Y/X/ where Y is a hash of the configuration file and X is a timestamp for the experiment. Saved models can be loaded (e.g. to run local optimisation) by specifying the saved model path in the base config. Note that these are saved weights, not full checkpoints, so they cannot be used to resume training.
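Loading such weights follows the usual PyTorch state_dict pattern, sketched here with a toy module (TinyEncoder is a hypothetical stand-in; the real model classes live in models/):

```python
import torch
import torch.nn as nn

# Toy stand-in for one of the repository's models (hypothetical class).
class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

model = TinyEncoder()
torch.save(model.state_dict(), "weights.pt")        # weights only, as in saved_models/

restored = TinyEncoder()
restored.load_state_dict(torch.load("weights.pt"))  # no optimiser state: cannot resume training
restored.eval()                                     # ready for evaluation / local optimisation
```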

Code Structure

Below is the structure of the relevant files in our repository.

│
├── requirements.txt
├── README.md
│
├── data
│
├── experiments
│    │
│    ├── additional_configs
│    │   ├── aux_flow_config.yaml
│    │   ├── esimator_config.yaml
│    │   ├── flow_config.yaml
│    │   ├── local_optimisation_config.yaml
│    │   └── planar_config.yaml
│    │
│    ├── experiment_list (bash scripts for paper experiments)
│    │   ├── expA
│    │   ├── expB
│    │   ├── expC
│    │   ├── expD
│    │   └── expE
│    │
│    ├── plotting
│    │   ├── plot_config.json
│    │   └── plot_from_df.py
│    │
│    ├── results
│    │   └── result files (not tracked/committed)
│    │
│    ├── saved_models
│    │   └── saved_model files (not tracked/committed)
│    │
│    ├── base_config.yaml
│    ├── context.py
│    └── main.py
│
├── models
│    │
│    ├── approximate_posteriors
│    │   ├── __init__.py
│    │   ├── base_approximate_posterior.py
│    │   ├── base_norm_flow.py
│    │   ├── gaussian.py
│    │   ├── planar_flow.py
│    │   ├── rnvp_aux_flow.py
│    │   ├── rnvp_flow.py
│    │   └── sylv_flow.py
│    │
│    ├── likelihood_estimators
│    │   ├── __init__.py
│    │   ├── ais_estimator.py
│    │   ├── base_estimator.py
│    │   ├── iwae_estimator.py
│    │   └── max_estimator.py
│    │
│    ├── local_optimisation_modules
│    │   ├── __init__.py
│    │   ├── base_local_optimisation.py
│    │   ├── gaussian_local_optimisation.py
│    │   ├── rnvp_aux_flow_local_optimisation.py
│    │   └── rnvp_flow_local_optimisation.py
│    │
│    ├── loss_modules
│    │   ├── __init__.py
│    │   ├── base_loss.py
│    │   ├── gaussian_loss.py
│    │   ├── planar_loss.py
│    │   ├── rnvp_aux_loss.py
│    │   └── rnvp_loss.py
│    │
│    ├── networks
│    │   ├── __init__.py
│    │   ├── base_network.py
│    │   ├── convolutional.py
│    │   ├── deconvolutional.py
│    │   ├── fc_encoder.py
│    │   └── fc_decoder.py
│    │
│    ├── decoder.py
│    ├── encoder.py
│    ├── vae_runner.py
│    └── vae.py
│
└── utils
     │
     ├── __init__.py
     ├── custom_torch_transforms.py
     ├── dataloaders.py
     ├── math_operations.py
     ├── parameters.py
     ├── torch_operations_test.py
     └── torch_operations.py
