PGGAN - Pytorch Implementation

PyTorch Implementation of Karras et al., "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (ICLR 2018).

Getting Started

Dataset

37,345 frontal face images of Korean people were used for training. Before being passed to the model as input, images were resized to 2**(scale_index+2) x 2**(scale_index+2).
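
For reference, the resize rule above maps each scale index to a square input resolution. The sketch below assumes scale_index starts at 0; the README does not say which scale is the largest, so the range 0 to 8 is only illustrative, and `input_resolution` is a hypothetical helper name.

```python
def input_resolution(scale_index: int) -> int:
    """Side length, in pixels, of the square images fed to the model at a given scale."""
    return 2 ** (scale_index + 2)


# scale_index 0 is the 4x4 starting resolution; each step up doubles the side length
for scale_index in range(9):
    side = input_resolution(scale_index)
    print(f"scale_index={scale_index}: {side}x{side}")
```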

Installation

  • numpy
  • opencv-python
  • Pillow
  • torch
  • torchvision
  • urllib3
  • wandb

To install the dependencies run:

pip install -r requirements.txt

Checkpoint

Settings for training:

  • add here!

Usage

Train

python train.py --run_id simple_test

To resume training from a checkpoint, pass --ckpt_id and --ckpt_step:

python train.py --run_id simple_test --ckpt_id={PATH/TO/CKPT} --ckpt_step {STEP}

To train on multiple GPUs, add --use_mGPU.
To log training to wandb, add --use_wandb.
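
The flags above could be declared with Python's standard argparse module roughly as follows. This is a hypothetical sketch, not the repository's train.py: only the flag names come from this README, while the types, defaults, help strings and the checkpoint path are assumptions. It also shows that argparse accepts both the `--flag=value` and `--flag value` forms used in the resume command.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names are from the README, types/defaults are guesses.
    p = argparse.ArgumentParser(description="PGGAN training (sketch)")
    p.add_argument("--run_id", type=str, required=True, help="name of this training run")
    p.add_argument("--ckpt_id", type=str, default=None, help="checkpoint to resume from")
    p.add_argument("--ckpt_step", type=int, default=None, help="training step of that checkpoint")
    p.add_argument("--use_mGPU", action="store_true", help="train on multiple GPUs")
    p.add_argument("--use_wandb", action="store_true", help="log to Weights & Biases")
    return p


# "ckpts/simple_test" is a made-up path used only for illustration
args = build_parser().parse_args(
    ["--run_id", "simple_test", "--ckpt_id=ckpts/simple_test", "--ckpt_step", "10000", "--use_wandb"]
)
print(args)
```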

Generate

python demo.py

Overview

Progressive Growing of GANs

Training starts with a network that generates low-resolution images, and new blocks operating at progressively larger scales are then added to both the generator and the discriminator. In our code, every scale_jump_step steps, both model.G and model.D call their add_block function.
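
The growing schedule can be sketched in plain Python. `GrowingNet` is a hypothetical stand-in for model.G and model.D that only records its blocks, and `grow_schedule` is a hypothetical training loop. Resetting alpha to 0 on every jump follows the Issues section below; the linear alpha ramp over `alpha_steps` steps is an assumption, since the README does not describe how alpha increases.

```python
class GrowingNet:
    """Hypothetical stand-in for model.G / model.D that only records its blocks."""

    def __init__(self):
        self.blocks = ["block_4x4"]  # training starts at the lowest resolution
        self.alpha = 0.0             # fade-in weight of the newest block

    def add_block(self):
        res = 4 * 2 ** len(self.blocks)  # each new block doubles the resolution
        self.blocks.append(f"block_{res}x{res}")
        self.alpha = 0.0                 # a freshly added block starts fully faded out


def grow_schedule(total_steps, scale_jump_step, alpha_steps, G, D):
    for step in range(1, total_steps + 1):
        # ... one discriminator update and one generator update would go here ...
        for net in (G, D):
            # assumed: alpha ramps linearly towards 1 over alpha_steps steps
            net.alpha = min(1.0, net.alpha + 1.0 / alpha_steps)
        if step % scale_jump_step == 0:
            G.add_block()
            D.add_block()


G, D = GrowingNet(), GrowingNet()
grow_schedule(total_steps=3000, scale_jump_step=1000, alpha_steps=1000, G=G, D=D)
print(G.blocks)
```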

Smooth Resolution Transition

To fade in the new blocks smoothly, weighted residual connections are placed between layers. The generator upscales the previous block's output by 2, then blends it with the new block's output in the RGB domain.

for i, block in enumerate(self.blocks, 0):
    x = block(x)
  
    # Lower scale RGB image
    if self.alpha > 0 and i == (len(self.blocks) - 2):
        x_up = self.toRGB_blocks[-2](x, apply_upscale=True)

# Current scale RGB image
x = self.toRGB_blocks[-1](x)

# Blend!
if self.alpha > 0:
    x = self.alpha * x + (1.0 - self.alpha) * x_up
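
A self-contained toy version of this RGB-domain fade-in is sketched below. `ToyGenerator` is hypothetical: it has exactly two scales (an existing 4x4 block and a newly added 8x8 block), takes a 4x4 feature map instead of a latent vector, uses nearest-neighbour upscaling, and always applies the blend. With alpha = 0 the output is the upscaled lower-scale image; with alpha = 1 it is the new block's image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGenerator(nn.Module):
    """Hypothetical generator that has just grown from 4x4 to 8x8."""

    def __init__(self, channels: int = 8):
        super().__init__()
        # existing (already trained) 4x4 block and its RGB head
        self.block_4 = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1), nn.LeakyReLU(0.2))
        self.to_rgb_4 = nn.Conv2d(channels, 3, 1)
        # newly added 8x8 block and its RGB head
        self.block_8 = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1), nn.LeakyReLU(0.2))
        self.to_rgb_8 = nn.Conv2d(channels, 3, 1)
        self.alpha = 0.0  # 0: lower-scale path only, 1: new block only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block_4(x)
        # lower-scale RGB image, upscaled by 2 to match the new resolution
        x_up = F.interpolate(self.to_rgb_4(x), scale_factor=2, mode="nearest")
        # current-scale RGB image produced by the new block
        x = self.to_rgb_8(self.block_8(F.interpolate(x, scale_factor=2, mode="nearest")))
        # blend in the RGB domain
        return self.alpha * x + (1.0 - self.alpha) * x_up


g = ToyGenerator()
g.alpha = 0.5
print(g(torch.randn(2, 8, 4, 4)).shape)
```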

The discriminator, on the other hand, downscales the current RGB input by 2 and passes it to the previous block. In this case, blending occurs in the feature domain.

# Lower scale features
if self.alpha > 0 and len(self.fromRGB_blocks) > 1:
    x_down = self.fromRGB_blocks[-2](x, apply_downscale=True)

# Current scale features
x = self.fromRGB_blocks[-1](x)

apply_merge = self.alpha > 0 and len(self.blocks) > 1
for block in reversed(self.blocks):
    x = block(x)

    # Blend!
    if apply_merge:
        apply_merge = False
        x = self.alpha * x + (1 - self.alpha) * x_down
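
The mirror-image fade-in for the discriminator can be sketched the same way. `ToyDiscriminator` is hypothetical: it has just grown from 4x4 to 8x8 inputs, downscales with average pooling, and uses a single 4x4 convolution as its final scoring block. The blend happens on feature maps, after the new block and before the old one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDiscriminator(nn.Module):
    """Hypothetical discriminator that has just grown from 4x4 to 8x8 inputs."""

    def __init__(self, channels: int = 8):
        super().__init__()
        # fromRGB layers for the new (8x8) and previous (4x4) resolutions
        self.from_rgb_8 = nn.Conv2d(3, channels, 1)
        self.from_rgb_4 = nn.Conv2d(3, channels, 1)
        # newly added block: 8x8 features -> 4x4 features
        self.block_8 = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1), nn.LeakyReLU(0.2), nn.AvgPool2d(2)
        )
        # existing block: 4x4 features -> one score per image
        self.block_4 = nn.Conv2d(channels, 1, 4)
        self.alpha = 0.0  # 0: lower-scale path only, 1: new block only

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # lower-scale features: downscale the RGB input by 2, then apply the old fromRGB
        x_down = self.from_rgb_4(F.avg_pool2d(img, 2))
        # current-scale features through the new block
        x = self.block_8(self.from_rgb_8(img))
        # blend in the feature domain, then continue through the old block
        x = self.alpha * x + (1.0 - self.alpha) * x_down
        return self.block_4(x).flatten(1)


d = ToyDiscriminator()
d.alpha = 0.5
print(d(torch.randn(2, 3, 8, 8)).shape)
```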

Objectives

Training uses the WGAN-GP loss. Both the generator and the discriminator are updated on every minibatch. In addition, a drift loss, which keeps the discriminator output from drifting too far from 0, is added to the discriminator loss.
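
A minimal sketch of these objectives is shown below. The function names are hypothetical, and the weights are assumptions taken from the papers rather than from this repository: gradient-penalty weight 10 (WGAN-GP paper) and drift weight 0.001 (PGGAN paper, applied to the squared scores on real images). The penalty is computed on random interpolates between real and fake images.

```python
import torch


def discriminator_loss(D, real, fake, gp_weight=10.0, drift_weight=1e-3):
    fake = fake.detach()  # no generator gradients during the discriminator step
    real_scores = D(real)
    fake_scores = D(fake)
    # Wasserstein term: push real scores up and fake scores down
    wgan = fake_scores.mean() - real_scores.mean()
    # gradient penalty on random interpolates between real and fake images
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    (grad,) = torch.autograd.grad(D(x_hat).sum(), x_hat, create_graph=True)
    gp = ((grad.flatten(1).norm(2, dim=1) - 1.0) ** 2).mean()
    # drift term: keeps the real scores from wandering far from 0
    drift = (real_scores ** 2).mean()
    return wgan + gp_weight * gp + drift_weight * drift


def generator_loss(D, fake):
    # the generator tries to raise the discriminator's score on its samples
    return -D(fake).mean()
```

Any critic that maps a batch of images to one score per image can be passed as `D`, for example the toy discriminator above or a flatten-plus-linear layer.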

Versions

Issues

Wrong Blending

Immediately after a scale jump, the model generated meaningless images.

[Figure: sample outputs generated immediately after a scale jump, before and after the fix]

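This was caused by swapped blending weights in both the generator and the discriminator. Immediately after a scale jump, alpha is set to 0, so the blend should be dominated by the lower-scale path (x_up) and shift toward the new block only as alpha grows toward 1. The old equations did the opposite. The equations are now fixed: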

# Before
if self.alpha > 0:
    x = self.alpha * x_up + (1.0 - self.alpha) * x
# After
if self.alpha > 0:
    x = (1.0 - self.alpha) * x_up + self.alpha * x
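
The effect of the fix can be checked with scalar stand-ins. In this sketch (hypothetical names, no gating on alpha), `x_new` plays the untrained new block's output and `x_up` the lower-scale output. At the start of the transition the old formula hands control to the untrained block, while the fixed one keeps the lower-scale result and hands over gradually.

```python
def blend_before(x_new: float, x_up: float, alpha: float) -> float:
    # old (buggy) weighting: the new block dominates when alpha is near 0
    return alpha * x_up + (1.0 - alpha) * x_new


def blend_after(x_new: float, x_up: float, alpha: float) -> float:
    # fixed weighting: the lower-scale path dominates when alpha is near 0
    return (1.0 - alpha) * x_up + alpha * x_new


x_new, x_up = 5.0, 1.0  # toy values: untrained new block vs. trained lower scale
for alpha in (0.0, 0.5, 1.0):
    print(alpha, blend_before(x_new, x_up, alpha), blend_after(x_new, x_up, alpha))
```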

TO DO

  • implement fourth term of discriminator loss
  • implement test.py
  • upload requirements.txt
  • upload checkpoint and sample output
  • fix checkpoint loading part to automatically set scale & alpha jump related variables

Authors

Acknowledgements

Contributors

yukyeongleee, 1zong2
