huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

"Common Diffusion Noise Schedules and Sample Steps are Flawed" integration #3475

Open vvvm23 opened 1 year ago

vvvm23 commented 1 year ago

I was reading the paper Common Diffusion Noise Schedules and Sample Steps are Flawed and found it pretty interesting. It proposes a few simple changes that could be useful when integrated into existing schedulers in the diffusers library. Namely:

  1. "rescale the noise schedule to enforce zero terminal SNR"
  2. "change the sampler to always start from the last timestep"
  3. "rescale classifier-free guidance to prevent over-exposure"

They demonstrate the proposed fixes by finetuning Stable Diffusion and showing it can generate concepts it could not successfully produce before, such as solid colour backgrounds. I've attached a figure to demonstrate this.

[figure omitted]

Is this something that would be valuable to add to the diffusers library? Are these proposals already integrated? Thanks~
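
For concreteness, fix 1 amounts to rescaling the beta schedule so the final cumulative alpha is exactly zero (zero SNR at the last timestep). A minimal sketch following the paper's Algorithm 1, just for illustration (not taken from diffusers):

```python
import torch

def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """Rescale a beta schedule so alpha_bar at the last timestep is exactly 0
    (zero terminal SNR), following Algorithm 1 of the paper."""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # Remember the old first/last values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last value becomes 0, then scale so the first value is unchanged.
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert back to betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas

# Example: Stable Diffusion's "scaled_linear" schedule, then rescaled.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
betas = rescale_zero_terminal_snr(betas)
```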

Max-We commented 1 year ago

FYI I played around with this here: https://github.com/Max-We/sf-zero-signal-to-noise

vvvm23 commented 1 year ago

@Max-We nice write up! Do you have any plans to integrate the changes into this library?

I am wondering whether LoRA finetuning would be sufficient to adapt the model for the new scheduler, or whether we would need to do full scale finetuning.

I asked online about releasing the official finetuned model weights, but I don't think we will get them.

rvorias commented 1 year ago

> I am wondering whether LoRA finetuning would be sufficient to adapt the model for the new scheduler, or whether we would need to do full scale finetuning.

I've also been wondering about this, but LoRA only acts on the (cross-)attention layers, which is nice for "what" and "where", but I'm not sure it can also recalibrate the total noise prediction. Keen to try it out though.

vvvm23 commented 1 year ago

I think LoRA can be applied to arbitrary layers, but maybe the default is cross attention only. It is a good observation though that maybe more than that is needed.
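
For example, with peft the adapter could be pointed at more than the attention projections. The target module names below follow common diffusers UNet naming and are just an illustration, not a tested recipe:

```python
from peft import LoraConfig, get_peft_model

# Hypothetical LoRA config that goes beyond the cross-attention projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=[
        "to_q", "to_k", "to_v", "to_out.0",  # attention projections (the usual default)
        "proj_in", "proj_out",               # transformer block in/out projections
        "ff.net.0.proj", "ff.net.2",         # feed-forward layers
    ],
)

# unet = get_peft_model(unet, lora_config)  # wrap an existing UNet2DConditionModel
```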

adammenges commented 1 year ago

Huge vote up for this, looks like it'll solve many of the issues I've been experiencing in my work with images using a white background.

Max-We commented 1 year ago

@vvvm23 I will try to implement parts of it this Saturday

patrickvonplaten commented 1 year ago

Very cool! Let me know if you need help @Max-We

patrickvonplaten commented 1 year ago

BTW Stable Diffusion 2.x is already trained with v-prediction, so we should be able to see clear improvements there without having to retrain the model - did anybody check this by any chance?

bghira commented 1 year ago

i have checked that. it works, but there are even better results from fine-tuning 2.1 even just a little bit with captioned photos after enabling the fixes. and the deeper you go, the better it gets. but even out of the box, it results in much more usable 2.1 outputs.

eeyrw commented 1 year ago

I have tried the "rescale classifier-free guidance to prevent over-exposure" trick with the following implementation, but it seems to just make the image more photorealistic rather than fixing the over-exposure:

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Here we rescale: match the std of the guided prediction to the
                # std of the text-conditioned prediction, then blend with factor phi
                sigma_cfg = torch.std(noise_pred)
                sigma_pos = torch.std(noise_pred_text)
                noise_pred_rescaled = noise_pred * (sigma_pos / sigma_cfg)
                phi = 0.7
                noise_pred_new = phi * noise_pred_rescaled + (1 - phi) * noise_pred
                noise_pred = noise_pred_new
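
For comparison, the paper's Algorithm 2 computes the standard deviations per sample (over all non-batch dimensions) rather than over the whole tensor. A sketch of that variant, just for illustration:

```python
import torch

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor,
                      guidance_rescale: float = 0.7) -> torch.Tensor:
    # Per-sample statistics: std over all non-batch dims, keeping dims for broadcasting.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    # Rescale the guided prediction to the text-conditioned std, then blend it back in.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
```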

Max-We commented 1 year ago

FYI I suggested changes for

  1. "rescale the noise schedule to enforce zero terminal SNR"
  2. "rescale classifier-free guidance to prevent over-exposure"

in PR #3664

PeterL1n commented 1 year ago

> BTW Stable Diffusion 2.x is already trained with v-prediction, so we should be able to see clear improvements there without having to retrain the model - did anybody check this by any chance?

This is incorrect. After changing the schedule you must retrain (finetune) the model on the new schedule.

patrickvonplaten commented 1 year ago

Great job opening the PR @Max-We! I played around with it this morning and changed the PR slightly to be less backwards-breaking. Could you (and @PeterL1n maybe as well) take a look and check whether the current implementation / PR is good for you? I've also added "3. Sample from the Last Timestep".
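
With that, all three fixes should be toggleable roughly like this; the argument names follow the PR discussion and later comments in this thread, so double-check them against the merged version:

```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

# 1. Enforce zero terminal SNR and 2. sample from the last timestep ("trailing" spacing).
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    rescale_betas_zero_snr=True,
    timestep_spacing="trailing",
)

# 3. Rescale classifier-free guidance to prevent over-exposure.
image = pipe(
    "a photo of a cat on a solid white background",
    guidance_scale=7.5,
    guidance_rescale=0.7,
).images[0]
```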

Max-We commented 1 year ago

@patrickvonplaten Thank you for your assistance! The changes look good to me.

eeyrw commented 1 year ago

> BTW Stable Diffusion 2.x is already trained with v-prediction, so we should be able to see clear improvements there without having to retrain the model - did anybody check this by any chance?
>
> This is incorrect. After changing the schedule you must retrain (finetune) the model on the new schedule.

How many samples do we need to finetune on to achieve an acceptable result?

bghira commented 1 year ago

@eeyrw

in my testing last night, at a batch size of 18 on an A100-80G it took about 100-500 iterations of training (1800-9000 samples) using a learning rate of 1e-8 with a partially frozen text encoder (up to layer 18) to really see a stark contrast from the baseline model, using very high quality curated training data.

i tried training with DDIM as the sampler but it did not work. all of the tensors went NaN in about 5-6 steps. i tried debugging, and no combination of settings made it work.

using DDPMScheduler for training now, but i'm unsure whether the new fixes are actually being applied this way, because no changes were made to that scheduler?

eeyrw commented 1 year ago

@bghira It seems to take fewer samples and iterations to make it work than I imagined. But the learning rate you used is quite small; I often use 5e-6 or smaller, but larger than 1e-6. Does that help the model learn overall style and image quality rather than the specific content of images?

It is also new to me that there are options for the training scheduler other than DDPMScheduler.

bghira commented 1 year ago

the dreambooth example code, for instance, used DPMSolverMultistepScheduler.

the learning rate is kept low because i'm training the text encoder in addition to the unet. i've done a lot of experimentation with this: a 1e-6 rate will change the output a lot with each step of training, and eventually basically destroy the model in a thousand steps or less. by destroy, i mean it loses comprehension of basic shapes and textures. a prompt for a dog will look like a couch cushion was photographed uncomfortably close.

if you freeze the encoder partially so that, for example, you train only the last 25% of its layers, you get a much better effect, but if your training data isn't balanced enough, you can lose access to concepts that the base checkpoint knew. e.g. i did not provide pictures of geckos, and so, despite being better at human subjects, the model could not make a gecko.

if you freeze the text encoder completely when fine-tuning, the benefit of fine-tuning almost entirely disappears: not much changes in the contents of the image, and the model essentially just absorbs the artifacts of the training material. it can become very pixelated and grainy, as it overfits on the training data's textures. Edit: this is particularly true with 2.1-v, as it is massively overtrained compared to 2.1-base or 2.0-v.

the higher the learning rate, the more likely it is to break during fine-tuning, and the more likely it is to enter catastrophic forgetting. but the lower learning rates don't appear to do much for the unet. perhaps this is desired. perhaps we really do need split learning rate support. i would be happy to be corrected on any assertions i've made.
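
for illustration, the partial freeze i mean looks roughly like this, assuming the transformers CLIPTextModel layout (not the exact script i'm running):

```python
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="text_encoder"
)

# freeze everything, then unfreeze the last ~25% of transformer layers
layers = text_encoder.text_model.encoder.layers
first_trainable = int(len(layers) * 0.75)

for param in text_encoder.parameters():
    param.requires_grad = False
for layer in layers[first_trainable:]:
    for param in layer.parameters():
        param.requires_grad = True
```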

eeyrw commented 1 year ago

@bghira You provided a lot of useful information about finetuning, thank you. I observed the same phenomenon: finetuning only the unet gives pixelated and grainy output. I used to add some poor-quality images with a negative tag when training, and use negative prompts at inference, to avoid those artifacts.

I have limited GPU memory and I cache the text encoder outputs to disk, so I do not try to train the text encoder simultaneously. DeepFloyd-IF indicates that increasing the text encoder size can significantly improve performance; the text encoder is a less-noticed part of SD. Swapping the CLIP text encoder for another LLM such as T5 or GPT might also help SD's performance, but that requires training from scratch, which is quite expensive, and no one has tried it.
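
A rough sketch of what I mean by caching the text encoder outputs to disk (model paths and file handling are just illustrative):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="text_encoder"
).eval()

@torch.no_grad()
def cache_prompt_embedding(prompt: str, path: str) -> None:
    # Precompute the encoder hidden states once so the text encoder
    # never needs to sit in GPU memory during unet finetuning.
    tokens = tokenizer(
        prompt, padding="max_length", truncation=True,
        max_length=tokenizer.model_max_length, return_tensors="pt",
    )
    embeds = text_encoder(tokens.input_ids)[0]  # (1, 77, hidden_dim)
    torch.save(embeds.cpu(), path)
```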

bghira commented 1 year ago

[example image] so far this has obsoleted offset noise for me

flymin commented 1 year ago

> @eeyrw
>
> in my testing last night, at a batch size of 18 on an A100-80G it took about 100-500 iterations of training (1800-9000 samples) using a learning rate of 1e-8 with a partially frozen text encoder (up to layer 18) to really see a stark contrast from the baseline model, using very high quality curated training data.
>
> i tried training with DDIM as the sampler but it did not work. all of the tensors went NaN in about 5-6 steps. i tried debugging, and no combination of settings made it work.
>
> using DDPMScheduler for training now, but i'm unsure whether the new fixes are actually being applied this way, because no changes were made to that scheduler?

It seems that fix #3664 only changes the DDIM scheduler. Since it does not change the DDPM schedule, which is used for training/finetuning, why is it necessary to fine-tune the model?

fangchuan commented 4 days ago

hi all, I am using the zero-terminal-SNR noise schedule to fine-tune a multi-modal stable diffusion model (Wonder3D). The loss is pretty large even though I get reasonable generations on the validation set. Did you encounter this problem? [loss curve image]

batch_size=256, and the experimental dataset contains 5600 objects. I use DDPMScheduler with rescale_betas_zero_snr=True during training, and DDIMScheduler with rescale_betas_zero_snr=True, timestep_spacing='trailing' during CFG inference.