Closed ChintanTrivedi closed 2 years ago
Hi Chintan,
We didn't use any class embedding in this work.
`emb_ = self.act_emb(emb)` is your embedding; from it you get a learnable slope and intercept via a linear/dense layer. Instead of using two dense layers, one for the shift and one for the scale, we use a single dense layer that outputs 2x the output dim and split the result into two using `chunk()`. This gives us scale(emb) and shift(emb). Both scale and shift are functions of the same embedding (the time embedding).
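A minimal PyTorch sketch of this single-dense-plus-`chunk()` trick (class and variable names here are illustrative, not the actual MCVD identifiers):

```python
import torch
import torch.nn as nn

class AdaGNBlock(nn.Module):
    """One dense layer produces both scale and shift from the embedding,
    instead of two separate dense layers."""
    def __init__(self, channels, emb_dim, groups=8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, channels)
        # Single linear with 2x output dim; split into (scale, shift) later.
        self.to_scale_shift = nn.Linear(emb_dim, 2 * channels)

    def forward(self, x, emb):
        scale, shift = self.to_scale_shift(emb).chunk(2, dim=1)
        # Broadcast over spatial dims: (B, C) -> (B, C, 1, 1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return scale * self.norm(x) + shift

x = torch.randn(4, 16, 8, 8)     # (B, C, H, W)
temb = torch.randn(4, 32)        # time embedding
block = AdaGNBlock(channels=16, emb_dim=32)
out = block(x, temb)             # same shape as x
```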
I wasn't aware of them injecting the embedding differently for `yemb` and `temb`.
There are two ways of having two embeddings that I would use: 1) concatenate `temb` and `yemb` into `emb`; this makes sense if you assume no interactions, or 2) sequentially inject them:

```python
x = AdaGN(x, temb) = slope(temb)*GroupNorm(x) + shift(temb)  # inject temb
x = act(x)                                                   # add nonlinearity
x = slope(yemb)*x + shift(yemb)                              # inject yemb
```
Thanks for the clarifications. I read the literature again and my interpretations of ys and yb were incorrect. For now, I'm using the first embedding option for training.
No problem! Btw I made a typo, approach 2 would instead be:

```python
x = AdaGN(x, temb) = slope(temb)*GroupNorm(x) + shift(temb)  # inject temb
x = act(x)                                                   # add nonlinearity
x = slope(yemb)*x + shift(yemb)                              # inject yemb
```
There is no double AdaGN. But yeah, approach 1 should be fine; interaction can happen anyway between `yemb` and `temb` from one layer to the next, even if not within one layer.
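Approach 1 (concatenation) can be sketched as follows; the dimensions are hypothetical, chosen only for illustration:

```python
import torch
import torch.nn as nn

# Approach 1: concatenate temb and yemb into a single embedding, then feed
# it through one scale/shift projection exactly as in the time-only case.
temb = torch.randn(4, 32)             # time embedding
yemb = torch.randn(4, 32)             # class embedding
emb = torch.cat([temb, yemb], dim=1)  # (4, 64) combined embedding

channels = 16
proj = nn.Linear(64, 2 * channels)          # one dense, 2x output dim
scale, shift = proj(emb).chunk(2, dim=1)    # both depend on temb AND yemb

h = torch.randn(4, channels, 8, 8)
norm = nn.GroupNorm(8, channels)
out = scale[:, :, None, None] * norm(h) + shift[:, :, None, None]
```

Note that both chunks are functions of the combined embedding, so neither chunk corresponds to `temb` or `yemb` alone.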
Update: I trained the class-conditioned model using the first option to introduce embeddings, but it doesn't seem to correlate at all. In fact, the generated frames seem to completely ignore the class label. I'll try to add the classifier-gradient method for guided diffusion and see if that helps.
The referenced paper by Dhariwal et al. 2021 suggests using

`AdaGN(h, y) = ys * GroupNorm(h) + yb`

to combine time (`ys`) and class (`yb`) embeddings with the ResNet block activations `h`. I am having some trouble understanding how to implement this in the MCVD code in this repo, since the class-conditioning lines are commented out. It seems time and class embeddings are meant to be concatenated together (based on the commented-out code) and fed to the ResNet block together as `emb`.

My confusion: how does splitting the linear projection of the combined embeddings into 2 chunks give us `scale` and `shift`? How should these two values be interpreted in relation to the time and class embeddings? It seems `scale` might be analogous to `temb` and `shift` to `yemb`, but that's not what the code suggests.

PS: Getting some really good results for prediction tasks, thanks for making your code available!