sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

interpolation hook returns wrong result when `scale_factor` is scalar #144

Closed: rtyasdf closed this issue 1 month ago.

rtyasdf commented 1 month ago

Consider the following example:

```python
import torch
import torch.nn as nn
from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Upsample both spatial dims by 2: (N, C, H, W) -> (N, C, 2H, 2W)
        return nn.functional.interpolate(input=x, scale_factor=2,
                                         mode='bilinear', align_corners=False)

def custom_constructor(shape):
    # ptflops passes the input resolution tuple here; build a random batch of that shape
    return torch.randn(*shape)

if __name__ == '__main__':
    # Expected to count one MAC per output element; instead this call raises OverflowError
    macs, params = get_model_complexity_info(CustomModel(), (2000, 3, 10, 10),
                                             as_strings=False, print_per_layer_stat=False,
                                             input_constructor=custom_constructor, backend=FLOPS_BACKEND.PYTORCH)
```

Computing one element of the output tensor with bilinear interpolation costs 1 MAC (actually 4, but assume 1 for simplicity; supporting different interpolation modes is a separate issue). So a tensor of shape (2000, 3, 10, 10), interpolated over its spatial dimensions with scale 2, is expected to have shape (2000, 3, 20, 20), for a total of `2000 * 3 * 20 * 20` = 2,400,000 MACs (2.4 MMac).
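
The expected count can be checked with plain arithmetic. This is only a worked sketch of the issue's simplified cost model (1 MAC per output element, mode-specific costs ignored); the variable names are illustrative and nothing here calls ptflops.

```python
# Worked arithmetic for the example above, under the simplified cost model
# of 1 MAC per output element (interpolation-mode costs ignored).
batch, channels, height, width = 2000, 3, 10, 10
scale = 2

# Only the spatial dims are scaled; batch and channels are unchanged.
out_height, out_width = height * scale, width * scale   # 20, 20
expected_macs = batch * channels * out_height * out_width

print(expected_macs)          # 2400000
print(expected_macs / 1e6)    # 2.4, i.e. 2.4 MMac
```

Note that the product is 2.4 million, not 2.4 billion: `2000 * 3` is 6,000 and `20 * 20` is 400.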

However, the call to `get_model_complexity_info` raises `OverflowError: int too large to convert to float`.

The problem is caused by an incorrect computation in `_interpolate_functional_flops_hook`. When `scale_factor` is a scalar, the following lines are executed:

```python
flops = input.numel()
flops *= scale_factor ** len(input)  # equivalent to flops *= scale_factor ** input.shape[0], i.e. the batch size
```
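
To see why this overflows, the same arithmetic can be reproduced with plain Python ints and the shapes from the example, without torch. This is a sketch: it assumes, rather than confirms, that some later step in ptflops converts the count to `float`, which is where the reported error would surface.

```python
# Mimic the buggy scalar branch using the example's shapes (no torch needed).
batch, channels, height, width = 2000, 3, 10, 10
scale_factor = 2

numel = batch * channels * height * width              # input.numel() == 600000
buggy_exponent = batch                                  # len(input) == input.shape[0]
buggy_flops = numel * scale_factor ** buggy_exponent    # exact Python int, no overflow yet

print(buggy_flops.bit_length())   # 2020, far beyond float64's ~2**1024 limit

try:
    # Assumed downstream step: converting the count to float overflows.
    float(buggy_flops)
except OverflowError as err:
    print("OverflowError:", err)  # int too large to convert to float
```

The exponent tracks the batch size rather than the number of spatial dimensions, so the result grows as `2 ** 2000` here and is meaningless even for small batches.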

For `flops` to equal the number of elements in the output tensor, the last line should be changed to:

```python
flops *= scale_factor ** (input.dim() - 2)  # covering cases of 3D, 4D and 5D input to F.interpolate
```

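
A slightly more general version of this fix could work on shapes alone and accept either a scalar or a per-dimension `scale_factor`. The helper below, `interp_output_numel_from_scale`, is hypothetical (not part of ptflops) and is only a sketch under the issue's 1-MAC-per-output-element assumption. One design choice differs from the one-liner: it floors `in_size * scale` per spatial dimension, which to our understanding mirrors how PyTorch derives the output size, so it stays an integer for non-integral products.

```python
import math
from typing import Sequence, Union

Number = Union[int, float]


def interp_output_numel_from_scale(input_shape: Sequence[int],
                                   scale_factor: Union[Number, Sequence[Number]]) -> int:
    """Hypothetical helper: element count of F.interpolate's output.

    The first two dims (batch, channels) are unchanged; only the
    len(input_shape) - 2 spatial dims are scaled, each to
    floor(in_size * scale).
    """
    n_spatial = len(input_shape) - 2
    if n_spatial < 1:
        raise ValueError("interpolate expects 3D, 4D or 5D input")
    if isinstance(scale_factor, (int, float)):
        # Scalar: the same scale applies to every spatial dim.
        scales = [scale_factor] * n_spatial
    else:
        scales = list(scale_factor)
        if len(scales) != n_spatial:
            raise ValueError("need exactly one scale per spatial dim")
    out_spatial = [math.floor(size * scale)
                   for size, scale in zip(input_shape[2:], scales)]
    return input_shape[0] * input_shape[1] * math.prod(out_spatial)


print(interp_output_numel_from_scale((2000, 3, 10, 10), 2))        # 2400000
print(interp_output_numel_from_scale((1, 1, 7), 1.5))              # 10
print(interp_output_numel_from_scale((2, 3, 4, 5, 6), (2, 1, 0.5)))  # 720
```

For integral products such as the example's, this matches `input.numel() * scale_factor ** (input.dim() - 2)`; for `(1, 1, 7)` with scale 1.5 the one-liner would give 10.5 while the floored count is 10.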
sovrasov commented 1 month ago

@rtyasdf you're right, the hook currently calculates nonsense. I should fix the `size` case as well.
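
The thread does not say what is wrong with the `size` branch. As an illustration only, under the same 1-MAC-per-output-element assumption, the count for that branch would be batch times channels times the product of the target spatial sizes. The helper name `interp_output_numel_from_size` is hypothetical and not part of ptflops.

```python
import math
from typing import Sequence, Union


def interp_output_numel_from_size(input_shape: Sequence[int],
                                  size: Union[int, Sequence[int]]) -> int:
    """Hypothetical helper: output element count when `size` is given.

    A scalar `size` applies to every spatial dim; a sequence gives one
    target size per spatial dim. Batch and channel dims are kept, so
    they still multiply into the count.
    """
    n_spatial = len(input_shape) - 2
    if n_spatial < 1:
        raise ValueError("interpolate expects 3D, 4D or 5D input")
    if isinstance(size, int):
        out_spatial = [size] * n_spatial
    else:
        out_spatial = list(size)
        if len(out_spatial) != n_spatial:
            raise ValueError("need exactly one size per spatial dim")
    return input_shape[0] * input_shape[1] * math.prod(out_spatial)


print(interp_output_numel_from_size((2000, 3, 10, 10), 20))        # 2400000
print(interp_output_numel_from_size((2000, 3, 10, 10), (20, 30)))  # 3600000
```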