sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

support for torch.compile? #110

Open CocytusDuo opened 1 year ago

CocytusDuo commented 1 year ago

The FLOPs count is wrong when the model is compiled with torch.compile:

  1. FLOPs of built-in torch.nn modules, such as nn.Linear and nn.Conv2d, are tripled.
  2. FLOPs of custom modules are not counted.

Here is a code example:

import torch.nn as nn
import ptflops
import torch

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x

class Test_model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_layer = nn.Linear(1000, 1000, bias=False)
        self.custom_layer = MyModule()

    def forward(self, x):
        out = self.linear_layer(x)
        out = self.custom_layer(out)
        return out

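def mymodule_flops_counter_hook(module, input, output):
    # Forward hooks receive the module's positional inputs as a tuple
    input = input[0]
    # Count 1000 MACs per input element (1000 * 1000 = 1M for a (1, 1000) input)
    mul_count = input.numel() * 1000
    module.__flops__ += int(mul_count)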

MyModuleMapping = {
    MyModule: mymodule_flops_counter_hook
}

net = Test_model()
net = torch.compile(net)
print(ptflops.get_model_complexity_info(net, (1000,), custom_modules_hooks=MyModuleMapping, output_precision=3))

Output without torch.compile:

Test_model(
  1.0 M, 100.000% Params, 2.0 MMac, 100.000% MACs,
  (linear_layer): Linear(1.0 M, 100.000% Params, 1.0 MMac, 50.000% MACs, in_features=1000, out_features=1000, bias=False)
  (custom_layer): MyModule(0, 0.000% Params, 1.0 MMac, 50.000% MACs, )
)
('2.0 MMac', '1.0 M')

Output with torch.compile:

OptimizedModule(
  1.0 M, 100.000% Params, 3.0 MMac, 100.000% MACs,
  (_orig_mod): Test_model(
    1.0 M, 100.000% Params, 3.0 MMac, 100.000% MACs,
    (linear_layer): Linear(1.0 M, 100.000% Params, 3.0 MMac, 100.000% MACs, in_features=1000, out_features=1000, bias=False)
    (custom_layer): MyModule(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
  )
)
('3.0 MMac', '1.0 M')
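Per layer, the compiled figures add up as follows: the Linear reports 3 × its eager 1.0 MMac and MyModule reports 0, giving 3.0 MMac instead of 2.0 MMac. The pure-Python sketch below reproduces both printouts under one reading of the numbers, namely that the Linear hook fired three times and the MyModule hook never fired during the compiled pass. That reading, the hypothetical `Counter` class, and the bias-free Linear rule (input elements × out_features) are assumptions for illustration; the thread does not establish why the hooks behave this way. The accumulate-on-every-call behaviour mirrors the `module.__flops__ += ...` line in the report's custom hook.

```python
# Minimal model of hook-based counting, assuming (as the custom hook in the
# report does) that every forward-hook call ADDS to a per-module counter.


class Counter:
    """Hypothetical stand-in for a module's __flops__ attribute."""

    def __init__(self) -> None:
        self.flops = 0


def linear_hook(counter: Counter, input_elements: int, out_features: int) -> None:
    # Assumed rule for a bias-free Linear: input elements * out_features MACs
    counter.flops += input_elements * out_features


def custom_hook(counter: Counter, input_elements: int) -> None:
    # Same rule as mymodule_flops_counter_hook in the report: numel() * 1000
    counter.flops += input_elements * 1000


def run(linear_calls: int, custom_calls: int) -> tuple:
    """Count MACs for a (1, 1000) input, given how often each hook fired."""
    linear, custom = Counter(), Counter()
    for _ in range(linear_calls):
        linear_hook(linear, 1 * 1000, 1000)
    for _ in range(custom_calls):
        custom_hook(custom, 1 * 1000)
    return linear.flops, custom.flops, linear.flops + custom.flops


eager = run(linear_calls=1, custom_calls=1)     # matches the eager printout
compiled = run(linear_calls=3, custom_calls=0)  # matches the compiled printout
print(eager)     # (1000000, 1000000, 2000000)
print(compiled)  # (3000000, 0, 3000000)
```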

sovrasov commented 1 year ago

This issue comes down to PyTorch 2.0 support. I'll have a look, but the only solution is probably to avoid compile mode.
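Until compile mode is supported, one way to avoid it for counting is to run ptflops on the model before calling torch.compile, or on the eager module held inside the compiled wrapper. A minimal sketch of the latter, assuming the wrapper exposes the original module as `_orig_mod` (the attribute name visible in the compiled printout; it is private and may change between PyTorch versions). `unwrap_compiled` and `FakeOptimizedModule` are hypothetical names, not part of ptflops or PyTorch; the demo uses a stand-in class so it runs without torch.

```python
from typing import Any


def unwrap_compiled(model: Any) -> Any:
    """Return the eager module wrapped by torch.compile, or `model` unchanged.

    Assumption: the compiled wrapper keeps the original module in the private
    `_orig_mod` attribute, as the OptimizedModule printout in this issue shows.
    """
    return getattr(model, "_orig_mod", model)


# Demo with a hypothetical stand-in for OptimizedModule (no torch needed).
class FakeOptimizedModule:
    def __init__(self, mod: Any) -> None:
        self._orig_mod = mod


inner = object()
assert unwrap_compiled(FakeOptimizedModule(inner)) is inner  # wrapper is unwrapped
assert unwrap_compiled(inner) is inner                        # plain models pass through

# Intended use with the report's code (not executed here):
# net = torch.compile(Test_model())
# macs, params = ptflops.get_model_complexity_info(
#     unwrap_compiled(net), (1000,), custom_modules_hooks=MyModuleMapping)
```

If the assumption holds, counting on the unwrapped module should report the eager-mode figures (2.0 MMac for this model), since the hooks then fire on an uncompiled forward pass.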