Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models

This repository provides a reference implementation for learning Flow-GAN models as described in the paper:

Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models
Aditya Grover, Manik Dhar, and Stefano Ermon.
AAAI Conference on Artificial Intelligence (AAAI), 2018.
Paper: https://arxiv.org/pdf/1705.08868.pdf
Blog post: https://ermongroup.github.io/blog/flow-gan

Requirements

The codebase is implemented in Python 3.6. To install the required packages, run:

pip install -r requirements.txt

Datasets

Scripts for downloading and loading the MNIST and CIFAR-10 datasets are included in the datasets_loader folder. They are called automatically the first time main.py is run.

Options

Learning and inference for Flow-GAN models are handled by the main.py script, which accepts the following command-line arguments:

  --beta1 FLOAT           beta1 parameter for the Adam optimizer
  --epoch INT             Number of training epochs
  --batch_size INT        Training batch size
  --learning_rate FLOAT   Learning rate
  --input_height INT      Height of the input images
  --input_width INT       Width of the input images (defaults to input_height if not given)
  --c_dim INT             Number of image color channels
  --dataset STR           Name of the dataset [mnist, svhn, cifar-10]
  --checkpoint_dir STR    Directory in which to save checkpoints
  --log_dir STR           Directory in which to save logs
  --sample_dir STR        Directory in which to save image samples
  --f_div STR             Divergence used to specify the GAN objective
  --prior STR             Prior distribution for the generator
  --alpha FLOAT           Alpha value used when applying the logit transform
  --lr_decay FLOAT        Learning rate decay rate
  --min_lr FLOAT          Minimum learning rate allowed during decay
  --reg FLOAT             Regularization parameter for adversarial training
  --model_type STR        Flow architecture: real_nvp or nice
  --n_critic INT          Number of discriminator (critic) iterations
  --no_of_layers INT      Number of units between input and output in the m function of a coupling layer
  --hidden_layers INT     Size of the hidden layers (applicable only to NICE)
  --like_reg FLOAT        Weight of the likelihood loss relative to the adversarial loss in hybrid training
  --df_dim INT            Depth (number of filters) of the discriminator
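For scripting these runs, here is a minimal argparse sketch that mirrors the flags listed above with their documented types. It is not the repository's parser (main.py may define its flags differently, for example through TensorFlow's flag module); no defaults are assumed, and the names `FLAG_TYPES` and `parse_flags` are hypothetical. It encodes the one documented default: `--input_width` falls back to `--input_height`.

```python
import argparse

# Hypothetical parser: flag names and types copied from the Options list.
FLAG_TYPES = {
    "beta1": float, "epoch": int, "batch_size": int, "learning_rate": float,
    "input_height": int, "input_width": int, "c_dim": int, "dataset": str,
    "checkpoint_dir": str, "log_dir": str, "sample_dir": str, "f_div": str,
    "prior": str, "alpha": float, "lr_decay": float, "min_lr": float,
    "reg": float, "n_critic": int, "no_of_layers": int, "hidden_layers": int,
    "like_reg": float, "df_dim": int,
}


def parse_flags(argv):
    """Parse Flow-GAN style flags; returns (namespace, unrecognized args)."""
    parser = argparse.ArgumentParser(description="Flow-GAN flags (sketch)",
                                     allow_abbrev=False)
    for name, typ in FLAG_TYPES.items():
        parser.add_argument(f"--{name}", type=typ, default=None)
    parser.add_argument("--model_type", choices=["real_nvp", "nice"], default=None)
    # Flags not in the list (e.g. --batch_norm_adaptive) are returned, not rejected.
    args, unknown = parser.parse_known_args(argv)
    # Documented default: input_width falls back to input_height.
    if args.input_width is None:
        args.input_width = args.input_height
    return args, unknown
```

Passing unknown flags through instead of rejecting them is a choice made for this sketch. Setting `allow_abbrev=False` stops argparse from silently expanding a prefix such as `--beta` into `--beta1`.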

Examples

Training Flow-GAN models on the MNIST dataset with the NICE architecture:

Maximum Likelihood Estimation (MLE)

python main.py --dataset mnist --input_height=28 --c_dim=1 --checkpoint_dir checkpoint_mnist/mle --sample_dir samples_mnist/mle --model_type nice --log_dir logs_mnist/mle \
  --prior logistic --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --epoch 500 --batch_size 100 --like_reg 1.0 --n_critic 0 --no_of_layers 5
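In this configuration `--n_critic 0` runs no discriminator updates and `--like_reg 1.0` keeps the likelihood term, so the run trains by maximum likelihood using the exact log-likelihood given by the change-of-variables formula. Below is a minimal NumPy sketch of that quantity for a NICE-style flow with a logistic prior: additive coupling layers (unit Jacobian determinant) followed by a diagonal scaling layer. The class names, the ReLU coupling network and the initialization are assumptions for illustration, not the repository's implementation.

```python
import numpy as np

# Sketch only: names and network choices are assumptions, not the repository's code.


def logistic_log_prob(h):
    # Standard logistic log-density summed over dimensions:
    # log p(h) = -(softplus(h) + softplus(-h)), with softplus(z) = log(1 + e^z).
    return -(np.logaddexp(0.0, h) + np.logaddexp(0.0, -h)).sum(axis=-1)


class AdditiveCoupling:
    """NICE-style coupling: one half passes through, the other gets + m(first half)."""

    def __init__(self, dim, hidden, rng, swap):
        assert dim % 2 == 0, "this sketch assumes an even input dimension"
        self.swap = swap  # alternate which half is transformed
        half = dim // 2
        # m(.) is a one-hidden-layer ReLU MLP here; the real network may differ.
        self.W1 = rng.normal(scale=0.1, size=(half, hidden))
        self.W2 = rng.normal(scale=0.1, size=(hidden, half))

    def _m(self, a):
        return np.maximum(a @ self.W1, 0.0) @ self.W2

    def _split(self, x):
        a, b = np.split(x, 2, axis=-1)
        return (b, a) if self.swap else (a, b)  # (conditioning, transformed)

    def _merge(self, cond, trans):
        return np.concatenate((trans, cond) if self.swap else (cond, trans), axis=-1)

    def forward(self, x):
        cond, trans = self._split(x)
        return self._merge(cond, trans + self._m(cond))  # log|det J| = 0

    def inverse(self, y):
        cond, trans = self._split(y)
        return self._merge(cond, trans - self._m(cond))


class NiceFlow:
    """Stack of additive couplings followed by a diagonal scaling layer."""

    def __init__(self, dim, n_layers, hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.layers = [AdditiveCoupling(dim, hidden, rng, swap=(i % 2 == 1))
                       for i in range(n_layers)]
        self.log_scale = np.zeros(dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x * np.exp(self.log_scale)

    def inverse(self, h):
        x = h * np.exp(-self.log_scale)
        for layer in reversed(self.layers):
            x = layer.inverse(x)
        return x

    def log_likelihood(self, x):
        # Change of variables: log p(x) = log p_H(f(x)) + log|det df/dx|.
        # Couplings contribute 0, the scaling layer contributes sum(log_scale).
        return logistic_log_prob(self.forward(x)) + self.log_scale.sum()
```

The MLE loss would be the batch mean of `-model.log_likelihood(x)`; gradients would come from an autodiff framework, which this sketch leaves out.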

Adversarial training (ADV)

python main.py --dataset mnist --input_height=28 --c_dim=1 --checkpoint_dir checkpoint_mnist/gan --sample_dir samples_mnist/gan --model_type nice --log_dir logs_mnist/gan \
  --prior logistic --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --reg 10.0 --epoch 500 --batch_size 100 --like_reg 0.0 --n_critic 5 --no_of_layers 5
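Here `--like_reg 0.0` drops the likelihood term and `--n_critic 5` trains a discriminator. The value `--reg 10.0` matches the gradient-penalty coefficient commonly used with WGAN-GP, so one plausible reading (an assumption, not confirmed by this README; the objective may also depend on `--f_div`) is a Wasserstein critic with a gradient penalty weighted by `--reg`. A NumPy sketch of those losses, using a hypothetical one-hidden-layer critic whose input gradient can be written analytically:

```python
import numpy as np

# Sketch only: assumes a WGAN-GP style objective; the critic is hypothetical.


class TanhCritic:
    """Hypothetical one-hidden-layer critic D(x) = v . tanh(W x + b)."""

    def __init__(self, dim, hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.5, size=(hidden, dim))
        self.b = rng.normal(scale=0.1, size=hidden)
        self.v = rng.normal(scale=0.5, size=hidden)

    def __call__(self, x):
        return np.tanh(x @ self.W.T + self.b) @ self.v

    def grad_x(self, x):
        # Analytic input gradient, one row per sample:
        # dD/dx = W^T (v * (1 - tanh^2(W x + b))).
        t = np.tanh(x @ self.W.T + self.b)
        return (self.v * (1.0 - t ** 2)) @ self.W


def critic_loss(critic, x_real, x_fake, reg, rng):
    """WGAN critic loss plus a gradient penalty weighted by `reg`."""
    eps = rng.uniform(size=(x_real.shape[0], 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake  # random interpolates
    grad_norm = np.linalg.norm(critic.grad_x(x_hat), axis=1)
    penalty = reg * np.mean((grad_norm - 1.0) ** 2)
    return np.mean(critic(x_fake)) - np.mean(critic(x_real)) + penalty


def generator_adv_loss(critic, x_fake):
    """Adversarial loss for the flow generator: raise the critic's score on samples."""
    return -np.mean(critic(x_fake))
```

In training, the critic would take `--n_critic` steps on `critic_loss` for every generator step on `generator_adv_loss`, with gradients supplied by an autodiff framework.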

Hybrid

python main.py --dataset mnist --input_height=28 --c_dim=1 --checkpoint_dir checkpoint_mnist/flow --sample_dir samples_mnist/flow --model_type nice --log_dir logs_mnist/flow \
  --prior logistic --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --reg 10.0 --epoch 500 --batch_size 100 --like_reg 1.0 --n_critic 5 --no_of_layers 5
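In the hybrid setting both terms are active. The flow acts as the GAN generator and is also trained to assign high likelihood to the data, with `--like_reg` weighting the likelihood term (the paper writes this as the adversarial objective minus λ times the expected data log-likelihood). The small sketch below, with hypothetical helper names, shows the generator-side loss and how the three MNIST configurations map onto modes through `--n_critic` and `--like_reg`:

```python
# Sketch only: helper names are hypothetical, not functions from main.py.


def generator_loss(adv_loss, nll, like_reg):
    """Hybrid loss for the flow generator: adversarial term + like_reg * NLL.

    adv_loss: generator's adversarial loss on a batch (0.0 if no critic is trained).
    nll: mean negative log-likelihood of a batch of real data under the flow.
    """
    return adv_loss + like_reg * nll


def training_mode(like_reg, n_critic):
    """Map the flag combinations used in the examples to a training mode."""
    if n_critic == 0 and like_reg == 0.0:
        raise ValueError("no training signal: n_critic=0 and like_reg=0")
    if n_critic == 0:
        return "mle"
    if like_reg == 0.0:
        return "adv"
    return "hybrid"
```

With `--like_reg 0.0` the loss reduces to the purely adversarial one, and a larger value (such as the 20.0 used in the CIFAR-10 hybrid run) pulls training toward maximum likelihood.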

Training Flow-GAN models on the CIFAR-10 dataset with the Real-NVP architecture:
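Real-NVP replaces NICE's additive coupling with affine coupling. Each layer rescales and shifts some dimensions conditioned on the others, so its log-determinant is the sum of its log-scales. Below is a minimal NumPy sketch on flat vectors under a standard Gaussian prior (matching `--prior gaussian` in the commands below). The alternating mask, the tanh-bounded scale network and the helper names are assumptions; the multi-scale architecture, squeezing and batch normalization of the full Real-NVP model are omitted.

```python
import numpy as np

# Sketch only: names and network choices are assumptions, not the repository's code.


def gaussian_log_prob(z):
    # Standard normal log-density, summed over dimensions.
    return -0.5 * (z ** 2 + np.log(2 * np.pi)).sum(axis=-1)


class AffineCoupling:
    """Masked dims pass through; the rest become x * exp(s) + t."""

    def __init__(self, mask, hidden, rng):
        self.mask = mask.astype(float)
        d = mask.size
        self.W1 = rng.normal(scale=0.1, size=(d, hidden))
        self.Ws = rng.normal(scale=0.1, size=(hidden, d))
        self.Wt = rng.normal(scale=0.1, size=(hidden, d))

    def _scale_shift(self, x_masked):
        h = np.tanh(x_masked @ self.W1)
        # s and t are zeroed on pass-through dims; tanh bounds the log-scale.
        s = np.tanh(h @ self.Ws) * (1 - self.mask)
        t = (h @ self.Wt) * (1 - self.mask)
        return s, t

    def forward(self, x):
        xm = x * self.mask
        s, t = self._scale_shift(xm)
        y = xm + (1 - self.mask) * (x * np.exp(s) + t)
        return y, s.sum(axis=-1)  # log|det J| of this layer

    def inverse(self, y):
        ym = y * self.mask
        s, t = self._scale_shift(ym)
        return ym + (1 - self.mask) * (y - t) * np.exp(-s)


def build_flow(dim, n_layers, hidden, seed=0):
    rng = np.random.default_rng(seed)
    base = np.arange(dim) % 2
    # Alternate the mask so every dimension gets transformed.
    return [AffineCoupling(base if i % 2 == 0 else 1 - base, hidden, rng)
            for i in range(n_layers)]


def log_likelihood(layers, x):
    total_logdet = np.zeros(x.shape[0])
    for layer in layers:
        x, logdet = layer.forward(x)
        total_logdet += logdet
    return gaussian_log_prob(x) + total_logdet


def invert(layers, z):
    for layer in reversed(layers):
        z = layer.inverse(z)
    return z
```

Because the scale and shift depend only on the unchanged dimensions, the inverse is as cheap as the forward pass, which is what makes exact likelihood evaluation and sampling both tractable.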

Maximum Likelihood Estimation (MLE)

python main.py --dataset cifar --input_height=32 --c_dim=3 --checkpoint_dir checkpoint_cifar/mle --sample_dir samples_cifar/mle --model_type real_nvp --log_dir logs_cifar/mle \
  --prior gaussian --beta1 0.9 --learning_rate 1e-3 --alpha 1e-7 --epoch 300 --lr_decay 0.999995 --batch_size 64 --like_reg 1.0 --n_critic 0 --no_of_layers 8 --batch_norm_adaptive 0
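`--alpha` controls the logit transform applied to the inputs. One common form, used in several flow implementations and assumed here rather than taken from this repository, maps pixel intensities x in [0, 1] to logit(α + (1 - 2α)x), which keeps the logit's argument away from 0 and 1. Its log-Jacobian has to be added to the flow's log-likelihood to obtain a density over the original pixel values. A NumPy sketch with hypothetical function names:

```python
import numpy as np

# Sketch only: assumes the form logit(alpha + (1 - 2*alpha) * x).


def logit_transform(x, alpha):
    """Map x in [0, 1] to logit space; also return the summed log|dy/dx|."""
    p = alpha + (1 - 2 * alpha) * x
    y = np.log(p) - np.log1p(-p)
    # Per-dimension derivative: dy/dx = (1 - 2*alpha) / (p * (1 - p)).
    log_det = (np.log(1 - 2 * alpha) - np.log(p) - np.log1p(-p)).sum(axis=-1)
    return y, log_det


def inverse_logit_transform(y, alpha):
    """Undo the transform: sigmoid, then remove the alpha squashing."""
    p = 1.0 / (1.0 + np.exp(-y))
    return (p - alpha) / (1 - 2 * alpha)
```

Under this assumption, log p(x) = log p_flow(y) + log_det, and a very small α such as 1e-7 keeps pixels at exactly 0 or 1 at a finite logit (about ±16).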

Adversarial training (ADV)

python main.py --dataset cifar --input_height=32 --c_dim=3 --checkpoint_dir checkpoint_cifar/gan --sample_dir samples_cifar/gan --model_type real_nvp --log_dir logs_cifar/gan \
  --prior gaussian --beta1 0.5 --learning_rate 1e-4 --alpha 1e-7 --epoch 300 --batch_size 64 --like_reg 0.0 --n_critic 5 --no_of_layers 8

Hybrid

python main.py --dataset cifar --input_height=32 --c_dim=3 --checkpoint_dir checkpoint_cifar/flow --sample_dir samples_cifar/flow --model_type real_nvp --log_dir logs_cifar/flow \
  --prior gaussian --beta1 0.5 --learning_rate 1e-3 --lr_decay 0.99999 --alpha 1e-7 --epoch 500 --batch_size 64 --like_reg 20.0 --n_critic 5 --no_of_layers 8
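The CIFAR-10 MLE and hybrid runs pass `--lr_decay` (0.999995 and 0.99999), and `--min_lr` sets a floor on the decayed rate. The exact schedule is not documented here; the sketch below assumes per-update exponential decay clipped at `min_lr`, and both function names are hypothetical.

```python
import math

# Sketch only: the per-step exponential schedule is an assumption.


def learning_rate(step, base_lr, lr_decay, min_lr=0.0):
    """Assumed schedule: exponential decay per update step, floored at min_lr."""
    return max(base_lr * lr_decay ** step, min_lr)


def steps_until_floor(base_lr, lr_decay, min_lr):
    """First step at which the decayed rate reaches min_lr (for 0 < lr_decay < 1)."""
    return math.ceil(math.log(min_lr / base_lr) / math.log(lr_decay))
```

Under this assumed schedule, the hybrid settings (base 1e-3, decay 0.99999) would bring the rate to about 3.7e-4 after 100,000 updates.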

Portions of the codebase in this repository use code originally provided in the open-source DCGAN and Real-NVP repositories.

Citing

If you find Flow-GANs useful in your research, please consider citing the following paper:

@inproceedings{grover2018flowgan,
title={Flow-GAN: Combining Maximum Likelihood and Adversarial Learning in Generative Models},
author={Grover, Aditya and Dhar, Manik and Ermon, Stefano},
booktitle={AAAI Conference on Artificial Intelligence},
year={2018}
}
