
A Style-aware Discriminator for Controllable Image Translation

Kunhee Kim, Sanghun Park, Eunyeong Jeon, Taehun Kim, Daijin Kim
POSTECH

Our model discovers various style prototypes from the dataset in a self-supervised manner. The style prototype consists of a combination of various attributes including (left) time, weather, season, and texture; and (right) age, gender, and accessories.

Paper: https://arxiv.org/abs/2203.15375

Abstract: Current image-to-image translations do not control the output domain beyond the classes used during training, nor do they interpolate between different domains well, leading to implausible results. This limitation largely arises because labels do not consider the semantic distance. To mitigate such problems, we propose a style-aware discriminator that acts as a critic as well as a style encoder to provide conditions. The style-aware discriminator learns a controllable style space using prototype-based self-supervised learning and simultaneously guides the generator. Experiments on multiple datasets verify that the proposed model outperforms current state-of-the-art image-to-image translation methods. In contrast with current methods, the proposed approach supports various applications, including style interpolation, content transplantation, and local image translation.

Installation / Requirements

  • CUDA 10.1 or newer is required for the StyleGAN2-based model since it uses custom CUDA kernels of StyleGAN2 ported by @rosinality.
  • We mainly tested on Python 3.8 and PyTorch 1.10.2 with cudatoolkit=11.3 (see environment.yml), using CUDA 11.2 to build the custom CUDA kernels.

Clone this repository:

git clone https://github.com/kunheek/style-aware-discriminator.git
cd style-aware-discriminator

Then, install dependencies using anaconda or pip:

conda env create -f environment.yml
# or
pip install -r requirements.txt
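
Before training or building the custom StyleGAN2 kernels, it can be worth confirming that PyTorch sees your GPU and reports the expected CUDA version. This is a generic PyTorch check, not a script shipped with this repository:

# quick sanity check: PyTorch version, CUDA toolkit version, GPU visibility
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"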

Testing and Evaluation

We provide the following pre-trained networks.

Name               | Dataset      | Resolution | Method    | #images | OneDrive link
-------------------|--------------|------------|-----------|---------|--------------------------
afhq-adain         | AFHQ         | $256^2$    | AdaIN     | 1.6 M   | afhq-adain.pt
afhq-stylegan2     | AFHQ         | $256^2$    | StyleGAN2 | 5 M     | afhq-stylegan2-5M.pt
afhqv2             | AFHQ v2      | $512^2$    | StyleGAN2 | 5 M     | afhqv2-512x512-5M.pt
celebahq-adain     | CelebA-HQ    | $256^2$    | AdaIN     | 1.6 M   | celebahq-adain.pt
celebahq-stylegan2 | CelebA-HQ    | $256^2$    | StyleGAN2 | 5 M     | celebahq-stylegan2-5M.pt
church             | LSUN church  | $256^2$    | StyleGAN2 | 25 M    | church-25M.pt
ffhq               | FFHQ         | $256^2$    | StyleGAN2 | 25 M    | ffhq-25M.pt
flower             | Oxford 102   | $256^2$    | AdaIN     | 1.6 M   | flower-256x256-adain.pt

Here are links to all checkpoints (checkpoints.zip) and the MD5 checksum file (checkpoints.md5). If you have wget and unzip in your environment, you can also download the checkpoints using the following commands:

# download all checkpoints.
bash download.sh checkpoints
# download a specific checkpoint.
bash download.sh afhq-adain

See the table above or download.sh for available checkpoints.
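
If you downloaded checkpoints.zip manually, you can verify the files against checkpoints.md5. The commands below assume the archive was extracted into ./checkpoints and that the .md5 file uses the standard md5sum format; adjust the paths to your layout:

# verify checkpoint integrity (assumes ./checkpoints holds the extracted files and checkpoints.md5)
cd checkpoints
md5sum -c checkpoints.md5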

Quantitative results

(Optional) Computing Inception statistics takes a long time. We provide pre-calculated stats for the AFHQ 256 and CelebA-HQ 256 datasets (link). You can download and register them using the following commands:

bash download.sh stats
# python -m tools.register_stats PATH/TO/STATS
python -m tools.register_stats assets/stats

To evaluate our model, run python -m metrics METRICS --checkpoint CKPT --train-dataset TRAINDIR --eval-dataset EVALDIR. By default, all metrics are saved to runs/{run-dir}/metrics.txt. Available metrics include fid, reconstruction, and mean_fid, as used in the examples below.

See metrics/{task}_evaluator.py for task-specific options. You can pass multiple metrics at the same time. Here are some examples:

python -m metrics fid reconstruction --seed 123 --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --train-dataset ./datasets/afhq/train --eval-dataset ./datasets/afhq/val

python -m metrics mean_fid --seed 777 --checkpoint ./checkpoints/celebahq-stylegan2-5M.pt --train-dataset ./datasets/celeba_hq/train --eval-dataset ./datasets/celeba_hq/val
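
The same CLI can be wrapped in a small shell loop to evaluate several checkpoints in one go; the glob below is a placeholder for wherever you keep the downloaded checkpoints:

# evaluate FID for every AFHQ checkpoint in ./checkpoints (placeholder paths)
for ckpt in ./checkpoints/afhq-*.pt; do
    python -m metrics fid --checkpoint "$ckpt" --train-dataset ./datasets/afhq/train --eval-dataset ./datasets/afhq/val
done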

Qualitative results

You can synthesize images in the same way as the quantitative evaluations (replace metrics with synthesis). By default, all images are saved in the runs/{run-dir}/{task} folder.

# python -m synthesis [TASKS] --checkpoint PATH/TO/CKPT --folder PATH/TO/FOLDERS
python -m synthesis swap --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style

python -m synthesis interpolation --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style

Some tasks require multiple folders (e.g., content and style) or extra arguments. Available synthesis tasks include swap and interpolation, as shown above; a small batching example follows.
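
To run several tasks back to back with the same checkpoint and folders, a simple loop over the documented CLI works; the task names and paths are taken from the examples above:

# run each synthesis task with the same checkpoint and content/style folders
for task in swap interpolation; do
    python -m synthesis "$task" --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style
done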

Additional tools

We provide additional tools for exploring the learned style space:

  • plot_tsne: visualize the learned style space and prototypes using t-SNE.

python -m tools.plot_tsne --checkpoint checkpoints/afhq-stylegan2-5M.pt --target-dataset datasets/afhq/val --seed 7 --title AFHQ --legends cat dog wild

python -m tools.plot_tsne --checkpoint checkpoints/celebahq-stylegan2-5M.pt --target-dataset datasets/celeba_hq/val --seed 7 --title CelebA-HQ --legends female male

  • similarity_search: find samples that are most similar to the query (in the style space and the content space) in the target dataset.

python -m tools.similarity_search --checkpoint CKPT --query QUERY_IMAGE --target-dataset TESTDIR
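
For example, to look up the nearest AFHQ validation images to a query photo (the query path below is a placeholder; use any image file you like):

# placeholder query image; replace with your own file
python -m tools.similarity_search --checkpoint checkpoints/afhq-stylegan2-5M.pt --query ./testphotos/afhq/content/cat.png --target-dataset datasets/afhq/val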

Training

Datasets

By default, all images in the folder are used for training or evaluation (supported image formats can be found here). For example, if you pass --train-dataset=./datasets/afhq/train, all images in the ./datasets/afhq/train folder are used for training.
For LSUN datasets, lsun must be included in the folder path.

datasets
└─ lsun
   ├─ church_outdoor_train_lmdb
   └─ church_outdoor_val_lmdb

To measure mean_fid, the dataset folder must contain one subdirectory per class (fewer than five classes). If you want to reproduce the experiments in the paper, we recommend using the following structure:

datasets
├─ afhq
│  ├─ train
│  │  ├─ cat
│  │  ├─ dog
│  │  └─ wild
│  └─ val (or test)
│     └─ (cat/dog/wild)
└─ celeba_hq
   ├─ train
   │  ├─ female
   │  └─ male
   └─ val
      └─ (female/male)
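
As a concrete sketch, the directory skeleton above can be created as follows; the locations of your raw images are up to you, and the copy command is a placeholder:

# build the recommended class-per-subdirectory layout (image sources are placeholders)
mkdir -p datasets/afhq/train/{cat,dog,wild} datasets/afhq/val/{cat,dog,wild}
mkdir -p datasets/celeba_hq/train/{female,male} datasets/celeba_hq/val/{female,male}
# then copy or symlink your images into the class folders, e.g.:
# cp /path/to/raw/afhq/train/cat/*.jpg datasets/afhq/train/cat/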

Training scripts

Notice: We recommend training networks on a single GPU with enough memory (e.g., an A100) to obtain the best results, since we observed performance degradation with the current implementation when using multiple GPUs (DDP). For example, a model trained on a single A100 GPU (40 GB) is slightly better than a model trained on two TITAN Xp GPUs (12 GB each). We used a single NVIDIA A100 GPU for the AFHQ and CelebA-HQ experiments and four NVIDIA RTX 3090 GPUs for the AFHQ v2, LSUN churches, and FFHQ experiments. Note that we disabled TF32 for all experiments.

We provide training scripts here. Use the following commands to train networks with custom arguments:

# Single GPU training.
python train.py --mod-type adain --total-nimg 1.6M --batch-size 16 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/afhq/train --eval-dataset datasets/afhq/val --out-dir runs --extra-desc some descriptions

# Multi-GPU training.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --total-nimg 25M --batch-size 64 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/ffhq/images1024x1024 --eval-dataset datasets/ffhq/images1024x1024 --nb-proto 128 --latent-dim 512 --latent-ratio 0.5 --jitter true --cutout true --out-dir runs --extra-desc some descriptions
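
The notice above mentions that TF32 was disabled for all experiments. The training code may already handle this internally; if you want to enforce it at the library level yourself, NVIDIA's override variable can be set when launching training (shown here with the single-GPU command from above):

# force-disable TF32 in NVIDIA libraries; the repository may already do this in train.py
NVIDIA_TF32_OVERRIDE=0 python train.py --mod-type adain --total-nimg 1.6M --batch-size 16 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/afhq/train --eval-dataset datasets/afhq/val --out-dir runs --extra-desc some descriptions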

Training options, code, checkpoints, and snapshots are saved in {out-dir}/{run-id}-{dataset}-{resolution}-{extra-desc}. Please see train.py, model.py, and augmentation.py for available arguments.

To resume training, run python train.py --resume PATH/TO/RUNDIR. For example:

# Single GPU training.
python train.py --resume runs/000-afhq-256x256-some-descriptions

# Multi-GPU training.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --resume runs/001-ffhq-some-descriptions
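
As a small convenience, the most recently modified run directory can be picked up automatically; this assumes the runs/ layout described above and is not part of the repository's CLI:

# resume the most recently modified run directory (assumes the runs/ layout above)
python train.py --resume "$(ls -dt runs/*/ | head -n 1)"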

Citation

If you find this repository useful for your research, please cite our paper:

@InProceedings{kim2022style,
  title={A Style-Aware Discriminator for Controllable Image Translation},
  author={Kim, Kunhee and Park, Sanghun and Jeon, Eunyeong and Kim, Taehun and Kim, Daijin},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022},
  pages={18239--18248}
}

Acknowledgements

Many of our implementations are adapted from previous works, including SwAV, DINO, StarGAN v2, Swapping Autoencoder, clean-fid, and stylegan2-pytorch.

Licenses

All materials except custom CUDA kernels in this repository are made available under the MIT License.

The custom CUDA kernels (fused_bias_act_kernel.cu and upfirdn2d_kernel.cu) are under the Nvidia Source Code License, and are for non-commercial use only.
