lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
MIT License

How to calculate Params and Flops for Vit? #225

Open myzhuang opened 2 years ago

myzhuang commented 2 years ago

When I use torchstat to calculate Params and FLOPs for ViT, some errors happen. After debugging, I found that some ops are not supported:

    class Attention(nn.Module):
        def forward(self, x):
            qkv = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

Here the chunk op is not supported. Which tool should I use to calculate Params and FLOPs? Thanks!

Code:

    from torchstat import stat

    mbvit_xs = MobileViT(
        image_size = (256, 256),
        dims = [96, 120, 144],
        channels = [16, 32, 48, 48, 64, 64, 64, 32, 16, 2],
        num_classes = 1000
    )
    stat(mbvit_xs, (3, 256, 256))

Error print:

    File "mobile_vit.py", line 410, in <module>
      stat(mbvit_xs, (3, 256, 256))
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/statistics.py", line 71, in stat
      ms.show_report()
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/statistics.py", line 64, in show_report
      collected_nodes = self._analyze_model()
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/statistics.py", line 57, in _analyze_model
      model_hook = ModelHook(self._model, self._input_size)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/model_hook.py", line 24, in __init__
      self._model(x)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
      return forward_call(*input, **kwargs)
    File "mobile_vit.py", line 383, in forward
      h5 = self.trunk12(h4)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
      return forward_call(*input, **kwargs)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
      input = module(input)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
      return forward_call(*input, **kwargs)
    File "mobile_vit.py", line 186, in forward
      x = self.transformer(x)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
      return forward_call(*input, **kwargs)
    File "mobile_vit.py", line 111, in forward
      x = attn(x) + x
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
      return forward_call(*input, **kwargs)
    File "mobile_vit.py", line 32, in forward
      return self.fn(self.norm(x), **kwargs)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
      return forward_call(*input, **kwargs)
    File "mobile_vit.py", line 71, in forward
      qkv = self.to_qkv(x).chunk(3, dim=-1)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/model_hook.py", line 76, in wrap_call
      madd = compute_madd(module, input[0], output)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/compute_madd.py", line 156, in compute_madd
      return compute_Linear_madd(module, inp, out)
    File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/compute_madd.py", line 117, in compute_Linear_madd
      assert len(inp.size()) == 2 and len(out.size()) == 2
    AssertionError
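Note on the traceback: the last frame is the assertion in torchstat's `compute_Linear_madd`, which requires the input and output of a linear layer to be 2-D. The `rearrange` pattern `b p n (h d)` suggests that `to_qkv` receives a 4-D tensor here, so the failure may come from the linear layer's hook rather than from `chunk` itself. The following is a minimal sketch, not torchstat's code, of how the multiply-accumulates (MACs) and parameters of a linear layer could be counted for any number of leading dimensions. The function names `linear_macs` and `linear_params` and the example shape are hypothetical.

```python
from math import prod


def linear_macs(input_shape, out_features):
    """MACs of a linear layer applied to an input of shape (..., in_features).

    Assumption: one MAC per weight per input vector; the bias add is ignored.
    """
    *leading, in_features = input_shape
    # The layer is applied independently to every vector along the last axis,
    # so the leading dimensions just multiply the per-vector cost.
    positions = prod(leading)  # prod([]) == 1 covers a plain 1-D input
    return positions * in_features * out_features


def linear_params(in_features, out_features, bias=True):
    """Number of learnable parameters in a linear layer."""
    return in_features * out_features + (out_features if bias else 0)


if __name__ == "__main__":
    # Hypothetical 4-D input laid out as (b, p, n, dim), projected to 3 * 96.
    print(linear_macs((1, 4, 64, 96), 288))
    print(linear_params(96, 288, bias=False))
```

For a 2-D input the same formula reduces to `batch * in_features * out_features`, so the count does not depend on how the leading dimensions are arranged.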

ChuxiaYang commented 2 years ago

I have encountered the same problem. Have you found a solution to this problem?

myzhuang commented 2 years ago

Sorry, I have not found a good solution to this problem so far.

krrishdholakia commented 1 year ago

What version of TorchStat are you using? - Clerkie (https://clerkie.co/)

cc: @lucidrains trying to gather some context for future debugging - hope that's okay!

mcihadarslanoglu commented 1 year ago

I use torch-summary; maybe it can help you.

pohunshi commented 7 months ago

@myzhuang you can see this: https://github.com/Swall0w/torchstat/issues/18. The code runs successfully, but the resulting MACs do not seem right compared with other methods.
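When tools disagree on MACs, a hand estimate for one attention block can serve as a sanity check. The sketch below assumes a standard multi-head self-attention layout: a single `to_qkv` projection, a scaled dot product between queries and keys, a weighted sum over values, and an output projection back to `dim`. Softmax, scaling, normalization, and bias adds are ignored. The function name `attention_macs`, its parameters, and the example sizes are hypothetical and should be adjusted to the actual model.

```python
def attention_macs(batch, patches, tokens, dim, heads, dim_head):
    """Approximate MACs for one multi-head self-attention block.

    Input layout is assumed to be (batch, patches, tokens, dim), matching the
    'b p n (h d)' pattern quoted in the issue.
    """
    inner = heads * dim_head
    groups = batch * patches  # attention is computed independently per group

    qkv = groups * tokens * dim * 3 * inner               # to_qkv projection
    scores = groups * heads * tokens * tokens * dim_head  # q @ k^T
    weighted = groups * heads * tokens * tokens * dim_head  # attn @ v
    out = groups * tokens * inner * dim                   # output projection

    return {
        "qkv": qkv,
        "scores": scores,
        "weighted": weighted,
        "out": out,
        "total": qkv + scores + weighted + out,
    }


if __name__ == "__main__":
    # Hypothetical sizes; replace with the values of the block being checked.
    for name, value in attention_macs(1, 4, 64, 96, 4, 8).items():
        print(f"{name}: {value}")
```

Some tools report FLOPs as twice the MACs and others skip the two matrix products inside attention entirely, which is one possible reason for different totals.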