@hologerry - can you confirm that the paper is correct and you're using 2 different-sized tensors to calculate the attention? Or is that figure wrong, and if so, what are the correct dimensions?
@johndpope Sorry for the late reply, I've been busy this week. I will look into the information you provided and check your implementation this weekend.
There's so much flux in the code - this commit is my last stable run:
git checkout -b "eternal-sun-1473" 24f9d4edff4aa8de17eb0eb9b0b19d1f1c57e79b
Note - this is diverging from the paper:
```python
import torch
import torch.nn as nn
# StyledConv comes from rosinality's stylegan2-pytorch

class LatentTokenDecoder(nn.Module):
    def __init__(self, latent_dim=32, const_dim=32):
        super().__init__()
        self.const = nn.Parameter(torch.randn(1, const_dim, 4, 4))
        self.style_conv_layers = nn.ModuleList([
            StyledConv(const_dim, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim),  # 512
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 512, 3, latent_dim),  # 512
            StyledConv(512, 512, 3, latent_dim, upsample=True),
            StyledConv(512, 512, 3, latent_dim),
            StyledConv(512, 256, 3, latent_dim),  # 512 ? 🤷 or 256??? https://github.com/hologerry/IMF/issues/4
            StyledConv(256, 256, 3, latent_dim, upsample=True),
            StyledConv(256, 256, 3, latent_dim),
            StyledConv(256, 128, 3, latent_dim),  # 256 ? 🤷 or 128 ??
        ])

    def forward(self, t):
        x = self.const.repeat(t.shape[0], 1, 1, 1)
        m1, m2, m3, m4 = None, None, None, None
        for i, layer in enumerate(self.style_conv_layers):
            x = layer(x, t)
            # tap the multi-scale motion features after layers 3, 6, 9 and 12
            if i == 3:
                m1 = x
            elif i == 6:
                m2 = x
            elif i == 9:
                m3 = x
            elif i == 12:
                m4 = x
        return m4, m3, m2, m1
```
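For a quick sanity check of the decoder above, here is a sketch that traces the output shapes. The `StyledConv` below is only a stand-in (plain conv plus an optional 2x upsample, style ignored), since the real one comes from rosinality's stylegan2-pytorch; the channel counts and tap indices are exactly as written in the class above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StyledConv(nn.Module):
    """Stand-in for shape tracing only: plain conv + optional 2x upsample, style is ignored."""
    def __init__(self, in_ch, out_ch, kernel_size, style_dim, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2)

    def forward(self, x, style):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)

# assumes the LatentTokenDecoder class above is in scope
decoder = LatentTokenDecoder(latent_dim=32, const_dim=32)
t = torch.randn(2, 32)                  # a batch of 2 latent tokens
m4, m3, m2, m1 = decoder(t)
print(m4.shape, m3.shape, m2.shape, m1.shape)
# with the channel choices as written above:
# (2, 128, 64, 64) (2, 256, 32, 32) (2, 512, 16, 16) (2, 512, 8, 8)
```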
https://wandb.ai/snoozie/IMF/runs/aged896s/workspace?nw=nwusersnoozie
The main branch is suffering mode collapse.
Q1) Did you use EMA / ADA?
Q2) Did you consider just using a ResNet-50 as the dense feature encoder?
Q3) Are you using gradient clipping? I have it set at 0.75 (sketch below). Are you using mixed precision?
Q4) Did you consider just using a vanilla multi-scale discriminator, or is there something important about the 2 scales?
Q5) I had some code to work with mp4s - creating npz cached tensors - but I abandoned it and just use images rendered out. Do you do the same?
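For reference, a minimal sketch of the clipping setup mentioned in Q3 (max_norm=0.75 is the value quoted above, not from the paper; the model and loss here are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512)                     # placeholder for the generator
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(8, 512)
loss = model(x).pow(2).mean()                   # placeholder loss
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.75)  # clip before stepping
opt.step()
opt.zero_grad(set_to_none=True)
```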
UPDATE fyi - this is my model training - I don't know why, but it suffers from catastrophic artifacts: https://wandb.ai/snoozie/IMF/runs/qaw8axgg?nw=nwusersnoozie - maybe vanishing gradients... up to 36 it's fine and working as expected.
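One quick way to test the vanishing-gradient hypothesis (a generic sketch, not IMF code) is to log the global gradient norm after each backward():

```python
import torch

def global_grad_norm(model: torch.nn.Module) -> float:
    """L2 norm over all parameter gradients; log this every step to spot vanishing/exploding gradients."""
    norms = [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]
    return torch.stack(norms).norm().item() if norms else 0.0
```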
The log you provided:
f:   torch.Size([2, 128, 64, 64])   f:   torch.Size([2, 256, 32, 32])   f:   torch.Size([2, 512, 16, 16])   f:   torch.Size([2, 512, 8, 8])
m_r: torch.Size([2, 256, 64, 64])   m_r: torch.Size([2, 512, 32, 32])   m_r: torch.Size([2, 512, 16, 16])   m_r: torch.Size([2, 512, 8, 8])
m_c: torch.Size([2, 256, 64, 64])   m_c: torch.Size([2, 512, 32, 32])   m_c: torch.Size([2, 512, 16, 16])   m_c: torch.Size([2, 512, 8, 8])
is correct.
As I mentioned in https://github.com/hologerry/IMF/issues/3#issuecomment-2255369735, the tensors' H and W should be the same, the motion features should share the same channel dim, and the reference features do not need to be the same. Here is some verification code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(
        self,
        dim_spatial=4096,
        dim_qk=256,
        dim_v=256,
        **kwargs,
    ):
        super().__init__()
        self.scale = dim_qk**-0.5
        self.q_pos_embedding = nn.Parameter(torch.randn(1, dim_spatial, dim_qk))
        self.k_pos_embedding = nn.Parameter(torch.randn(1, dim_spatial, dim_qk))
        self.attend = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values):
        # (b, dim_qk, h, w) -> (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_qk)
        q = torch.flatten(queries, start_dim=2).transpose(-1, -2)
        q = q + self.q_pos_embedding  # (b, dim_spatial, dim_qk)

        # in the paper, the key dim_spatial may differ from the query dim_spatial
        # (b, dim_qk, h, w) -> (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_qk)
        k = torch.flatten(keys, start_dim=2).transpose(-1, -2)
        k = k + self.k_pos_embedding  # (b, dim_spatial, dim_qk)

        # (b, dim_v, h, w) -> (b, dim_v, dim_spatial) -> (b, dim_spatial, dim_v)
        v = torch.flatten(values, start_dim=2).transpose(-1, -2)

        # (b, dim_spatial, dim_qk) * (b, dim_qk, dim_spatial) -> (b, dim_spatial, dim_spatial)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)  # (b, dim_spatial, dim_spatial)

        # (b, dim_spatial, dim_spatial) * (b, dim_spatial, dim_v) -> (b, dim_spatial, dim_v)
        out = torch.matmul(attn, v)
        # Or the torch fast-attention version:
        # out = F.scaled_dot_product_attention(q, k, v)

        out = torch.reshape(out.transpose(-1, -2), values.shape)  # (b, dim_spatial, dim_v) -> (b, dim_v, h, w)
        return out


if __name__ == "__main__":
    # (b, c, h, w)
    queries = torch.randn(1, 128, 32, 32)
    keys = torch.randn(1, 128, 32, 32)
    values = torch.randn(1, 256, 32, 32)
    attn = Attention(dim_spatial=32 * 32, dim_qk=128, dim_v=256)
    out = attn(queries, keys, values)
    print(out.shape)
```
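Running this prints torch.Size([1, 256, 32, 32]): the output takes its channel dim from values and the shared H x W, while queries and keys only have to agree with each other on dim_qk (128 here) and values carry their own dim_v (256 here).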
We use the stylegan2-pytorch implementation: https://github.com/rosinality/stylegan2-pytorch. You can refer to this repo https://github.com/wyhsirius/LIA/ for StyleGAN2-related code, such as the Encoder and Decoder. We turn OFF ALL NOISE.
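One way to do that (a sketch, assuming the NoiseInjection module from rosinality's stylegan2-pytorch model.py, which adds `weight * noise` to each feature map through a learnable scalar weight) is to zero the noise weights and freeze them:

```python
import torch.nn as nn
from model import NoiseInjection  # rosinality/stylegan2-pytorch's model.py

def disable_stylegan2_noise(net: nn.Module) -> None:
    """Zero and freeze every NoiseInjection weight so no noise is ever added."""
    for m in net.modules():
        if isinstance(m, NoiseInjection):
            m.weight.data.zero_()
            m.weight.requires_grad_(False)
```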
Maybe you should increase the perceptual loss weight, or try another perceptual loss implementation. In my experiments, the perceptual loss is crucial, and the discriminator loss is important for realism.
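As one hypothetical way to re-weight those terms, here is a sketch that uses the lpips package as an alternative perceptual-loss implementation; the weight values and the d_fake_logits placeholder are illustrative, not from the paper:

```python
import lpips
import torch.nn.functional as F

perceptual = lpips.LPIPS(net="vgg")   # expects images scaled to [-1, 1]

def generator_loss(pred, target, d_fake_logits, w_perc=10.0, w_pix=1.0, w_gan=1.0):
    l_perc = perceptual(pred, target).mean()    # perceptual term (weight this up if results look off)
    l_pix = F.l1_loss(pred, target)             # pixel reconstruction term
    l_gan = F.softplus(-d_fake_logits).mean()   # non-saturating GAN term
    return w_perc * l_perc + w_pix * l_pix + w_gan * l_gan
```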
(1). We do not use EMA or ADA. EMA is worth a try; it does work in my other projects (see the sketch after this list). But ADA is not necessary.
(2). We do not use any pretrained feature encoders.
(3). We do not use gradient clipping or mixed precision. (I think these should only be applied after the model is working.)
(4). I think the MultiScaleDiscriminator is important, but we did not fully ablate it, as it is basically the default setting for GAN models. (It's old... Diffusion Models are the way.)
(5). During training, we use imageio to save mp4 files: we convert the tensors to numpy frames and save them.
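Since answer (1) says EMA is worth a try, here is a minimal, generic EMA sketch (not IMF's code): keep a frozen shadow copy of the generator and update it after each optimizer step.

```python
import copy
import torch

class EMA:
    """Exponential moving average of a model's weights."""
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.lerp_(p.detach(), 1.0 - self.decay)  # s = decay * s + (1 - decay) * p
        for s, b in zip(self.shadow.buffers(), model.buffers()):
            s.copy_(b)

# usage: ema = EMA(generator); ...; loss.backward(); opt.step(); ema.update(generator)
```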
Thanks Yue. Have a great weekend.
Sorry - one more thing - do you use mixed_precision? I set it to no and the discriminator loss goes down 1000x.
We do not use mixed precision. Regarding the loss range, our perceptual loss decreases from above 150 to 50/60, the pixel loss decreases from above 0.5 to 0.01/0.02, and the GAN loss oscillates around 1.0.
Before we do the implicit motion function, I have the dense_feature_encoding channels - 128, 256, 512, 512 across the HxW scales - but the latentTokenDecoder (in reversed order) gives 512, 512, 512, 256, so I end up with the dims quoted earlier, and then I do implicit_motion_alignment - but it just seems off that they're different sizes.
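To make that concrete, here is a sketch (not the authors' code) of how those per-scale channels can still feed the Attention module from the verification snippet above: m_c and m_r share dim_qk, the appearance feature f sets dim_v, and only H x W has to match within a scale. The channel lists below are taken from the log quoted earlier.

```python
import torch
# assumes the Attention class from the verification snippet above is in scope

motion_ch   = [256, 512, 512, 512]   # m_c / m_r channels per scale (from the log)
feature_ch  = [128, 256, 512, 512]   # f channels per scale (dense feature encoder)
resolutions = [64, 32, 16, 8]

aligned = []
for c_m, c_f, r in zip(motion_ch, feature_ch, resolutions):
    attn = Attention(dim_spatial=r * r, dim_qk=c_m, dim_v=c_f)
    m_c = torch.randn(2, c_m, r, r)   # current-frame motion feature (query)
    m_r = torch.randn(2, c_m, r, r)   # reference-frame motion feature (key)
    f_r = torch.randn(2, c_f, r, r)   # reference-frame appearance feature (value)
    aligned.append(attn(m_c, m_r, f_r))

print([tuple(a.shape) for a in aligned])
# [(2, 128, 64, 64), (2, 256, 32, 32), (2, 512, 16, 16), (2, 512, 8, 8)]
```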
IMFModel - are you able to log the forward pass of the IMF model (for the 4 layers)? I can adjust either way to make it higher or lower, but when I train I end up with results that are a bit off and I'm not sure why. https://github.com/johndpope/IMF/pull/17
specifically - here https://wandb.ai/snoozie/IMF/runs/zh1o9mo0/workspace?nw=nwusersnoozie
https://github.com/johndpope/IMF/blob/main/vit.py