autoliuweijie / FastBERT

The source code of FastBERT (ACL 2020)
https://www.aclweb.org/anthology/2020.acl-main.537/

Missing attention FLOPs? #14

Open YeDeming opened 4 years ago

YeDeming commented 4 years ago

Hi,

I found that in MultiHeadedAttention, thop only counts the FLOPs of the linear layers and misses the attention matmul operations.

autoliuweijie commented 4 years ago

Hello! Could you describe this issue in more detail? E.g., how did you find that the attention operation is missed?

In our run_fastbert.py, we use the following code to obtain the total FLOPs of the model with thop.profile:

# Get FLOPs at this batch
inputs = (input_ids_batch, label_ids_batch, mask_ids_batch, fast_mode)
flops, params = profile(model, inputs, verbose=False)
total_flops += flops

Also, in a previous experiment, we measured the FLOPs of the self-attention operation, which is about 603.0M, while that of the FeedForward layer is about 1207.9M.

YeDeming commented 4 years ago

When testing MultiHeadedAttention:


import torch
from thop import profile
from uer.layers.multi_headed_attn import MultiHeadedAttention

# hidden size 768, 12 attention heads, dropout 0.0
encoder = MultiHeadedAttention(768, 12, 0.0)

seq_len = 512
key = torch.randn(1, seq_len, 768)       # dummy input of shape (batch, seq_len, hidden)
mask = torch.ones(1, seq_len, seq_len)   # attend to every position
mask = mask.unsqueeze(1)

inputs = (key, key, key, mask)
macs, params = profile(encoder, inputs, verbose=False)
print(macs)

If you delete the matmul operations in the code, the reported MACs stay the same, e.g., after deleting the following two lines in multi_headed_attn.py:

scores = torch.matmul(query, key.transpose(-2, -1))
....
output = unshape(torch.matmul(probs, value))
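
For reference, a rough hand count of the MACs that thop misses here (just a sketch, assuming the same shapes as in the test above: batch size 1, 12 heads, sequence length 512, hidden size 768, head dimension 64):

# Rough hand count of the matmul MACs that thop does not see in MultiHeadedAttention.
# Assumed shapes: batch 1, 12 heads, sequence length 512, head dimension 64 (hidden 768).
batch, heads, seq_len, head_dim = 1, 12, 512, 64

# scores = Q @ K^T: (heads, seq_len, head_dim) x (heads, head_dim, seq_len)
scores_macs = batch * heads * seq_len * seq_len * head_dim

# output = probs @ V: (heads, seq_len, seq_len) x (heads, seq_len, head_dim)
context_macs = batch * heads * seq_len * seq_len * head_dim

print((scores_macs + context_macs) / 1e6)  # ~402.7M MACs per attention layer for these shapes
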
autoliuweijie commented 4 years ago

Thank you for testing this! I will analyze it further and share my results as soon as possible.

autoliuweijie commented 4 years ago

After testing, we found that thop.profile does not calculate FLOPs for the torch.matmul() operation.

So the FLOPs we reported are missing the torch.matmul part.

References: https://discuss.pytorch.org/t/get-the-matmul-operations-in-a-net/61058
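
For anyone who needs the missing counts, one possible workaround (only a sketch, assuming thop's custom_ops hook; the MatMul wrapper and count_matmul counter below are hypothetical helpers, not part of this repo) is to wrap the matmul calls in a small nn.Module so that thop can hook them:

import torch
import torch.nn as nn
from thop import profile

class MatMul(nn.Module):
    # Thin wrapper so that torch.matmul shows up as a module that thop can hook.
    def forward(self, a, b):
        return torch.matmul(a, b)

def count_matmul(m, inputs, output):
    # One multiply-accumulate per output element, times the shared inner dimension.
    a, _ = inputs
    m.total_ops += torch.DoubleTensor([output.numel() * a.shape[-1]])

# In multi_headed_attn.py, torch.matmul(query, key.transpose(-2, -1)) would become
# self.matmul(query, key.transpose(-2, -1)) with self.matmul = MatMul(), and then:
# flops, params = profile(model, inputs, custom_ops={MatMul: count_matmul}, verbose=False)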

This is a mistake in our work; however, it does not affect the conclusions of the paper, because the speedup ratio is unchanged and the FLOPs of torch.matmul are only a small fraction of the total.

Thank you for finding this issue.