open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Something wrong when calculating the FLOPs for Linear #2936

Closed: xushilin1 closed this issue 1 year ago

xushilin1 commented 1 year ago

Prerequisite

Environment

MMCV==2.0

Reproduces the problem - code sample

```python
import torch
from mmcv.cnn.utils.flops_counter import (add_flops_counting_methods,
                                          print_model_with_flops)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


model = Net()
inputs = torch.randn(32, 10)  # batch of 32 samples with 10 features each

# Attach the FLOPs-counting hooks, run one forward pass, then read the result.
flops_model = add_flops_counting_methods(model)
flops_model.eval()
flops_model.start_flops_count()
_ = flops_model(inputs)
flops_count, params_count = flops_model.compute_average_flops_cost()
print_model_with_flops(flops_model, flops_count, params_count)
flops_model.stop_flops_count()
print(flops_count, params_count)
```

Reproduces the problem - command or script

None

Reproduces the problem - error message

Expected FLOPs = batch_size $\times$ input_channels $\times$ output_channels = $32\times 10\times 10 = 3200$, but MMCV reports only $10\times 10 = 100$.

Additional information

No response

zhouzaida commented 1 year ago

Hi @xushilin1, the FLOPs calculation does not include the batch-size dimension. This is because users may measure with different batch sizes, which would make the results hard to compare.
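The per-sample convention can be illustrated with a minimal PyTorch sketch: forward hooks accumulate the FLOPs for the whole batch, and dividing by the batch size gives the per-sample figure. MMCV's `compute_average_flops_cost()` appears to follow the same idea, dividing the accumulated count by the number of samples seen, which is why the report shows 100 rather than 3200. `count_linear_flops` is a hypothetical helper, not an MMCV API; it only handles `nn.Linear`, counts one multiply-accumulate as one FLOP (bias ignored), and assumes dimension 0 is the batch dimension.

```python
from typing import Tuple

import torch
from torch import nn


def count_linear_flops(model: nn.Module,
                       inputs: torch.Tensor) -> Tuple[int, float]:
    """Return (total FLOPs for the batch, average FLOPs per sample).

    Hypothetical sketch: only nn.Linear layers are counted, one
    multiply-accumulate counts as one FLOP, and the bias add is ignored.
    """
    total = 0

    def hook(module: nn.Module, args: tuple, output: torch.Tensor) -> None:
        nonlocal total
        # Every input element is multiplied once per output feature.
        total += args[0].numel() * output.shape[-1]

    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, nn.Linear)]
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for handle in handles:
            handle.remove()

    batch_size = inputs.shape[0]  # assumption: dim 0 is the batch dimension
    return total, total / batch_size


model = nn.Sequential(nn.Linear(10, 10))
total, per_sample = count_linear_flops(model, torch.randn(32, 10))
print(total, per_sample)  # 3200 100.0
```

Reporting the per-sample value keeps the number independent of whatever batch size was used for the measurement.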

Additionally, we recommend using the FLOPs counting tool in MMEngine. Its usage documentation is available at https://mmengine.readthedocs.io/en/latest/common_usage/model_analysis.html

xushilin1 commented 1 year ago

Hi, thanks for your reply! When I use the FLOPs counting tool in MMEngine, I get a different result (3200 = 32x10x10, compared with 100 = 10x10 when using MMCV). There may be something wrong in MMCV.

```python
from mmengine.analysis import get_model_complexity_info

model = Net()  # the same Net as in the original report
analysis_results = get_model_complexity_info(model, (32, 10))
analysis_results['flops_str']
```

zhouzaida commented 1 year ago

The input shape should not include the batch-size dimension. You can see the details at https://github.com/open-mmlab/mmengine/blob/53474ef1ba0b166508c231fa525b55b580adf20f/mmengine/analysis/print_helper.py#L730.

The correct usage is as follows:

```python
model = Net()
analysis_results = get_model_complexity_info(model, (10,))  # shape without the batch dim
analysis_results['flops_str']  # '100'
```
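The difference between the two calls can be reproduced with a little shape arithmetic. The sketch below assumes, as the linked `print_helper.py` appears to do, that `get_model_complexity_info` builds its dummy input as `torch.randn(1, *input_shape)`. Under that assumption, `(32, 10)` becomes a `(1, 32, 10)` tensor, so the Linear layer processes 32 rows and the count is 3200 instead of 100. `with_batch_dim` and `linear_flops` are hypothetical helpers, not MMEngine APIs.

```python
from math import prod
from typing import Tuple


def with_batch_dim(input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of the dummy input, assuming a batch dim of 1 is prepended."""
    return (1, *input_shape)


def linear_flops(input_shape: Tuple[int, ...], out_features: int) -> int:
    """FLOPs of a Linear layer for the dummy input built from input_shape.

    One multiply-accumulate per input element per output feature; bias ignored.
    """
    return prod(with_batch_dim(input_shape)) * out_features


print(with_batch_dim((32, 10)), linear_flops((32, 10), 10))  # (1, 32, 10) 3200
print(with_batch_dim((10,)), linear_flops((10,), 10))        # (1, 10) 100
```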