gans

This is a lightweight framework for tooling Generative Adversarial Networks (GANs) on TF 2.0, supporting eager execution and Keras tooling.

Structure of GANs library

  • models - One of the core modules; it defines the architectures for the GAN models, be it generator or discriminator.
  • losses - This module defines the losses and penalties for GANs, such as Wasserstein, DCGAN and LSGAN.
  • core - One of the core modules; it ties all the pieces together for the models and holds all the basic training ops and their loss-based extensions (see the sketch below).
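
To orient yourself, the examples below import the model architectures from gans.models and the loss-specific trainers from gans.train_ops; here is a minimal sketch using only names that appear later in this README:

# model architectures (generator / discriminator)
from gans.models import ConvolutionalDiscriminativeModel, ConvolutionalGeneratorModel
# loss-specific trainers built on top of the training ops
from gans.train_ops import DCGANTrainer, LSGANTrainer, WasserstienGANTrainer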

How to train a model?

Training a GAN model consists of the following steps:

  1. Specify the inputs for your models and wrap them in a tf.data.Dataset instance to feed to the trainer.
  2. Define a generator model and a discriminator model. Currently, we provide abstractions for convolutional models only, i.e., ConvolutionalGeneratorModel and ConvolutionalDiscriminativeModel. You can also use any other model derived from tf.keras.Model.
  3. Select your optimizers for the training ops from tf.optimizers.
  4. Based on your loss function, select one of the trainers. Currently we provide training using Wasserstein, DCGAN and LSGAN based learning strategies. If you want to create your own custom trainer, derive a class from BaseGANTrainer and override the train, train_step and __init__ methods (see the sketch after this list).

At this stage you can either use the GAN classes as a convenient abstraction or provision your own for fine-grained control.
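
A minimal sketch of such a custom trainer is shown below; the import path for BaseGANTrainer and the exact method signatures are assumptions for illustration (they are not spelled out in this README), so adapt them to the actual base class:

from gans.train_ops import BaseGANTrainer  # assumed import path, not confirmed by this README


class MyGANTrainer(BaseGANTrainer):
    def __init__(self, *args, **kwargs):
        # forward everything to the base trainer; add any custom state here
        super().__init__(*args, **kwargs)

    def train_step(self, image_batch):
        # compute your custom generator/discriminator losses and apply the gradients here
        raise NotImplementedError

    def train(self, dataset, batch_size, noise_dim, epochs):
        # iterate over epochs and batches, delegating each batch to train_step
        for _ in range(epochs):
            for image_batch in dataset:
                self.train_step(image_batch)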

Examples

This section provides some basic examples of using this library for training your GAN ops.

  • Wasserstein GAN without any checkpoints or image saving after every iteration, using the default filters, strides and kernel_size
import tensorflow as tf # for tf related ops

from gans.models import ConvolutionalGeneratorModel, ConvolutionalDiscriminativeModel # importing models
from gans.train_ops import WasserstienGANTrainer # for performing training ops

# load your dataset here; for the sake of simplicity we name it dataset

# instantiating generator and discriminator models for the GAN
generator = ConvolutionalGeneratorModel(filters=[256, 128, 128, 64],
                                        shape=(7, 7))
discriminator = ConvolutionalDiscriminativeModel(filters=[64, 128])

# defining optimizers
generator_optimizer = tf.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.optimizers.Adam(learning_rate=1e-4)

# defining trainer instance
trainer = WasserstienGANTrainer(
    generator_model=generator,
    discriminator_model=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    save_images=False)

# perform training ops
trainer.train(dataset, batch_size=256, noise_dim=100, epochs=100)
  • DCGAN based learning for 3-channel images, saving checkpoints at every 10th epoch and saving generated images from a fixed seed for visualizing the GAN's performance. Here we use a kernel_size of 3x3 and update the shape from 7x7 to 8x8 to generate images of 128x128 dims.
import tensorflow as tf  # for tf related ops

from gans.models import (  # import models for ops
    ConvolutionalDiscriminativeModel, ConvolutionalGeneratorModel)
from gans.train_ops import DCGANTrainer  # DCGAN trainer

# perform dataset ops here

# instantiating generator and discriminator models for the GAN
generator = ConvolutionalGeneratorModel(filters=[512, 256, 256, 128, 64, 32],
                                        shape=(8, 8),
                                        channel_dim=3,
                                        strides=2,
                                        kernel_size=(3, 3))
discriminator = ConvolutionalDiscriminativeModel(filters=[64, 64, 128],
                                                 kernel_size=(3, 3),
                                                 dropout_rate=0.4)

# defining optimizers
generator_optimizer = tf.optimizers.RMSprop(learning_rate=1e-4)
discriminator_optimizer = tf.optimizers.RMSprop(learning_rate=1e-4)

# defining trainer instance
trainer = DCGANTrainer(generator_model=generator,
                       discriminator_model=discriminator,
                       generator_optimizer=generator_optimizer,
                       discriminator_optimizer=discriminator_optimizer,
                       save_checkpoint_at=10,
                       checkpoint_dir="./training_checkpoints",
                       save_images=True)

# perform training ops
trainer.train(dataset, batch_size=256, noise_dim=100, epochs=100)
  • Using the LSGAN based learning strategy to generate 28x28x1 dim images without any intermediate checkpoints and without displaying the saved images during the process. Instead, we save the model at the end of all the training.
import tensorflow as tf  # for tf related ops

from gans.models import (  # import models for ops
    ConvolutionalDiscriminativeModel, ConvolutionalGeneratorModel)
from gans.train_ops import LSGANTrainer  # LSGAN trainer

# perform dataset ops here

# instantiating generator and discriminator models for the GAN
generator = ConvolutionalGeneratorModel(filters=[256, 128, 128],
                                        shape=(7, 7),
                                        kernel_size=(5, 5))
discriminator = ConvolutionalDiscriminativeModel(filters=[64, 128],
                                                 kernel_size=(5, 5))

# defining optimizers
generator_optimizer = tf.optimizers.RMSprop(learning_rate=1e-4)
discriminator_optimizer = tf.optimizers.RMSprop(learning_rate=1e-4)

# defining trainer instance
trainer = LSGANTrainer(generator_model=generator,
                       discriminator_model=discriminator,
                       generator_optimizer=generator_optimizer,
                       discriminator_optimizer=discriminator_optimizer,
                       save_images=True)

# perform training ops
trainer.train(dataset, batch_size=256, noise_dim=100, epochs=100, show_image=False)

# save the model after training under dir lsgan
trainer.save_checkpoint("lsgan")
  • Using the LSGAN based learning strategy to generate 64x64x3 dim images without any intermediate checkpoints and without displaying the saved images during the process. Instead, we save the model at the end of all the training. Also, instead of using the normal discriminator, we go with the residual framework.
import tensorflow as tf  # for tf related ops

from gans.models import (  # import models for ops
    ConvolutionalGeneratorModel, ResidualConvolutionalDiscriminative)

from gans.train_ops import LSGANTrainer  # LSGAN trainer

# perform dataset ops here

# instantiating generator and discriminator models for the GAN
generator = ConvolutionalGeneratorModel(filters=[512, 256, 128, 64],
                                        shape=(8, 8),
                                        kernel_size=(5, 5), channel_dim=3)
discriminator = ResidualConvolutionalDiscriminative(
    filters=[64, 64, 128], block_type="basic",
    kernel_size=(5, 5))  # we can also use bottleneck
 
# defining optimizers
generator_optimizer = tf.optimizers.RMSprop(learning_rate=1e-4)
discriminator_optimizer = tf.optimizers.RMSprop(learning_rate=1e-4)

# defining trainer instance
trainer = LSGANTrainer(generator_model=generator,
                       discriminator_model=discriminator,
                       generator_optimizer=generator_optimizer,
                       discriminator_optimizer=discriminator_optimizer,
                       save_images=True)

# perform training ops
trainer.train(dataset, batch_size=256, noise_dim=100, epochs=100, show_image=False)

# save the model after training under dir lsgan
trainer.save_checkpoint("lsgan")

Data Loading Pipelines

This section shows an example of a data loading pipeline.

  • Using images stored in an image directory with a batch_size of 256, cache and prefetch.

Data must be stored in the following fashion:

images_dir
    |- image1.png
    |- image2.png

Code for retrieving data from the above directory:

from gans.datapipeline.data_loader import FileDataLoader

import tensorflow as tf

# AUTOTUNE defined
AUTOTUNE = tf.data.AUTOTUNE

# create an instance of FileDataLoader
data_handler = FileDataLoader(path_to_images="images_dir",
                              image_extension="png",
                              image_dims=(64, 64),
                              image_channels=3)

# create a dataset of images in batches of 256, shuffled, cached and prefetched
data_handler.create_dataset(batch_size=256,
                            shuffle=True,
                            autotune=AUTOTUNE,
                            cache=True,
                            prefetch=AUTOTUNE)
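
The resulting dataset can then be handed to any of the trainers above. A minimal sketch, assuming create_dataset returns the prepared tf.data.Dataset (if the loader instead stores the dataset internally, adapt the hand-off accordingly):

# assumption: create_dataset returns the prepared tf.data.Dataset
dataset = data_handler.create_dataset(batch_size=256, shuffle=True,
                                      autotune=AUTOTUNE, cache=True,
                                      prefetch=AUTOTUNE)

# hand the dataset to a trainer exactly as in the training examples above
trainer.train(dataset, batch_size=256, noise_dim=100, epochs=100)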

Author

Sanidhya Mangal, [email protected], website


gans's Issues

Trainer Module seems to be broken

The trainer module under namedtuples seems to be broken; it is not working as expected. The model appears to fail to converge, and because of this all the generated images are black.

Wasserstein Loss not working as expected

Wasserstein Loss not working as expected.

TypeError                                 Traceback (most recent call last)

<ipython-input-37-852e80b41cf6> in <module>()
----> 1 trainer.train(train_dataset, 256,100)

9 frames

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

TypeError: in user code:

    /content/gans/namedtuples.py:112 train_step  *
        gen_loss = self.get_generator_loss(generated_output)
    /content/gans/namedtuples.py:67 get_generator_loss  *
        return self.generator_loss(logits)

    TypeError: tf__wasserstien_generator_loss() takes 1 positional argument but 2 were given
