kohya-ss / sd-scripts

Apache License 2.0
Update debiased estimation loss function to accommodate V-pred #1715

Closed catboxanon closed 1 month ago

catboxanon commented 1 month ago

This PR: 1) Updates debiased estimation loss function for V-pred. The previous function was intended only for epsilon. For reference: https://github.com/kohya-ss/sd-scripts/issues/1058#issuecomment-1916893278

~~2) Adds a deprecation notice for scale_v_pred_loss_like_noise_pred. For reference: https://github.com/kohya-ss/sd-scripts/pull/934 https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/main/guided_diffusion/gaussian_diffusion.py#L864~~

Removed per discussion below.

cc @feffy380 @sdbds, let me know if I'm missing something.

liesened commented 1 month ago

At the current time I believe you never actually want to use scale_v_prediction_loss_like_noise_prediction, as it was originally a "fix" for an already outdated and wrongly implemented version of MinSNR (which is fixed now). Tying it to ZSNR doesn't really make any sense.

catboxanon commented 1 month ago

You're right. I guess only a warning will have to suffice.

kohya-ss commented 1 month ago

Thank you for this! I plotted the weights of each loss, is this correct? Debiased estimation vpred is based on the implementation of the new apply_debiased_estimation.

[Images: Figure_1, Figure_2]

Based on this chart, I think there may be some use for scale_v_pred_loss_like_noise_pred as well.

catboxanon commented 1 month ago

> is this correct?

It's difficult for me to say definitively. I'd prefer waiting for a second or third opinion on it.

> Based on this chart, I think there may be some use for scale_v_pred_loss_like_noise_pred as well.

I can remove the deprecation notice if you would like.

liesened commented 1 month ago

Your plots are a bit confusing. I assume "Debiased estimation" is the one in this PR with v_prediction=False and "Debiased estimation vpred" is the one with v_prediction=True. So far so good, but with "SNR" it gets unclear. If you used "Scale vpred like noise vpred" together with "SNR weighted loss vpred", then "Scale vpred like noise vpred" definitely doesn't look right to me: the weighting shouldn't decrease as timesteps increase.

The purpose of this change is to clamp the debiased weighting to the range (0, 1) to accommodate v-prediction loss, which looks right.
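For illustration, here is a minimal sketch of why the v-pred variant stays bounded. It assumes the epsilon weight is 1/sqrt(SNR) and the v-pred weight is 1/(SNR+1) (the latter form is mentioned later in this thread). The function name `debiased_weight` and the sample SNR values are hypothetical, and this is not the repository's actual code.

```python
import math


def debiased_weight(snr: float, v_prediction: bool = False) -> float:
    """Per-timestep loss weight as a function of SNR (illustrative only)."""
    if v_prediction:
        # Assumed v-pred form: 1 / (SNR + 1), which lies in (0, 1] for SNR >= 0.
        return 1.0 / (snr + 1.0)
    # Assumed epsilon form: 1 / sqrt(SNR), which grows without bound as SNR -> 0.
    return 1.0 / math.sqrt(snr)


# Hypothetical SNR values, from low-noise (high SNR) to high-noise (low SNR).
for snr in (1000.0, 10.0, 1.0, 0.1, 0.001):
    eps_w = debiased_weight(snr)
    v_w = debiased_weight(snr, v_prediction=True)
    print(f"SNR={snr:g}  eps weight={eps_w:.4f}  v-pred weight={v_w:.4f}")
```

Under these assumptions the v-pred weight remains finite even when SNR reaches 0, as it does at the final timestep of a ZSNR schedule, while the epsilon form is undefined there.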

liesened commented 1 month ago

Some A/B/C tests from bluvoll trying to convert eps-based checkpoint to vpred on a tiny dataset: MinSNR | broken Debiased | fixed Debiased

[Images: xyz_grid-0007-4018816163, xyz_grid-0010-2740262790, xyz_grid-0011-4062610619, xyz_grid-0012-1817032482]

kohya-ss commented 1 month ago

Sorry for the confusion. Each corresponds as follows.

As I understand it, scale_v_prediction_loss_like_noise_prediction is not intended to be used in conjunction with apply_snr_weight.

liesened commented 1 month ago

Tried to do some math and plot this as well.

[Image]

I think it's important to see how each weighting strategy changes the effective SNR. However, in practice no one trains v-pred without ZSNR, so it would make more sense to plot using a ZSNR schedule. Notice that the "Debiased estimation vpred" weighting * SNR(t) overlaps with the vpred-like loss:

[Image]
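The overlap noted here follows algebraically if one assumes the v-pred debiased weight is 1/(SNR+1) and the scale applied by scale_v_prediction_loss_like_noise_prediction is SNR/(SNR+1). Both formulas are assumptions for this sketch, and the helper names are hypothetical. Multiplying the first by SNR gives exactly the second:

```python
import math


def debiased_vpred_weight(snr: float) -> float:
    # Assumed v-pred debiased estimation weight: 1 / (SNR + 1).
    return 1.0 / (snr + 1.0)


def vpred_like_noise_scale(snr: float) -> float:
    # Assumed scale for "v-pred loss like noise pred": SNR / (SNR + 1).
    return snr / (snr + 1.0)


# Hypothetical SNR samples, including SNR = 0 (final timestep under ZSNR).
for snr in (1000.0, 25.0, 1.0, 0.01, 0.0):
    lhs = debiased_vpred_weight(snr) * snr
    rhs = vpred_like_noise_scale(snr)
    assert math.isclose(lhs, rhs, abs_tol=1e-12)
    print(f"SNR={snr:g}  weight*SNR={lhs:.6f}  scale={rhs:.6f}")
```

If the actual implementations differ from these assumed forms (for example by clamping SNR), the curves would only coincide where the clamp is inactive.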

If there's anything of note here, it's that Debiased+vpred WITH vpred-like loss looks suspiciously similar to MinSNR+vpred at higher timesteps, which suggests that scale_v_prediction_loss_like_noise_prediction may (or may not) be useful for Debiased+vpred. I didn't include some options since the readability is already really poor.

Either way, discarding scale_v_prediction_loss_like_noise_prediction right away was a poor choice. More tests would likely be needed.

liesened commented 1 month ago

Additionally, here's a plot comparing variants of debiased estimation.

[Image]

Non-vpred variant is kinda useless with that spike at the beginning.

What I'd suggest is to leave scale_v_pred_loss_like_noise_pred as is. V-prediction variant of Debiased estimation seems useful.

kohya-ss commented 1 month ago

Thank you for the great diagrams and insight!

I did not expect scale_v_pred_loss_like_noise_pred to be combined with ZSNR, but it's interesting that it seems to make sense to combine them.

> Non-vpred variant is kinda useless with that spike at the beginning.

It's true that they are for noise pred, so it can't be helped that they have no meaning with v-pred.

@catboxanon I think it would be better to remove the deprecation warning for scale_v_pred_loss_like_noise_pred. Then this PR may be ready to be merged.

catboxanon commented 1 month ago

@kohya-ss

> I think it would be better to remove the deprecation warning for scale_v_pred_loss_like_noise_pred. Then this PR may be ready to be merged.

Done.

sdbds commented 1 month ago

> Some A/B/C tests from bluvoll trying to convert eps-based checkpoint to vpred on a tiny dataset: MinSNR | broken Debiased | fixed Debiased
>
> [Images: xyz_grid-0007-4018816163, xyz_grid-0010-2740262790, xyz_grid-0011-4062610619, xyz_grid-0012-1817032482]

It looks good from the results. For some reason, it seems that the debiased estimation loss function causes color contamination late in training.

sdbds commented 1 month ago

Considering that the original paper was published a long time ago, I thought it would be a good idea to refer to the cited paper for updates. https://scholar.google.com/scholar?cites=6450976606823846518&as_sdt=2005&sciodt=0,5&oi=gsb

liesened commented 1 month ago

@sdbds

> For some reason, it seems that the debiased estimation loss function causes color contamination late in training.

I'm not exactly sure what you mean by "color contamination", but if it's about weird color splotches, I do think this is rather strange. My theory is that they appear because the new prediction target and schedule are wonky at first, and that they will eventually go away with sufficient training. In fact, this model was trained on 300k samples using 1/(snr+1) and I don't see any splotches there. It likely doesn't happen with MinSNR on the test examples because MinSNR rescales neither the mid nor the high timesteps, and mid timesteps receive more training with MinSNR compared to debiased. Notice how "grey" the MinSNR results look. This is only speculation though, and it may not be true.
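A rough way to see the "mid timesteps receive more training with MinSNR" point is to compare the two v-pred weights directly. This sketch assumes the Min-SNR v-pred weight is min(SNR, gamma)/(SNR+1) and the debiased v-pred weight is 1/(SNR+1). The value gamma=5.0, the helper names, and the sample SNR values are all hypothetical choices for illustration.

```python
def min_snr_vpred_weight(snr: float, gamma: float = 5.0) -> float:
    # Assumed Min-SNR weight for v-prediction: min(SNR, gamma) / (SNR + 1).
    return min(snr, gamma) / (snr + 1.0)


def debiased_vpred_weight(snr: float) -> float:
    # Assumed debiased estimation weight for v-prediction: 1 / (SNR + 1).
    return 1.0 / (snr + 1.0)


# Hypothetical SNR samples: high SNR = low timesteps, low SNR = high timesteps.
for snr in (100.0, 5.0, 1.0, 0.1):
    m = min_snr_vpred_weight(snr)
    d = debiased_vpred_weight(snr)
    print(f"SNR={snr:g}  MinSNR={m:.3f}  debiased={d:.3f}")
```

Under these assumptions the ratio of the two weights is min(SNR, gamma), so Min-SNR weights a timestep more heavily whenever SNR > 1 and less heavily whenever SNR < 1.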

> I thought it would be a good idea to refer to the cited paper for updates.

I only found these loosely related papers, but it looks like they don't focus on what this PR attempts to do at all.

kohya-ss commented 1 month ago

I've merged. Sorry for the delay.