ssgan-tensorflow's Introduction

Semi-supervised learning GAN in TensorFlow

Descriptions

This is my TensorFlow implementation of Semi-supervised Learning Generative Adversarial Networks, proposed in the paper Improved Techniques for Training GANs. The goal of this work is to exploit samples generated by GAN generators to boost the performance of image classification tasks by improving generalization.

In sum, the main idea is to train a single network that acts both as a classifier performing the image classification task and as a discriminator trained to distinguish samples drawn from the generator distribution from real data. More specifically, the discriminator/classifier takes an image as input and classifies it into n+1 classes, where n is the number of classes in the classification task. True samples are classified into the first n classes and generated samples into the (n+1)-th class, as shown in the figure below.
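A minimal sketch of how the n+1 outputs of such a network can be read, assuming the network produces a vector of n+1 logits with the "generated" class stored last. The function names and the index convention are illustrative assumptions, not taken from this repository's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits, n):
    """Interpret n+1 logits: indices 0..n-1 are real classes, index n is 'generated'.

    Assumption: the generated class is the last index (hypothetical convention).
    """
    assert len(logits) == n + 1
    probs = softmax(logits)
    p_fake = probs[n]  # probability that the input came from the generator
    # Predicted label is the most likely of the n real classes only.
    real_class = max(range(n), key=lambda k: probs[k])
    return real_class, p_fake


# Example with n = 3 real classes plus one "generated" class.
real_class, p_fake = classify([2.0, 0.5, 0.1, -1.0], n=3)
print(real_class, round(p_fake, 3))
```

At test time only the first n probabilities matter for classification; the last one is what lets the same network serve as the GAN discriminator.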

The loss of this multi-task learning framework can be decomposed into the supervised loss and the GAN loss of a discriminator. During the training phase, we jointly minimize the total loss obtained by simply combining the two losses together.
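For reference, the two losses as formulated in Improved Techniques for Training GANs can be written as below, with the paper's K replaced by n to match the notation used here. This follows the paper; the summed total is also the paper's form, and the exact weighting and any auxiliary terms used in this implementation may differ.

```latex
\begin{align}
L_{\text{supervised}} &= -\mathbb{E}_{x, y \sim p_{\text{data}}(x, y)}
    \log p_{\text{model}}(y \mid x,\; y < n+1) \\
L_{\text{GAN}} &= -\Big\{ \mathbb{E}_{x \sim p_{\text{data}}(x)}
    \log\big[1 - p_{\text{model}}(y = n+1 \mid x)\big]
    + \mathbb{E}_{x \sim G} \log\big[p_{\text{model}}(y = n+1 \mid x)\big] \Big\} \\
L &= L_{\text{supervised}} + L_{\text{GAN}}
\end{align}
```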

The implemented model is trained and tested on three publicly available datasets: MNIST, SVHN, and CIFAR-10.

Note that this implementation only follows the main idea of the original paper and differs substantially in implementation details such as model architectures, hyperparameters, and the optimizer. Some useful training tricks applied in this implementation are listed in the Training tricks section at the end of this README.

*This code is still being developed and is subject to change.

Prerequisites

Usage

Download datasets with:

$ python download.py --dataset MNIST SVHN CIFAR10

Train models with downloaded datasets:

$ python trainer.py --dataset MNIST
$ python trainer.py --dataset SVHN
$ python trainer.py --dataset CIFAR10

Test models with saved checkpoints:

$ python evaler.py --dataset MNIST --checkpoint ckpt_dir
$ python evaler.py --dataset SVHN --checkpoint ckpt_dir
$ python evaler.py --dataset CIFAR10 --checkpoint ckpt_dir

Train and test your own datasets:

  • Create a directory
$ mkdir datasets/YOUR_DATASET
  • Store your data as an h5py file datasets/YOUR_DATASET/data.hy, in which each data point contains
    • 'image': has shape [h, w, c], where c is the number of channels (grayscale images: 1, color images: 3)
    • 'label': represented as a one-hot vector
  • Maintain a list datasets/YOUR_DATASET/id.txt listing ids of all data points
  • Modify trainer.py including args, data_info, etc.
  • Finally, train and test models:
$ python trainer.py --dataset YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET
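The data preparation steps above can be sketched as follows. This is a hedged illustration, not the repository's own loader contract: it assumes each data point is an HDF5 group named by its id, that ids are simply "0", "1", "2", ..., and that labels are stored as uint8. The helper names one_hot and write_dataset are hypothetical, and h5py and NumPy must be installed to call write_dataset.

```python
import os


def one_hot(label, num_classes):
    """Return a one-hot list of length num_classes with a 1 at index `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    return [1 if k == label else 0 for k in range(num_classes)]


def write_dataset(out_dir, images, labels, num_classes):
    """Write data.hy and id.txt under out_dir.

    Assumption: each data point is an HDF5 group named by its id, holding
    an 'image' array of shape [h, w, c] and a one-hot 'label' vector.
    """
    import h5py  # third-party; imported lazily so one_hot has no dependencies
    import numpy as np

    os.makedirs(out_dir, exist_ok=True)
    ids = [str(i) for i in range(len(images))]  # hypothetical id scheme
    with h5py.File(os.path.join(out_dir, "data.hy"), "w") as f:
        for data_id, image, label in zip(ids, images, labels):
            grp = f.create_group(data_id)
            grp["image"] = np.asarray(image)
            grp["label"] = np.asarray(one_hot(label, num_classes), dtype=np.uint8)
    # One id per line, matching the group names written above.
    with open(os.path.join(out_dir, "id.txt"), "w") as f:
        f.write("\n".join(ids) + "\n")
```

For example, write_dataset("datasets/YOUR_DATASET", images, labels, num_classes=10) would produce both files for a list of [h, w, c] arrays and integer labels. Whether the loader accepts this exact layout and dtype should be checked against the dataset code before training.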

Results

MNIST

  • Generated samples (100th epoch)

  • First 40 epochs

SVHN

  • Generated samples (100th epoch)

  • First 160 epochs

CIFAR-10

  • Generated samples (1000th epoch)

  • First 200 epochs

Training details

MNIST

  • The supervised loss

  • The loss of Discriminator

D_loss_real

D_loss_fake

D_loss (total loss)

  • The loss of Generator

G_loss

  • Classification accuracy

SVHN

  • The supervised loss

  • The loss of Discriminator

D_loss_real

D_loss_fake

D_loss (total loss)

  • The loss of Generator

G_loss

  • Classification accuracy

CIFAR-10

  • The supervised loss

  • The loss of Discriminator

D_loss_real

D_loss_fake

D_loss (total loss)

  • The loss of Generator

G_loss

  • Classification accuracy

Training tricks

  • To avoid the fast convergence of the discriminator network:
    • The generator network is updated more frequently.
    • A higher learning rate is applied to the training of the generator.
  • One-sided label smoothing is applied to the positive labels.
  • Gradient clipping is applied to stabilize training.
  • Reconstruction loss with an annealed weight is applied as an auxiliary loss to help the generator get rid of the initial local minimum.
  • The Adam optimizer is used with higher momentum.
  • Please refer to the code for more details.
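Three of the tricks above can be sketched in plain Python as follows. The smoothing value 0.9, the linear annealing schedule, and the clipping threshold are illustrative assumptions; the values and schedules actually used are in the repository's code, and the function names here are hypothetical.

```python
import math


def smooth_positive_labels(labels, smooth=0.9):
    """One-sided label smoothing: real targets 1.0 become `smooth`, fake targets stay 0.0."""
    return [smooth if y == 1.0 else y for y in labels]


def annealed_weight(step, initial=1.0, decay_steps=1000):
    """Linearly decay an auxiliary-loss weight from `initial` to 0 over decay_steps."""
    return initial * max(0.0, 1.0 - step / decay_steps)


def clip_by_global_norm(grads, max_norm):
    """Rescale a flat list of gradients so that their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]
```

In this sketch the generator objective would be its GAN loss plus annealed_weight(step) times the reconstruction loss, so the auxiliary term dominates early and vanishes once the weight reaches zero.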

Related works

Author

Shao-Hua Sun / @shaohua0116

Acknowledgement

@wookayin for the code structure

@carpedm20 for the README format
