sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

FLOPs for a linear layer with 3D input #121

Closed: martinferianc closed this issue 1 year ago

martinferianc commented 1 year ago

Hey! Thanks for the amazing work. I want to ask: is the FLOPs counting method for a linear layer with 3D input correct?

Right now it is:

input = input[0]
# pytorch checks dimensions, so here we don't care much
output_last_dim = output.shape[-1]
bias_flops = output_last_dim if module.bias is not None else 0
module.__flops__ += int(np.prod(input.shape, dtype=np.int64) *
                            output_last_dim + bias_flops)

The bias is counted only once for a 3D input. If I am not mistaken, the bias is added to each element of the output, so shouldn't it be:

input = input[0]
output_last_dim = output.shape[-1]
bias_flops = output_last_dim if module.bias is not None else 0
if len(input.shape) == 3:
    B, C, D = input.shape
    module.__flops__ += int((D * output_last_dim + bias_flops) * B * C)
else:
    B, D = input.shape
    module.__flops__ += int((D * output_last_dim + bias_flops) * B)
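To make the difference concrete, here is a small worked comparison of the two formulas in plain Python. The function names and the shape (B, C, D) = (2, 3, 4) with 5 output features are made up for illustration and are not part of the library.

```python
from math import prod


def flops_bias_once(input_shape, out_features, has_bias=True):
    # Formula from the first snippet: the bias is added once per call.
    bias_flops = out_features if has_bias else 0
    return prod(input_shape) * out_features + bias_flops


def flops_bias_per_row_3d(input_shape, out_features, has_bias=True):
    # Formula from the second snippet (3D branch): the bias is added
    # once for each of the B * C output rows.
    bias_flops = out_features if has_bias else 0
    B, C, D = input_shape
    return (D * out_features + bias_flops) * B * C


shape = (2, 3, 4)   # hypothetical (B, C, D)
out_features = 5    # hypothetical output width

once = flops_bias_once(shape, out_features)           # 24 * 5 + 5 = 125
per_row = flops_bias_per_row_3d(shape, out_features)  # (4 * 5 + 5) * 6 = 150
print(once, per_row, per_row - once)
```

The gap is (B * C - 1) * out_features = 25 here, so it grows with the number of rows that pass through the layer.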

sovrasov commented 1 year ago

@martinferianc you're right, PyTorch supports any number of dimensions before the feature dimension. I'll change the computations accordingly. Initially that hook assumed the 1D case, which was widely used in old-fashioned CNNs.
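One way such a change could look is sketched below. This is not the code that was actually merged: the name `linear_flops_hook_nd` is hypothetical, and the sketch keeps the counting convention used in this thread (one FLOP per multiply-accumulate). It reads only `.shape`, `.bias` and `__flops__`, so the demo uses `SimpleNamespace` stand-ins instead of real tensors and modules.

```python
from math import prod
from types import SimpleNamespace


def linear_flops_hook_nd(module, input, output):
    """Hypothetical hook: counts the bias add once per output row."""
    input = input[0]  # forward hooks receive positional inputs as a tuple
    in_features = input.shape[-1]
    out_features = output.shape[-1]
    # All dimensions before the feature dimension act as independent rows;
    # prod of an empty shape is 1, which covers an unbatched 1D input.
    rows = prod(input.shape[:-1])
    bias_flops = out_features if module.bias is not None else 0
    module.__flops__ += int(rows * (in_features * out_features + bias_flops))


# Stand-ins that expose only the attributes the hook reads (assumption:
# real tensors and nn.Linear modules provide the same attributes).
layer = SimpleNamespace(bias=object(), __flops__=0)
x = SimpleNamespace(shape=(2, 3, 7, 4))  # hypothetical 4D input
y = SimpleNamespace(shape=(2, 3, 7, 5))  # matching output
linear_flops_hook_nd(layer, (x,), y)
flops_4d = layer.__flops__               # 42 rows * (4 * 5 + 5) = 1050
```

Taking the product over every leading dimension avoids a separate branch per input rank, so 2D, 3D and higher-rank inputs share one formula.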

martinferianc commented 1 year ago

Understood and thank you!