Closed PallottaEnrico closed 1 month ago
Solution: There are quite a lot of things to change in order to use a different number of frames and resolution; indeed, the code contains a few TODO comments that give some hints.
Regarding the hardcoded 32 and 16 that need to be treated as variables, I worked it out by reading the paper carefully. You need to change the code as follows:
```python
# You can also use a single res if you are not interested in non-square resolutions.
res1 = self.image_size
res2 = self.image_size
t = self.timesteps
h_xy = h[:, :, 0:res1*res2].view(h.size(0), h.size(1), res1, res2)
h_yt = h[:, :, res1*res2:res1*(res2+t)].view(h.size(0), h.size(1), t, res1)
h_xt = h[:, :, res1*(res2+t):res1*(res2+t+t)].view(h.size(0), h.size(1), t, res1)
```
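As a sanity check, the three slices partition the flattened axis exactly; here is a minimal sketch of the slice arithmetic with illustrative values (res 32, 16 timesteps are assumptions for the example, not the repo's config):

```python
# Slice boundaries for the xy / yt / xt triplane split, with example values.
res1 = res2 = 32   # spatial resolution of the latent (assumed)
t = 16             # number of frames / timesteps (assumed)

xy_end = res1 * res2            # end of the xy plane
yt_end = res1 * (res2 + t)      # end of the yt plane
xt_end = res1 * (res2 + t + t)  # end of the xt plane

assert xy_end == 1024
assert yt_end - xy_end == t * res1           # yt plane holds t*res1 elements
assert xt_end - yt_end == t * res1           # xt plane holds t*res1 elements
assert xt_end == res1 * res2 + 2 * t * res1  # slices cover the whole axis
```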
And you also need to apply the reshaping logic I wrote here to the following lines in the forward pass.
This requires a few changes in the constructor of the UNetModel class, namely:

```python
self.timesteps = timesteps
self.ae_emb_dim = (image_size * image_size) + (timesteps * image_size) + (timesteps * image_size)
if cond_model:
    self.register_buffer("zeros", torch.zeros(1, self.in_channels, self.ae_emb_dim))
```
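For concreteness, here is the embedding-dimension formula as a standalone sketch (the helper name `ae_emb_dim` is mine; the example values are illustrative):

```python
def ae_emb_dim(image_size: int, timesteps: int) -> int:
    # xy plane + yt plane + xt plane, matching the constructor formula above
    return image_size * image_size + 2 * timesteps * image_size

# e.g. a 32x32 latent with 16 frames:
assert ae_emb_dim(32, 16) == 2048
# and with 8 frames, matching the [1, 4, 1536] embedding mentioned below:
assert ae_emb_dim(32, 8) == 1536
```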
So remember to pass the timesteps argument when you call the constructor (diffusion.py :: diffusion function).
You also need to change the DDPM constructor, adding the image_size attribute used later for the reshaping:

```python
self.image_size = model.module.diffusion_model.ae_emb_dim
```
Another change is required in the autoencoder_vit.py ViTAutoencoder class, namely:
```python
def decode_from_sample(self, h):
    latent_res = self.res // (2**self.down)
    h_xy = h[:, :, 0:latent_res*latent_res].view(h.size(0), h.size(1), latent_res, latent_res)
    h_yt = h[:, :, latent_res*latent_res:latent_res*(latent_res+self.s)].view(h.size(0), h.size(1), self.s, latent_res)
    h_xt = h[:, :, latent_res*(latent_res+self.s):latent_res*(latent_res+self.s+self.s)].view(h.size(0), h.size(1), self.s, latent_res)
```
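The latent resolution and total flattened length work out as follows; a sketch with assumed values (res 128, 8 frames, and a downsampling factor of 2, which is my guess, not the repo's stated default):

```python
# How decode_from_sample's slice layout relates to the embedding length.
res, down, s = 128, 2, 8          # assumed example values
latent_res = res // (2 ** down)   # 128 // 4 = 32
total = latent_res * (latent_res + 2 * s)

assert latent_res == 32
assert total == 1536  # consistent with an embedding of shape [B, C, 1536]
```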
You're almost done. The evaluation will raise an error related to the kernel size of the average pooling operation; just add another avgpool operator (avg_pool1):

```python
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
self.avg_pool1 = nn.AvgPool3d(kernel_size=[1, 7, 7], stride=(1, 1, 1))
```
And use the one that matches your input shape in the forward method:
```python
if x.size(2) == 2:
    x = self.logits(self.dropout(self.avg_pool(x)))
elif x.size(2) == 1:
    x = self.logits(self.dropout(self.avg_pool1(x)))
```
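The reason a second pool is needed comes down to the usual no-padding output-size formula, sketched here in plain Python (not the PyTorch internals):

```python
# Output size of a pooling op along one axis with no padding:
# out = (in - kernel) // stride + 1, so the kernel must not exceed the input.
def pool_out(size: int, kernel: int, stride: int = 1) -> int:
    if kernel > size:
        raise ValueError("kernel larger than input")
    return (size - kernel) // stride + 1

assert pool_out(2, 2) == 1   # avg_pool handles a temporal dim of 2
assert pool_out(1, 1) == 1   # avg_pool1 handles a temporal dim of 1
# pool_out(1, 2) would raise, which is why the second pool is required
```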
The repo has a few hardcoded things that make it difficult to use with a different setting, such as a different resolution or number of timesteps. I managed to solve the resolution problem, thanks in part to this issue. Now I'm really struggling with the timesteps (number of frames in a video) parameter.
Apparently using a number that's not a power of two (8, 16, 32) causes problems in the UNet (when concatenating residuals with the newly upsampled dimension).
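A rough sketch of why non-power-of-two frame counts break the skip connections: each UNet level halves the temporal size with integer division and then doubles it on the way back up, and an odd intermediate size loses a unit the upsample cannot recover (illustrative only; the depths and sizes are assumptions, not the repo's actual configuration):

```python
# Halve `levels` times (strided downsample), then double back up,
# and compare with the original size a residual would be concatenated to.
def down_up(size: int, levels: int) -> int:
    for _ in range(levels):
        size //= 2   # downsampling path
    for _ in range(levels):
        size *= 2    # upsampling path
    return size

assert down_up(16, 3) == 16  # power of two: shapes line up with the residuals
assert down_up(12, 3) == 8   # 12 -> 6 -> 3 -> 1 on the way down: mismatch
```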
I managed to train the AE with timesteps 8 and res 128, so it now produces an embedding of dim [1, 4, 1536], one for the noisy frames and one for the conditioning frames. I also had to change the code in the UNet that is marked with a TODO:
So I defined a variable n2 = 32 and n = n2 // 2 to replace the raw numbers. To use timesteps 8 I set n2 to 16; I'm not sure that's correct, but if n2 was 32 for timesteps 16, the same relation should hold.
The problem now is that the forward pass of the UNet produces a tensor of shape [1, 4, 512], so there's a dimension mismatch when trying to compute the loss. I'm referring to the code in the function
Which causes the following error:
@sihyun-yu Did I miss anything else that should be changed in order to make this code "timesteps adaptive" ?