crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

Common Diffusion Noise Schedules and Sample Steps are Flawed #64

Open sdbds opened 1 year ago

sdbds commented 1 year ago

https://arxiv.org/abs/2305.08891

I think the findings in this paper might be helpful.

drhead commented 1 year ago

I've been experimenting with a model trained on v-prediction with zero terminal SNR, including with K-diffusion samplers. I've also applied the modification to the DDIM sampler under LDM and tested it, and it works as described -- the model can produce a solid black image with a dark subject. Not applying the zero terminal SNR beta rescale still produces results close to this -- far closer than anything achieved while the model was being trained on epsilon loss with zero terminal SNR -- but does not quite achieve the solid black background. Here is what would need to be done on K-diffusion's end to bring it up to this capability, along with notes from my experiments so it's clear what each change should be expected to do (following the paper's suggestions in order):

  1. Enforce Zero Terminal SNR -- This is likely to be the most difficult one. The noise schedule would have to be rescaled so that the final timestep has exactly zero SNR (a sketch of the paper's beta rescale appears after this list). I do not know enough about how K-diffusion determines its noise schedule to fix this, or to know whether it is possible without significant rewrites, and the lack of publicly available checkpoints trained on zero terminal SNR with v-loss that could be used to properly test this does not help either (I know someone who will start training a 1.5 model on v-loss with zero terminal SNR soon, and that checkpoint will be made available). Based on my testing, this is necessary to get the most out of the model.
  2. Train with V Prediction and V Loss -- It should be noted that any model trained on epsilon loss will either produce a solid black image or throw a tensor full of NaNs when inference is attempted under a noise schedule that enforces zero terminal SNR (see the sketch after this list for why the epsilon target degenerates there). The paper's other suggestions can have visible benefits even with epsilon loss, but zero terminal SNR noise schedules might as well be forcibly disabled if the model is not running on V-parameterization. I will note that in practice a zero terminal SNR schedule trained with epsilon loss still performs quite well -- about as well as offset noise but with more consistency -- but V-parameterization is necessary for it to work "as advertised".
  3. Sample from the Last Timestep -- K-diffusion already appears to do this, since it uses linspace sample step selection. The authors of the paper, testing the DDIM sampler, found very little difference between linspace and trailing step selection at common step counts like 25, but found trailing noticeably better at small step counts like 5, a benefit that presumably carries over to some degree at higher step counts. For thoroughness, this should also be tested with K-diffusion samplers (a sketch of the selection schemes appears after this list), but the expectation is at most a very minor increase in visual fidelity that is imperceptible during normal use.
  4. Rescale Classifier-Free Guidance -- I have experimented with someone's implementation of this for K-diffusion and compared it to the implementation in Diffusers, which is DDIM only. They behave nothing alike, and neither looks anything like the results shown in the paper, so I suspect one or both implementations is wrong: the K-diffusion implementation I tried heavily desaturates images, while the Diffusers implementation makes no significant difference on 1.5 (if anything, many of the test images looked worse with the recommended rescale of phi = 0.7), and on 2.1 it slightly desaturates images while giving some increase in visual fidelity. The authors unfortunately do not provide code for this (a sketch of the rescale as I understand it from the paper appears after this list). I have since gotten the K-diffusion implementation to work as implemented here; the desaturation is far less of an issue when using a model trained on zero terminal SNR with v-loss. I've included examples sampled using 30 steps of DPM++ 2M SDE Karras in the linked pull request.
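
For item 1, a minimal sketch of the beta rescale the paper describes (its Algorithm 1), assuming a standard DDPM-style tensor of betas. This is separate from the question of how to express the same constraint in K-diffusion's sigma-based schedule, which I don't know how to do yet:

```python
import torch

def rescale_zero_terminal_snr(betas):
    """Rescale a DDPM beta schedule so alpha_bar at the final timestep is exactly
    zero (zero terminal SNR), following Algorithm 1 of arXiv:2305.08891."""
    alphas = 1.0 - betas
    alphas_bar_sqrt = torch.cumprod(alphas, dim=0).sqrt()

    # Remember the original first/last values of sqrt(alpha_bar).
    first = alphas_bar_sqrt[0].clone()
    last = alphas_bar_sqrt[-1].clone()

    # Shift so the last value is exactly zero, then rescale so the first value
    # is unchanged.
    alphas_bar_sqrt -= last
    alphas_bar_sqrt *= first / (first - last)

    # Convert back to betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas
```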
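
For item 2, a quick illustration (placeholder names, not K-diffusion API) of why epsilon loss breaks down at the zero-SNR step while v prediction does not:

```python
import torch

def v_prediction_target(x_0, noise, alphas_cumprod, t):
    """v target from Salimans & Ho (2022): v = sqrt(abar_t) * noise - sqrt(1 - abar_t) * x_0.
    With zero terminal SNR, abar_T = 0, so x_T is pure noise and an epsilon target
    tells the model nothing about the image; the v target degenerates to -x_0,
    which is why v prediction still works at the final step."""
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over an NCHW batch
    return abar.sqrt() * noise - (1 - abar).sqrt() * x_0
```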
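
For item 3, the three step-selection schemes compared in the paper, sketched in NumPy so it's easier to check which timesteps a sampler actually hits (returned in descending sampling order):

```python
import numpy as np

def select_timesteps(num_steps, num_train_timesteps=1000, mode="trailing"):
    """Sample step selection per arXiv:2305.08891. 'trailing' always includes the
    final training timestep (where SNR is zero), 'leading' never reaches it, and
    'linspace' spaces steps evenly across the whole range."""
    if mode == "trailing":
        t = np.arange(num_train_timesteps, 0, -num_train_timesteps / num_steps)
        return np.round(t).astype(np.int64) - 1
    if mode == "linspace":
        return np.linspace(num_train_timesteps - 1, 0, num_steps).round().astype(np.int64)
    # "leading": starts counting from 0, so the final training step is skipped.
    return np.arange(0, num_train_timesteps, num_train_timesteps // num_steps)[::-1].astype(np.int64)
```

For example, with 5 steps the trailing schedule gives [999, 799, 599, 399, 199] while the leading one gives [800, 600, 400, 200, 0], which is where the gap at low step counts comes from.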
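
For item 4, here is the rescale as I understand it from the paper (Sec. 3.4); comparing existing implementations against something like this might help figure out which one is wrong. `pos` / `neg` stand for the conditional and unconditional model outputs and are placeholder names, not an existing K-diffusion API:

```python
import torch

def rescaled_guidance(pos, neg, guidance_scale, phi=0.7):
    """Classifier-free guidance with rescaling, per arXiv:2305.08891."""
    cfg = neg + guidance_scale * (pos - neg)
    # Rescale so the guided output's per-sample std matches the conditional output's.
    dims = list(range(1, pos.ndim))
    rescaled = cfg * (pos.std(dim=dims, keepdim=True) / cfg.std(dim=dims, keepdim=True))
    # Blend with the plain CFG result; phi = 0.7 is the paper's recommended setting.
    return phi * rescaled + (1 - phi) * cfg
```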

I would love to see K-diffusion support this fully. Zero terminal SNR models are incredibly capable.