Open LinB203 opened 1 week ago
Indeed. I have very seldom trained with v-pred, so I didn't explore this setting much. Would you like to open a PR?
Does it have any problem?
@drhead cc
So, the intent of min-snr-gamma is to balance how important each timestep is, so that they are treated as equally important.
Min-snr-gamma attempts to do this with a fixed formula, but you can also do it with trainable parameters that control the timestep weighting, which is what the EDM2 paper did (https://arxiv.org/pdf/2312.02696 pp. 19-22), based on earlier work by Kendall et al. (https://arxiv.org/abs/1705.07115). This is a much more flexible approach: you effectively don't have to do anything except train for long enough for the weights to settle, and you can effectively guarantee that they will settle at a point where the average training loss of each timestep contributes equally to the gradients.
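To make the "settles where every timestep contributes equally" claim concrete, here is a minimal sketch of the uncertainty-weighting idea in plain Python. It is a simplification and not the EDM2 implementation: it assumes one learnable log-variance scalar `u` per timestep bucket, holds the raw losses fixed instead of training a model alongside, and uses made-up loss values. Each bucket's objective is `L * exp(-u) + u`, whose minimum is at `u = log(L)`, so the effective weighted loss `L * exp(-u)` ends up at 1 for every bucket.

```python
import math


def train_log_variances(mean_losses, lr=0.1, steps=2000):
    """Gradient descent on sum_i [L_i * exp(-u_i) + u_i] with respect to u_i only.

    mean_losses: average raw loss per timestep bucket (held fixed in this sketch).
    Returns the learned log-variances u_i.
    """
    u = [0.0] * len(mean_losses)
    for _ in range(steps):
        for i, loss in enumerate(mean_losses):
            # d/du [L * exp(-u) + u] = 1 - L * exp(-u)
            grad = 1.0 - loss * math.exp(-u[i])
            u[i] -= lr * grad
    return u


# Hypothetical per-bucket raw losses (illustrative numbers, not measured).
raw_losses = [0.02, 0.15, 0.6, 0.3, 0.05]

u = train_log_variances(raw_losses)
weights = [math.exp(-ui) for ui in u]            # learned loss weights
weighted = [l * w for l, w in zip(raw_losses, weights)]

for l, w, wl in zip(raw_losses, weights, weighted):
    print(f"raw={l:.3f}  weight={w:8.3f}  weighted={wl:.6f}")
```

In a real training run the raw losses change as the model learns, so the weights keep drifting until both have settled, which is why the run has to be long enough.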
I do have a training run where I am doing this right now, on v-prediction, and this is what the timestep weightings look like currently:
I would say that this is evidence that the v-prediction min-snr-gamma curve is at least more correct than one that looks like the epsilon curve, in that there's a peak around timestep 200-300 and lower weights towards the ends. But we can also see that the weights at the tails are definitely not zero, so I would say you are right to suspect that the tails of min-snr-gamma approaching zero is not ideal.
I can't really nail down a fixed formula that would represent the type of curve that I've gotten from the homoscedastic uncertainty method, but I do know that the intensity of the peak increases along with the size of the latent, so that would need to be accounted for:
I do have a suspicion that this is related to what was pointed out in Hoogeboom et al. (https://arxiv.org/pdf/2301.11093) about high resolution models needing more noise to fully destroy the signal.
In conclusion though, re: min-snr-gamma: if the curve is that shape, the formula is correct as implemented, because that's what the formula described in the paper for v-prediction looks like. There are arguably better methods, and room for those to be implemented, but they'd have to be separate; there's nothing to really "fix" with min-snr-gamma.
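For reference, the two weightings being compared can be sketched as below. This assumes the Min-SNR paper's formulas, `min(SNR, gamma) / SNR` for epsilon prediction and `min(SNR, gamma) / (SNR + 1)` for v-prediction, with `gamma = 5`. It also assumes a Stable Diffusion style "scaled linear" beta schedule (0.00085 to 0.012 over 1000 steps); the helper names are made up for this example and are not the diffusers API.

```python
import math


def scaled_linear_alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for an assumed 'scaled linear' schedule."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    alphas_cumprod, prod = [], 1.0
    for i in range(num_steps):
        beta = (lo + (hi - lo) * i / (num_steps - 1)) ** 2
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def min_snr_weights(alphas_cumprod, gamma=5.0, prediction_type="epsilon"):
    """Per-timestep Min-SNR-gamma loss weights for the given prediction type."""
    weights = []
    for a in alphas_cumprod:
        snr = a / (1.0 - a)
        clipped = min(snr, gamma)
        if prediction_type == "epsilon":
            weights.append(clipped / snr)
        elif prediction_type == "v_prediction":
            weights.append(clipped / (snr + 1.0))
        else:
            raise ValueError(f"unknown prediction_type: {prediction_type}")
    return weights


acp = scaled_linear_alphas_cumprod()
w_eps = min_snr_weights(acp, prediction_type="epsilon")
w_v = min_snr_weights(acp, prediction_type="v_prediction")

for t in (0, 100, 200, 400, 600, 800, 999):
    print(f"t={t:4d}  eps={w_eps[t]:.4f}  v={w_v[t]:.4f}")
```

Under these assumptions the epsilon weights stay at 1 for every timestep with SNR below gamma, while the v-prediction weights peak near 5/6 where SNR equals gamma and fall towards zero at both ends, which is the shape discussed in this thread.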
Describe the bug
I believe the SNR weighting of v_prediction should follow a similar trend to eps; otherwise, for T > 600, the model learns almost nothing because the weight approaches zero. If I am wrong, please correct me. Thank you!
Reproduction
Logs
System Info
None
Who can help?
@yiyixuxu @asomoza @sayakpaul