sovrasov / flops-counter.pytorch

Flops counter for convolutional networks in pytorch framework
MIT License

MultiheadAttention 0 MACs #65

Closed: lawlict closed this issue 3 years ago

lawlict commented 3 years ago

Hi, I ran a model built on torch.nn.MultiheadAttention and ptflops reports 0 MACs for it. A simplified version of the code is below. Could anyone give me a hand? The PyTorch version is 1.8.1.

import torch
import torch.nn as nn
from ptflops import get_model_complexity_info

class NN(nn.Module):
    def __init__(self, in_size, nhead):
        super().__init__()
        self.net = nn.MultiheadAttention(in_size, nhead)

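    def forward(self, x):
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim) by default,
        # so swap the batch and sequence axes before and after the call.
        x = x.transpose(0, 1)
        # Self-attention: query, key and value are the same tensor;
        # [0] keeps the attention output and drops the attention weights.
        x = self.net(x, x, x)[0]
        x = x.transpose(0, 1)
        return x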

model = NN(512, 4)
model.eval()

with torch.no_grad():
    # ptflops prepends a batch dimension of 1, so the dummy input is (1, 580, 512).
    macs, params = get_model_complexity_info(
        model, (580, 512), as_strings=True, print_per_layer_stat=True, verbose=True
    )
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Error message:

~/anaconda3/envs/torch1.8/lib/python3.7/site-packages/ptflops/flops_counter.py in flops_repr(self)
    122                           flops_to_string(accumulated_flops_cost,
    123                                           units=units, precision=precision),
--> 124                           '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
    125                           self.original_extra_repr()])
    126

ZeroDivisionError: float division by zero
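The last frame of the traceback points at the cause: the percentage printed next to each layer is that layer's cost divided by the model's total cost, and here the total is 0, presumably because this ptflops version had no counting hook that fired for nn.MultiheadAttention. Below is a minimal sketch of the same '{:.3%} MACs' expression with a guard for a zero total; `format_macs_share` is a hypothetical helper for illustration, not part of ptflops.

```python
def format_macs_share(accumulated_macs: float, total_macs: float, precision: int = 3) -> str:
    """Format one layer's share of the model's MACs.

    Hypothetical helper mirroring the '{:.3%} MACs' expression from the
    traceback, but returning "n/a" instead of dividing by zero when nothing
    in the model was counted (total_macs == 0).
    """
    if total_macs == 0:
        return "n/a"
    return f"{accumulated_macs / total_macs:.{precision}%} MACs"


print(format_macs_share(50, 200))  # 25.000% MACs
print(format_macs_share(0, 0))     # n/a
```

A guard like this only hides the symptom: the reported total would still be 0 MACs until a hook for nn.MultiheadAttention is counted, which is what the reply below addresses.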

sovrasov commented 3 years ago

@lawlict please check the implementation of the MHA hook from #68.
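For reference, the cost such a hook has to account for can be estimated by hand. This is a minimal sketch, assuming self-attention as in the snippet above (query, key and value are the same (seq_len, embed_dim) input) and counting only the multiply-accumulates of the matrix products; bias additions, the 1/sqrt(head_dim) scaling, softmax and dropout are ignored. `mha_self_attention_macs` is a hypothetical name, and this is not necessarily the exact formula used by the hook in #68.

```python
def mha_self_attention_macs(seq_len: int, embed_dim: int, num_heads: int, batch_size: int = 1) -> int:
    """Rough MAC count for self-attention in nn.MultiheadAttention (matmuls only)."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    # Q, K and V input projections: three (seq_len x embed_dim) @ (embed_dim x embed_dim) products.
    qkv_proj = 3 * seq_len * embed_dim * embed_dim
    # Attention scores Q_h @ K_h^T for each head: (seq_len x head_dim) @ (head_dim x seq_len).
    scores = num_heads * seq_len * seq_len * head_dim
    # Weighted values softmax(scores_h) @ V_h: (seq_len x seq_len) @ (seq_len x head_dim).
    weighted_values = num_heads * seq_len * seq_len * head_dim
    # Output projection: (seq_len x embed_dim) @ (embed_dim x embed_dim).
    out_proj = seq_len * embed_dim * embed_dim
    return batch_size * (qkv_proj + scores + weighted_values + out_proj)


# Shapes from the issue: sequence length 580, embed_dim 512, 4 heads, batch 1.
macs = mha_self_attention_macs(seq_len=580, embed_dim=512, num_heads=4)
print(macs)  # 952647680, i.e. about 0.95 GMac
```

Because num_heads * head_dim == embed_dim, the head count cancels out and the total reduces to 4 * L * E^2 + 2 * L^2 * E per sequence, so the 4-head model in the issue costs the same as a single-head one under this count.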