sihyun-yu / PVDM

Official PyTorch implementation of Video Probabilistic Diffusion Models in Projected Latent Space (CVPR 2023).
https://sihyun.me/PVDM
MIT License

Code can't adapt to different number of timesteps #37

Closed PallottaEnrico closed 1 month ago

PallottaEnrico commented 2 months ago

The repo has a few hardcoded things that make it difficult to use with a different setting, such as a different resolution or number of timesteps. I think I've handled the resolution problem, partly thanks to this issue. Now I'm really struggling with the timesteps (number of frames in a video) parameter.

Apparently, using a number of timesteps that isn't a power of two (8, 16, 32) causes problems in the UNet (when concatenating residuals with the newly upsampled dimension).

I managed to train the AE with timesteps 8 and resolution 128, so it now produces embeddings of dim [1, 4, 1536], one for the noisy frames and one for the conditioning frames. I also had to change the code in the UNet that is marked with a TODO:

# TODO: treat 32 and 16 as variables
h_xy = h[:, :, 0:32*32].view(h.size(0), h.size(1), 32, 32)
h_yt = h[:, :, 32*32:32*(32+16)].view(h.size(0), h.size(1), 16, 32)
h_xt = h[:, :, 32*(32+16):32*(32+16+16)].view(h.size(0), h.size(1), 16, 32)

So I defined a variable n2 = 32 and n = n2 // 2 to replace the raw numbers. To use timesteps 8 I set n2 to 16, which I'm not sure is correct, but if 32 was right for timesteps 16 then the same scaling should hold.

The problem now is that the forward pass of the UNet produces a tensor of shape [1, 4, 512], so there's a dimension mismatch when trying to compute the loss. I'm referring to the code in the p_losses function:

    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, cond, t)
        ...
        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2])

Which causes the following error:

RuntimeError: The size of tensor a (1536) must match the size of tensor b (512) at non-singleton dimension 2
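
For reference, here is a rough check of where those two numbers come from (my own back-of-the-envelope math, assuming the latent spatial resolution at res 128 is 32, which is what makes the sizes add up):

# rough size check, not repo code: flattened latent = xy plane + yt plane + xt plane
latent_res, t = 32, 8
print(latent_res*latent_res + t*latent_res + t*latent_res)  # 1536 -> what the AE produces
n2, n = 16, 8
print(n2*n2 + n*n2 + n*n2)                                   # 512 -> what the modified UNet now outputs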

@sihyun-yu Did I miss anything else that should be changed in order to make this code "timesteps adaptive"?

PallottaEnrico commented 1 month ago

Solution: there are quite a lot of things to change in order to use a different number of frames and resolution; indeed, there are a few TODO comments in the code that give some hints.

Regarding the 32 and 16 that need to be treated as variables: I understood what they mean by reading the paper carefully, and you need to change the code like this:

# you can also use a single res if you are not interested in non-square resolutions
res1 = self.image_size
res2 = self.image_size
t = self.timesteps

# the three projected planes (xy, yt, xt) sit next to each other along the last dim
h_xy = h[:, :, 0:res1*res2].view(h.size(0), h.size(1), res1, res2)
h_yt = h[:, :, res1*res2:res1*(res2+t)].view(h.size(0), h.size(1), t, res1)
h_xt = h[:, :, res1*(res2+t):res1*(res2+t+t)].view(h.size(0), h.size(1), t, res1)

You also need to apply the same reshaping logic to the corresponding lines in the forward pass.
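
To make that concrete, this is roughly what the flattening-back direction looks like with the same variables; it's only an illustration of the idea, not the exact repo code, and it assumes torch and the h_xy / h_yt / h_xt tensors from the snippet above:

# sketch: concatenate the three planes back into the flattened layout [B, C, res1*res2 + 2*t*res1]
h = torch.cat([
    h_xy.reshape(h_xy.size(0), h_xy.size(1), res1 * res2),
    h_yt.reshape(h_yt.size(0), h_yt.size(1), t * res1),
    h_xt.reshape(h_xt.size(0), h_xt.size(1), t * res1),
], dim=2)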

This requires a few changes in the constructor of the UNetModel class, namely:

self.timesteps = timesteps
# flattened latent size: xy plane + yt plane + xt plane
self.ae_emb_dim = (image_size * image_size) + (timesteps * image_size) + (timesteps * image_size)
if cond_model:
    self.register_buffer("zeros", torch.zeros(1, self.in_channels, self.ae_emb_dim))
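
Assuming image_size here is the latent resolution (32 in my case), ae_emb_dim works out to 32*32 + 8*32 + 8*32 = 1536, which matches the [1, 4, 1536] embedding produced by the AE.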

So remember to pass the timesteps argument when you call the constructor (the diffusion function in diffusion.py). You also need to change the DDPM constructor, adding the image_size attribute used later for the reshaping part:

self.image_size = model.module.diffusion_model.ae_emb_dim

Another change is required in the ViTAutoencoder class in autoencoder_vit.py, namely:

def decode_from_sample(self, h):
    # latent spatial resolution after `down` downsampling steps
    latent_res = self.res // (2**self.down)
    h_xy = h[:, :, 0:latent_res*latent_res].view(h.size(0), h.size(1), latent_res, latent_res)
    h_yt = h[:, :, latent_res*latent_res:latent_res*(latent_res+self.s)].view(h.size(0), h.size(1), self.s, latent_res)
    h_xt = h[:, :, latent_res*(latent_res+self.s):latent_res*(latent_res+self.s+self.s)].view(h.size(0), h.size(1), self.s, latent_res)
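
Optionally, you can drop a small sanity check right after these lines to make sure the three slices tile the whole flattened latent (my addition, not part of the original method):

# optional: the xy/yt/xt slices together should cover the full flattened latent dim
assert h.size(2) == latent_res * (latent_res + 2 * self.s)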

You're almost done. The evaluation will raise an error related to the kernel size of the average pooling operation; just add another avg-pool operator (avg_pool1):

self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
self.avg_pool1 = nn.AvgPool3d(kernel_size=[1, 7, 7], stride=(1, 1, 1))

And use the one that matches your input shape in the forward method:

if x.size(2) == 2:
    x = self.logits(self.dropout(self.avg_pool(x)))
elif x.size(2) == 1:
    x = self.logits(self.dropout(self.avg_pool1(x)))
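
The branch works because the temporal extent of the final feature map depends on the number of input frames (presumably 2 in the default 16-frame setting and 1 with 8 frames here). If you'd rather not maintain two pooling layers, an adaptive pool is a possible alternative; this is just a sketch of that idea, not what the repo does, and it only behaves identically when the feature map is no larger than the original kernels:

# in __init__: one adaptive pool replaces both fixed-kernel pools
self.avg_pool = nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))

# in forward: no branching on x.size(2) is needed anymore
x = self.logits(self.dropout(self.avg_pool(x)))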