ali-pytorch's Introduction

Adversarially Learned Inference

PyTorch implementation of the paper Adversarially Learned Inference.

main.py includes training code for the following datasets:

  • SVHN
  • CIFAR10
  • CelebA

models.py includes the network architectures for the different datasets, as defined in the original paper.
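
The snippet below is a rough sketch of how these networks relate to one another, based on the ALI paper and the network names used by the CLI flags (netGx, netGz, netDx, netDz, netDxz); the tiny linear layers are placeholders standing in for the real convolutional architectures in models.py.

# Minimal sketch of the ALI wiring; placeholder modules, not the code in models.py.
import torch
import torch.nn as nn

nz, nx = 256, 32 * 32 * 3            # latent size and flattened image size (illustrative)
netGz = nn.Linear(nx, nz)            # encoder: x -> z_hat
netGx = nn.Linear(nz, nx)            # decoder: z -> x_tilde
netDx = nn.Linear(nx, 128)           # image branch of the discriminator
netDz = nn.Linear(nz, 128)           # latent branch of the discriminator
netDxz = nn.Linear(256, 1)           # joint discriminator on the concatenated branches

def discriminate(x, z):
    # D(x, z) = Dxz([Dx(x), Dz(z)]) as in the ALI paper
    return netDxz(torch.cat([netDx(x), netDz(z)], dim=1))

x = torch.randn(8, nx)                       # a batch of (flattened) real images
z = torch.randn(8, nz)                       # a batch of prior samples
data_preds = discriminate(x, netGz(x))       # scores for (x, Gz(x)) pairs
sample_preds = discriminate(netGx(z), z)     # scores for (Gx(z), z) pairs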

Usage

usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS]
               [--batch-size BATCH_SIZE] [--image-size IMAGE_SIZE] [--nc NC]
               [--nz NZ] [--epochs EPOCHS] [--lr LR] [--beta1 BETA1]
               [--beta2 BETA2] [--cuda] [--ngpu NGPU] [--gpu-id GPU_ID]
               [--netGx NETGX] [--netGz NETGZ] [--netDz NETDZ] [--netDx NETDX]
               [--netDxz NETDXZ] [--clamp_lower CLAMP_LOWER]
               [--clamp_upper CLAMP_UPPER] [--experiment EXPERIMENT]

optional arguments:
  -h, --help            show this help message and exit
  --dataset DATASET     cifar10 | svhn | celeba
  --dataroot DATAROOT   path to dataset
  --workers WORKERS     number of data loading workers
  --batch-size BATCH_SIZE
                        input batch size
  --image-size IMAGE_SIZE
                        the height / width of the input image to network
  --nc NC               input image channels
  --nz NZ               size of the latent z vector
  --epochs EPOCHS       number of epochs to train for
  --lr LR               learning rate for optimizer, default=0.00005
  --beta1 BETA1         beta1 for adam. default=0.5
  --beta2 BETA2         beta2 for adam. default=0.999
  --cuda                enables cuda
  --ngpu NGPU           number of GPUs to use
  --gpu-id GPU_ID       id(s) for CUDA_VISIBLE_DEVICES
  --netGx NETGX         path to netGx (to continue training)
  --netGz NETGZ         path to netGz (to continue training)
  --netDz NETDZ         path to netDz (to continue training)
  --netDx NETDX         path to netDx (to continue training)
  --netDxz NETDXZ       path to netDxz (to continue training)
  --clamp_lower CLAMP_LOWER
  --clamp_upper CLAMP_UPPER
  --experiment EXPERIMENT
                        Where to store samples and models

Example

Command-line example for training on SVHN:

python main.py --dataset svhn --dataroot . --experiment svhn_ali --cuda --ngpu 1 --gpu-id 1 --batch-size 100 --epochs 100 --image-size 32 --nz 256 --lr 1e-4 --beta1 0.5 --beta2 10e-3

Cite

@article{DBLP:journals/corr/DumoulinBPLAMC16,
  author    = {Vincent Dumoulin and
               Ishmael Belghazi and
               Ben Poole and
               Alex Lamb and
               Mart{\'{\i}}n Arjovsky and
               Olivier Mastropietro and
               Aaron C. Courville},
  title     = {Adversarially Learned Inference},
  journal   = {CoRR},
  volume    = {abs/1606.00704},
  year      = {2016},
  url       = {http://arxiv.org/abs/1606.00704},
}

ali-pytorch's People

Contributors

edgarriba

ali-pytorch's Issues

Bug found: exploding gradients

I've found the reason why the gradients in your ALI implementation explode.
Firstly, use the default parameters of Adam, that is: betas=(0.9, 0.999).
Secondly, the shape of epsilon should be the same as that of z.
Thank you!
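
A minimal sketch of the two fixes suggested in this issue; the variable names (mu, log_sigma, params) are illustrative and may not match what main.py actually uses.

import torch
import torch.optim as optim

batch_size, nz = 100, 256

# 2) Draw epsilon with exactly the same shape as z before reparameterizing:
mu = torch.zeros(batch_size, nz, 1, 1)         # stand-in for the encoder's mean output
log_sigma = torch.zeros(batch_size, nz, 1, 1)  # stand-in for the encoder's log-std output
eps = torch.randn_like(mu)                     # epsilon now matches z's shape
z_hat = mu + eps * torch.exp(log_sigma)

# 1) Use Adam's default betas instead of the values from the README example:
params = [torch.nn.Parameter(torch.randn(3, 3))]  # stand-in for the real network parameters
optimizer = optim.Adam(params, lr=1e-4, betas=(0.9, 0.999))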

Function Headers

Hi! Thanks so much for posting this code. This isn't so much an issue as it is a request: Would it be possible to write some function headers and variable definitions? You have some, which is appreciated, but those remaining are a bit confusing to me. Thanks in advance!

Example Run from README leads to infinite loss error after a few iterations

Having checked my GPU setup and gotten the code to download the dataset and run, I get this:

[0/100][0/733] Loss_D: 1.38686203956604 Loss_G: 1.3840898275375366
[0/100][1/733] Loss_D: 1.385105848312378 Loss_G: 1.3885846138000488
[0/100][2/733] Loss_D: 1.3828593492507935 Loss_G: 1.387503981590271
[0/100][3/733] Loss_D: 1.3838441371917725 Loss_G: 1.3841853141784668
[0/100][4/733] Loss_D: 1.6616134643554688 Loss_G: 1.4784884452819824
[0/100][5/733] Loss_D: 18.7115478515625 Loss_G: 14.966793060302734
Traceback (most recent call last):
  File "main.py", line 328, in <module>
    # call train/test routines
  File "main.py", line 254, in train
    D_loss = compute_loss(batch_size, d_loss=True)
  File "main.py", line 170, in compute_loss
    loss = torch.mean(softplus(-data_preds) + softplus(sample_preds))
RuntimeError: value cannot be converted to type double without overflow: inf

This is from running the example CLI command given in the README.
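
For context, the line named in the traceback computes the discriminator loss as the mean of softplus terms over the discriminator logits. The sketch below reproduces that expression with dummy tensors; the generator-side counterpart shown here is the standard ALI form and is an assumption about how main.py defines it.

import torch
from torch.nn.functional import softplus

data_preds = torch.randn(100, 1)    # discriminator logits for (x, Gz(x)) pairs (dummy values)
sample_preds = torch.randn(100, 1)  # discriminator logits for (Gx(z), z) pairs (dummy values)

# Discriminator loss: -log D(x, z_hat) - log(1 - D(x_tilde, z)), written with softplus on logits
d_loss = torch.mean(softplus(-data_preds) + softplus(sample_preds))
# Generator/encoder loss: the two terms swap roles
g_loss = torch.mean(softplus(data_preds) + softplus(-sample_preds))
# Once the logits blow up, softplus(logit) grows like the logit itself, so the losses follow them to inf.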

Loss goes to infinity on running with default setup

I just cloned the repo and ran main.py with the parameters listed in the README, and got the following training progress:

[0/100][0/733] Loss_D: 1.38627135754 Loss_G: 1.38978719711
[0/100][1/733] Loss_D: 1.38632798195 Loss_G: 1.38739454746
[0/100][2/733] Loss_D: 1.38354945183 Loss_G: 1.38604617119
[0/100][3/733] Loss_D: 1.38877046108 Loss_G: 1.38658261299
[0/100][4/733] Loss_D: 1.42622494698 Loss_G: 1.35285282135
[0/100][5/733] Loss_D: 3.73983073235 Loss_G: 2.93968224525
[0/100][6/733] Loss_D: inf Loss_G: 19.790802002
[0/100][7/733] Loss_D: inf Loss_G: inf
[0/100][8/733] Loss_D: nan Loss_G: nan
