lucidrains / TimeSformer-pytorch

Implementation of TimeSformer from Facebook AI, a pure attention-based solution for video classification

fix runtime error in SpaceTime Attention #21

Closed adam-mehdi closed 3 years ago

adam-mehdi commented 3 years ago

There is a shape mismatch error in Attention. When we slice out the classification token from the first position of each sequence in q, k, and v, the shape becomes (batch_size * num_heads, num_frames * num_patches - 1, head_dim). We then try to reshape the tensor by factoring num_frames or num_patches (depending on whether it is space or time attention) out of dimension 1. That fails because, after subtracting the classification token, dimension 1 is no longer divisible by either factor.

I found that performing the rearrange operation before slicing out the token fixes the issue, as sketched below.
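A minimal sketch of the failure as I understood it, using made-up dimensions (bh, f, n, d are hypothetical stand-ins for batch_size * num_heads, num_frames, num_patches, and head_dim):

```python
import torch
from einops import rearrange

bh, f, n, d = 16, 4, 8, 64  # hypothetical sizes, for illustration only

q = torch.randn(bh, f * n, d)   # sequence assumed to hold f*n tokens
q_ = q[:, 1:]                   # slicing a token off leaves f*n - 1 tokens

# f*n - 1 is divisible by neither f nor n, so this rearrange
# (the space-attention grouping) raises a shape error:
try:
    rearrange(q_, 'b (f n) d -> (b f) n d', f=f)
except Exception as e:
    print(type(e).__name__, '-', e)

# Rearranging first, while the length is still f*n, works fine:
q_grouped = rearrange(q, 'b (f n) d -> (b f) n d', f=f)
print(q_grouped.shape)  # torch.Size([64, 8, 64])
```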

I recreate the problem and illustrate the solution in this notebook: https://colab.research.google.com/drive/1lHFcn_vgSDJNSqxHy7rtqhMVxe0nUCMS?usp=sharing.

By the way, thank you to @lucidrains; all of your implementations of attention-based models are helping me more than you know.

adam-mehdi commented 3 years ago

Correction: the shape after slicing was actually (batch_size * num_heads, num_frames * num_patches, head_dim), because the classification token had been concatenated in TimeSformer, making the pre-slice sequence length num_frames * num_patches + 1. The error was in my test, then.
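To make the corrected accounting concrete, a continuation of the sketch above with the same made-up dimensions:

```python
import torch
from einops import rearrange

bh, f, n, d = 16, 4, 8, 64  # same hypothetical sizes as before

# With the classification token concatenated, the sequence holds
# f*n + 1 tokens, so slicing it out leaves exactly f*n:
q = torch.randn(bh, f * n + 1, d)
q_ = q[:, 1:]                                       # (bh, f*n, d)
out = rearrange(q_, 'b (f n) d -> (b f) n d', f=f)  # divides evenly
print(out.shape)  # torch.Size([64, 8, 64])
```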