raoyongming / DynamicViT

[NeurIPS 2021] [T-PAMI] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
https://dynamicvit.ivg-research.xyz/
MIT License

Regarding connection elimination of the dropped patches. #35

Closed: Alihjt closed this issue 1 year ago

Alihjt commented 1 year ago

Hello. I have a question. Where in your code do you eliminate the connections of the dropped patches? I can see where you zero them out, but where do you actually drop the connections at inference time for faster inference?

raoyongming commented 1 year ago

Hi @Alihjt, thanks for your interest in our work. This function is used to eliminate the connections during training, based on a differentiable binary policy produced by Gumbel-Softmax. In this part of the code, we directly remove tokens based on their scores during inference.
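For readers landing here later, a minimal sketch of the two modes, assuming a PyTorch setup like the repo's. The function names and shapes below are illustrative, not the repo's exact code:

```python
import torch

def masked_softmax(attn, policy, eps=1e-6):
    """Training-time sketch: attend only to kept tokens.

    attn:   (B, H, N, N) raw attention logits
    policy: (B, N, 1) differentiable keep-decisions from Gumbel-Softmax
            (~1 = keep, ~0 = drop)
    """
    B, _, _, N = attn.shape
    key_mask = policy.reshape(B, 1, 1, N)
    attn = attn - attn.max(dim=-1, keepdim=True).values  # numerical stability
    attn = attn.exp() * key_mask                         # zero dropped keys
    return attn / (attn.sum(dim=-1, keepdim=True) + eps)

def prune_tokens(x, scores, keep_ratio):
    """Inference-time sketch: physically gather the top-scoring tokens so
    downstream layers see a shorter sequence (this is where the speedup
    comes from, since the masked softmax alone saves no compute)."""
    B, N, C = x.shape
    num_keep = int(N * keep_ratio)
    keep_idx = scores.topk(num_keep, dim=1).indices      # (B, num_keep)
    return x.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, C))
```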

Alihjt commented 1 year ago

@raoyongming Thank you for your response. Did you calculate your FLOPs manually or use FlopCountAnalysis? I am not getting the same FLOPs for your models. For example, you report that DynamicViT-LVViT-M/0.7 has 8.5 GFLOPs, but I get around 12.9 GFLOPs from FlopCountAnalysis:

```python
import torch
from fvcore.nn import FlopCountAnalysis

flops = FlopCountAnalysis(model, torch.zeros((1, 3, 224, 224), dtype=torch.float32, device='cuda'))
print(flops.total())
# 12962899456
```

raoyongming commented 1 year ago

We also use fvcore to compute the FLOPs of DynamicViT. Please refer to this part of the code. You may need to pass the right pruning ratio to the model and set it to inference mode (model.eval()) before computing FLOPs; otherwise the traced graph is the training one, where dropped tokens are only masked rather than removed, so the measured FLOPs stay high.
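A sketch of that measurement setup, for anyone hitting the same discrepancy. `build_dynamicvit_lvvit_m` is a hypothetical constructor standing in for the repo's model-building code, and the pruning locations and keep ratios are assumptions for a keep ratio of 0.7; check the repo's FLOPs-counting script for the exact interface:

```python
import torch
from fvcore.nn import FlopCountAnalysis

# Hypothetical constructor; pruning_loc/token_ratio values are assumed,
# following the "prune at three stages with keep ratio 0.7" setup.
model = build_dynamicvit_lvvit_m(pruning_loc=[5, 10, 15],
                                 token_ratio=[0.7, 0.7 ** 2, 0.7 ** 3])
model.eval()           # inference mode: tokens are actually removed
with torch.no_grad():  # no autograd bookkeeping while tracing
    flops = FlopCountAnalysis(model, torch.zeros((1, 3, 224, 224)))
    print(f'{flops.total() / 1e9:.1f} GFLOPs')
```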