Closed zhangbw21 closed 1 year ago
Hi, I think the code is correct. In timm's implementation, the input to each transformer layer has shape (batch_size, token_num, dim) — for ViT-B/16 that is (batch_size, 197, 768). See https://github.com/rwightman/pytorch-image-models/blob/7c4682dc08e3964bc6eb2479152c5cdde465a961/timm/models/vision_transformer.py#L200
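A quick sanity check of the 197-token count for ViT-B/16 (224×224 input, 16×16 patches, plus one class token) — a minimal sketch, not timm's actual code:

```python
img_size, patch_size = 224, 16

# 224 / 16 = 14 patches per side, so 14 * 14 = 196 patch tokens
num_patches = (img_size // patch_size) ** 2

# +1 for the [CLS] token prepended to the patch sequence
token_num = num_patches + 1

print(token_num)  # 197
```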
Hello, the shape of the tensor in ViT seems to be (patch_num ** 2 + 1, bs, dim); however, the released code uses
B, N, C = x.shape
Is this correct, or is it a mistake?
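For illustration only (a minimal numpy sketch with made-up sizes, not timm's code), the batch-first layout unpacks cleanly with `B, N, C = x.shape`, and the sequence-first layout from the question is just the transpose of the first two axes:

```python
import numpy as np

# Batch-first layout used by timm: (batch_size, token_num, dim)
x = np.zeros((8, 197, 768))
B, N, C = x.shape  # unpacks as in the released code
print(B, N, C)     # 8 197 768

# The sequence-first layout (token_num, batch_size, dim) mentioned in the
# question would be the same data with the first two axes swapped:
x_seq_first = x.transpose(1, 0, 2)
print(x_seq_first.shape)  # (197, 8, 768)
```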