Consider the following example:

```python
import torch
import torch.nn as nn

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND


class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return nn.functional.interpolate(input=x, scale_factor=2,
                                         mode='bilinear', align_corners=False)


def custom_constructor(shape):
    return torch.randn(*shape)


if __name__ == '__main__':
    macs, params = get_model_complexity_info(CustomModel(), (2000, 3, 10, 10),
                                             as_strings=False,
                                             print_per_layer_stat=False,
                                             input_constructor=custom_constructor,
                                             backend=FLOPS_BACKEND.PYTORCH)
```
To compute the value of one element of the output tensor, bilinear interpolation spends 1 MAC (actually 4, but assume 1 for simplicity; support for the different interpolation modes is a separate issue). So a tensor of shape (2000, 3, 10, 10), interpolated over its spatial dimensions with scale factor 2, is expected to become (2000, 3, 20, 20), for a total of 2000 * 3 * 20 * 20 = 2.4 MMac.
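The expected output shape and element count can be checked directly against `F.interpolate` (a minimal sketch of the arithmetic above):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2000, 3, 10, 10)
y = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

# Only the spatial dimensions are doubled; batch and channels are untouched.
print(tuple(y.shape))  # (2000, 3, 20, 20)
print(y.numel())       # 2400000, i.e. 2.4 M output elements
```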
At the same time, the call to `get_model_complexity_info` raises `OverflowError: int too large to convert to float`.

The problem is caused by a wrong computation in `_interpolate_functional_flops_hook`. In the case of a scalar `scale_factor`, the following lines get called:
```python
flops = input.numel()
flops *= scale_factor ** len(input)  # equivalent to: flops *= scale_factor ** input.shape[0]
```
So for `flops` to be equal to the number of elements in the output tensor, the last line should be changed to:

```python
flops *= scale_factor ** (input.dim() - 2)  # covers 3d, 4d and 5d inputs to F.interpolate
```
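A quick sketch contrasting the two formulas on the tensor from the example above: `len(x)` is the batch size (2000), so the current code raises `2` to the 2000th power, while `x.dim() - 2` counts only the spatial dimensions:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2000, 3, 10, 10)
scale_factor = 2

# Current (buggy) computation: len(x) == x.shape[0] == 2000,
# so the hook accumulates x.numel() * 2 ** 2000 flops ...
buggy = x.numel() * scale_factor ** len(x)
try:
    float(buggy)  # ... which overflows once converted to float
except OverflowError as e:
    print(e)      # int too large to convert to float

# Proposed fix: scale only the spatial dimensions (dim() - 2 of them).
fixed = x.numel() * scale_factor ** (x.dim() - 2)
y = F.interpolate(x, scale_factor=scale_factor, mode='bilinear',
                  align_corners=False)
assert fixed == y.numel()  # both are 2400000
```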