Tencent / HunyuanDiT

Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
https://dit.hunyuan.tencent.com/

Question about vb_terms in training_losses function #112

Closed: congdm closed this issue 3 months ago

congdm commented 3 months ago

Hi, I'm currently writing a training script for IPEX based on your example training script. I've noticed that in the training_losses function there is an additional vb term when calculating the loss:

B, C = x_t.shape[:2]
assert_shape(model_output, (B, C * 2, *x_t.shape[2:]))
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
    model=lambda *args, r=frozen_out: dict(x=r),
    x_start=x_start,
    x_t=x_t,
    t=t,
    clip_denoised=False,
)["output"]
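For context, the model output here carries both a mean prediction and variance values, split along the channel dimension. In the IDDPM "learned range" parameterization, the variance head predicts a value v in [-1, 1] that interpolates, in log space, between the two analytic bounds beta_t and beta_tilde_t (the posterior variance). Below is a minimal sketch of that interpolation; the function name and tensor shapes are assumptions for illustration, not code from this repository:

```python
import torch

def learned_range_variance(model_var_values, beta_t, beta_tilde_t):
    # Sketch of the IDDPM "learned range" variance (assumed helper name).
    # The network predicts v in [-1, 1]; the variance interpolates in log
    # space between beta_tilde_t (posterior variance, lower bound) and
    # beta_t (upper bound).
    frac = (model_var_values + 1) / 2  # map [-1, 1] -> [0, 1]
    log_var = frac * torch.log(beta_t) + (1 - frac) * torch.log(beta_tilde_t)
    return torch.exp(log_var)
```

The _vb_terms_bpd call then scores this learned variance with the variational bound (a KL term against the true posterior), while the frozen_out trick (detaching the mean) ensures the bound only trains the variance head and cannot disturb the noise prediction.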

I see that this vb term is added to the final loss, but I don't understand the significance of this "variational bound". If I ignore it in my training script, like this, what would be the impact?

noise_pred = model(noisy_latents, timesteps.to(dtype=noisy_latents.dtype), **model_kwargs)['x']
B, C = noisy_latents.shape[:2]
assert noise_pred.shape == (B, C * 2, *noisy_latents.shape[2:])
noise_pred, model_var_values = torch.split(noise_pred, C, dim=1) # ignore model_var_values

if args.v_parameterization:
    # v-parameterization training
    target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    target = noise

loss = conditional_loss(
    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean()
return loss
zml-ai commented 3 months ago

We suggest learning the variance through the variational bound during training. As noted in IDDPM, this helps the diffusion model learn more modes, creating more peaks in the probability distribution. For text-to-image tasks, this enhances the diversity of the generated images. For more details, please refer to the IDDPM paper.
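To keep the variance term in a custom training loop, the IDDPM-style hybrid objective combines the simple MSE loss with a heavily down-weighted VB term (IDDPM uses a weight of 0.001), detaching the mean so the bound only trains the variance head. A minimal sketch with assumed names (vb_fn stands in for a _vb_terms_bpd-style callable; this is not code from this repository):

```python
import torch
import torch.nn.functional as F

def hybrid_loss(model_output, target, vb_fn, vb_weight=1e-3):
    """Sketch of an IDDPM-style hybrid objective (assumed names).

    model_output: (B, 2C, ...) tensor, mean prediction + variance values
    target:       (B, C, ...) tensor, noise (or velocity) target
    vb_fn:        callable returning the variational-bound term given the
                  frozen mean prediction and the learned variance values
    """
    C = target.shape[1]
    mean_pred, var_values = torch.split(model_output, C, dim=1)
    mse = F.mse_loss(mean_pred, target)
    # Detach the mean so the VB term only supervises the variance head,
    # mirroring the frozen_out trick in training_losses.
    vb = vb_fn(mean_pred.detach(), var_values)
    return mse + vb_weight * vb
```

Dropping the term entirely, as in the snippet above, leaves the variance channels unsupervised, so sampling should then fall back to a fixed (beta_t or beta_tilde_t) variance rather than the learned one.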