voletiv / mcvd-pytorch

Official implementation of MCVD: Masked Conditional Video Diffusion for Prediction, Generation, and Interpolation (https://arxiv.org/abs/2205.09853)

Adding class condition to time embeddings in resnet block #4

Closed ChintanTrivedi closed 2 years ago

ChintanTrivedi commented 2 years ago

The referenced paper by Dhariwal et al. (2021) suggests using AdaGN(h, y) = ys GroupNorm(h) + yb to combine the time (ys) and class (yb) embeddings with the ResNet block activations h. I am having some trouble understanding how to implement this in the mcvd code in this repo, since the class-conditioning lines are commented out. It seems the time and class embeddings are meant to be concatenated (based on the commented-out code) and fed to the ResNet block together as "emb".

# resnetblock
def forward(self, x, temb=None, yemb=None, cond=None):
    if temb is not None and yemb is not None:
        emb = torch.cat([temb, yemb], dim=1)  # combine time and class embeddings
        emb_out = self.Dense_0(self.act_emb(emb))[:, :, None, None]  # linear projection
        scale, shift = torch.chunk(emb_out, 2, dim=1)  # split projection into scale and shift
        [ ... ]
        emb_norm = self.Norm_0(x)  # GroupNorm
        x = emb_norm * (1 + scale) + shift  # AdaGN-style modulation

My confusion: how does splitting the linear projection of the combined embeddings into two chunks give us scale and shift? How should these two values be interpreted in relation to the time and class embeddings? It seems scale might be analogous to temb and shift to yemb, but that's not what the code suggests.

PS: Getting some really good results for prediction tasks, thanks for making your code available!

AlexiaJM commented 2 years ago

Hi Chintan,

We didn't use any class embedding in this work.

emb_ = self.act_emb(emb) is your embedding; from it, a linear/dense layer gives you a learnable slope and intercept. Instead of using two dense layers, one for the scale and one for the shift, we use a single dense layer that outputs 2x the output dim and split the result in two using chunk(). That gives us scale(emb) and shift(emb). Both scale and shift are functions of the same embedding (the time embedding).
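A minimal sketch of that single-dense-layer trick (the module and variable names here are illustrative, not this repo's actual ones):

    import torch
    import torch.nn as nn

    class AdaGNBlock(nn.Module):
        def __init__(self, channels, emb_dim):
            super().__init__()
            self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)  # assumes channels % 32 == 0
            self.act_emb = nn.SiLU()
            self.dense = nn.Linear(emb_dim, 2 * channels)  # one dense layer with 2x output dim

        def forward(self, x, emb):
            emb_out = self.dense(self.act_emb(emb))[:, :, None, None]  # (B, 2C, 1, 1)
            scale, shift = torch.chunk(emb_out, 2, dim=1)  # both are functions of the same emb
            return self.norm(x) * (1 + scale) + shift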

I wasn't aware that they inject the embeddings differently for yemb and temb.

There are two ways of having two embeddings that I would use: 1) concatenate temb and yemb into emb, which makes sense if you assume no interactions, or 2) inject them sequentially:

    x = AdaGN(x, temb) = slope(temb) * GroupNorm(x) + shift(temb)  # inject temb
    x = act(x)                                                     # add nonlinearity
    x = slope(yemb) * x + shift(yemb)                              # inject yemb

ChintanTrivedi commented 2 years ago

Thanks for the clarifications. I re-read the paper, and my interpretations of ys and yb were indeed incorrect. For now, I'm using the first option (concatenation) for training.

AlexiaJM commented 2 years ago

No problem! By the way, I made a typo; approach 2 would instead be:

    x = AdaGN(x, temb) = slope(temb) * GroupNorm(x) + shift(temb)  # inject temb
    x = act(x)                                                     # add nonlinearity
    x = slope(yemb) * x + shift(yemb)                              # inject yemb

There is no double AdaGN. But approach 1 should be fine anyway: interactions between yemb and temb can still happen from one layer to the next, even if not within a single layer.
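A minimal sketch of that sequential injection, assuming each embedding gets its own dense layer for its slope and shift (the module and variable names here are illustrative, not this repo's):

    import torch
    import torch.nn as nn

    class SequentialAdaGN(nn.Module):
        def __init__(self, channels, temb_dim, yemb_dim):
            super().__init__()
            self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)  # assumes channels % 32 == 0
            self.act = nn.SiLU()
            self.dense_t = nn.Linear(temb_dim, 2 * channels)  # slope/shift from temb
            self.dense_y = nn.Linear(yemb_dim, 2 * channels)  # slope/shift from yemb

        def forward(self, x, temb, yemb):
            slope_t, shift_t = torch.chunk(self.dense_t(temb)[:, :, None, None], 2, dim=1)
            slope_y, shift_y = torch.chunk(self.dense_y(yemb)[:, :, None, None], 2, dim=1)
            x = slope_t * self.norm(x) + shift_t  # inject temb (AdaGN)
            x = self.act(x)                       # add nonlinearity
            x = slope_y * x + shift_y             # inject yemb
            return x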

ChintanTrivedi commented 2 years ago

Update: I trained the class-conditioned model using the first option for introducing the embeddings, but the generated frames don't seem to correlate with the class label at all; the model appears to ignore it entirely. I'll try adding the classifier-gradient method for guided diffusion and see if that helps.
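For reference, a hypothetical sketch of the classifier-guidance step from Dhariwal & Nichol (2021), which shifts the reverse-process mean by the gradient of a noisy classifier's log-probability for the target class; classifier and guidance_scale here are assumptions, not part of this repo:

    import torch

    def classifier_guided_mean(mean, variance, x_t, t, y, classifier, guidance_scale=1.0):
        # Shift the predicted mean toward samples the classifier assigns to class y.
        with torch.enable_grad():
            x_in = x_t.detach().requires_grad_(True)
            logits = classifier(x_in, t)  # classifier trained on noisy inputs
            log_probs = torch.log_softmax(logits, dim=-1)
            selected = log_probs[torch.arange(len(y)), y].sum()
            grad = torch.autograd.grad(selected, x_in)[0]
        return mean + guidance_scale * variance * grad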