When I used ptflops to calculate the MACs of a ViT on a Windows machine, I found that the MACs of some modules were negative. I traced the bug to this code:

```python
import numpy as np

def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    bias_flops = output_last_dim if module.bias is not None else 0
    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)  # bug in this line
```
On Windows, `np.prod(input.shape)` defaults to `np.int32` (the platform C `long`), so the product silently wraps around for large layers, which is what makes the counts negative. I changed it to `np.prod(input.shape, dtype=np.int64)`, and the bug was fixed.
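Here is a minimal sketch of the overflow, using hypothetical ViT dimensions (batch 64, 197 tokens, hidden width 768, output width 3072); forcing `dtype=np.int32` reproduces on any platform what happens implicitly on Windows:

```python
import numpy as np

# Hypothetical linear layer in a ViT: input shape (64, 197, 768), output width 3072.
# The MAC count is prod(input.shape) * output_last_dim = 29,746,003,968,
# which exceeds the signed 32-bit maximum of 2,147,483,647.
dims = (64, 197, 768, 3072)

print(np.prod(dims, dtype=np.int32))  # -318767104: wrapped around in int32
print(np.prod(dims, dtype=np.int64))  # 29746003968: correct count
```

Passing `dtype=np.int64` explicitly also makes the count platform-independent instead of depending on NumPy's default integer type.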
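For reference, the patched hook is the same function with only the accumulation dtype changed:

```python
def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    bias_flops = output_last_dim if module.bias is not None else 0
    # accumulate in int64 so large layers cannot wrap around to a negative count
    module.__flops__ += int(np.prod(input.shape, dtype=np.int64) * output_last_dim
                            + bias_flops)
```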