sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in the PyTorch framework
MIT License

The EfficientNet paper claims efficientnet-b0 is 0.39 GFLOPs, but the measurement is 0.02 GFLOPs #59

Closed ytfksw closed 3 years ago

ytfksw commented 3 years ago

The EfficientNet paper claims efficientnet-b0 is 0.39 GFLOPs, but the measurement is 0.02 GFLOPs (= 2 × 0.01 GMac):

% python                                                                       
Python 3.6.5 (default, Nov 11 2019, 18:04:50)
[GCC 5.4.0 20160609] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from ptflops import get_model_complexity_info
>>> from efficientnet_pytorch import EfficientNet
>>> network = EfficientNet.from_pretrained("efficientnet-b0")
Loaded pretrained weights for efficientnet-b0
>>> get_model_complexity_info(network, (3, 224, 224), print_per_layer_stat=False)
('0.01 GMac', '5.29 M')
sovrasov commented 3 years ago

This implementation of EfficientNet uses F.conv2d to perform convolutions. Unfortunately, ptflops can't catch arbitrary F.* functions, since it works by adding hooks to known subclasses of nn.Module such as nn.Conv2d. If you enable print_per_layer_stat, you'll see that the Conv2dStaticSamePadding module is treated as a zero-op, because there is no handler for this custom operation (which is actually implemented via a call to F.conv2d). That's why ptflops reports fewer operations than expected. This issue is described in the readme (see the usage tips), and I guess all similar packages will experience the same problem.

The only way around this is to implement a custom hook for Conv2dStaticSamePadding and pass it via the custom_modules_hooks parameter. I think it can actually be the same hook ptflops uses for a regular nn.Conv2d.
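
A minimal sketch of such a hook, assuming Conv2dStaticSamePadding subclasses nn.Conv2d (as it does in efficientnet_pytorch) and reproducing the MAC formula of ptflops' built-in conv hook:

import numpy as np
from ptflops import get_model_complexity_info
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import Conv2dStaticSamePadding

def conv_hook(module, input, output):
    # MACs = (kernel ops per output element) * (number of output elements),
    # the same count ptflops applies to a regular nn.Conv2d
    batch_size = input[0].shape[0]
    output_elements = batch_size * int(np.prod(list(output.shape[2:])))
    macs_per_position = (int(np.prod(module.kernel_size))
                         * module.in_channels * module.out_channels
                         // module.groups)
    overall = macs_per_position * output_elements
    if module.bias is not None:
        overall += module.out_channels * output_elements
    # ptflops accumulates per-module counts in the __flops__ attribute
    module.__flops__ += int(overall)

net = EfficientNet.from_pretrained('efficientnet-b0')
macs, params = get_model_complexity_info(
    net, (3, 224, 224), print_per_layer_stat=False,
    custom_modules_hooks={Conv2dStaticSamePadding: conv_hook})
# macs should now be close to 0.4 GMac instead of 0.01 GMac

With the hook registered, the convolutions hidden inside Conv2dStaticSamePadding are counted, so the result should match the pytorchcv number below.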

I'd recommend using pytorchcv instead; in my view it has a cleaner implementation:

from ptflops import get_model_complexity_info
from pytorchcv.model_provider import get_model as ptcv_get_model

net = ptcv_get_model("efficientnet_b0", pretrained=True)
get_model_complexity_info(net, (3, 224, 224), print_per_layer_stat=False)
# ('0.4 GMac', '5.29 M')
sovrasov commented 3 years ago

Also, it seems that in the original paper FLOPs actually means MACs; the two are often mixed up. This table also states 0.4 GFLOPs for efficientnet-b0: https://github.com/osmr/imgclsmob/tree/master/pytorch I guess the source of their numbers is a built-in ops counter in TF or in some other framework with a similar feature, so the numbers in the table are reliable.
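
A quick sanity check on the units, assuming the common convention FLOPs = 2 × MACs:

paper_gflops = 0.39
measured_gmacs = 0.4
# If the paper counted true FLOPs (2 ops per multiply-accumulate), we'd expect:
print(2 * measured_gmacs)  # 0.8 -> far from the paper's 0.39
# If the paper's "FLOPs" are really MACs:
print(1 * measured_gmacs)  # 0.4 -> matches 0.39 up to rounding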

ytfksw commented 3 years ago

@sovrasov I am grateful for your support!!

I see, now I understand why it wasn't counted. As you pointed out, I used the pytorchcv model and got 0.4 GMac. I also confirmed that it is counted correctly when using custom_modules_hooks!