hologerry / IMF

[CVPR 2024] Implicit Motion Function

Question on tensor sizes before cross attention #4

Closed: johndpope closed this issue 2 months ago

johndpope commented 2 months ago

Before we apply the implicit motion function:

I have the dense feature encoding with 128, 256, 512, 512 channels (from the largest to the smallest H×W).

[screenshot: DenseFeatureEncoder]

But the LatentTokenDecoder (in reversed order) gives 512, 512, 512, 256 channels.

[screenshot: LatentTokenDecoder]

So I end up with these dims:

f:torch.Size([2, 128, 64, 64])
f:torch.Size([2, 256, 32, 32])
f:torch.Size([2, 512, 16, 16])
f:torch.Size([2, 512, 8, 8])

m_r:torch.Size([2, 256, 64, 64])
m_r:torch.Size([2, 512, 32, 32])
m_r:torch.Size([2, 512, 16, 16])
m_r:torch.Size([2, 512, 8, 8])

m_c:torch.Size([2, 256, 64, 64])
m_c:torch.Size([2, 512, 32, 32])
m_c:torch.Size([2, 512, 16, 16])
m_c:torch.Size([2, 512, 8, 8])

Then I run implicit_motion_alignment, but it just seems off that they're different sizes.

IMFModel


    def forward(self, x_current, x_reference):
        x_current = x_current.requires_grad_()
        x_reference = x_reference.requires_grad_()

        # Dense feature encoding
        f_r = self.dense_feature_encoder(x_reference)

        # Latent token encoding
        t_r = self.latent_token_encoder(x_reference)
        t_c = self.latent_token_encoder(x_current)

        # StyleGAN2-like mapping network
        t_r = self.mapping_network(t_r)
        t_c = self.mapping_network(t_c)

        # Add noise to latent tokens
        t_r = self.add_noise(t_r)
        t_c = self.add_noise(t_c)

        # Apply style mixing
        t_c, t_r = self.style_mixing(t_c, t_r)

        # Latent token decoding
        m_r = self.latent_token_decoder(t_r)
        m_c = self.latent_token_decoder(t_c)

        # Implicit motion alignment with noise injection
        aligned_features = []
        for i in range(len(self.implicit_motion_alignment)):
            f_r_i = f_r[i]
            m_r_i = self.noise_injection(m_r[i])
            m_c_i = self.noise_injection(m_c[i])
            align_layer = self.implicit_motion_alignment[i]
            aligned_feature = align_layer(m_c_i, m_r_i, f_r_i)
            aligned_features.append(aligned_feature)

        # Frame decoding
        reconstructed_frame = self.frame_decoder(aligned_features)

        return reconstructed_frame, {
            'dense_features': f_r,
            'latent_tokens': (t_c, t_r),
            'motion_features': (m_c, m_r),
            'aligned_features': aligned_features
        }

Are you able to log the forward pass of IMF (for the 4 layers)? I can adjust the sizes either way, higher or lower, but when I train I end up with results that are a bit off and I'm not sure why. https://github.com/johndpope/IMF/pull/17

Specifically, here: https://wandb.ai/snoozie/IMF/runs/zh1o9mo0/workspace?nw=nwusersnoozie

class ImplicitMotionAlignment(nn.Module):
    def __init__(self, feature_dim, motion_dim, depth=2, heads=8, dim_head=64, mlp_dim=1024):
        super().__init__()
        self.cross_attention = CrossAttentionModule(feature_dim, motion_dim, heads, dim_head)
        # x4
        self.transformer_blocks = nn.ModuleList([
            # TransformerBlock(...) layers; arguments elided in the original
        ])

    def forward(self, ml_c, ml_r, fl_r):
        print(f"ml_c:{ml_c.shape},ml_r:{ml_r.shape} fl_r:{fl_r.shape}")  # <-- here

https://github.com/johndpope/IMF/blob/main/vit.py
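Rather than adding a print inside each forward(), PyTorch forward hooks can log the input and output shapes of every alignment layer in one pass without editing the model. A minimal sketch: `log_shapes`, `ToyAlign`, and `ToyIMF` are hypothetical names, and the toy model only imitates the (m_c, m_r, f_r) call pattern of the loop in IMFModel.forward. On the real model you would presumably pass `(ImplicitMotionAlignment,)` as `module_types`.

```python
import torch
import torch.nn as nn


def log_shapes(model, module_types=(nn.Module,)):
    """Register forward hooks that record input/output shapes of matching submodules.

    Returns (handles, records); call h.remove() on each handle when done.
    """
    handles, records = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # inputs is the tuple of positional args passed to forward()
            in_shapes = [tuple(t.shape) for t in inputs if torch.is_tensor(t)]
            outs = output if isinstance(output, (tuple, list)) else (output,)
            out_shapes = [tuple(t.shape) for t in outs if torch.is_tensor(t)]
            records.append((name, in_shapes, out_shapes))
            print(f"{name}: in={in_shapes} out={out_shapes}")
        return hook

    for name, module in model.named_modules():
        if name and isinstance(module, module_types):  # skip the root module ("")
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles, records


class ToyAlign(nn.Module):
    # Hypothetical stand-in for ImplicitMotionAlignment: returns f_r unchanged
    def forward(self, m_c, m_r, f_r):
        return f_r


class ToyIMF(nn.Module):
    # Hypothetical stand-in for the per-level alignment loop in IMFModel.forward
    def __init__(self, levels=2):
        super().__init__()
        self.implicit_motion_alignment = nn.ModuleList([ToyAlign() for _ in range(levels)])

    def forward(self, f_r, m_r, m_c):
        return [layer(mc, mr, f) for layer, f, mr, mc
                in zip(self.implicit_motion_alignment, f_r, m_r, m_c)]


model = ToyIMF()
handles, records = log_shapes(model, module_types=(ToyAlign,))
f_r = [torch.randn(2, 128, 64, 64), torch.randn(2, 256, 32, 32)]
m_r = [torch.randn(2, 256, 64, 64), torch.randn(2, 512, 32, 32)]
model(f_r, m_r, m_r)
for h in handles:
    h.remove()  # detach the hooks after one logged pass
```

Because the hooks are removed afterwards, this can be run once on a single batch without affecting training speed.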

johndpope commented 2 months ago

@hologerry, can you confirm that the paper is correct and you're using two differently sized tensors to compute the attention? Or is that figure wrong, and if so, what are the correct dimensions?

hologerry commented 2 months ago

@johndpope Sorry for the late reply; I've been busy this week. I will look into the information you provided and check your implementation this weekend.

johndpope commented 2 months ago

There's so much flux in the code; this commit is my last stable run:

git checkout -b "eternal-sun-1473" 24f9d4edff4aa8de17eb0eb9b0b19d1f1c57e79b

Note: this diverges from the paper.


class LatentTokenDecoder(nn.Module):
    def __init__(self, latent_dim=32, const_dim=32):
        super().__init__()
        self.const = nn.Parameter(torch.randn(1, const_dim, 4, 4))

        self.style_conv_layers = nn.ModuleList([
            StyledConv(const_dim, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim),# 512
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim),# 512
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 256, 3, latent_dim),# 512 ? 🤷 or 256??? https://github.com/hologerry/IMF/issues/4
            StyledConv(256, 256, 3, latent_dim, upsample=True),
            StyledConv(256, 256, 3, latent_dim),
            StyledConv(256, 128, 3, latent_dim) #  256 ? 🤷 or 128 ??
        ])

    def forward(self, t):
        x = self.const.repeat(t.shape[0], 1, 1, 1)
        m1, m2, m3, m4 = None, None, None, None
        for i, layer in enumerate(self.style_conv_layers):
            x = layer(x, t)
            if i == 3:
                m1 = x
            elif i == 6:
                m2 = x
            elif i == 9:
                m3 = x
            elif i == 12:
                m4 = x
        return m4, m3, m2, m1 
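To double-check what this decoder returns without running it, the channel and resolution schedule can be traced by hand. This pure-Python sketch assumes that a StyledConv with upsample=True doubles H and W (as in rosinality's stylegan2-pytorch) and that one without upsampling keeps them unchanged; `LAYERS`, `TAPS`, and `trace_decoder` are hypothetical names.

```python
# (in_ch, out_ch, upsample) for each StyledConv in the LatentTokenDecoder above
LAYERS = [
    (32, 512, False),
    (512, 512, True),
    (512, 512, False),
    (512, 512, False),   # i == 3  -> m1
    (512, 512, True),
    (512, 512, False),
    (512, 512, False),   # i == 6  -> m2
    (512, 512, True),
    (512, 512, False),
    (512, 256, False),   # i == 9  -> m3
    (256, 256, True),
    (256, 256, False),
    (256, 128, False),   # i == 12 -> m4
]
TAPS = {3: "m1", 6: "m2", 9: "m3", 12: "m4"}


def trace_decoder(layers, taps, const_size=4):
    """Return {tap_name: (channels, h, w)} starting from the 4x4 learned constant."""
    size = const_size
    channels = layers[0][0]
    shapes = {}
    for i, (c_in, c_out, upsample) in enumerate(layers):
        assert c_in == channels, f"layer {i} expects {c_in} channels, got {channels}"
        if upsample:
            size *= 2  # assumed: upsample=True doubles the resolution
        channels = c_out
        if i in taps:
            shapes[taps[i]] = (channels, size, size)
    return shapes


shapes = trace_decoder(LAYERS, TAPS)
m_shapes = [shapes[n] for n in ("m4", "m3", "m2", "m1")]  # order returned by forward()
f_shapes = [(128, 64, 64), (256, 32, 32), (512, 16, 16), (512, 8, 8)]  # from the f log
for f, m in zip(f_shapes, m_shapes):
    print("f", f, "m", m)
```

Under these assumptions the taps come out as 128, 256, 512, 512 channels at 64, 32, 16, 8 pixels, i.e. the same per-level channels as the dense features f rather than the 256, 512, 512, 512 of the earlier m_r/m_c log.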

https://wandb.ai/snoozie/IMF/runs/aged896s/workspace?nw=nwusersnoozie

The main branch is suffering from mode collapse.

Questions:
(1) Did you use EMA / ADA?
(2) Did you consider just using ResNet-50 as the dense feature encoder?
(3) Are you using gradient clipping? I have it set at 0.75. Are you using mixed precision?
(4) Did you consider just using a vanilla multi-scale discriminator, or is there something important about the 2 scales?
(5) I had some code to work with MP4s, creating cached NPZ tensors, but I abandoned it and just use rendered-out images. Do you do the same?

UPDATE, FYI: this is my model training. I don't know why, but it suffers from catastrophic artifacts (https://wandb.ai/snoozie/IMF/runs/qaw8axgg?nw=nwusersnoozie), maybe from vanishing gradients. Up to 36 it's fine and working as expected.

hologerry commented 2 months ago

1. Tensor sizes:

The log you provided above (the f, m_r, and m_c shapes at 64×64, 32×32, 16×16, and 8×8) is correct.

As I mentioned in https://github.com/hologerry/IMF/issues/3#issuecomment-2255369735, the tensors' H and W should be the same and the two motion features should have the same channel dim, but the reference features do not need to match that channel dim. Here is some verification code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(
        self,
        dim_spatial=4096,
        dim_qk=256,
        dim_v=256,
        **kwargs,
    ):
        super().__init__()

        self.scale = dim_qk**-0.5

        self.q_pos_embedding = nn.Parameter(torch.randn(1, dim_spatial, dim_qk))
        self.k_pos_embedding = nn.Parameter(torch.randn(1, dim_spatial, dim_qk))

        self.attend = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values):
        # (b, dim_qk, h, w) -> (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_qk)
        q = torch.flatten(queries, start_dim=2).transpose(-1, -2)
        q = q + self.q_pos_embedding  # (b, dim_spatial, dim_qk)

        # in paper, key dim_spatial may be different from query dim_spatial
        # (b, dim_qk, h, w) -> (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_qk)
        k = torch.flatten(keys, start_dim=2).transpose(-1, -2)
        k = k + self.k_pos_embedding  # (b, dim_spatial, dim_qk)
        # (b, dim_v, h, w) -> (b, dim_v, dim_spatial) -> (b, dim_spatial, dim_v)
        v = torch.flatten(values, start_dim=2).transpose(-1, -2)

        # (b, dim_spatial, dim_qk) * (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_spatial)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)  # (b, dim_spatial, dim_spatial)

        # (b, dim_spatial, dim_spatial) * (b, dim_spatial, dim_v) -> (b, dim_spatial, dim_v)
        out = torch.matmul(attn, v)

        # Or use PyTorch's fused attention (its default scale is also dim_qk**-0.5):
        # out = F.scaled_dot_product_attention(q, k, v)
        out = torch.reshape(out.transpose(-1, -2), values.shape)  # (b, dim_spatial, dim_v) -> (b, dim_v, h, w)

        return out

if __name__ == "__main__":
    # (b, c, h, w)
    queries = torch.randn(1, 128, 32, 32)
    keys = torch.randn(1, 128, 32, 32)
    values = torch.randn(1, 256, 32, 32)

    attn = Attention(dim_spatial=32 * 32, dim_qk=128, dim_v=256)
    out = attn(queries, keys, values)
    print(out.shape)
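The rule stated in this reply can also be checked on shapes alone, before any attention is computed. A minimal pure-Python sketch: `check_alignment_shapes` is a hypothetical helper, shapes are (b, c, h, w) tuples, and it assumes m_c supplies the queries, m_r the keys, and f_r the values, following the ImplicitMotionAlignment(ml_c, ml_r, fl_r) signature and the Attention(queries, keys, values) code above.

```python
def check_alignment_shapes(m_c, m_r, f_r):
    """Return a list of problems for one level; an empty list means the level is OK."""
    problems = []
    if not (m_c[0] == m_r[0] == f_r[0]):
        problems.append(f"batch mismatch: {m_c[0]}, {m_r[0]}, {f_r[0]}")
    # H and W must agree: the positional embeddings and the final reshape share one dim_spatial
    if not (m_c[2:] == m_r[2:] == f_r[2:]):
        problems.append(f"spatial mismatch: m_c {m_c[2:]}, m_r {m_r[2:]}, f_r {f_r[2:]}")
    # q @ k^T needs queries and keys with the same channel dim (dim_qk)
    if m_c[1] != m_r[1]:
        problems.append(f"q/k channel mismatch: m_c {m_c[1]} vs m_r {m_r[1]}")
    # f_r's channel dim (dim_v) is free: it only sets the output channels
    return problems


# Shapes from the log earlier in the thread
f = [(2, 128, 64, 64), (2, 256, 32, 32), (2, 512, 16, 16), (2, 512, 8, 8)]
m_r = [(2, 256, 64, 64), (2, 512, 32, 32), (2, 512, 16, 16), (2, 512, 8, 8)]
m_c = list(m_r)

results = [check_alignment_shapes(c, r, v) for c, r, v in zip(m_c, m_r, f)]
for level, problems in enumerate(results):
    print(level, problems or "ok")

# A deliberately broken level: keys at the wrong resolution and channel count
bad = check_alignment_shapes((2, 256, 64, 64), (2, 512, 32, 32), (2, 128, 64, 64))
print(bad)
```

All four logged levels pass, which matches the reply: differing channel counts between f and m are expected, while mismatched H×W or query/key channels are not.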
hologerry commented 2 months ago

2. LatentTokenEncoder and LatentTokenDecoder

We use the stylegan2-pytorch implementation (https://github.com/rosinality/stylegan2-pytorch). You can refer to this repo, https://github.com/wyhsirius/LIA/, for StyleGAN2-related code such as the Encoder and Decoder. We turn OFF ALL NOISE.
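For reference, noise in rosinality's stylegan2-pytorch is added by a NoiseInjection module that computes x + weight * noise, with per-pixel Gaussian noise and a learned scalar weight initialized to zero. The sketch below is modeled on that module; the `enabled` flag is a hypothetical addition showing one way to switch noise off entirely, not necessarily how the authors did it. Turning off all noise would presumably also cover the `add_noise` and `noise_injection` calls in IMFModel.forward above.

```python
import torch
import torch.nn as nn


class NoiseInjection(nn.Module):
    """Per-pixel noise x + weight * noise, modeled on stylegan2-pytorch.

    `enabled` is a hypothetical switch: when False the module is an identity.
    """

    def __init__(self, enabled=True):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))  # learned scale, starts at 0
        self.enabled = enabled

    def forward(self, x, noise=None):
        if not self.enabled:
            return x  # noise turned off
        if noise is None:
            b, _, h, w = x.shape
            # one noise channel, broadcast across all feature channels
            noise = x.new_empty(b, 1, h, w).normal_()
        return x + self.weight * noise


x = torch.randn(2, 8, 4, 4)
off = NoiseInjection(enabled=False)
on = NoiseInjection()
with torch.no_grad():
    on.weight.fill_(1.0)  # pretend training has learned a non-zero scale
y_off, y_on = off(x), on(x)
```

Note that with the zero-initialized weight, noise has no effect at the start of training even when enabled; it only matters once the weight has been learned.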

3. Mode collapse

Maybe you should increase the perceptual loss weight, or try another perceptual loss implementation. In my experiments, the perceptual loss is crucial, while the discriminator loss is important for realism.

4. Questions

(1) We do not use EMA or ADA. EMA is worth a try; it does work in my other projects. But ADA is not necessary.
(2) We do not use any pretrained feature encoders.
(3) We do not use gradient clipping or mixed precision. (I think these should only be applied after the model is working.)
(4) I think the MultiScaleDiscriminator is important, but we did not fully ablate it, as it is basically the default setting for GAN models. (It's old... diffusion models are the way.)
(5) During training, we use imageio to save mp4 files: we convert the tensors to numpy frames and save them.
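Since EMA is described as worth a try, here is a minimal weight-EMA sketch in PyTorch. The `EMA` class and the default `decay=0.999` are assumptions for illustration, not the authors' setup (they did not use EMA): call `ema.update(model)` after every optimizer step and use `ema.shadow` for evaluation or sampling.

```python
import copy

import torch
import torch.nn as nn


class EMA:
    """Keeps an exponential moving average of a model's weights (hypothetical helper)."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        # shadow = decay * shadow + (1 - decay) * current
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.lerp_(p, 1.0 - self.decay)
        # Buffers (e.g. BatchNorm running stats) are copied, not averaged
        for s, b in zip(self.shadow.buffers(), model.buffers()):
            s.copy_(b)


model = nn.Linear(4, 2)
ema = EMA(model, decay=0.9)
with torch.no_grad():
    model.weight.add_(1.0)  # pretend an optimizer step moved the weights by +1
ema.update(model)           # shadow weight moves 10% of the way: +0.1
```

Copying buffers rather than averaging them is a common simplification; with a high decay, the shadow weights lag the live ones, so EMA samples typically look smoother early in training.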

johndpope commented 2 months ago

Thanks Yue. Have a great weekend.

johndpope commented 2 months ago

Sorry, one more thing: do you use mixed_precision? I set it to no, and the discriminator loss goes down 1000x.

hologerry commented 2 months ago

We do not use mixed precision. Regarding the loss range, our perceptual loss decreases from above 150 to 50/60, the pixel loss decreases from above 0.5 to 0.01/0.02, and the GAN loss oscillates around 1.0.