wilson1yan / VideoGPT

MIT License

Some questions on VideoGPT. #17

Closed: kami93 closed this issue 3 years ago

kami93 commented 3 years ago

Hi. Thanks for this great work. I have some questions about the general state of video generative modeling and about the implementation, and I would appreciate answers from the authors and from anyone else interested in weighing in.

Does this model deserve to be called "a video GPT", given the community's consensus on the encoding/decoding capacity and quality achievable by publicly available state-of-the-art video generative models?

After looking through the implementation, I quickly noticed that VideoGPT consumes clips of about 16 consecutive frames taken from raw videos at 25+ fps (e.g., UCF-101). That is, this model targets video clips well under one second long (16 frames at 25 fps is 0.64 seconds).

I am not very familiar with the current state of video generative modeling, but I am curious whether even these short video models are significantly harder to design and train than VQ-VAE-family models for still images. (I am not sure "short" is the right word, but since most YouTube clips are, I think, at least 10 seconds long, that is the word I chose.)

My first impression from the name "VideoGPT" was that it could model videos tens of seconds long, just as GPT models can generate paragraph-length text. In its current form, though, I think VideoGPT leaves a lot of room for improvement before it matches the generative capabilities of GPT in natural language generation.

Why are the codebook vectors first initialized with specific initializers like randn and zeros, and then re-initialized with the latent states of a training batch?

Referring to lines 126-156 at https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py.

I am wondering whether there is some hacky reason behind this. It would be simpler to initialize the codebook once, directly from the latents computed on a training batch.

Why are the Transformer and VQ-VAE trained separately (not jointly)?

I imagine there are several reasons (e.g., memory consumption, training instability). But what is the main reason not to train them jointly?

wilson1yan commented 3 years ago

Thanks for the questions!

Does this model deserve to be called "a video GPT", given the community's consensus on the encoding/decoding capacity and quality achievable by publicly available state-of-the-art video generative models?

Sorry for the confusion: VideoGPT simply adopts the architecture of the standard GPT model and applies it to video modeling, hence the name. We primarily train on 16 frames because that is the standard for evaluating models in the video generation literature. You could scale to longer sequences and larger GPT models, but we don't have the compute to try that. Still, you are right that 16 frames is very short. I believe some GAN models have gone to 32 or 50 frames, but not much further without a large degradation in prediction quality.

The main reason for such short sequences is that modeling video is prohibitively expensive. For example, even after encoding a 16 x 64 x 64 video with a VQ-VAE, we get a 4 x 32 x 32 grid of latent codes, i.e., 4096 tokens, which is like training a transformer on text sequences of 4096 tokens. A lot of related research in video generation focuses on building architectures (transformer-like models or GANs) that can process and predict video efficiently.
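To make that cost concrete, here is a small Python sketch of the arithmetic. The per-axis downsampling factors (4, 2, 2) are an assumption inferred from the 16 x 64 x 64 to 4 x 32 x 32 example, the attention count assumes a dense self-attention layer, and the helper names are hypothetical.

```python
def latent_shape(video_shape, downsample):
    """Latent grid (T', H', W') for a VQ-VAE that downsamples each axis by an integer factor."""
    return tuple(size // factor for size, factor in zip(video_shape, downsample))


def num_tokens(shape):
    """Number of discrete codes the transformer must model (product of the grid sizes)."""
    count = 1
    for size in shape:
        count *= size
    return count


def attention_entries(seq_len, n_heads=1):
    """Entries in the seq_len x seq_len score matrices of one dense self-attention layer."""
    return n_heads * seq_len * seq_len


# Assumed downsampling of (4, 2, 2) for time, height, width.
shape = latent_shape((16, 64, 64), (4, 2, 2))
tokens = num_tokens(shape)
print(shape, tokens)              # (4, 32, 32) 4096
print(attention_entries(tokens))  # 16777216
```

Under the same assumptions, 32 frames would give 8192 tokens and four times as many attention entries, since dense attention grows quadratically with sequence length.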

Why are the codebook vectors first initialized with specific initializers like randn and zeros, and then re-initialized with the latent states of a training batch?

In this case, initializing the embeddings with torch.randn(n_codes, embedding_dim) doesn't actually matter since, as you said, they are re-initialized later from the encoder outputs on a training batch. I think it is left over from older code that didn't use data-dependent initialization.

As for the zeros initialization, if you mean self.N: it tracks the average usage of each code, so it must start at zero, since no code has been used at the beginning.

The data-dependent initialization and usage tracking are borrowed from the Jukebox paper. In general, they help training stability by improving code usage and preventing codebook collapse, which is quite common with VQ-VAEs.
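The data-dependent initialization and usage tracking can be sketched in plain Python as follows. This is a hedged illustration, not VideoGPT's code: the function names, the EMA decay of 0.99, the noise added when a batch has fewer latents than codes, and the dead-code threshold are all assumptions, and the restart step follows the random-restart idea described in the Jukebox paper.

```python
import random


def nearest_code(z, codebook):
    """Index of the codebook vector closest to latent z (squared Euclidean distance)."""
    return min(range(len(codebook)),
               key=lambda k: sum((a - b) ** 2 for a, b in zip(z, codebook[k])))


def data_dependent_init(latents, n_codes, rng, noise_std=0.01):
    """Initialize the codebook with randomly chosen encoder latents from one batch.

    If the batch holds fewer latents than codes, the latents are tiled with small
    Gaussian noise so repeated latents do not produce identical codes (assumed detail).
    """
    pool = [list(z) for z in latents]
    while len(pool) < n_codes:
        pool += [[x + rng.gauss(0.0, noise_std) for x in z] for z in latents]
    return [list(v) for v in rng.sample(pool, n_codes)]


def update_usage(usage, assignments, decay=0.99):
    """EMA of how often each code is selected; usage starts at zeros (no code used yet)."""
    counts = [0] * len(usage)
    for k in assignments:
        counts[k] += 1
    return [decay * u + (1 - decay) * c for u, c in zip(usage, counts)]


def restart_dead_codes(codebook, usage, latents, rng, threshold=1.0):
    """Jukebox-style random restart: move rarely used codes onto random batch latents."""
    return [list(rng.choice(latents)) if u < threshold else code
            for code, u in zip(codebook, usage)]


rng = random.Random(0)
batch = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(32)]  # 32 latents of dim 4
codebook = data_dependent_init(batch, n_codes=8, rng=rng)
usage = [0.0] * len(codebook)
assignments = [nearest_code(z, codebook) for z in batch]
usage = update_usage(usage, assignments)
codebook = restart_dead_codes(codebook, usage, batch, rng, threshold=1e-3)
```

Tracking usage as an EMA lets codes that stop being selected decay toward zero usage, at which point a restart moves them back onto the data instead of leaving them permanently unused.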

Why are the Transformer and VQ-VAE trained separately (not jointly)?

They could be trained jointly, but I imagine there might be a few difficulties:
1) Memory usage: this might be less of an issue with substantially more compute (this is the primary reason we didn't do it).
2) The VQ-VAE and GPT differ widely in parameter count and may learn at different rates, so the learning rate and other training hyperparameters may need adjusting (a guess, since I didn't try it).

Hope that answers your questions!

kami93 commented 3 years ago

@wilson1yan Thanks very much for your detailed comments. Now I better understand how VideoGPT is designed. I'm working on reproducing the results in the paper, and I have a few more questions about the implementation. I'd really appreciate any help you can provide.

What exactly does "Residual Units" in Table A.1 refer to?

Apparently, "Residual Units" refers to some part of AttentionResidualBlock, but the current code does not set the hidden size of AttentionResidualBlock separately; it just uses n_hiddens.

Coefficient that the "reconstruction loss" should be multiplied by

The paper does not specify a coefficient for the reconstruction loss (i.e., the coefficient is 1). However, in the code (https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py#L54), the reconstruction loss is divided by 0.07, making the coefficient about 14.29. It would be helpful to know which coefficient should be used for each dataset to reproduce the results.

Additional training hyperparameter details other than the learning rate

In particular, it would really help reproducibility to know the optimizer, the learning rate schedule, the N-fold testing scheme for datasets that do not come with preset train/validation/test splits (e.g., M-MNIST and UCF-101), the random seeds, etc.

Why is "weights" instead of "self.N" used to normalize the moving-average codebook?

Referring to lines 187-189 at https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py.

What is the reason for computing n = self.N.sum(); weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n, and then using weights instead of self.N to normalize the moving-average codebook? This weights term does not exist in the Jukebox code, which normalizes with self.k_elem, the equivalent of self.N in VideoGPT.

I guess it is a heuristic that works better in VideoGPT's setting, but I still do not understand why self.n_codes * 1e-7 is added to n in the denominator, since n = self.N.sum() >> 1 should already be large enough to avoid overflow. It would be helpful to know how weights was designed.

Thank you for your time and effort on the code.

wilson1yan commented 3 years ago

What exactly does "Residual Units" in Table A.1 refer to?

It refers to the bottleneck size in the residual blocks, found here. It is hard-coded as n_hiddens // 2. I think some older code took it as an input, but I removed that later on. Sorry for the confusion!
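The effect of that hard-coded bottleneck can be illustrated by counting convolution weights. This sketch assumes a layout that is not spelled out in this thread: a 3 x 3 x 3 convolution from n_hiddens down to n_hiddens // 2 channels, followed by a 1 x 1 x 1 convolution back up to n_hiddens, both without bias, with the attention part of the block left out. The helper names and n_hiddens = 128 are just example choices.

```python
def residual_block_convs(n_hiddens, kernel=3):
    """Assumed conv layers of one bottleneck residual block (attention part omitted)."""
    bottleneck = n_hiddens // 2  # the "Residual Units" width, hard-coded as n_hiddens // 2
    return [
        {"in": n_hiddens, "out": bottleneck, "kernel": (kernel, kernel, kernel)},
        {"in": bottleneck, "out": n_hiddens, "kernel": (1, 1, 1)},
    ]


def conv_weight_count(layer):
    """Weights in a bias-free 3D convolution: in_channels * out_channels * kT * kH * kW."""
    kt, kh, kw = layer["kernel"]
    return layer["in"] * layer["out"] * kt * kh * kw


layers = residual_block_convs(128)
print([layer["out"] for layer in layers])                 # [64, 128]
print(sum(conv_weight_count(layer) for layer in layers))  # 229376
```

Because the bottleneck is tied to n_hiddens, changing n_hiddens alone changes both widths in each block.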

Coefficient that the "reconstruction loss" should be multiplied by

I believe it was originally meant to be the variance of the pixels in the dataset. If I remember correctly, it was around 0.0598 for BAIR, so we just used 0.06 for everything (BAIR, UCF-101, Kinetics, etc.), and it seemed to work fine. Using 0.06 for any dataset should work.
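Dividing the MSE by a fixed constant is equivalent to scaling the reconstruction term by its reciprocal, so 0.06 corresponds to a coefficient of about 16.67 (and the 0.07 mentioned in the question to about 14.29). A minimal sketch, assuming videos are given as flat lists of normalized pixel values; the function names are hypothetical, and in VideoGPT itself this would be a tensor operation.

```python
def pixel_variance(videos):
    """Population variance of all pixel values across a list of flattened videos."""
    pixels = [p for video in videos for p in video]
    mean = sum(pixels) / len(pixels)
    return sum((p - mean) ** 2 for p in pixels) / len(pixels)


def normalized_recon_loss(x_recon, x, data_variance=0.06):
    """Mean squared error divided by the data variance (0.06 is the value quoted above)."""
    mse = sum((a - b) ** 2 for a, b in zip(x_recon, x)) / len(x)
    return mse / data_variance


x = [0.0, 0.5, -0.5, 0.25]
x_recon = [0.1, 0.4, -0.5, 0.25]
print(normalized_recon_loss(x_recon, x))  # about 0.0833 (MSE 0.005 / 0.06)
print(1 / 0.06)                           # about 16.67, the implied loss coefficient
```

Estimating the variance on your own dataset with pixel_variance and passing it as data_variance would follow the original intent; per the reply above, simply using 0.06 should also work.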

Additional training hyperparameter details other than the learning rate

The default optimizer is Adam with a cosine annealing learning rate schedule, which should work for pretty much any dataset. The LR schedule is only used when training the transformer. Unfortunately, I don't think I set seeds for the experiments, but I haven't observed any substantial variance in results across seeds when training the same model.
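For reference, cosine annealing decays the learning rate from its base value to a minimum along half a cosine period. The sketch below uses the closed form that PyTorch documents for torch.optim.lr_scheduler.CosineAnnealingLR; the base learning rate, minimum, and step count in the example are hypothetical, since the thread does not give VideoGPT's values.

```python
import math


def cosine_annealing_lr(step, total_steps, base_lr, min_lr=0.0):
    """Learning rate at a given step: base_lr at step 0, min_lr at step total_steps."""
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return min_lr + (base_lr - min_lr) * cosine


# Hypothetical values: base LR 3e-4 annealed to 0 over 100000 transformer steps.
for step in (0, 25000, 50000, 100000):
    print(step, cosine_annealing_lr(step, 100000, 3e-4))
```

In PyTorch this would typically be paired with torch.optim.Adam and stepped once per training iteration of the transformer only, since the reply above says the VQ-VAE is trained without a schedule.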

Why is "weights" instead of "self.N" used to normalize the moving-average codebook?

It is applying a Laplace smoothing prior to the codebook counts. It is generally just there for stability reasons, in case a count for a particular code is zero.
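Reading the expression quoted in the question, the extra self.n_codes * 1e-7 in the denominator appears to keep the smoothed counts consistent rather than to guard against overflow: summing the numerators N_k + 1e-7 over all n_codes codes gives exactly n + n_codes * 1e-7, so (N_k + 1e-7) / (n + n_codes * 1e-7) is a proper distribution over codes, and multiplying by n rescales it so the smoothed counts still sum to n while no count is zero. The sketch below reproduces that expression in plain Python under hypothetical names and uses it to normalize running sums of assigned latents; it illustrates the idea and is not the repository's code.

```python
def smoothed_counts(counts, eps=1e-7):
    """Laplace-smoothed counts: (N_k + eps) / (n + K * eps) * n, where n = sum(N).

    The denominator equals the sum of the smoothed numerators, so the result
    still sums to n while every entry is strictly positive (assumes n > 0).
    """
    n = sum(counts)
    n_codes = len(counts)
    return [(c + eps) / (n + n_codes * eps) * n for c in counts]


def normalize_codebook(latent_sums, counts, eps=1e-7):
    """Codebook = running sum of assigned latents / smoothed count, per code."""
    weights = smoothed_counts(counts, eps)
    return [[s / w for s in row] for row, w in zip(latent_sums, weights)]


counts = [3.0, 0.0, 1.0]  # code 1 has never been used
weights = smoothed_counts(counts)
print(sum(weights))  # 4.0 up to floating-point rounding
print(normalize_codebook([[6.0, 3.0], [0.0, 0.0], [2.0, 2.0]], counts))
```

Without the n_codes * 1e-7 term, the smoothed counts would sum to n + n_codes * 1e-7 instead of n; the difference is tiny, but the version above keeps the totals exact.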

kami93 commented 3 years ago

@wilson1yan Wow. Thank you for your kind and detailed explanations. I was not aware of the concept of Laplace smoothing before. Now everything makes sense.