raoyongming / DynamicViT

[NeurIPS 2021] [T-PAMI] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
https://dynamicvit.ivg-research.xyz/
MIT License
551 stars 69 forks

Attention mask computation during training #11

Closed: mtchiu2 closed this issue 2 years ago

mtchiu2 commented 2 years ago

Hello,

Thank you for your work. I'm reading your code for computing the attention mask. In lvvit.py, line 665, you use F.gumbel_softmax(pred_score) to obtain the binary mask, but the last layer that computes pred_score is a log_softmax (line 529 of the same file). Shouldn't it be a plain linear layer rather than a log_softmax, since F.gumbel_softmax takes logits rather than probabilities?
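For readers following the thread, the pattern under discussion looks roughly like the minimal sketch below. It is not the code from lvvit.py. The class name ScorePredictor, the helper binary_keep_mask, the MLP sizes, and the convention that index 0 means "keep" are all hypothetical. The sketch only illustrates a score predictor that ends in log_softmax and feeds its output directly into F.gumbel_softmax with hard=True.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScorePredictor(nn.Module):
    """Hypothetical keep/drop predictor that ends in log_softmax, as described in the question."""

    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.GELU(),
            nn.Linear(dim // 2, 2),  # two classes; index 0 = keep, index 1 = drop (assumed convention)
            nn.LogSoftmax(dim=-1),   # log-probabilities over {keep, drop}
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, num_tokens, dim) -> pred_score: (batch, num_tokens, 2)
        return self.mlp(tokens)


def binary_keep_mask(pred_score: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    # hard=True gives one-hot samples in the forward pass, while gradients
    # flow through the relaxed (soft) sample: a straight-through estimator.
    one_hot = F.gumbel_softmax(pred_score, tau=tau, hard=True, dim=-1)
    return one_hot[..., 0]  # ~1.0 where "keep" was sampled, 0.0 otherwise


torch.manual_seed(0)
predictor = ScorePredictor(dim=64)
tokens = torch.randn(2, 10, 64)
mask = binary_keep_mask(predictor(tokens))
print(mask.shape)  # torch.Size([2, 10])
```

With hard=True the mask is one-hot in the forward pass, but it still carries gradients back to the predictor. This is what makes a discrete keep/drop decision trainable end to end.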

Thank you.

raoyongming commented 2 years ago

Hi @mtchiu2,

Thanks for your interest in our work. According to the PyTorch documentation, the input to F.gumbel_softmax should be unnormalized log probabilities. Comparing the original Gumbel-Softmax paper with the PyTorch implementation, the logits correspond to log π in Eq. 1 and Eq. 2 of the Gumbel-Softmax paper, i.e., the log class probabilities. Therefore, we use log_softmax to compute the logits.
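The following check shows why the two choices agree in practice. It is offered as an illustration, not as the authors' own reasoning. By definition, log_softmax(x)_i = x_i - log Σ_j exp(x_j), so it shifts every score in a row by the same constant. Eq. 2 of the Gumbel-Softmax paper, y_i = exp((log π_i + g_i)/τ) / Σ_j exp((log π_j + g_j)/τ), is a softmax and is unchanged by such a shift. The hard argmax of Eq. 1 is unchanged as well. As a result, for the same Gumbel noise, raw linear outputs and their log_softmax produce the same samples. The sketch draws the noise explicitly so the two inputs can be compared sample by sample. The helper gumbel_softmax_with_noise is hypothetical and simply evaluates Eq. 2.

```python
import torch
import torch.nn.functional as F


def gumbel_softmax_with_noise(logits: torch.Tensor, gumbels: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    # Relaxed sample of Eq. 2: softmax((logits + g) / tau), with the noise passed in explicitly
    return F.softmax((logits + gumbels) / tau, dim=-1)


torch.manual_seed(0)
raw_logits = torch.randn(4, 2) * 3.0           # unnormalized scores, e.g. from a plain linear layer
log_probs = F.log_softmax(raw_logits, dim=-1)  # log pi, as used in the reply

# Standard Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1); clamp avoids log(0)
u = torch.rand_like(raw_logits).clamp(min=1e-10)
gumbels = -torch.log(-torch.log(u))

results = []
for tau in (0.1, 1.0, 5.0):
    y_raw = gumbel_softmax_with_noise(raw_logits, gumbels, tau)
    y_log = gumbel_softmax_with_noise(log_probs, gumbels, tau)
    # log_softmax subtracts a per-row constant (logsumexp), which the softmax cancels
    same_soft = torch.allclose(y_raw, y_log, atol=1e-5)
    # The hard (argmax) decision of Eq. 1 is shift-invariant as well
    same_hard = torch.equal(y_raw.argmax(dim=-1), y_log.argmax(dim=-1))
    results.append((tau, same_soft, same_hard))
    print(tau, same_soft, same_hard)
```

The map from the linear layer's output to the sample is the same function either way, so gradients agree as well. A side benefit of the log_softmax form is that pred_score can be read directly as log-probabilities.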

mtchiu2 commented 2 years ago

I see. Thank you very much for the quick response.