kohya-ss / sd-scripts

Apache License 2.0

with_prior_preservation #613

Open · Ted-developer opened this issue 1 year ago

Ted-developer commented 1 year ago

Hello,

I've been training LoRA models with sd-scripts, incorporating regularization images (reg_data_dir). However, the training results seem a bit peculiar, so I examined the loss calculation in train_network.py:

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
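For reference, here is a runnable version of that snippet with dummy tensors (the shapes and values are made up for illustration), showing that the `mean([1, 2, 3])` reduction leaves one loss value per sample rather than a single batch mean:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Dummy latents standing in for the model prediction and the noise target (B, C, H, W)
noise_pred = torch.randn(4, 4, 8, 8)
target = torch.randn(4, 4, 8, 8)

# Per-element squared error, then reduce over channel/height/width:
# one scalar loss per sample remains, so per-sample weights can still be applied.
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
print(loss.shape)  # torch.Size([4])
```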

Many people do not use regularization images at all. I also looked at how the diffusers DreamBooth example uses regularization images; there, the loss calculation looks like this:

if args.with_prior_preservation:
    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Compute instance loss
    pred_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

    # Compute prior loss
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

    # Add the prior loss to the instance loss.
    loss = pred_loss + args.prior_loss_weight * prior_loss
else:
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

I would like to ask why the loss calculation in sd-scripts doesn't handle the regularization images separately. Have I missed the corresponding code, or is there another consideration?

Thank you very much for your response.

crj1998 commented 1 year ago

This is because the loss calculation for a regularization image is the same as for a normal image. The only difference is the loss weight, so the two can be folded together, much like "combining like terms". For example, instead of computing the weighted sum of two separate losses:

loss = a * F.mse_loss(pred_a, target_a, reduction="mean") + b * F.mse_loss(pred_b, target_b, reduction="mean")

we can compute one batched per-sample loss and apply per-sample weights (the per-element loss must first be reduced to one value per sample, then weighted and averaged):

loss = F.mse_loss(torch.cat([pred_a, pred_b]), torch.cat([target_a, target_b]), reduction="none").mean([1, 2, 3])
loss = (torch.tensor([a, a, ..., a, b, b, ..., b]) * loss).mean()

Up to a constant normalization factor, the two are equal.
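As a quick numeric check (a sketch with made-up tensors; the weights a, b and the batch sizes are arbitrary), the two formulations agree up to a constant factor of 2 when both halves have the same batch size, and that constant only rescales the effective learning rate:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a, b = 1.0, 0.5   # instance / prior loss weights (example values)
pred_a, target_a = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
pred_b, target_b = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)

# (1) Two separate mean-reduced losses, weighted and summed (diffusers style)
separate = a * F.mse_loss(pred_a, target_a) + b * F.mse_loss(pred_b, target_b)

# (2) One batched computation: reduce to per-sample losses, then apply weights
per_sample = F.mse_loss(
    torch.cat([pred_a, pred_b]), torch.cat([target_a, target_b]), reduction="none"
).mean([1, 2, 3])
weights = torch.tensor([a, a, b, b])
combined = (weights * per_sample).mean()

# With equal halves, (1) == 2 * (2): the constant factor just rescales the loss.
print(torch.allclose(2 * combined, separate))  # True
```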

So the per-sample loss weights are already prepared in the dataset: https://github.com/kohya-ss/sd-scripts/blob/0cfcb5a49cf813547d728101cc05edf1a9b7d06c/library/train_util.py#L917-L920
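For illustration, the per-sample weight vector could be assembled roughly like this (the names and values here are illustrative, not the exact sd-scripts code; in sd-scripts the CLI default for prior_loss_weight is 1.0):

```python
import torch

# Illustrative: the dataset records, for each example in the batch, whether it
# is a regularization image, and assigns it the prior loss weight (1.0 otherwise).
prior_loss_weight = 0.5                # example value; the CLI default is 1.0
is_reg = [False, False, True, True]    # made-up batch composition
loss_weights = torch.tensor(
    [prior_loss_weight if r else 1.0 for r in is_reg]
)
print(loss_weights)  # tensor([1.0000, 1.0000, 0.5000, 0.5000])
```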