Closed: catboxanon closed this 1 month ago.
At the current time I believe you never actually want to use `scale_v_prediction_loss_like_noise_prediction`, as it was originally a "fix" for an already outdated and incorrectly implemented version of MinSNR (which has since been fixed). Tying it to ZSNR doesn't really make sense.
You're right. I guess a warning alone will have to suffice.
Thank you for this! I plotted the weights of each loss; is this correct? "Debiased estimation vpred" is based on the implementation of the new `apply_debiased_estimation`.
Based on this chart, I think there may be some use for `scale_v_pred_loss_like_noise_pred` as well.
> is this correct?

Difficult for me to say affirmatively. I'd prefer waiting for a second or third opinion on it.

> Based on this chart, I think there may be some use for `scale_v_pred_loss_like_noise_pred` as well.

I can remove the deprecation notice if you would like.
Your plots are a bit confusing. I assume "Debiased estimation" is the one in this PR with `v_prediction=False` and "Debiased estimation vpred" is with `v_prediction=True`; so far so good, but with "SNR" it gets unclear. If you use "Scale vpred like noise vpred" with "SNR weighted loss vpred", then "Scale vpred like noise vpred" definitely doesn't look right to me: the weighting shouldn't decrease as timesteps increase.
The purpose of this change is to clamp the debiased weighting to the (0, 1) range to accommodate v-prediction loss, which looks right.
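To illustrate the bounded-weighting point, here is a minimal numpy sketch. The exact formulas are my reading of `apply_debiased_estimation` in sd-scripts (eps variant `1/sqrt(SNR)`, v-pred variant `1/(SNR+1)`) and should be treated as an assumption, not an authoritative statement of the merged code:

```python
import numpy as np

# Assumed weightings (my reading of sd-scripts' apply_debiased_estimation):
#   eps variant:    w(t) = 1 / sqrt(SNR(t))   -- unbounded as SNR(t) -> 0
#   v-pred variant: w(t) = 1 / (SNR(t) + 1)   -- always strictly inside (0, 1)
betas = np.linspace(1e-4, 0.02, 1000)          # standard linear DDPM schedule
alpha_bar = np.cumprod(1.0 - betas)
snr = alpha_bar / (1.0 - alpha_bar)            # SNR(t) = alpha_bar / (1 - alpha_bar)

w_eps = 1.0 / np.sqrt(snr)
w_vpred = 1.0 / (snr + 1.0)

# The v-pred weight never leaves (0, 1); the eps weight blows up at high
# timesteps, where SNR approaches zero on this schedule.
print("vpred weight range:", w_vpred.min(), w_vpred.max())
print("eps weight max:", w_eps.max())
```

Under this assumption the v-pred variant stays bounded for every timestep, which matches the "clamp between (0, 1)" description above.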
Some A/B/C tests from bluvoll, trying to convert an eps-based checkpoint to v-pred on a tiny dataset: MinSNR | broken Debiased | fixed Debiased
Sorry for the confusion. Each corresponds as follows:

- `apply_debiased_estimation`, `v_pred=False`
- `apply_debiased_estimation`, `v_pred=True`
- `scale_v_prediction_loss_like_noise_prediction`
- `apply_snr_weight`, `min_snr_gamma=5`, `vpred=False`
- `apply_snr_weight`, `min_snr_gamma=5`, `vpred=True`

As I understand it, `scale_v_prediction_loss_like_noise_prediction` is not intended to be used in conjunction with `apply_snr_weight`.
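For anyone wanting to reproduce the plot, here is a sketch of the five curves using the weightings I believe the named sd-scripts functions apply; the formulas are my assumption (pieced together from the discussion here), not a verified copy of the library code:

```python
import numpy as np

# Assumed per-timestep loss weights (treat these formulas as my assumption):
#   apply_debiased_estimation, v_pred=False        : 1 / sqrt(SNR)
#   apply_debiased_estimation, v_pred=True         : 1 / (SNR + 1)
#   scale_v_prediction_loss_like_noise_prediction  : SNR / (SNR + 1)
#   apply_snr_weight (MinSNR), vpred=False         : min(SNR, gamma) / SNR
#   apply_snr_weight (MinSNR), vpred=True          : min(SNR, gamma) / (SNR + 1)
betas = np.linspace(1e-4, 0.02, 1000)          # standard linear DDPM schedule
alpha_bar = np.cumprod(1.0 - betas)
snr = alpha_bar / (1.0 - alpha_bar)
gamma = 5.0                                     # min_snr_gamma=5, as in the legend

curves = {
    "debiased (eps)":         1.0 / np.sqrt(snr),
    "debiased (vpred)":       1.0 / (snr + 1.0),
    "scale vpred like noise": snr / (snr + 1.0),
    "min-snr (eps)":          np.minimum(snr, gamma) / snr,
    "min-snr (vpred)":        np.minimum(snr, gamma) / (snr + 1.0),
}
for name, w in curves.items():
    print(f"{name}: w(0)={w[0]:.4g}, w(T)={w[-1]:.4g}")
```

Plotting each array against the timestep index (e.g. with matplotlib) should reproduce the general shapes shown in the charts above.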
Tried to do some math and plot this as well.
I think it's important to see how each weighting strategy changes the effective SNR. However, in practice no one trains v-pred without ZSNR, so it would make more sense to plot using the ZSNR schedule. Notice that the "Debiased estimation vpred" weighting multiplied by `SNR(t)` overlaps with the "vpred-like loss":

- `apply_debiased_estimation`, `v_pred=True`
- `scale_v_prediction_loss_like_noise_prediction`
- `apply_snr_weight`, `min_snr_gamma=5`, `vpred=True`
- `1/SNR(t)`
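The observed overlap falls out algebraically under the formulas I'm assuming for these functions (debiased v-pred weight `1/(SNR+1)`, v-pred-like scale `SNR/(SNR+1)`; both are my assumptions, not confirmed from the source):

```python
import numpy as np

# If debiased-vpred weight = 1/(SNR+1) and the vpred-like-loss scale
# = SNR/(SNR+1), then (1/(SNR+1)) * SNR == SNR/(SNR+1) identically,
# so the two curves must coincide at every timestep.
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)
snr = alpha_bar / (1.0 - alpha_bar)

debiased_vpred_times_snr = (1.0 / (snr + 1.0)) * snr
vpred_like_scale = snr / (snr + 1.0)
print(np.allclose(debiased_vpred_times_snr, vpred_like_scale))  # True
```

In other words, if those assumed formulas are right, the overlap in the plot is an exact identity rather than a coincidence of this particular schedule.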
If there's anything of note here, it's that Debiased+vpred WITH the vpred-like loss looks suspiciously similar to MinSNR+vpred at higher timesteps, which suggests that `scale_v_prediction_loss_like_noise_prediction` may (or may not) be useful together with Debiased+vpred. I didn't include some options since readability is already really poor.

Either way, discarding `scale_v_prediction_loss_like_noise_prediction` right away would be a poor choice. More tests would likely be needed.
Additionally, here's a plot comparing variants of debiased estimation. The non-vpred variant is fairly useless with that spike at the beginning. What I'd suggest is to leave `scale_v_pred_loss_like_noise_pred` as is. The v-prediction variant of debiased estimation seems useful.
Thank you for the great diagrams and insight! I did not expect `scale_v_pred_loss_like_noise_pred` to be combined with ZSNR, but it's interesting that combining them seems to make sense.

> Non-vpred variant is kinda useless with that spike at the beginning.

It's true that those weightings are for noise prediction, so it can't be helped that they have no meaning with v-pred.
@catboxanon I think it would be better to remove the deprecation warning for `scale_v_pred_loss_like_noise_pred`. Then this PR may be ready to be merged.
@kohya-ss

> I think it would be better to remove the deprecation warning for `scale_v_pred_loss_like_noise_pred`. Then this PR may be ready to be merged.

Done.
> Some A/B/C tests from bluvoll trying to convert eps-based checkpoint to vpred on a tiny dataset: MinSNR | broken Debiased | fixed Debiased
It looks good from the results. For some reason, the debiased estimation loss function seems to cause color contamination late in training.
Considering that the original paper was published a while ago, I thought it would be a good idea to look through the papers citing it for updates: https://scholar.google.com/scholar?cites=6450976606823846518&as_sdt=2005&sciodt=0,5&oi=gsb
@sdbds

> For some reason, it seems that debias estimation loss function late in training cause color contamination.
I'm not exactly sure what you mean by "color contamination", but if it's about weird color splotches, I do think this is rather strange. My theory is that they appear because the new prediction target and schedule are wonky at first, and that this will eventually go away with sufficient training. In fact, this model was trained on 300k samples using `1/(snr+1)` weighting and I don't see any splotches there. It likely doesn't happen with MinSNR on the test examples because MinSNR rescales neither the mid nor the high timesteps, so mid timesteps receive more training with MinSNR compared to debiased. Notice how "grey" the MinSNR results look. This is only speculation though, and it may not be true.
> I thought it would be a good idea to refer to the cited paper for updates.

I only found these loosely related papers, but it looks like they don't focus on what this PR attempts to do at all.
I've merged. Sorry for the delay.
This PR:

1) Updates the debiased estimation loss function for v-pred. The previous function was intended only for epsilon prediction. For reference: https://github.com/kohya-ss/sd-scripts/issues/1058#issuecomment-1916893278

~~2) Adds a deprecation notice for `scale_v_pred_loss_like_noise_pred`. For reference: https://github.com/kohya-ss/sd-scripts/pull/934 https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/main/guided_diffusion/gaussian_diffusion.py#L864~~ Removed per discussion below.
cc @feffy380 @sdbds, let me know if I'm missing something.