huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

The density_for_timestep_sampling and loss_weighting for SD3 Training!!! #9056

Open DidiD1 opened 2 months ago

DidiD1 commented 2 months ago

Thanks to Rafie Walker's code, we can try to train SD3 models with flow matching! But some places don't seem to match what's in the paper. Rafie Walker's code is below:

import math

import torch


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$):
        # a sigmoid of a normal sample, which concentrates u away from 0 and 1.
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        # "Mode sampling with heavy tails" from the SD3 paper.
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # Uniform sampling over [0, 1).
        u = torch.rand(size=(batch_size,), device="cpu")
    return u

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    if weighting_scheme == "sigma_sqrt":
        # Weight each sample by 1 / sigma^2.
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        # CosMap weighting from the SD3 paper.
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        # No reweighting: plain MSE.
        weighting = torch.ones_like(sigmas)
    return weighting
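
For context, this is roughly how the two helpers are wired together in an SD3 flow-matching training step (a simplified sketch; noise_scheduler, get_sigmas, latents, noise, and transformer are placeholder names, not exact script code):

u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=bsz,
    logit_mean=0.0, logit_std=1.0, mode_scale=1.29,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices]

sigmas = get_sigmas(timesteps)                 # per-sample sigma, broadcast to the latent shape
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
pred = transformer(noisy_latents, timesteps)   # velocity prediction
target = noise - latents                       # flow-matching target

weighting = compute_loss_weighting_for_sd3(weighting_scheme="logit_normal", sigmas=sigmas)
loss = (weighting * (pred - target) ** 2).mean()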

My questions are below:

  1. When weighting_scheme == "mode", the code only computes f_mode. To obtain 'u' itself, shouldn't some additional operation be needed?
  2. CosMap seems to compute the weight of the timesteps, not the weight of the loss?
  3. When we use logit_normal, it is based on the RF setting, so the loss weight should be t/(1-t); but the code returns torch.ones_like(sigmas) instead of computing that weight? (See the sketch after this list.)
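
For question 3, this is the weight I have in mind, as a sketch of my reading of the paper (assuming sigmas plays the role of t in [0, 1); this is not what the current code does):

def rf_loss_weighting(sigmas):
    # Hypothetical t/(1-t) loss weight from the RF setting in the SD3 paper;
    # assumes sigmas is the flow-matching time t in [0, 1).
    return sigmas / (1.0 - sigmas)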

So I think some modifications are needed to correctly compute the SD3 loss! Thanks for discussing this together!

bghira commented 2 months ago

honestly none of the weighting tricks really seem relevant to finetuning SD3. not using the timestep weighting has better results.

xiao2mo commented 2 months ago


Could you give some more details? Thanks a lot.

bghira commented 2 months ago

yes, if you look at the timestep selection distribution using the SD3 style training, it effectively does not ever train the 900-1000 or 0-100 range of timesteps. they are just ignored:

[image: timestep selection distribution during training]

Ignoring the gaps in the chart here (wandb was having issues), the timestep selection at the end is where I switched to uniform sampling and the model started learning composition and details properly.
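
To make this concrete, here is a small self-contained sketch (my own check, not from the training code) comparing how often each scheme lands in the tail timestep buckets:

import torch

bsz = 1_000_000
u_lognorm = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=(bsz,)))  # SD3 default logit-normal
u_uniform = torch.rand(bsz)

for name, u in (("lognorm", u_lognorm), ("uniform", u_uniform)):
    t = (u * 1000).long()  # map samples to 0-999 timestep indices
    lo = (t < 100).float().mean().item()
    hi = (t >= 900).float().mean().item()
    print(f"{name}: {lo:.1%} in 0-100, {hi:.1%} in 900-1000")

# logit-normal puts only ~1-2% of samples in each tail, versus 10% each for uniform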

DidiD1 commented 2 months ago

This phenomenon was mentioned in the SD3 paper, which is maybe why they proposed the 'mode sampling with heavy tails' time-sampling method. However, it's strange that in their experimental results 'log-norm' performs much better than 'mode' and uniform sampling. So I guess each sampling method may have its own advantages, and experiments are needed to validate which one is suitable for your own task.

bghira commented 2 months ago

[image]

bghira commented 2 months ago

it just needs an absolutely enormous batch size for these to make sense.

edit: also worth noting these parameters are likely dependent on model size, the same way LR scales with model size when not using microsoft/mup

DidiD1 commented 2 months ago

Thanks a lot. And for my question 3: when we use logit_normal, it is based on the RF setting, so the loss weight should be t/(1-t), but the code returns torch.ones_like(sigmas) instead. Do I need to modify the loss weight?

xiao2mo commented 1 month ago


Thanks a lot

ivylilili commented 1 month ago


@bghira Hi bghira~ I'd like to know: when you tried "SD3-style training (lognorm sampling)" versus "uniform sampling", what was the difference in the training loss? When you switched to uniform sampling, did it help lower the loss curve? In my uniform-sampling training there are still some artifacts in the generated images, so I wonder which part of the noise sampling is important for improving this. I'd like to hear your insights, thanks~

bghira commented 1 month ago

currently we're using sigmoid sampling for timesteps, which seems fine, but no one has really ablated whether it leaves fine details out
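
For reference, a minimal sketch of what sigmoid timestep sampling means here, assuming it is just a standard normal squashed through a sigmoid (the scale knob is hypothetical):

import torch

def sigmoid_timestep_sample(batch_size, scale=1.0):
    # Squash a standard normal through a sigmoid so t lands in (0, 1);
    # with scale=1.0 this matches the logit_normal branch above.
    return torch.sigmoid(scale * torch.randn(batch_size))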

culeao commented 1 month ago


Actually, sigmoid and lognorm are mathematically equivalent. But I'm curious why existing open-source training implementations don't use timeshift during training, while the SD3 paper does.

[image]

DidiD1 commented 1 month ago


In fact, the diffusers version for SD3 does use the timeshifting. You can see it in the init of FlowMatchEulerDiscreteScheduler:

{
  "_class_name": "FlowMatchEulerDiscreteScheduler",
  "_diffusers_version": "0.29.0.dev0",
  "num_train_timesteps": 1000,
  "shift": 3.0
}

The schedule applies the shift as:

sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
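
A quick sketch of what that shift does to the sigma schedule (shift=3.0 mirrors the shipped config; the linspace construction is a simplification of the scheduler internals):

import torch

shift = 3.0
sigmas = torch.linspace(1.0, 1.0 / 1000, 1000)  # simplified descending sigma schedule
shifted = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
print(sigmas[500].item(), shifted[500].item())  # ~0.5 -> ~0.75

# shift > 1 pushes every sigma toward 1, so more of the schedule sits at high noise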

github-actions[bot] commented 3 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.