facebookresearch / deit

Official DeiT repository
Apache License 2.0

The concatenation of 'cls_tokens' and 'patch_embedding' is not necessary. #188

Closed · thb1314 closed this 2 years ago

thb1314 commented 2 years ago

As shown in https://github.com/facebookresearch/deit/blob/main/cait_models.py#L241

x = torch.cat((cls_tokens, x), dim=1)
x = self.norm(x)
return x[:, 0]

is equivalent to

x = self.norm(cls_tokens)
return x[:, 0]

Suppose x is a tensor with shape [B, N, C]. Because LayerNorm computes the mean and std over the last dim of its input, the mean has shape [B, N, 1], so each token position is normalized independently of the B and N dims. Slicing out x[:, 0] after the norm therefore gives the same result as normalizing cls_tokens alone, and the torch.cat operation seems unnecessary.
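A minimal self-contained check of this claim (a sketch, not code from this repository; the standalone nn.LayerNorm and the shapes B, N, C are illustrative):

import torch
import torch.nn as nn

B, N, C = 2, 196, 192              # illustrative batch, token, and channel sizes
norm = nn.LayerNorm(C)             # normalizes over the last dim, like self.norm

cls_tokens = torch.randn(B, 1, C)  # class token
x = torch.randn(B, N, C)           # patch embeddings

# Repository path: concatenate, normalize the whole sequence, slice token 0.
out_cat = norm(torch.cat((cls_tokens, x), dim=1))[:, 0]

# Shortcut: normalize the class token alone.
out_cls = norm(cls_tokens)[:, 0]

# Each (batch, token) position is normalized independently over C,
# so the two paths agree.
print(torch.allclose(out_cat, out_cls))  # True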

TouvronHugo commented 2 years ago

Hi @thb1314, thanks for your comment. Yes, you are right: the concatenation is not necessary. However, it does not affect performance and keeps the final representation similar to that of the DeiT code, so we keep this implementation.

Best,

Hugo