raoyongming / DynamicViT

[NeurIPS 2021] [T-PAMI] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
https://dynamicvit.ivg-research.xyz/
MIT License

Regarding connection elimination of the dropped patches. #35

Closed Alihjt closed 9 months ago

Alihjt commented 11 months ago

Hello. I have a question: where in your code do you eliminate the connections of the dropped patches? I can see where you zero them out, but where is the connection actually dropped at inference time to make inference faster?

raoyongming commented 11 months ago

Hi @Alihjt, thanks for your interest in our work. During training, this function eliminates the connections based on a differentiable binary policy produced by Gumbel-Softmax. During inference, we directly remove tokens based on their scores.
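
For concreteness, here is a minimal sketch of those two regimes, not the repository's exact code: `prune_tokens` is a hypothetical helper, and `scores` is assumed to be a per-token keep-probability in (0, 1).

```python
import torch
import torch.nn.functional as F

def prune_tokens(x, scores, keep_ratio, training):
    # x: (B, N, C) token embeddings; scores: (B, N) keep-probabilities.
    if training:
        # Training: sample a differentiable binary mask with Gumbel-Softmax
        # and zero out the dropped tokens. The tensor keeps its full shape,
        # so gradients can flow back into the score predictor.
        logits = torch.stack([scores, 1.0 - scores], dim=-1).log()
        mask = F.gumbel_softmax(logits, hard=True)[..., 0]  # (B, N), 0/1
        return x * mask.unsqueeze(-1)
    # Inference: actually gather the top-scoring tokens, shrinking the
    # sequence so all subsequent layers do proportionally less work.
    num_keep = int(x.shape[1] * keep_ratio)
    keep_idx = scores.topk(num_keep, dim=1).indices  # (B, num_keep)
    return x.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, x.shape[2]))
```

The masked path is what keeps token selection end-to-end trainable; the hard top-k path only makes sense at inference, where no gradients are needed and the shorter sequence is what yields the speedup.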

Alihjt commented 11 months ago

@raoyongming Thank you for your response. Did you calculate your FLOPs manually or use FlopCountAnalysis? I am not getting the same FLOPs for your models. For example, you reported that DynamicViT-LVViT-M/0.7 has 8.5 GFLOPs, but I got around 12.9 GFLOPs from FlopCountAnalysis:

```python
import torch
from fvcore.nn import FlopCountAnalysis

flops = FlopCountAnalysis(model, torch.zeros((1, 3, 224, 224), dtype=torch.float32, device='cuda'))
print(flops.total())
# 12962899456
```

raoyongming commented 11 months ago

We also use fvcore to compute the FLOPs of DynamicViT. Please refer to this part of the code. You may need to pass the right pruning ratio to the model and set it to inference mode (model.eval()) before computing FLOPs: in training mode the dropped tokens are only masked out, so fvcore still traces the full-length sequence and reports the unpruned cost.
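
As a hedged sketch of that measurement setup (assuming `model` is a DynamicViT instance that has already been built with the intended keep ratio, e.g. 0.7, and moved to CUDA; this is not verbatim repository code):

```python
import torch
from fvcore.nn import FlopCountAnalysis

# Assumption: `model` was constructed with the desired pruning ratio
# and lives on the same device as the dummy input below.
model.eval()  # inference mode: dropped tokens are removed, not just masked

with torch.no_grad():
    flops = FlopCountAnalysis(
        model,
        torch.zeros((1, 3, 224, 224), dtype=torch.float32, device='cuda'),
    )
    print(f'{flops.total() / 1e9:.1f} GFLOPs')
```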