wilson1yan / VideoGPT

Generating longer sequences #22

Closed: universome closed this issue 3 years ago

universome commented 3 years ago

Hi! Thank you for your wonderful work and for releasing the code publicly! Could you please tell me whether there is an easy way (from the implementation perspective) to generate videos longer than 16 frames with your model trained on 16-frame videos?

My initial attempt was to simply edit the sample(...) method to run over the latent shape of (8, 32, 32) instead of (4, 32, 32) (for UCF-101), but that didn't work due to caching and positional embeddings being constrained to the (4, 32, 32) shape. So I suppose that I need to change other places as well to discard old context, but it's not clear what these places should be.

wilson1yan commented 3 years ago

Hi, unfortunately there isn't really an easy way to adapt the model to generalize to longer sequences. You would need to (1) train the whole model on longer sequences, or (2) adapt the code to train a conditional model in latent space (e.g. predict the next frame from the previous frame, or the next x frames from the previous x frames).

You could try interpolating the positional embeddings to a larger size, since I know that generally works well when adapting ViT models to larger-resolution images, but I don't think it'd work too well here.
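For reference, a minimal sketch of that kind of interpolation, assuming the positional embeddings are stored as a single learned grid of shape (t, h, w, d); this may not match how this repo actually stores them, so treat it as illustrative only:

```python
import torch
import torch.nn.functional as F

def interpolate_pos_emb(pos_emb, new_shape):
    """Resize a learned positional-embedding grid (t, h, w, d) to a new latent shape."""
    # Move the embedding dim to the channel slot expected by interpolate:
    # (t, h, w, d) -> (1, d, t, h, w)
    x = pos_emb.permute(3, 0, 1, 2).unsqueeze(0)
    x = F.interpolate(x, size=new_shape, mode="trilinear", align_corners=False)
    # Back to (t', h', w', d)
    return x.squeeze(0).permute(1, 2, 3, 0)

# e.g. stretch a (4, 32, 32) grid of 768-d embeddings to (8, 32, 32)
pos_emb = torch.randn(4, 32, 32, 768)
longer = interpolate_pos_emb(pos_emb, (8, 32, 32))
print(longer.shape)  # torch.Size([8, 32, 32, 768])
```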

universome commented 3 years ago

Ok, after inspecting the shapes processed by the GPT model during sampling, I noticed that attn_stack uses neither context nor conditioning information (for UCF-101), since it always operates on [8, 1, 1, 1, 768]-sized tensors. Do I get it right that it is not an autoregressive model, and that the only place where the generated sequence becomes coherent (i.e. where the representations become dependent on one another before being passed into the decoder) is the layer norm operation?

I am asking because it generates tokens one by one (i.e. there is a for-loop of size 4096), which made me think that it generates them autoregressively.

wilson1yan commented 3 years ago

You're right that during sampling it passes in single-token slices, but the model is still fully autoregressive. It uses caching to decrease memory usage and increase sampling speed. The caching code can be found here. The output at the current index is computed by applying attention across all of the previously generated / cached activations at each layer.
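To make that concrete, here is a toy sketch of cached decoding with a single attention layer; it is not the repo's actual code (the real attention stack is multi-head and axial), but it shows why feeding one token at a time is still fully autoregressive:

```python
import torch
import torch.nn.functional as F

# Each step feeds a single new token, but attention covers every previously
# generated position because keys/values are appended to a cache.
d = 16
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

cache_k, cache_v = [], []
x = torch.randn(1, d)                # embedding of the first token
for step in range(5):
    q = x @ Wq
    cache_k.append(x @ Wk)
    cache_v.append(x @ Wv)
    K = torch.cat(cache_k)           # (step + 1, d): all positions so far
    V = torch.cat(cache_v)
    attn = F.softmax(q @ K.t() / d ** 0.5, dim=-1)
    out = attn @ V                   # depends on the full cached context
    x = torch.randn(1, d)            # stand-in for embedding the sampled token
    print(f"step {step}: attended over {K.shape[0]} cached positions")
```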

universome commented 3 years ago

Oh, I missed this, thank you. Then I believe it should be able to generate longer sequences without any retraining, if one carefully shifts the preceding context and adjusts the positional embeddings accordingly. For example, when we are generating the token at position [4, 0, 0] (i.e. the 4097th token) of the [8, 32, 32] latent shape, we can "pretend" it is the token at position [3, 0, 0] and use the previously generated tokens at positions [1, 0, 0] - [3, 31, 31] as its context.* Do you see any issue with this reasoning? The main problem I see is that it might be non-trivial to implement, but maybe I am missing something.

*The latent shape in the paragraph above is in [t, h, w] format, where t, h, w index the time, height, and width axes respectively.
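To be concrete, this is the kind of index remapping I have in mind (illustrative arithmetic only; the cached activations would also need to be shifted to match):

```python
# Latent shape is [t, h, w]; the model was trained with t_train = 4 but we
# sample t_total = 8 time steps, so tokens past the trained window "pretend"
# to sit at the last trained time index.
t_train, h, w = 4, 32, 32

def pretend_position(flat_idx):
    t, rem = divmod(flat_idx, h * w)
    i, j = divmod(rem, w)
    if t < t_train:
        return (t, i, j)              # inside the trained window
    return (t_train - 1, i, j)        # shifted: reuse the last trained index

print(pretend_position(4096))                        # token [4, 0, 0] -> (3, 0, 0)
print(pretend_position(4 * 32 * 32 + 31 * 32 + 5))   # token [4, 31, 5] -> (3, 31, 5)
```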

wilson1yan commented 3 years ago

I don't think that would work exactly as intended, since each latent code depends on every pixel of the input video (due to the 3D convolutions and attention layers in the VQ-VAE). The resulting latent code would also be out of distribution for the VQ-VAE, so I'm not too sure what it would output. It might end up being a similar video to the one you'd get with the smaller [4, 32, 32] latent shape, just 2x slower to sample.

If you want to do what you're describing, you could train a frame-wise VQ-VAE (to do this with this codebase, you would need to remove the axial attention over time so there is no dependency across frames), and then train a VideoGPT on the stacked encoded frames. Then you could sample arbitrarily many timesteps into the future by predicting the next set of latent codes from the ones at the preceding timesteps.

The above also generalizes to chunks of x frames: train a VQ-VAE to encode x frames, then train a VideoGPT on c chunks of encoded x frames. At prediction time, you predict the next chunk of encoded x frames from the preceding c-1 chunks.
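Schematically, the rollout loop would look something like the sketch below; `prior_sample` is just a hypothetical stand-in for the trained conditional prior (it ignores its context here), and the decode step is omitted:

```python
import torch

c = 4                                   # context length in chunks
latent_shape = (1, 32, 32)              # codes per chunk, e.g. a frame-wise VQ-VAE

def prior_sample(context):              # context: (c - 1, *latent_shape)
    # Placeholder: a real prior would condition on `context`.
    return torch.randint(0, 1024, latent_shape)

context = [torch.randint(0, 1024, latent_shape) for _ in range(c - 1)]
generated = list(context)
for _ in range(10):                     # roll out 10 extra chunks
    nxt = prior_sample(torch.stack(generated[-(c - 1):]))
    generated.append(nxt)

video_codes = torch.stack(generated)    # decode each chunk back to frames afterwards
print(video_codes.shape)                # torch.Size([13, 1, 32, 32])
```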

universome commented 3 years ago

Ok, I think I understand what you mean. Thank you for your explanations. And once again thank you for your nice work and for the repo!