
Fixed Point Diffusion Models

Project Page · Paper


DiT samples


Roadmap

  • Code and paper release 🎉🎉
  • Jupyter notebook example
  • Pretrained model release (coming soon)
  • Code walkthrough and tutorial

Abstract

We introduce the Fixed Point Diffusion Model (FPDM), a novel approach to image generation that integrates the concept of fixed point solving into the framework of diffusion-based generative modeling. Our approach embeds an implicit fixed point solving layer into the denoising network of a diffusion model, transforming the diffusion process into a sequence of closely-related fixed point problems. Combined with a new stochastic training method, this approach significantly reduces model size, reduces memory usage, and accelerates training. Moreover, it enables the development of two new techniques to improve sampling efficiency: reallocating computation across timesteps and reusing fixed point solutions between timesteps. We conduct extensive experiments with state-of-the-art models on ImageNet, FFHQ, CelebA-HQ, and LSUN-Church, demonstrating substantial improvements in performance and efficiency. Compared to the state-of-the-art DiT model, FPDM contains 87% fewer parameters, consumes 60% less memory during training, and improves image generation quality in situations where sampling computation or time is limited.

Setup

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate DiT

Model

Our model definition, including all fixed point functionality, is included in models.py.
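For intuition, here is a minimal, self-contained sketch of the kind of fixed-point iteration such a layer performs: it repeatedly applies a weight-tied block f until the hidden state converges to a solution of z* = f(z*, x). This is an illustrative example only, not the implementation in models.py.

import torch

def fixed_point_iterate(f, x, z_init, max_iters=12, tol=1e-4):
    # Iterate z <- f(z, x) until the relative change falls below tol
    # or max_iters is reached.
    z = z_init
    for _ in range(max_iters):
        z_next = f(z, x)
        if torch.norm(z_next - z) / (torch.norm(z) + 1e-8) < tol:
            return z_next
        z = z_next
    return z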

Training

Example training scripts:

# Standard model
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py

# Fixed Point Diffusion Model
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --fixed_point True --deq_pre_depth 1 --deq_post_depth 1

# With v-prediction and zero-SNR
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 1 --deq_post_depth 1

# With v-prediction and zero-SNR, with 4 pre- and post-layers
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 4 --deq_post_depth 4
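The fixed-point arguments above (together with the fixed_point_no_grad_min/max_iters and fixed_point_with_grad_min/max_iters options exposed by train.py) reflect the stochastic training scheme described in the paper: at each training step, a random number of solver iterations is run without gradients, followed by a random number that is backpropagated through. The sketch below is only a conceptual illustration of that idea; the function names and structure are assumptions, not the repository's code.

import random
import torch

def stochastic_fixed_point_forward(f, x, z_init,
                                   no_grad_min=0, no_grad_max=10,
                                   with_grad_min=1, with_grad_max=12):
    # Run a random number of gradient-free iterations, then a random
    # number of iterations that gradients flow through.
    z = z_init
    with torch.no_grad():
        for _ in range(random.randint(no_grad_min, no_grad_max)):
            z = f(z, x)
    for _ in range(random.randint(with_grad_min, with_grad_max)):
        z = f(z, x)
    return z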

Sampling

Example sampling scripts:

# Sample
python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 20

# Sample with fewer iterations per timestep and more timesteps
python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --fixed_point_iters 12 --num_sampling_steps 40 --fixed_point_reuse_solution True
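The --fixed_point_reuse_solution flag corresponds to the solution-reuse technique mentioned in the abstract: because consecutive timesteps solve closely related fixed-point problems, the solution from one timestep can warm-start the solver at the next, so fewer iterations per timestep suffice. The loop below is a conceptual sketch under that assumption, not the actual sampler in sample.py.

import torch

def sample_with_reuse(f, denoise_step, x_T, timesteps, iters_per_step=12):
    # f(z, x, t): one fixed-point iteration of the denoising network
    # denoise_step(x, z, t): one reverse-diffusion update using the solution z
    x = x_T
    z = torch.zeros_like(x)  # initial guess for the first timestep
    for t in timesteps:
        for _ in range(iters_per_step):
            z = f(z, x, t)
        x = denoise_step(x, z, t)
        # z is intentionally not reset: it warm-starts the next timestep
    return x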

Contribution

Pull requests are welcome!

Acknowledgements

  • The strong baseline from DiT:

    @article{Peebles2022DiT,
      title   = {Scalable Diffusion Models with Transformers},
      author  = {William Peebles and Saining Xie},
      journal = {arXiv preprint arXiv:2212.09748},
      year    = {2022},
    }
    
  • The fast-DiT code from chuanyangjin.

  • All the great work from the CMU Locus Lab on Deep Equilibrium Models, which started with:

    @inproceedings{bai2019deep,
    author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
    title     = {Deep Equilibrium Models},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year      = {2019},
    }
    
  • L.M.K. thanks the Rhodes Trust for their scholarship support.


Issues

Reproduction problem

Hi,

Thanks for open-sourcing the code. I tried to reproduce the results of FPDM on ImageNet, but I find that the generated images are much worse than those from the original fast-DiT model. Could you help me check whether I am running the code correctly?

I trained the model with fixed-point layers and with zero_snr/v_pred disabled; the arguments are as follows:

{
    "ckpt_every": 100000,
    "compile": false,
    "dataset_name": "imagenet256",
    "debug": false,
    "dino_supervised": false,
    "dino_supervised_dim": 768,
    "epochs": 1400,
    "feature_path": "/home/yiming/mnt_dataset/ImageNet/ILSVRC2012/vae_features/",
    "fixed_point": true,
    "fixed_point_no_grad_max_iters": 10,
    "fixed_point_no_grad_min_iters": 0,
    "fixed_point_post_depth": 1,
    "fixed_point_pre_depth": 1,
    "fixed_point_pre_post_timestep_conditioning": false,
    "fixed_point_with_grad_max_iters": 12,
    "fixed_point_with_grad_min_iters": 1,
    "flow": false,
    "global_batch_size": 512,
    "global_seed": 0,
    "image_size": 256,
    "log_every": 100,
    "log_with": "wandb",
    "lr": 0.0001,
    "model": "DiT-XL/2",
    "name": "006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False",
    "num_classes": 1000,
    "num_workers": 4,
    "output_dir": "results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False",
    "output_subdir": "runs",
    "predict_v": false,
    "reproducibility": {
        "command_line": "python train.py --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1",
        "time": "Wed Feb 21 16:49:11 2024"
    },
    "resume": null,
    "unsupervised": false,
    "use_zero_terminal_snr": false
}

I then tested it by running:

python sample.py --image_size 256 --global_seed 1 --ckpt ./results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/checkpoints/0400000.pt  --sample_index_end 100 --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 250

The testing arguments are as follows:

{
    "adaptive": false,
    "adaptive_type": "increasing",
    "batch_size": 32,
    "cfg_scale": 4.0,
    "ckpt": "/home/yiming/project/Acceleration/fixed-point-diffusion-models/results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/checkpoints/0400000.pt",
    "dataset_name": "imagenet256",
    "ddim": false,
    "debug": false,
    "dino_supervised": false,
    "dino_supervised_dim": 768,
    "fixed_point": true,
    "fixed_point_iters": 26,
    "fixed_point_post_depth": 1,
    "fixed_point_pre_depth": 1,
    "fixed_point_pre_post_timestep_conditioning": false,
    "fixed_point_reuse_solution": false,
    "flow": false,
    "global_seed": 1,
    "image_size": 256,
    "iteration_controller": null,
    "model": "DiT-XL/2",
    "num_classes": 1000,
    "num_sampling_steps": 250,
    "output_dir": "samples/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/num_sampling_steps-250--cfg_scale-4.0--fixed_point_iters-26--fixed_point_reuse_solution-False--fixed_point_pptc-False",
    "predict_v": false,
    "reproducibility": {
        "command_line": "python sample.py --image_size 256 --global_seed 1 --ckpt /home/yiming/project/Acceleration/fixed-point-diffusion-models/results/runs/006-DiT-XL-2--diff--0.000100--fixed_point-pre_depth-1-post_depth-1-no_grad_iters-00-10-with_grad_iters-01-12-pre_post_time_cond_False/checkpoints/0400000.pt --sample_index_end 100 --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 250",
        "git_has_uncommitted_changes": true,
        "git_root": "/mnt/dongxu-fs2/data-ssd/yiming/project/Acceleration/fixed-point-diffusion-models",
        "git_url": "https://github.com/lukemelas/fixed-point-diffusion-models/tree/519e1286ba27c34e177e05962c5d9e66edce31e6",
        "time": "Fri Mar 22 12:54:29 2024"
    },
    "sample_index_end": 100,
    "sample_index_start": 0,
    "unsupervised": false,
    "use_zero_terminal_snr": false,
    "vae": "mse"
}

Here are some of the generated samples (attached images): 00003--974--geyser, 00004--088--macaw, 00005--979--valley, 00006--417--balloon.

Additionally, when can we expect the pre-trained models to be released?

Best Regards
