XiangLi1999 / Diffusion-LM
Apache License 2.0

About the tT_loss #63

Open zzc681 opened 1 year ago

zzc681 commented 1 year ago

Hi, thanks for your excellent work! I have a small question about the loss function. While reading the code, I found that tT_loss computes a loss between the mean of x_T and 0. Is there any meaning to doing this? The code is in gaussian_diffusion.py, in the function training_losses_e2e of class GaussianDiffusion:

```python
out_mean, _, _ = self.q_mean_variance(x_start, torch.LongTensor([self.num_timesteps - 1]).to(x_start.device))
tT_loss = mean_flat(out_mean ** 2)
```

ryuliuxiaodong commented 7 months ago

Same question for me.

The other loss terms in training_losses_e2e are clear, and they are described in the paper (Equation 2). But I don't quite understand this tT_loss: why is a loss computed at the final timestep of the forward diffusion process?
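As far as I can tell, this is the prior-matching term of the diffusion ELBO: the forward process q(x_T | x_0) has mean sqrt(alpha_bar_T) * x_0, and since the sampling prior is N(0, I), that mean should be close to 0 by the final step. In Diffusion-LM the embeddings x_0 are learned, so this term is not a constant and acts as a regularizer on their scale. A minimal numpy sketch of the idea (the linear beta schedule, sizes, and mean_flat here are illustrative assumptions, not the repo's exact code):

```python
import numpy as np

# Forward-process schedule (assumed linear, as an illustration).
T = 2000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # \bar{alpha}_t = prod_{s<=t} (1 - beta_s)

def mean_flat(x):
    # Average over all non-batch dimensions (mirrors the helper's intent).
    return x.reshape(x.shape[0], -1).mean(axis=-1)

# A batch of (hypothetical) learned embeddings x_0.
rng = np.random.default_rng(0)
x_start = rng.standard_normal((4, 16))

# Mean of q(x_T | x_0) at the final timestep t = T - 1.
out_mean = np.sqrt(alpha_bar[T - 1]) * x_start
tT_loss = mean_flat(out_mean ** 2)

# Because alpha_bar at t = T-1 is nearly 0, tT_loss is tiny for
# reasonably-scaled embeddings; it only grows if x_start becomes so
# large that the forward process no longer destroys it by step T.
print(tT_loss)
```

Under this reading, the gradient of tT_loss flows into the embedding parameters through x_start, which is why the term appears in training even though it would be a constant in a fixed-data diffusion model.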