
Generative Adversarial Networks in TensorFlow 2.0

License: MIT License

Shell 0.58% Python 92.90% Jupyter Notebook 6.52%
generative-adversarial-network gan tensorflow tensorflow-2 tensorflow2 python3 vanilla-gan conditional-gan mnist mnist-classification fashion-mnist fashionmnist cifar10 cifar-10 deep-learning python tensorflow-examples tensorflow-models cyclegan style-transfer

gans-2.0's Introduction


GANs 2.0: Generative Adversarial Networks in TensorFlow 2.0

Project aim

The main aim of this project is to speed up the process of building deep learning pipelines based on Generative Adversarial Networks and to simplify the prototyping of various generator/discriminator models. The library provides several GAN trainers that can be used as off-the-shelf components, such as:

  • Vanilla GAN
  • Conditional GAN
  • Cycle GAN
  • Wasserstein GAN
  • Progressive GAN (WIP)

Examples

Function modeling

  • Vanilla GAN (Gaussian function) / Vanilla GAN (sigmoid function)

Image generation

  • Vanilla GAN (MNIST) / Conditional GAN (MNIST)
  • Vanilla GAN (FASHION_MNIST) / Conditional GAN (FASHION_MNIST)
  • Vanilla GAN (CIFAR10) / Conditional GAN (CIFAR10)

Image translation

  • Cycle GAN (SUMMER2WINTER) / Cycle GAN (WINTER2SUMMER)

Installation

Install with GPU support

pip install gans2[tensorflow_gpu]

Install with CPU support

pip install gans2[tensorflow]
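
A quick sanity check after installation is to import the top-level package (assuming the pip package exposes the same gans package used in the code examples below; see the Issues section for reported installation problems):

import gans  # should not raise ImportError if the installation succeeded
print(gans.__file__)  # prints where the package was installed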

Running the training pipeline: code examples for Vanilla GAN MNIST digit generation

Pre-defined models

import tensorflow as tf
from easydict import EasyDict as edict

from gans.datasets import mnist
from gans.models.discriminators import discriminator
from gans.models.generators.latent_to_image import latent_to_image
from gans.trainers import optimizers
from gans.trainers import vanilla_gan_trainer

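# Training and model hyper-parameters shared by the dataset, the models, and the trainer
# (image shape, batching, latent size, learning rates, image-logging frequency).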
model_parameters = edict({
    'img_height':                  28,
    'img_width':                   28,
    'num_channels':                1,
    'batch_size':                  16,
    'num_epochs':                  10,
    'buffer_size':                 1000,
    'latent_size':                 100,
    'learning_rate_generator':     0.0001,
    'learning_rate_discriminator': 0.0001,
    'save_images_every_n_steps':   10
})

dataset = mnist.MnistDataset(model_parameters)

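# Pre-defined models: a latent-vector-to-image generator and a matching discriminator.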
generator = latent_to_image.LatentToImageGenerator(model_parameters)
discriminator = discriminator.Discriminator(model_parameters)

generator_optimizer = optimizers.Adam(
    learning_rate=model_parameters.learning_rate_generator,
    beta_1=0.5,
)
discriminator_optimizer = optimizers.Adam(
    learning_rate=model_parameters.learning_rate_discriminator,
    beta_1=0.5,
)

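# Wire the models and optimizers into the off-the-shelf Vanilla GAN trainer; training
# outputs (generated image previews, TensorBoard logs) go under outputs/VANILLA_GAN_MNIST.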
gan_trainer = vanilla_gan_trainer.VanillaGANTrainer(
    batch_size=model_parameters.batch_size,
    generator=generator,
    discriminator=discriminator,
    training_name='VANILLA_GAN_MNIST',
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    latent_size=model_parameters.latent_size,
    continue_training=False,
    save_images_every_n_steps=model_parameters.save_images_every_n_steps,
    visualization_type='image',
)

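# Run the adversarial training loop for the configured number of epochs.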
gan_trainer.train(
    dataset=dataset,
    num_epochs=model_parameters.num_epochs,
)

Custom models

import tensorflow as tf
from easydict import EasyDict as edict
from tensorflow import keras
from tensorflow.keras import layers

from gans.datasets import mnist
from gans.models import sequential
from gans.trainers import optimizers
from gans.trainers import vanilla_gan_trainer

model_parameters = edict({
    'img_height':                  28,
    'img_width':                   28,
    'num_channels':                1,
    'batch_size':                  16,
    'num_epochs':                  10,
    'buffer_size':                 1000,
    'latent_size':                 100,
    'learning_rate_generator':     0.0001,
    'learning_rate_discriminator': 0.0001,
    'save_images_every_n_steps':   10
})

dataset = mnist.MnistDataset(model_parameters)

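# Custom DCGAN-style generator: project the latent vector to a 7x7x256 tensor, then
# upsample to a 28x28x1 image with transposed convolutions and a tanh output.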
generator = sequential.SequentialModel(
    layers=[
        keras.Input(shape=[model_parameters.latent_size]),
        layers.Dense(units=7 * 7 * 256, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Reshape((7, 7, 256)),
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ]
)

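# Custom discriminator: two strided convolutions with dropout, flattened into a single
# unnormalized logit for real/fake classification.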
discriminator = sequential.SequentialModel(
    [
        keras.Input(
            shape=[
                model_parameters.img_height,
                model_parameters.img_width,
                model_parameters.num_channels,
            ]),
        layers.Conv2D(filters=64, kernel_size=(5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Conv2D(filters=128, kernel_size=(5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(rate=0.3),

        layers.Flatten(),
        layers.Dense(units=1),
    ]
)

generator_optimizer = optimizers.Adam(
    learning_rate=model_parameters.learning_rate_generator,
    beta_1=0.5,
)
discriminator_optimizer = optimizers.Adam(
    learning_rate=model_parameters.learning_rate_discriminator,
    beta_1=0.5,
)

gan_trainer = vanilla_gan_trainer.VanillaGANTrainer(
    batch_size=model_parameters.batch_size,
    generator=generator,
    discriminator=discriminator,
    training_name='VANILLA_GAN_MNIST_CUSTOM_MODELS',
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    latent_size=model_parameters.latent_size,
    continue_training=False,
    save_images_every_n_steps=model_parameters.save_images_every_n_steps,
    visualization_type='image',
)

gan_trainer.train(
    dataset=dataset,
    num_epochs=model_parameters.num_epochs,
)

More code examples

Vanilla GAN for Gaussian function modeling

Vanilla GAN for sigmoid function modeling

Conditional GAN for MNIST digit generation

Cycle GAN for summer2winter/winter2summer style transfer

Wasserstein GAN for MNIST digit generation

Monitoring model training

In order to visualize the training process (loss values, generated outputs), run the following command in the project directory:

tensorboard --logdir outputs

To follow the training process, open the following address in a browser: http://your-workstation-name:6006/
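
If training runs on a remote workstation, TensorBoard may need to listen on all network interfaces; with TensorBoard 2.x this can be done with the --bind_all flag (an optional variant of the command above):

tensorboard --logdir outputs --bind_all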

The picture below presents the TensorBoard view launched for two experiments: VANILLA_MNIST and VANILLA_FASHION_MNIST.

References

  1. Deep Convolutional Generative Adversarial Network Tutorial in TensorFlow
  2. Cycle GAN Tutorial in TensorFlow
  3. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks paper

gans-2.0's People

Contributors

dependabot[bot], tlatkowski


gans-2.0's Issues

Errors in the examples

I'm not able to run the code under conda.

I tried installing the module with pip install gans2[tensorflow], but neither the gans nor the gans2 module is available. So I downloaded the code from GitHub, but the "pre-defined models" example crashes because of several errors:

  1. TypeError: __init__() got an unexpected keyword argument 'visualization_type' (fixed by commenting the line out),
  2. TypeError: __init__() missing 1 required positional argument: 'validation_dataset' (fixed by changing the content of several files, mainly gans/datasets/mnist.py),
  3. W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at summary_kernels.cc:57 : Not found: Failed to create a directory: ./outputs\VANILLA_GAN_MNIST; No such file or directory. I gave up on that one and didn't check what causes it (a possible workaround is sketched below).

Two questions:

  1. How should one install the code in Conda so the module is visible?
  2. Could you please provide an example of code that doesn't break?
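
For the third error, pre-creating the output directory before starting the training may work around the failure (a hedged sketch; the directory name is taken from the error message above):

import os

# Create the directory the summary writer failed to create; exist_ok avoids
# an error if it already exists.
os.makedirs(os.path.join('outputs', 'VANILLA_GAN_MNIST'), exist_ok=True)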

Missing file optimizer.py in pip install

To reproduce:
Install gans2

pip install gans2

attempt to import optimizer:

from gans.trainers import optimizers

Expected result: the import succeeds without error.

Actual result:

ImportError: cannot import name 'optimizers' from 'gans.trainers' (c:\users\**PATH**\lib\site-packages\gans\trainers\__init__.py)
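
Until the module ships in the pip package, a possible workaround is to construct the optimizers directly with tf.keras (an untested assumption: the trainers accept standard tf.keras optimizer instances in place of gans.trainers.optimizers):

import tensorflow as tf

# Assumed stand-ins for gans.trainers.optimizers.Adam
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)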

Question about Conv2d + ConvTranspose2d

Dear @tlatkowski, I was looking at your code for the CIFAR-10 conditional DCGAN and it is the best I've found, even though all your code is written in TensorFlow and I only have experience with PyTorch. Anyway, I found it quite strange to see both convolution and transposed convolution layers used in the generator model. Can you explain what led you to this choice?
Thanks again for your great work.

WARNING WITHOUT RESULT

Once I run the code, I get this warning and no result.

WARNING:tensorflow:Entity <bound method VanillaGANTrainer.train_step of <tensorflow.python.eager.function.TfMethodTarget object at 0x7eff3c4cada0>> could not be transformed and will be staged without change. Error details can be found in the logs when running with the env variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the AutoGraph team. Cause: Unexpected error transforming <bound method VanillaGANTrainer.train_step of <tensorflow.python.eager.function.TfMethodTarget object at 0x7eff3c4cada0>>. If you believe this is due to a bug, please set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output when filing the bug report. Caused by: Bad argument number for Name: 3, expecting 4
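
This AutoGraph failure ("Bad argument number for Name: 3, expecting 4") is usually caused by an incompatible gast version rather than by this project; downgrading gast is the commonly reported fix for TensorFlow releases from that period:

pip install gast==0.2.2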
