kahnchana / svt

Official repository for "Self-Supervised Video Transformer" (CVPR'22)
https://kahnchana.github.io/svt
MIT License

Code for loading Spatial Attention weights #7

Closed: ChengzhiCU closed this issue 2 years ago

ChengzhiCU commented 2 years ago

In the paper, you mention that "We randomly initialize weights relevant to temporal attention while spatial attention weights are initialized using a ViT model trained in a self-supervised manner over the ImageNet-1k dataset." May I ask whether this weight-loading code has been open-sourced in this repo? If not, when will you release it?
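If I understand correctly, this initialization amounts to a partial state-dict load. Here is a small runnable toy of the mechanism as I picture it (module names are illustrative only, not your actual code):

```python
import torch
import torch.nn as nn

# Toy stand-in for the video transformer: one spatial and one temporal
# attention module (names are illustrative, not the repo's).
class TinyVideoBlock(nn.Module):
    def __init__(self, dim=8, heads=2):
        super().__init__()
        self.spatial_attn = nn.MultiheadAttention(dim, heads)
        self.temporal_attn = nn.MultiheadAttention(dim, heads)

video_model = TinyVideoBlock()

# Fake "image checkpoint": like an image-only ViT checkpoint, it holds
# spatial weights but no temporal-attention weights.
image_state = {k: v for k, v in video_model.state_dict().items()
               if k.startswith("spatial_attn")}

# strict=False copies the matching spatial weights and leaves the
# temporal-attention parameters at their random initialization.
missing, unexpected = video_model.load_state_dict(image_state, strict=False)
print(missing)  # the temporal_attn parameter names remain unloaded
```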

kahnchana commented 2 years ago

This is already released; the code should load these weights in this manner by default.

ChengzhiCU commented 2 years ago

Thank you for your quick response. Can you point me to where the loading happens? I may have missed it, but I couldn't find it myself. Thanks again!

kahnchana commented 2 years ago

https://github.com/kahnchana/svt/blob/master/models/timesformer.py#L28

https://github.com/kahnchana/svt/blob/master/models/timesformer.py#L604

javierselva commented 1 year ago

Hi! First of all, thank you for providing the code for this fantastic paper.

I've been trying to run some experiments with the ImageNet pre-trained DINO weights, to compare against training from scratch on video data, and I am not completely sure these weights are loaded by default as the code currently stands. The default values in default_cfgs (timesformer.py#L28) are only used in timesformer.py#L391 within class vit_base_patch16_224 and in timesformer.py#L416 within class TimeSformer (though I haven't found anywhere in the code where these two classes are used), as well as within the functions get_vit_base_patch16_224 and get_aux_token_vit (timesformer.py#L592 and #L612, respectively).

The line #L604 you mention (within get_aux_token_vit) does indeed pre-load the weights, but as far as I understand, that function is only called when MODEL.TWO_TOKEN is True (see train_ssl.py#L204). Since MODEL.TWO_TOKEN is set to False by default in scripts/train.sh, get_vit_base_patch16_224 (#L592) is called instead, and it only pre-loads when cfg.TIMESFORMER.PRETRAINED_MODEL evaluates to True, which it doesn't by default, since it defaults to an empty string (see utils/defaults.py#L259). A runnable mock of this control flow, as I read it, is below.
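(Only the identifiers MODEL.TWO_TOKEN, TIMESFORMER.PRETRAINED_MODEL, get_aux_token_vit, and get_vit_base_patch16_224 come from the repo; the function bodies here are placeholder stubs of mine.)

```python
from types import SimpleNamespace

def get_aux_token_vit(cfg):
    # In the repo, this path pre-loads the DINO weights (timesformer.py#L604).
    return "model with pre-loaded spatial weights"

def get_vit_base_patch16_224(cfg):
    # Only pre-loads when PRETRAINED_MODEL is truthy; "" is falsy in Python.
    if cfg.TIMESFORMER.PRETRAINED_MODEL:
        return f"model pre-loaded from {cfg.TIMESFORMER.PRETRAINED_MODEL}"
    return "model with random initialization only"

cfg = SimpleNamespace(
    MODEL=SimpleNamespace(TWO_TOKEN=False),            # default in scripts/train.sh
    TIMESFORMER=SimpleNamespace(PRETRAINED_MODEL=""),  # default in utils/defaults.py#L259
)

# Mirrors the branch at train_ssl.py#L204 as I understand it.
model = get_aux_token_vit(cfg) if cfg.MODEL.TWO_TOKEN else get_vit_base_patch16_224(cfg)
print(model)  # -> "model with random initialization only"
```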

I think I can get the code to actually pre-load by modifying get_vit_base_patch16_224 to load weights the same way get_aux_token_vit does, but I wanted to make sure I am not mistaken (a sketch of the change I have in mind follows below). I also see that the only difference between these two functions (aside from the weight loading) is one extra token... So what does TWO_TOKEN stand for? What does it change aside from this? I have not been able to find any comment documenting it in the code.
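For concreteness, the change I have in mind is roughly the following. This is a sketch only, not tested against the repo: I am assuming default_cfgs follows the timm convention of storing a checkpoint URL under a "url" key, and that a plain strict=False state-dict load mirrors what get_aux_token_vit does at #L604.

```python
import torch

def load_spatial_pretrained(model, cfg, default_cfgs):
    """Fallback weight loading I would add inside get_vit_base_patch16_224."""
    ckpt = cfg.TIMESFORMER.PRETRAINED_MODEL
    if ckpt:
        state_dict = torch.load(ckpt, map_location="cpu")
    else:
        # Instead of skipping entirely, fall back to the default DINO
        # ImageNet-1k checkpoint from default_cfgs (timesformer.py#L28).
        # The "url" key is an assumption based on the timm convention.
        url = default_cfgs["vit_base_patch16_224"]["url"]
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    # strict=False leaves the randomly initialized temporal-attention
    # parameters untouched, since the image-only checkpoint lacks them.
    model.load_state_dict(state_dict, strict=False)
    return model
```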

Thanks a lot in advance!