wavediff's Introduction

Table of contents
  1. Installation
  2. Dataset preparation
  3. How to run
  4. Results
  5. Evaluation
  6. Acknowledgments
  7. Contacts

Official PyTorch implementation of "Wavelet Diffusion Models are fast and scalable Image Generators" (CVPR'23)

Hao Phung·Quan Dao·Anh Tran

VinAI Research

[Paper]    [Poster]    [Slides]    [Video]

WaveDiff is a novel wavelet-based diffusion scheme that employs the low- and high-frequency components of wavelet subbands at both the image and feature levels. These components are used adaptively to accelerate sampling while maintaining good generation quality. Experimental results on the CelebA-HQ, CIFAR-10, LSUN-Church, and STL-10 datasets show that WaveDiff provides state-of-the-art training and inference speed, which serves as a stepping stone toward real-time, high-fidelity diffusion models.

Details of the model architecture and experimental results can be found in our paper:
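
To make the idea of wavelet subbands concrete, the following is a minimal sketch of a single-level 2D wavelet decomposition and its inverse. It assumes a Haar wavelet and plain Python lists purely for illustration; the function names `haar_dwt2` and `haar_idwt2` are hypothetical and are not part of this repository, which relies on WaveCNet and pytorch_wavelets for its transforms. The LH/HL naming of the detail subbands is also a convention chosen here.

```python
def haar_dwt2(image):
    """Single-level 2D Haar transform of a 2D list with even height and width.

    Returns four subbands (LL, LH, HL, HH), each at half the input resolution.
    LL holds the low-frequency content; the other three hold details.
    """
    h, w = len(image), len(image[0])
    if h % 2 or w % 2:
        raise ValueError("height and width must be even")
    ll, lh, hl, hh = [], [], [], []
    for i in range(0, h, 2):
        r_ll, r_lh, r_hl, r_hh = [], [], [], []
        for j in range(0, w, 2):
            a, b = image[i][j], image[i][j + 1]          # top row of the 2x2 block
            c, d = image[i + 1][j], image[i + 1][j + 1]  # bottom row of the block
            r_ll.append((a + b + c + d) / 2)  # average of the block
            r_lh.append((a + b - c - d) / 2)  # top minus bottom
            r_hl.append((a - b + c - d) / 2)  # left minus right
            r_hh.append((a - b - c + d) / 2)  # diagonal difference
        ll.append(r_ll)
        lh.append(r_lh)
        hl.append(r_hl)
        hh.append(r_hh)
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Inverse of haar_dwt2: rebuilds the full-resolution 2D list."""
    image = []
    for i in range(len(ll)):
        top, bottom = [], []
        for j in range(len(ll[0])):
            s, v, u, x = ll[i][j], lh[i][j], hl[i][j], hh[i][j]
            top.extend([(s + v + u + x) / 2, (s + v - u - x) / 2])
            bottom.extend([(s - v + u - x) / 2, (s - v - u + x) / 2])
        image.append(top)
        image.append(bottom)
    return image
```

Each subband has a quarter of the input's spatial size, which is one intuition for why operating on wavelet coefficients can reduce the cost of each network pass.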

@InProceedings{phung2023wavediff,
    author    = {Phung, Hao and Dao, Quan and Tran, Anh},
    title     = {Wavelet Diffusion Models Are Fast and Scalable Image Generators},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {10199-10208}
}

Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software.

Installation

Python 3.7.13 and PyTorch 1.10.0 are used in this implementation.

We recommend creating a conda environment from the provided environment.yml:

conda env create -f environment.yml
conda activate wavediff

Alternatively, you can install the necessary libraries as follows:

pip install -r requirements.txt

For pytorch_wavelets, please follow the installation instructions in its repository.

Dataset preparation

We trained on four datasets: CIFAR10, STL10, LSUN Church Outdoor 256, and CelebA HQ (256 & 512).

CIFAR10 and STL10 are downloaded automatically on first execution.

For CelebA HQ (256) and LSUN, please check out here for dataset preparation.

For CelebA HQ (512 & 1024), please download the two zip files, data512x512.zip and data1024x1024.zip, and then generate an LMDB-format dataset with Torch Toolbox.

The two links to the high-resolution data appear to be broken, so we provide our processed LMDB files here.

Once a dataset is downloaded, please put it in the data/ directory as follows:

data/
├── STL-10
├── celeba
├── celeba_512
├── celeba_1024
├── cifar-10
└── lsun
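
Before launching a run, it can help to confirm that the expected folders are in place. The snippet below is a small sketch that lists which of the subdirectories shown above are missing; the helper name `missing_dataset_dirs` is hypothetical and not part of this repository, and only the folders for the datasets you actually use need to exist.

```python
from pathlib import Path

# Subdirectory names taken from the layout shown above.
EXPECTED_SUBDIRS = ("STL-10", "celeba", "celeba_512", "celeba_1024", "cifar-10", "lsun")


def missing_dataset_dirs(data_root="data", expected=EXPECTED_SUBDIRS):
    """Return the expected dataset folders that do not exist under data_root."""
    root = Path(data_root)
    return [name for name in expected if not (root / name).is_dir()]
```

For example, `missing_dataset_dirs("data", expected=("cifar-10",))` returns an empty list once the CIFAR-10 folder exists.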

How to run

We provide a bash script for our experiments on different datasets. The syntax is as follows:

bash run.sh <DATASET> <MODE> <#GPUS>

where:

  • <DATASET>: one of cifar10, stl10, celeba_256, celeba_512, celeba_1024, or lsun.
  • <MODE>: train or test.
  • <#GPUS>: the number of GPUs (e.g. 1, 2, 4, 8).

Note: please set the --exp argument consistently for both train and test modes. All detailed configurations are already set in run.sh.

GPU allocation: Our experiments were run on NVIDIA A100 GPUs with 40GB of memory. For train mode, we use a single GPU for CIFAR10 and STL10, 2 GPUs for CelebA-HQ 256, 4 GPUs for LSUN, and 8 GPUs for CelebA-HQ 512 & 1024. For test mode, a single GPU is sufficient for all experiments.
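
The invocation syntax and the GPU allocation described above can be summarized in a small helper. This is only a sketch: `build_run_command` and the `TRAIN_GPUS` table are hypothetical names introduced here, the GPU counts simply restate the numbers given above, and the result is an argument list that you could pass to `subprocess.run` from the repository root.

```python
DATASETS = ("cifar10", "stl10", "celeba_256", "celeba_512", "celeba_1024", "lsun")
MODES = ("train", "test")

# GPU counts used for training, as stated in the GPU allocation note.
TRAIN_GPUS = {
    "cifar10": 1,
    "stl10": 1,
    "celeba_256": 2,
    "lsun": 4,
    "celeba_512": 8,
    "celeba_1024": 8,
}


def build_run_command(dataset, mode, num_gpus=None):
    """Build the argument list for `bash run.sh <DATASET> <MODE> <#GPUS>`."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    if mode not in MODES:
        raise ValueError(f"unknown mode: {mode!r}")
    if num_gpus is None:
        # Default to the documented allocation: per-dataset for train, 1 for test.
        num_gpus = TRAIN_GPUS[dataset] if mode == "train" else 1
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return ["bash", "run.sh", dataset, mode, str(num_gpus)]
```

Validating the dataset and mode names up front avoids launching a multi-GPU job that fails only after the script has started.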

Results

Model performance and pretrained checkpoints are provided below:

Model                     FID     Recall   Time (s)   Checkpoint
CIFAR-10                  4.01    0.55     0.08       netG_1300.pth
STL-10                    12.93   0.41     0.38       netG_600.pth
CelebA-HQ (256 x 256)     5.94    0.37     0.79       netG_475.pth
CelebA-HQ (512 x 512)     6.40    0.35     0.59       netG_350.pth
LSUN Church               5.06    0.40     1.54       netG_400.pth
CelebA-HQ (1024 x 1024)   5.98    0.39     0.59       netG_350.pth

Inference time is averaged over 300 trials on a single NVIDIA A100 GPU with a batch size of 100, except for high-resolution CelebA-HQ (512 & 1024), where it is computed for a batch of 25 samples.
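
As an illustration of averaging a runtime over repeated trials, here is a minimal sketch. It is not the timing code used for the table: `mean_runtime` is a hypothetical helper, the warm-up count is an assumption, and `fn` stands in for whatever sampling call is being measured. When timing GPU work in PyTorch, `fn` should also call `torch.cuda.synchronize()` before returning, because CUDA kernels are launched asynchronously.

```python
import time


def mean_runtime(fn, trials=300, warmup=10):
    """Return the mean wall-clock time in seconds of fn() over `trials` calls.

    A few untimed warm-up calls are made first so that one-off setup costs
    do not skew the average.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(trials):
        fn()
    return (time.perf_counter() - start) / trials
```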

Downloaded pre-trained models should be placed in the saved_info/wdd_gan/<DATASET>/<EXP> directory, where <DATASET> is defined in the How to run section and <EXP> is the folder name of the pre-trained checkpoint.
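
The directory convention above can be expressed as a short path helper. This is a sketch only: `checkpoint_path` is a hypothetical function, the `exp` value in the usage note is a placeholder, and the `netG_<epoch>.pth` file name follows the pattern of the checkpoints listed in the table.

```python
from pathlib import Path


def checkpoint_path(dataset, exp, epoch, root="saved_info/wdd_gan"):
    """Return the expected location of a generator checkpoint.

    Follows the saved_info/wdd_gan/<DATASET>/<EXP> layout, with the file
    named netG_<epoch>.pth as in the results table.
    """
    return Path(root) / dataset / exp / f"netG_{epoch}.pth"
```

For instance, `checkpoint_path("celeba_256", "my_exp", 475)` points at `saved_info/wdd_gan/celeba_256/my_exp/netG_475.pth`, where `my_exp` must be replaced by the actual experiment folder name.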

Evaluation

Inference

Samples can be generated by calling run.sh in test mode.

FID

To compute the FID of a pretrained model at a specific epoch, add the arguments --compute_fid and --real_img_dir /path/to/real/images to the corresponding experiment in run.sh.
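
For reference, the Fréchet Inception Distance is conventionally defined as below, where $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the mean and covariance of Inception features for real and generated images. This is the standard definition and is stated here as background; it is not taken from this repository's code.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the real one.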

Recall

We adopt the official PyTorch implementation of StyleGAN2-ADA to compute the Recall of generated samples.

Acknowledgments

Thanks to Xiao et al. for releasing their official implementation of the DDGAN paper. For wavelet transformations, we use implementations from WaveCNet and pytorch_wavelets.

Contacts

If you have any problems, please open an issue in this repository or send an email to [email protected].

wavediff's Issues

The result contains serious noise

I tried to reproduce the results in the paper without changing any parameters, but the output contains serious noise. Has anyone reproduced the results cleanly, or is there something I can do to improve them?

checkpoints

Hi, I could not download the checkpoints. Could you please share them via Google Drive? Thanks!

FID computation (test_wddgan.py)

I'm trying to run the test code test_wddgan.py and have encountered a problem. I tested the pretrained checkpoint netG_475.pth (celeba_256), and it reported the error below when loading celebahq_stat.npy. I hit the same problem when I tried to compute the FID score with the pretrained celeba_512 checkpoint.

array = pickle.load(fp, **pickle_kwargs)
_pickle.UnpicklingError: pickle data was truncated

Then I tried to recalculate the npz file for celeba_256 with scripts/precompute_fid_statistics.py from NVAE and test netG_475.pth. The npz file could be loaded, but the FID of the model turned out abnormally large. How should I deal with this error? Thanks!

System configuration issue

Thanks for sharing your work.

I am trying to run this code on cifar10, but there seems to be an error in my system configuration. I have two questions.

Q1. CUDA 12.2 is already installed on my system, and I am creating a new environment following the instructions in the README. Will this cause any problems?

Q2. The environment was created successfully, but when I try to run on cifar10, an error appears: "/.cache/torch_extensions/py37_cu102/fused/fused.so: cannot open shared object file". How can I solve this issue?

week4

Running celeba_256 data on multiple gpus throws an error

Hi,
When I try to run the celeba_256 on 2 GPUs, I get the following error:

Node rank 0, local proc 0, global proc 0
Node rank 0, local proc 1, global proc 1
terminate called after throwing an instance of 'std::system_error'
  what():  Connection refused

However, the same code works on a single GPU. Additionally, for the method described in the paper, celeba_256 can be trained on 1 GPU with wddgan, but ddgan runs out of memory on the same setup. Could you please provide details on how training is executed across multiple GPUs?

Thanks.

task for super-resolution

Thanks for your work. I have reproduced this model and now want to apply it to a super-resolution task. Do I need to increase the number of steps T? Should I replace the adaptive GroupNorm with plain GroupNorm, since the latent embedding z is there for diversity and may not suit super-resolution? These are my thoughts, but I do not know whether they are correct.

train loss on cifar10 dataset is nan

Thanks for your work!
Why is the training loss on the cifar10 dataset NaN? I just downloaded the code and ran the script (bash run.sh cifar10 train 1).

DWT and IWT

Hi, thanks for your work. What is the difference between the DWT and IWT in WaveCNet and in pytorch_wavelets? Which package are you using?

ninja problem

When I use the pretrained weights to test on the celeba_256 dataset with one GPU, the following problem occurs. I cannot solve it; can you give me some advice?

Traceback (most recent call last):
File "/home/wbx/anaconda3/envs/wavediff/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1723, in _run_ninja_build
env=env)
File "/home/wbx/anaconda3/envs/wavediff/lib/python3.7/subprocess.py", line 512, in run
output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

About timesteps setting for train and test

Thanks for your excellent work! The sampling speed is amazing and really useful for other researchers to follow.

However, I still have several questions about the timesteps setting:

Q1: Why is the number of timesteps so small (2 or 4) for all datasets? As far as I know, many diffusion models are trained with hundreds of timesteps, and acceleration algorithms (for example, DDIM) are then used to sample within tens of steps.

Q2: Does using so few steps weaken the advantages of diffusion models over other generative models? In the extreme case where the number of timesteps is 1, the diffusion model looks quite similar to StyleGAN.

Q3: Can the proposed Wavelet Diffusion Models work if I set the number of timesteps to 1000? I tested sampling with 100 timesteps, and it was still fast enough that I do not think the small number of sampling steps is the key to the sampling speed.

I pay so much attention to the number of timesteps because many algorithms are based on editing the intermediate results generated by diffusion models. The proposed Wavelet Diffusion Models have a clear sampling speed advantage but are limited by the number of timesteps, so it seems that algorithms based on editing intermediate results cannot be used effectively with them.
