lukemelas / fixed-point-diffusion-models


### Fixed Point Diffusion Models

DiT samples




We introduce the Fixed Point Diffusion Model (FPDM), a novel approach to image generation that integrates the concept of fixed point solving into the framework of diffusion-based generative modeling. Our approach embeds an implicit fixed point solving layer into the denoising network of a diffusion model, transforming the diffusion process into a sequence of closely-related fixed point problems. Combined with a new stochastic training method, this approach significantly reduces model size, reduces memory usage, and accelerates training. Moreover, it enables the development of two new techniques to improve sampling efficiency: reallocating computation across timesteps and reusing fixed point solutions between timesteps. We conduct extensive experiments with state-of-the-art models on ImageNet, FFHQ, CelebA-HQ, and LSUN-Church, demonstrating substantial improvements in performance and efficiency. Compared to the state-of-the-art DiT model, FPDM contains 87% fewer parameters, consumes 60% less memory during training, and improves image generation quality in situations where sampling computation or time is limited.
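The idea of reusing fixed point solutions between timesteps can be illustrated with a small toy example. This is only a sketch under strong assumptions, not the repository's implementation: `toy_layer` is a hypothetical scalar stand-in for the implicit layer, and `fixed_point_solve` and `sample` are names invented for this illustration. Because neighbouring timesteps define closely related fixed point problems, starting each solve from the previous solution should need fewer iterations than starting from scratch.

```python
import math


def fixed_point_solve(f, z0, max_iters=50, tol=1e-6):
    """Iterate z <- f(z) from z0 until the update is smaller than tol.

    Returns the final iterate and the number of iterations used.
    """
    z = z0
    for i in range(1, max_iters + 1):
        z_next = f(z)
        if abs(z_next - z) < tol:
            return z_next, i
        z = z_next
    return z, max_iters


def toy_layer(z, x, t):
    # Hypothetical contractive "layer" (assumption, not the real network):
    # its derivative in z is at most 0.5, so plain iteration converges.
    return 0.5 * math.tanh(z) + x + 0.1 * t


def sample(timesteps, x, reuse):
    """Solve one fixed point problem per timestep.

    With reuse=True each solve starts from the previous timestep's solution;
    otherwise every solve starts from zero. Returns the last solution and the
    total number of solver iterations.
    """
    z, total_iters = 0.0, 0
    for t in timesteps:
        z0 = z if reuse else 0.0
        z, n = fixed_point_solve(lambda v: toy_layer(v, x, t), z0)
        total_iters += n
    return z, total_iters


timesteps = [k / 10 for k in range(10, -1, -1)]  # t = 1.0, 0.9, ..., 0.0
z_cold, cold_iters = sample(timesteps, x=1.0, reuse=False)
z_warm, warm_iters = sample(timesteps, x=1.0, reuse=True)
print(f"cold start: {cold_iters} iterations, warm start: {warm_iters} iterations")
```

Both runs reach the same solution up to the solver tolerance; only the iteration count differs.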


We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate DiT
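For a CPU-only setup, the removal of the CUDA requirements mentioned above can be scripted. The following is a minimal sketch: `strip_cuda_deps` is a hypothetical helper, and it assumes each dependency appears on its own `- name=version` line in `environment.yml`, which may not match the actual file layout.

```python
def strip_cuda_deps(yaml_text, names=("cudatoolkit", "pytorch-cuda")):
    """Drop dependency lines whose package name is in `names`.

    Works on plain text, line by line, so no YAML library is required.
    """
    kept = []
    for line in yaml_text.splitlines():
        entry = line.strip().lstrip("-").strip()
        # Compare only the package name, before any version pin such as "=11.7".
        package = entry.split("=")[0].split(">")[0].split("<")[0].strip()
        if package in names:
            continue
        kept.append(line)
    return "\n".join(kept) + "\n"


# Hypothetical file contents, for illustration only.
example = (
    "name: DiT\n"
    "dependencies:\n"
    "  - python=3.9\n"
    "  - pytorch\n"
    "  - cudatoolkit=11.7\n"
    "  - pytorch-cuda=11.7\n"
)
cpu_only = strip_cuda_deps(example)
```

To apply it to the real file, read `environment.yml` as text, pass it through the function, and write the result to a new file before running `conda env create -f` on that file.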


Our model definition, including all fixed point functionality, is included in


Example training scripts:

# Standard model
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8

# Fixed Point Diffusion Model
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 --fixed_point True --deq_pre_depth 1 --deq_post_depth 1

# With v-prediction and zero-SNR
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 1 --deq_post_depth 1

# With v-prediction and zero-SNR, with 4 pre- and post-layers
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 4 --deq_post_depth 4


Example sampling scripts:

# Sample
python --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 20

# Sample with fewer iterations per timestep and more timesteps
python --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --fixed_point_iters 12 --num_sampling_steps 40 --fixed_point_reuse_solution True
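The second command trades iterations per timestep for more timesteps. As a rough way to reason about that trade-off, the sketch below assumes that sampling cost is approximately the number of timesteps times the fixed point iterations per timestep, ignoring the pre- and post-layers. `split_budget` is a hypothetical helper, not part of this repository.

```python
def split_budget(total_evals, num_steps):
    """Spread total_evals solver iterations over num_steps timesteps.

    Returns a list of per-timestep iteration counts that sums to total_evals.
    Any remainder goes to the earliest timesteps (an arbitrary choice here).
    """
    if num_steps <= 0:
        raise ValueError("num_steps must be positive")
    base, extra = divmod(total_evals, num_steps)
    return [base + 1 if i < extra else base for i in range(num_steps)]


# Budget implied by the command above: 40 timesteps with 12 iterations each.
budget = 40 * 12

# The same budget spent on 20 timesteps instead (more iterations per step).
same_budget_20_steps = split_budget(budget, 20)
```

Under this assumption, halving the number of timesteps at a fixed budget doubles the iterations available per timestep, and the reverse.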


Pull requests are welcome!
