Efficient Diffusion Training via Min-SNR Weighting Strategy

By Tiankai Hang, Shuyang Gu, Chen Li, Jianmin Bao, Dong Chen, Han Hu, Xin Geng, Baining Guo.

Paper | arXiv | Code

Abstract.

Denoising diffusion models have been a mainstream approach for image generation; however, training these models often suffers from slow convergence. In this paper, we discovered that the slow convergence is partly due to conflicting optimization directions between timesteps. To address this issue, we treat the diffusion training as a multi-task learning problem, and introduce a simple yet effective approach referred to as Min-SNR-$\gamma$. This method adapts loss weights of timesteps based on clamped signal-to-noise ratios, which effectively balances the conflicts among timesteps. Our results demonstrate a significant improvement in convergence speed, 3.4x faster than previous weighting strategies. It is also more effective, achieving a new record FID score of 2.06 on the ImageNet 256x256 benchmark using smaller architectures than those employed in previous state-of-the-art methods.
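
As a rough illustration of the clamped weighting described in the abstract, the sketch below computes a per-timestep loss weight from the signal-to-noise ratio. It assumes the formulation in the paper, where the weight is min(SNR, gamma) when the network predicts x0 and min(SNR, gamma) / SNR when it predicts the noise. The function names and the sample alpha_bar values are hypothetical and are not taken from this repository's code.

```python
def snr(alpha_bar: float) -> float:
    """Signal-to-noise ratio for a cumulative signal level alpha_bar in (0, 1)."""
    if not 0.0 < alpha_bar < 1.0:
        raise ValueError("alpha_bar must lie strictly between 0 and 1")
    return alpha_bar / (1.0 - alpha_bar)


def min_snr_weight(alpha_bar: float, gamma: float = 5.0, target: str = "x0") -> float:
    """Loss weight for one timestep under the (assumed) Min-SNR-gamma rule."""
    s = snr(alpha_bar)
    clamped = min(s, gamma)  # clamp the SNR at gamma
    if target == "x0":
        return clamped       # x0 prediction: min(SNR, gamma)
    if target == "eps":
        return clamped / s   # noise prediction: min(SNR, gamma) / SNR
    raise ValueError("target must be 'x0' or 'eps'")


if __name__ == "__main__":
    # Hypothetical alpha_bar values: nearly clean, half noise, nearly pure noise.
    for a in (0.99, 0.5, 0.01):
        print(a, min_snr_weight(a, target="x0"), min_snr_weight(a, target="eps"))
```

With gamma = 5, the x0-prediction weight of low-noise timesteps (large SNR) is capped at 5 rather than growing with the SNR, which is the clamping the abstract refers to.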

Data Preparation

For the CelebA dataset, we follow ScoreSDE to process the data.

For the ImageNet dataset, we download it from the official website. For ImageNet-64, we did not adopt offline pre-processing. For ImageNet-256, we cropped the images to 256x256 and compressed them using AutoencoderKL from Diffusers. The compressed latent codes are handled in the same way as images, apart from the file extension.
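
The README does not say how the 256x256 crop is taken. One common choice, used here purely as an assumption, is to resize the shorter side to 256 and then take a center crop. The helper below only computes that geometry; its name and the example image size are hypothetical.

```python
def center_crop_geometry(width: int, height: int, size: int = 256):
    """Return the resized (width, height) and the crop box (left, top, right, bottom)."""
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    # Scale the image so that its shorter side becomes `size`.
    scale = size / min(width, height)
    new_w = max(size, round(width * scale))
    new_h = max(size, round(height * scale))
    # Center the size x size window inside the resized image.
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return (new_w, new_h), (left, top, left + size, top + size)


if __name__ == "__main__":
    # Hypothetical 500x375 input image.
    print(center_crop_geometry(500, 375))
```

The returned box uses the (left, top, right, bottom) convention, so it could be handed to an image library's resize and crop calls before the images are encoded.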

Training

To train the ViT-B model, first place the downloaded or processed data from the previous section in a directory of your choice, then set DATA_DIR in the config file vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh to that path. Then run:

GPUS=8
BATCH_SIZE_PER_GPU=32
bash configs/in256/vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh $GPUS $BATCH_SIZE_PER_GPU
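
The `bs8x32` suffix in the config name appears to encode the number of GPUs times the per-GPU batch size, matching GPUS=8 and BATCH_SIZE_PER_GPU=32 above. Assuming that reading is right, the small sketch below recovers the global batch size from a config name, which can help when checking a different GPU count against the original setup. The helper name is hypothetical.

```python
import re


def global_batch_size(config_name: str) -> int:
    """Parse a 'bs<GPUS>x<PER_GPU>' suffix and return GPUS * PER_GPU."""
    match = re.search(r"bs(\d+)x(\d+)", config_name)
    if match is None:
        raise ValueError("no 'bs<GPUS>x<PER_GPU>' pattern found in config name")
    gpus, per_gpu = int(match.group(1)), int(match.group(2))
    return gpus * per_gpu


if __name__ == "__main__":
    name = "vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh"
    print(global_batch_size(name))
```

Whether results stay comparable when the product GPUS x BATCH_SIZE_PER_GPU changes is not stated in the README.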

Sampling with Pre-trained Models

To sample on ImageNet-256, run:

bash configs/in256/inference.sh

Thanks to the sampling method from Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models, we achieve a new FID score of 1.57342 on the ImageNet 256x256 benchmark. To use this sampling method, run:

bash configs/in256/inference_limited_interval_guidance.sh
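
The idea behind limited-interval guidance, as the title of the cited paper suggests, is to apply classifier-free guidance only within a range of noise levels and to fall back to the plain conditional prediction elsewhere. The following is a minimal sketch of that rule on scalar values. The function name, the interval bounds, and the guidance scale are hypothetical, and the repository's script may parametrize the interval differently (for example by timestep index).

```python
def guided_prediction(cond: float, uncond: float, sigma: float,
                      scale: float, sigma_lo: float, sigma_hi: float) -> float:
    """Combine conditional and unconditional predictions with interval-limited guidance."""
    # Guidance is active only when the noise level lies in (sigma_lo, sigma_hi].
    inside = sigma_lo < sigma <= sigma_hi
    # A weight of 1.0 reduces the formula to the conditional prediction alone.
    w = scale if inside else 1.0
    return uncond + w * (cond - uncond)


if __name__ == "__main__":
    # Hypothetical values: guidance scale 3.0, active for sigma in (0.3, 1.0].
    print(guided_prediction(2.0, 1.0, sigma=0.5, scale=3.0, sigma_lo=0.3, sigma_hi=1.0))
    print(guided_prediction(2.0, 1.0, sigma=5.0, scale=3.0, sigma_lo=0.3, sigma_hi=1.0))
```

In a real sampler `cond` and `uncond` would be tensors produced by the denoiser, and the unconditional forward pass can be skipped entirely outside the interval since its output is not used there.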

To sample on ImageNet-64, run:

bash configs/in64/inference.sh

Here we use 8 GPUs for sampling. For single-GPU evaluation, change GPUS=8 to GPUS=1 in configs/in256/inference.sh. The pre-trained models will be downloaded automatically and FID-50K will be calculated.

Citing Min-SNR Diffusion Training

If you find our work useful for your research, please consider citing our paper. 😊

@InProceedings{Hang_2023_ICCV,
    author    = {Hang, Tiankai and Gu, Shuyang and Li, Chen and Bao, Jianmin and Chen, Dong and Hu, Han and Geng, Xin and Guo, Baining},
    title     = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {7441-7451}
}

Acknowledgements

This repository is based on openai/guided-diffusion. The sampling and FID evaluation implementation is adopted from NVlabs/edm.
