NVlabs / A-ViT

Official PyTorch implementation of A-ViT: Adaptive Tokens for Efficient Vision Transformer (CVPR 2022)

Non-zero outputs from discarded tokens #13

Open bartwojcik opened 9 months ago

bartwojcik commented 9 months ago

Hi! Congratulations on the great paper.

These lines are concerning to me:

else:
    x = x + self.drop_path(self.attn(self.norm1(x*(1-mask).view(bs, token, 1))*(1-mask).view(bs, token, 1), mask=mask))
    x = x + self.drop_path(self.mlp(self.norm2(x*(1-mask).view(bs, token, 1))*(1-mask).view(bs, token, 1)))
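For readability, here are the same two lines with the repeated mask expression factored out (just a sketch; `keep` is a name I introduce here, the other variables come from the surrounding code):

    keep = (1 - mask).view(bs, token, 1)  # 1.0 for still-active tokens, 0.0 for discarded ones
    x = x + self.drop_path(self.attn(self.norm1(x * keep) * keep, mask=mask))
    x = x + self.drop_path(self.mlp(self.norm2(x * keep) * keep))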

I can see two issues here:

1. Masking of the layer normalization inputs seems redundant. To see what I mean:

        In [54]: some_tokens = torch.randn(2, 5, 10)

        In [55]: ln = torch.nn.LayerNorm(10)

        In [56]: continue_mask = torch.zeros(2, 5, 1)

        In [57]: continue_mask[0, :3] = 1.0

        In [58]: continue_mask[1, 2:] = 1.0

        In [59]: ln(continue_mask * some_tokens) * continue_mask == ln(some_tokens) * continue_mask
        Out[59]:
        tensor([[[True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True]],

                [[True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True],
                 [True, True, True, True, True, True, True, True, True, True]]])

    Since LayerNorm normalizes each token independently and its output is multiplied by `(1-mask)` anyway, masking its input as well changes nothing and only adds extra work.

2. On the other hand, the outputs of the MHA and MLP modules are not masked. Since the MLP consists of two linear layers with biases (and an activation function in between, of course) - see [here](https://github.com/NVlabs/A-ViT/blob/master/timm/models/act_vision_transformer.py#L221) and [here](https://github.com/NVlabs/A-ViT/blob/master/timm/models/layers/mlp.py#L9-L27) - its output is not zero even for tokens that have been zeroed out by the mask (see the sketch below). While this can be seen as a kind of bias added to the output after a token has already been dropped, it goes against the spirit of: 1. not adding any parameters; and 2. reducing compute/FLOPs - as described in your paper. This seems like a bug to me.
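Here is a minimal, standalone sketch of the second point (not the repository's code; the layer sizes are arbitrary). A token that has already been zeroed out by the mask still receives a non-zero residual update, because the biases of the linear layers do not vanish for a zero input:

    import torch

    # Arbitrary dimensions, for illustration only.
    mlp = torch.nn.Sequential(
        torch.nn.Linear(10, 40),
        torch.nn.GELU(),
        torch.nn.Linear(40, 10),
    )

    zero_token = torch.zeros(1, 1, 10)  # a "discarded" token after masking
    update = mlp(zero_token)
    print(update.abs().max())           # > 0: the linear biases still reach the residual stream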

Please correct me if I am wrong.