NVlabs / A-ViT

Official PyTorch implementation of A-ViT: Adaptive Tokens for Efficient Vision Transformer (CVPR 2022)
Apache License 2.0

A question about the halting score distribution code #4

Open DYZhang09 opened 1 year ago

DYZhang09 commented 1 year ago

In the paper, the halting score distribution is defined as below:

[Image: the paper's equation defining the halting score distribution]

However, the corresponding code does not seem to match this definition: https://github.com/NVlabs/A-ViT/blob/120c9cb90acf86828f1c61dd42c08722aa7173c7/timm/models/act_vision_transformer.py#L464-L465

The shape of `h_lst[1]` is [B, N], so the code seems to average over the whole batch and to skip the first sample of each batch, rather than the first token of each sample. I think the correct code is: `self.halting_score_layer.append(torch.mean(h_lst[1][:, 1:], dim=-1))`
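To make the difference concrete, here is a minimal sketch of the two slicings on a toy [B, N] tensor. The values are made up, the name `h` is a hypothetical stand-in for `h_lst[1]`, and the `h[1:]` form is only assumed to mirror the indexing in the linked lines, which are not quoted here.

```python
import torch

# Hypothetical halting scores for B=2 samples and N=3 tokens each.
h = torch.tensor([[0.9, 0.1, 0.3],
                  [0.8, 0.5, 0.7]])

# Assumed current behavior: slicing dim 0 drops sample 0 entirely and
# reduces everything that remains to a single scalar.
batch_sliced = torch.mean(h[1:])

# Proposed behavior: slicing dim 1 drops token 0 of every sample and
# averages over the remaining tokens, giving one value per sample.
token_sliced = torch.mean(h[:, 1:], dim=-1)

print(batch_sliced.shape)  # torch.Size([])
print(token_sliced.shape)  # torch.Size([2])
```

With the first form, sample 0 contributes nothing and token 0 of the remaining samples is still included in the mean. The second form keeps every sample and excludes only token 0.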

Can you tell me which one is correct? Thanks!

Ther-nullptr commented 1 year ago

I have the same question.