ngitnenlim opened 7 months ago
Also, this change would cause apply_snr_weight (without v-prediction) and apply_debiased_estimation to break with a divide-by-zero error when zero_terminal_snr is True. With this modification, the loss of the last (zero-SNR) timestep would be scaled to 0 when training with min SNR and v-prediction. I've just recently started experimenting with zero terminal SNR models, so I'm not entirely sure if this is the desired behavior or how to address the issue. I noticed that @drhead seems to have quite a bit of experience in training these types of models. Would you mind sharing your thoughts or any advice you might have on this topic?
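For reference, here is a minimal sketch (my own approximation of the weighting formulas, not sd-scripts' exact code) showing why the terminal zero-SNR timestep either divides by zero or is weighted to zero:

```python
# Minimal sketch of the min-SNR and debiased-estimation weights (approximate, for
# illustration only), showing the zero-terminal-SNR edge case described above.
import torch

def min_snr_weight(snr, gamma=5.0, v_prediction=True):
    # min(SNR, gamma) / SNR for epsilon-prediction -> divides by zero at SNR == 0
    # min(SNR, gamma) / (SNR + 1) for v-prediction -> goes to 0 at SNR == 0
    clipped = torch.clamp(snr, max=gamma)
    return clipped / (snr + 1) if v_prediction else clipped / snr

def debiased_estimation_weight(snr):
    # 1 / sqrt(SNR) -> divides by zero at SNR == 0
    return 1.0 / torch.sqrt(snr)

snr = torch.tensor([0.0, 0.5, 5.0, 50.0])  # with zero_terminal_snr the last timestep has SNR == 0
print(min_snr_weight(snr, v_prediction=True))   # first element is 0: terminal step contributes nothing
print(min_snr_weight(snr, v_prediction=False))  # first element is nan (0/0)
print(debiased_estimation_weight(snr))          # first element is inf
```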
The difference won't be very significant either way. It's technically incorrect, but I wouldn't worry about it too much.
At least based on some of my ongoing experiments with adaptive loss weights, the bigger problem is that those ends go towards zero. Learned timestep weighting schemes will arrive at something like this (ignore the resolution bit and also the waviness of the line), which suggests that the ideal timestep weighting is one that does not weigh any timesteps as orders of magnitude less important than others (😱). That does make much more intuitive sense, but before I suggest changes I would want to do a symbolic regression first. In the meantime, I would say that the answer is that neither is really correct.
Thank you! I was getting underwhelming results from my experiments on zero terminal SNR models, and I figured there might be something wrong with the loss scaling. 😅 By the way, how did you get this graph? Is it describing a zero terminal SNR model with v loss? I'd be very appreciative if you could link me to any sources where I can learn about this adaptive loss weighting process.
This graph is the result of an adaptation of the continuous timestep weighting MLP described in the EDM2 paper back to discrete timesteps. Basically, instead of bashing your head into a wall figuring out which timestep weighting scheme is the best, you can let the model figure it out with a loss objective that will, over time, equalize the contribution of all timesteps. I'm in the process of ironing out some kinks in it (by my count, there appear to be about 6-7), after which I will set up a pull request for it here.
(the resolution part is me adding a proxy for image resolution to the mix, which is very narrowly applicable to the model I am training that trains on a variety of resolution groups -- the fact that those lines are in order suggests that this was the right thing to do, but it's very niche and would be a nightmare to set up in this codebase so it'll be left out)
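For anyone following along, here is a rough sketch (my own reconstruction of the idea, not the code from the gist or the eventual PR) of what an EDM2-style loss weighting adapted to discrete timesteps can look like:

```python
# Rough sketch of EDM2-style adaptive loss weighting adapted to discrete timesteps:
# a learned per-timestep log-weight u(t) is trained jointly with the diffusion model,
# which over time equalizes the contribution of all timesteps to the total loss.
import torch
import torch.nn as nn

class TimestepLossWeight(nn.Module):
    def __init__(self, num_timesteps=1000):
        super().__init__()
        # one learned log-weight per discrete timestep (EDM2 itself uses an MLP over log-sigma)
        self.u = nn.Embedding(num_timesteps, 1)
        nn.init.zeros_(self.u.weight)

    def forward(self, loss, timesteps):
        # loss: per-sample loss, shape [batch]; timesteps: integer indices, shape [batch]
        u = self.u(timesteps).squeeze(-1)
        # dividing by exp(u) rescales each timestep's loss toward ~1; the "+ u" term keeps
        # u from growing without bound, so u converges to the log of the typical raw loss
        # at that timestep and exp(-u) becomes the learned weight curve
        return loss / torch.exp(u) + u
```

The kind of curve discussed above would then correspond to the learned weights exp(-u) across timesteps.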
This is very interesting! Thank you again for sharing your amazing findings with me. I'm looking forward to your PR.
I'd like to provide an update on this. The original schedule, represented by the red line, used a minimum signal-to-noise ratio (SNR) gamma of 5, v-prediction, and a zero terminal SNR. The blue line is a new schedule inspired by the learned timestep weighting suggested by @drhead. The results are promising after around 4,500 steps. Notice that the scales at both ends are no longer close to zero, which should help the model learn effectively across all timesteps. I believe this is the right direction, though further testing may be needed.
@jihnenglin If you want to experiment with the learned timestep weighting, I have a prototype here as a gist: https://gist.github.com/drhead/ac6ecc1f6dc1fd478064f3d81ca12a25
It does typically take a long time with a fairly large batch size by the standards of most LoRAs, but if you plug that schedule in as an array for lambda_weights, it should converge in a reasonable timeframe.
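In case it helps, a hypothetical sketch of applying such a precomputed per-timestep weight array in a training loop (the name lambda_weights and its exact convention in the gist are assumptions here):

```python
# Hypothetical sketch: applying a precomputed per-timestep weight array to the loss.
# The file name and the lambda_weights convention are assumptions for illustration.
import torch

lambda_weights = torch.load("learned_timestep_weights.pt")  # shape [num_train_timesteps]

def apply_timestep_weights(loss, timesteps):
    # loss: per-sample loss [batch]; timesteps: [batch] integer timestep indices
    return loss * lambda_weights.to(loss.device)[timesteps]
```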
@drhead Thank you! I'll definitely try it later. And sorry for the confusion: I think it had converged way earlier; the 4.5k step is just the first checkpoint I saved and tested. Is the learning rate 0.005 good even for fine-tuning the entire model? I'm scared.
Oh, no, no, you want that just for the loss weight MLP itself. It's using EDM2-style weight normalization, so it can handle very high learning rates and a lack of weight decay, because its activations can't grow nearly unbounded the way they do in most models. Use your normal learning rate for the actual diffusion model/LoRA, and set this module's parameters to use the higher learning rate and zero weight decay.
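A minimal sketch of that setup (assumed names and values, not drhead's exact code), giving the loss-weight module its own optimizer parameter group:

```python
# Minimal sketch: separate optimizer parameter groups so only the loss-weight module
# gets the high learning rate and zero weight decay. The modules here are stand-ins.
import torch
import torch.nn as nn

unet = nn.Linear(4, 4)                    # stand-in for the diffusion model / LoRA parameters
loss_weight_mlp = nn.Embedding(1000, 1)   # stand-in for the EDM2-style loss-weight module

optimizer = torch.optim.AdamW([
    {"params": unet.parameters(), "lr": 1e-4, "weight_decay": 1e-2},            # usual training settings
    {"params": loss_weight_mlp.parameters(), "lr": 5e-3, "weight_decay": 0.0},  # high LR, no weight decay
])
```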
Oh, thanks for clearing it up for me! These are super helpful.
@drhead This may be an unrelated issue, but most samplers give me weird results at lower (20-30) sampling steps for my v-prediction model, and they all work perfectly fine at higher (>50) sampling steps! DDIM and Euler are the only ones that give decent results with lower steps. Have you ever had this problem before?
Very common issue. One thing I recommend for mitigating it is setting sigma_max to a lower value like 160 (zsnr options on A1111 and ComfyUI default to ~4500, Diffusers defaults to 4096), though this will not use the terminal timestep and is not ideal. I recently submitted a PR to A1111 with a scheduler from a recent paper that I find works better: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
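For illustration, a minimal sketch of what capping sigma_max looks like with k-diffusion's Karras sigma schedule (the sigma_min value and the wiring into any particular UI are assumptions):

```python
# Minimal sketch: build a Karras sigma schedule with sigma_max capped at 160 instead of
# the ~4500 that zero-terminal-SNR setups default to. Values here are illustrative.
from k_diffusion.sampling import get_sigmas_karras

sigmas = get_sigmas_karras(n=25, sigma_min=0.03, sigma_max=160.0, rho=7.0)
print(sigmas)  # descending noise levels, with a final 0 appended for the last denoising step
```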
Thank you so much for sharing this pro tip! Glad to know this is an inference-phase issue rather than a training-phase one. Setting sigma_max to 160 lets the DPM++ 2M Karras / DPM++ SDE Karras samplers converge faster. Unfortunately, KL Optimal doesn't make convergence faster for my model. 😔
I think the correct implementation should be like this. The current implementation calculates noise_scheduler.all_snr before even fixing the betas in the scheduler. This would make apply_snr_weight use the wrong SNR when training zero terminal SNR models with min SNR gamma.
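For context, a sketch of the ordering being argued for (the beta rescaling follows "Common Diffusion Noise Schedules and Sample Steps Are Flawed", Lin et al. 2023; this is an illustration, not the patch itself): fix the betas for zero terminal SNR first, and only then compute the per-timestep SNR table that apply_snr_weight reads from.

```python
# Illustration of the intended order: rescale betas for zero terminal SNR first,
# then compute all_snr, so the SNR table matches the schedule actually trained on.
import torch

def enforce_zero_terminal_snr(betas):
    alphas = 1.0 - betas
    alphas_bar_sqrt = alphas.cumprod(0).sqrt()
    # shift and rescale sqrt(alpha_bar) so the last value is exactly 0 (terminal SNR = 0)
    a0, aT = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2  # SD-style scaled-linear betas
betas = enforce_zero_terminal_snr(betas)                     # fix betas BEFORE computing all_snr
alphas_cumprod = (1.0 - betas).cumprod(0)
all_snr = alphas_cumprod / (1.0 - alphas_cumprod)            # last entry is now exactly 0
```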