kohya-ss / sd-scripts

Apache License 2.0

`fix_noise_scheduler_betas_for_zero_terminal_snr` should come before `prepare_scheduler_for_custom_training` #1263

Open jihnenglin opened 4 months ago

jihnenglin commented 4 months ago

I think the correct implementation should be like this.

```python
noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if args.zero_terminal_snr:
    custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
```

The current implementation calculates `noise_scheduler.all_snr` before even fixing the betas in the scheduler. This would make `apply_snr_weight` use the wrong SNR values when training zero terminal SNR models with min SNR gamma.
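To see why the order matters, here is a hedged numpy sketch with simplified stand-ins for the two helpers (the real ones live in `custom_train_functions` and operate on the scheduler in place, in torch): `all_snr` is derived from the betas, so caching it before the zero-terminal-SNR fix bakes the unfixed terminal SNR into every later weighting.

```python
import numpy as np

def all_snr_from_betas(betas):
    # stand-in for prepare_scheduler_for_custom_training: it caches
    # snr_t = alpha_bar_t / (1 - alpha_bar_t) for every timestep
    alphas_cumprod = np.cumprod(1.0 - betas)
    return alphas_cumprod / (1.0 - alphas_cumprod)

def fix_betas_for_zero_terminal_snr(betas):
    # stand-in for the zsnr fix (Lin et al., "Common Diffusion Noise
    # Schedules and Sample Steps are Flawed"): rescale sqrt(alpha_bar)
    # so the terminal value is exactly 0
    alphas_bar_sqrt = np.sqrt(np.cumprod(1.0 - betas))
    a0, aT = alphas_bar_sqrt[0], alphas_bar_sqrt[-1]
    alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar / np.concatenate(([1.0], alphas_bar[:-1]))
    return 1.0 - alphas

# SD's "scaled_linear" schedule: linear in sqrt(beta)
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
snr_wrong_order = all_snr_from_betas(betas)  # cached before the fix
snr_right_order = all_snr_from_betas(fix_betas_for_zero_terminal_snr(betas))
print(snr_wrong_order[-1], snr_right_order[-1])  # nonzero vs. exactly 0
```

With the current call order, the cached terminal SNR stays at the small nonzero value of the unfixed schedule instead of 0.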

jihnenglin commented 4 months ago

Also, this change would cause `apply_snr_weight` (without v-prediction) and `apply_debiased_estimation` to break with a divide-by-zero error when `zero_terminal_snr` is True. With this modification, the loss at the last (zero-SNR) timestep would be scaled to 0 when training with min SNR and v-prediction. I've just recently started experimenting with zero terminal SNR models, so I'm not entirely sure if this is the desired behavior or how to address the issue. I noticed that @drhead seems to have quite a bit of experience in training these types of models. Would you mind sharing your thoughts or any advice you might have on this topic?
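A small sketch of the Min-SNR-gamma weight (simplified names, not the actual `apply_snr_weight` code) illustrates the failure mode: the epsilon-prediction form divides by the SNR itself, which is exactly 0 at the terminal timestep, while the v-prediction form divides by `snr + 1` and so merely zeroes out that timestep's loss.

```python
import numpy as np

def min_snr_weight(snr, gamma=5.0, v_prediction=False):
    # simplified sketch of the Min-SNR-gamma loss weight
    if v_prediction:
        # v-prediction: min(snr, gamma) / (snr + 1)
        # snr == 0 gives weight 0 (loss scaled to zero, no crash)
        return np.minimum(snr, gamma) / (snr + 1.0)
    # epsilon-prediction: min(snr, gamma) / snr
    # snr == 0 gives 0/0 -> NaN (the divide-by-zero case above)
    return np.minimum(snr, gamma) / snr

print(min_snr_weight(np.array([0.0]), v_prediction=True))  # weight 0
with np.errstate(invalid="ignore"):
    print(min_snr_weight(np.array([0.0])))  # NaN
```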

drhead commented 4 months ago

The difference won't be very significant either way. It's technically incorrect, but I wouldn't worry about it too much.

At least based on some of my ongoing experiments with adaptive loss weights, the bigger problem is that those ends go towards zero. Learned timestep weighting schemes will arrive at something like this (ignore the resolution bit and also the waviness of the line):

[image: learned per-timestep loss weights]

This suggests that the ideal timestep weighting is one that does not weigh any timesteps as orders of magnitude less important than others (😱), which does make much more intuitive sense. But before I suggest changes I would want to do a symbolic regression first. In the meantime, I would say that neither is really correct.

jihnenglin commented 4 months ago

Thank you! I was getting underwhelming results from my experiments on zero terminal SNR models, and I figured there might be something wrong with the loss scaling. 😅 By the way, how did you get this graph? Is it describing a zero terminal SNR model with v loss? I'd appreciate it if you could link me to any sources where I can learn about this adaptive loss weighting process.

drhead commented 4 months ago

> Thank you! I was getting underwhelming results from my experiments on zero terminal SNR models, and I figured there might be something wrong with the loss scaling. 😅 By the way, how did you get this graph? Is it describing a zero terminal SNR model with v loss? I'd appreciate it if you could link me to any sources where I can learn about this adaptive loss weighting process.

This graph is the result of an adaptation of the continuous timestep weighting MLP described in the EDM2 paper back to discrete timesteps. Basically, instead of bashing your head into a wall figuring out which timestep weighting scheme is the best, you can let the model figure it out with a loss objective that will, over time, equalize the contribution of all timesteps. I'm in the process of ironing out some kinks in it (by my count, there appear to be about 6-7), after which I will set up a pull request for it here.
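For intuition, here is a toy numpy sketch of the idea (not drhead's implementation, and far simpler than the EDM2 MLP): each discrete timestep gets a learnable log-uncertainty u_t, and SGD on the objective loss_t · exp(−u_t) + u_t drives u_t toward log(loss_t), so exp(−u_t) becomes a weight that equalizes every timestep's contribution over time.

```python
import numpy as np

# gradient of loss_t * exp(-u_t) + u_t w.r.t. u_t is
# 1 - loss_t * exp(-u_t); the fixed point is u_t = log(loss_t)
rng = np.random.default_rng(0)
num_t = 10
log_u = np.zeros(num_t)          # learnable per-timestep log-uncertainty
lr = 0.05
true_loss = np.linspace(0.1, 2.0, num_t)  # pretend per-timestep loss scale

for _ in range(5000):
    t = rng.integers(num_t)      # sample a timestep, as training would
    grad = 1.0 - true_loss[t] * np.exp(-log_u[t])
    log_u[t] -= lr * grad

# after training, loss_t * exp(-u_t) is ~1 for every timestep,
# i.e. all timesteps contribute equally to the weighted loss
print(np.round(true_loss * np.exp(-log_u), 2))
```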

(the resolution part is me adding a proxy for image resolution to the mix, which is very narrowly applicable to the model I am training that trains on a variety of resolution groups -- the fact that those lines are in order suggests that this was the right thing to do, but it's very niche and would be a nightmare to set up in this codebase so it'll be left out)

jihnenglin commented 4 months ago

This is very interesting! Thank you again for sharing your amazing findings with me. I'm looking forward to your PR.

jihnenglin commented 4 months ago

[image: line graph comparing the two loss weight schedules]

I'd like to provide an update on this. The original schedule, represented by the red line, used a minimum signal-to-noise ratio (SNR) gamma of 5, v-prediction, and a zero terminal SNR. The blue line is a new schedule inspired by the learned timestep weighting suggested by @drhead. The results are promising after around 4,500 steps. Notice that the scales at both ends are no longer close to zero, which should help the model learn effectively across all timesteps. I believe this is the right direction, though further testing may be needed.

drhead commented 4 months ago

@jihnenglin If you want to experiment with the learned timestep weighting, I have a prototype here as a gist: https://gist.github.com/drhead/ac6ecc1f6dc1fd478064f3d81ca12a25

It does typically take a long time with a fairly large batch size by the standards of most LoRAs, but if you plug that schedule in as an array for `lambda_weights`, it should converge in a reasonable timeframe.
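As a sketch of what plugging the schedule in as an array might look like (the `lambda_weights` name follows the gist; everything else here is illustrative, not the gist's code):

```python
import numpy as np

# stand-in for a schedule exported from the learned weighting;
# a real array would come from the gist's trained model
lambda_weights = np.ones(1000)

def weighted_loss(per_sample_loss, timesteps):
    # scale each sample's loss by its timestep's precomputed weight
    return (per_sample_loss * lambda_weights[timesteps]).mean()

print(weighted_loss(np.array([1.0, 2.0]), np.array([10, 500])))  # 1.5
```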

jihnenglin commented 4 months ago

@drhead Thank you! I'll definitely try it later. And sorry for the confusion: I think it had converged way earlier; the 4.5k step is just the first checkpoint I saved and tested. Is the learning rate 0.005 good even for fine-tuning the entire model? I'm scared.

drhead commented 4 months ago

> Is the learning rate 0.005 good even for fine-tuning the entire model? I'm scared.

Oh, no, no, you want that just for the loss weight MLP itself. It's using EDM2-style weight normalization, so it can handle very high learning rates and a lack of weight decay because its activations can't grow almost unbounded like they do on most models. Use your normal learning rate for the actual diffusion model/LoRA, and set this model's parameters to use the higher learning rate and zero weight decay.
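In PyTorch terms, that setup is two optimizer parameter groups; a sketch with stand-in modules (the module names and exact values are illustrative, not sd-scripts code):

```python
import torch

unet = torch.nn.Linear(4, 4)             # stand-in for the diffusion model/LoRA
loss_weight_mlp = torch.nn.Linear(4, 1)  # stand-in for the loss-weight MLP

optimizer = torch.optim.AdamW([
    # normal learning rate and weight decay for the diffusion model
    {"params": unet.parameters(), "lr": 1e-4, "weight_decay": 1e-2},
    # high LR, zero weight decay for the weight-normalized MLP
    {"params": loss_weight_mlp.parameters(), "lr": 5e-3, "weight_decay": 0.0},
])
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.005]
```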

jihnenglin commented 4 months ago

Oh, thanks for clearing it up for me! These are super helpful.

jihnenglin commented 4 months ago

@drhead This may be an unrelated issue, but most samplers give me weird results at lower (20-30) sampling steps for my v-prediction model, while they all work perfectly fine at higher (>50) sampling steps! DDIM and Euler are the only ones that give decent results with lower steps. Have you ever had this problem before?

drhead commented 4 months ago

> @drhead This may be an unrelated issue, but most samplers give me weird results at lower (20-30) sampling steps for my v-prediction model, while they all work perfectly fine at higher (>50) sampling steps! DDIM and Euler are the only ones that give decent results with lower steps. Have you ever had this problem before?

Very common issue. One thing I recommend for mitigating it is setting `sigma_max` to a lower value like 160 (zsnr options on A1111 and ComfyUI default to ~4500, Diffusers defaults to 4096), though this will not use the terminal timestep and is not ideal. I recently submitted a PR to A1111 with a scheduler from a recent paper that I find works better: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
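A toy illustration of why the cap matters: converting alpha_bar values to k-diffusion-style sigmas via sigma = sqrt((1 − alpha_bar)/alpha_bar) shows the terminal noise level of a near-zero-terminal-SNR schedule blowing up, which is why samplers otherwise fall back to a huge default `sigma_max` (the numbers below are made up, not a real schedule):

```python
import numpy as np

# toy alpha_bar values; the last one approximates a zsnr terminal step
alpha_bar = np.array([0.9991, 0.5, 0.01, 1e-7])
sigmas = np.sqrt((1.0 - alpha_bar) / alpha_bar)
print(sigmas)  # the terminal sigma is in the thousands

# capping sigma_max trades the terminal step for stabler low-step sampling
sigma_max = 160.0
print(np.minimum(sigmas, sigma_max))
```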

jihnenglin commented 4 months ago

Thank you so much for sharing this pro tip! Glad to know this is an inference-phase issue rather than a training-phase one. Setting `sigma_max` to 160 enables the DPM++ 2M Karras / DPM++ SDE Karras samplers to converge faster. Unfortunately, KL Optimal doesn't make convergence faster for my model. 😔