Hi,
Thanks for the great question!
We are simplifying a bit in the main text of the paper, and you can regard this tT_loss term as part of Lsimple when t = T. It's a slight abuse of notation to keep the main text simple: by assumption \mu_theta(x_T, T) = 0, since p(x_T) = N(0, I), so the Lsimple term at t = T reduces to an L2 norm of x_T.
The Lsimple term is derived from the Lvlb term; we give a detailed derivation in appendix E (page 17), and this line in particular justifies the tT_loss term.
Intuitively, this term exists to keep the embedding norms from becoming too large.
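Concretely, here is a minimal sketch of what this t = T term computes (illustrative names only, not the exact code in the repo), assuming the standard forward process where q(x_T | x_0) has mean sqrt(alpha_bar_T) * x_0:

```python
import torch

def tT_regularizer(x_start, sqrt_alpha_cumprod_T):
    # Mean of q(x_T | x_0) under the standard forward process:
    # q(x_T | x_0) = N(sqrt(alpha_bar_T) * x_0, (1 - alpha_bar_T) * I).
    q_mean_T = sqrt_alpha_cumprod_T * x_start

    # With mu_theta(x_T, T) fixed to 0 (because p(x_T) = N(0, I)), the t = T piece of the
    # simplified objective reduces to the squared norm of this mean, so the term below
    # penalizes large embedding (x_0) norms.
    return (q_mean_T ** 2).mean(dim=list(range(1, q_mean_T.dim())))
```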
Thanks for the great explanation! I missed that part when originally looking over the paper, but it intuitively makes sense to have an L2 norm of x_T!
Hi, thanks for releasing the code! I had a quick question about the different loss functions in the code.
I'm trying to wrap my head around the loss function presented in the paper and compare it to what's in the code. I'm taking a look at the function
L^{e2e}_{simple}(w) = E_{q_\phi(x_{0:T} | w)} [ L_{simple}(x_0) + ||EMB(w) - \mu_\theta(x_1, 1)||^2 - \log p_\theta(w | x_0) ]
Lsimple appears to be this line
The loss between the embeddings seems to be these lines
And the cross-entropy loss between the logits and input tokens appears to be here
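To check that I'm reading those three terms correctly, here's a rough sketch of my mental model of how they could fit together (names like emb, model, decoder, and q_sample are just placeholders, not the repo's actual API):

```python
import torch
import torch.nn.functional as F

def e2e_simple_loss_sketch(w_tokens, emb, model, decoder, q_sample, T):
    # x_0 = EMB(w): embed the input tokens into the diffusion space.
    x_0 = emb(w_tokens)

    # 1) Lsimple(x_0): MSE between the model's prediction and its target at a random timestep t.
    t = torch.randint(1, T + 1, (x_0.shape[0],), device=x_0.device)
    x_t, target_mean = q_sample(x_0, t)
    l_simple = F.mse_loss(model(x_t, t), target_mean)

    # 2) ||EMB(w) - mu_theta(x_1, 1)||^2: ties the embeddings to the model's prediction at t = 1.
    ones = torch.ones_like(t)
    x_1, _ = q_sample(x_0, ones)
    l_emb = ((x_0 - model(x_1, ones)) ** 2).mean()

    # 3) -log p_theta(w | x_0): cross-entropy from x_0 back to the input tokens.
    logits = decoder(x_0)  # (batch, seq_len, vocab_size)
    l_round = F.cross_entropy(logits.transpose(1, 2), w_tokens)

    return l_simple + l_emb + l_round
```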
However, I'm a little confused about what these lines account for. From my debugging, this just seems to be taking the embeddings multiplied by noise, multiplied by sqrt_alphas_cumprod, across all timesteps. Am I misinterpreting what's in the code versus what's in the paper?