TylerYep / torchinfo

View model summaries in PyTorch!

Inaccurate Mult-Adds Estimation for Transformers #226

Open Yiming-M opened 1 year ago

Yiming-M commented 1 year ago

Describe the bug

For ViT, the total mult-adds returned by torchinfo.summary is much smaller than the value reported by other sources.

To Reproduce

Code snippet:

from torchinfo import summary
from torchvision.models import vit_b_16

vit = vit_b_16()
input_size = (1, 3, 224, 224)
summary(vit, input_size)

Output:

...
Total params: 86,567,656
Trainable params: 86,567,656
Non-trainable params: 0
Total mult-adds (M): 173.23
===============================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 104.09
Params size (MB): 232.27
Estimated Total Size (MB): 336.96

Expected behavior

According to other resources such as MMClassification and PapersWithCode, the FLOP count for this model is 33.03G. I understand that the number of mult-adds is not the same as the number of FLOPs, but for transformers, where matrix multiplication accounts for most of the computation, the two numbers should be of the same order of magnitude, not 33.03G vs. 173.23M! A rough estimate is sketched below.
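
For reference, here is a back-of-the-envelope estimate of the expected mult-adds for ViT-B/16 (a rough sketch, not torchinfo output: it ignores LayerNorm, biases, softmax, and the classification head, and assumes 197 tokens, hidden dim 768, MLP dim 3072, 12 layers):

# back-of-the-envelope estimate of ViT-B/16 mult-adds
N, D, D_mlp, depth = 197, 768, 3072, 12          # tokens, hidden dim, MLP dim, layers

patch_embed = 196 * (16 * 16 * 3) * D            # 16x16 conv patchify over 14x14 patches
qkv  = N * D * 3 * D                             # Q, K, V projections
attn = 2 * N * N * D                             # QK^T and attn @ V
proj = N * D * D                                 # attention output projection
mlp  = 2 * N * D * D_mlp                         # two MLP linear layers

total = patch_embed + depth * (qkv + attn + proj + mlp)
print(f"{total / 1e9:.1f} G mult-adds")          # ~17.6 G, i.e. roughly 35 GFLOPs

That is in the same ballpark as the 33.03G FLOPs reported elsewhere, and orders of magnitude above the 173.23M that torchinfo reports.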

hellcer commented 1 year ago

I'm running into the same issue and hope the developers will take a look at it. Thanks a lot.

quancs commented 1 year ago

I encountered a similar bug: the MACs of the MultiheadAttention module don't get counted. A minimal example is below.
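
A minimal sketch (my own toy example, assuming input_data accepts a list of tensors for the query/key/value arguments):

import torch
import torch.nn as nn
from torchinfo import summary

# The projection parameters are listed in the summary, but the attention
# math runs inside F.multi_head_attention_forward, so its mult-adds are
# expected to be missing from the reported total.
mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.randn(1, 197, 768)
summary(mha, input_data=[x, x, x])  # query, key, value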

snimu commented 1 year ago

The problem is that torchinfo currently only traces nn.Module instances, not function calls. Transformer modules often implement their core operations with functional shortcuts, so those operations don't get traced.
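
As a toy illustration (hypothetical modules, not from torchvision): the same matrix multiplication expressed as a plain function call vs. an nn.Linear submodule:

import torch
import torch.nn as nn
from torchinfo import summary

class FunctionalMatmul(nn.Module):
    # The matmul happens in a plain function call, so torchinfo has no
    # submodule hook through which to attribute mult-adds to it.
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(768, 768))

    def forward(self, x):
        return x @ self.weight

class ModuleMatmul(nn.Module):
    # The same computation expressed as an nn.Linear submodule, which
    # torchinfo hooks and counts.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(768, 768, bias=False)

    def forward(self, x):
        return self.linear(x)

summary(FunctionalMatmul(), input_size=(1, 197, 768))  # mult-adds near zero
summary(ModuleMatmul(), input_size=(1, 197, 768))      # mult-adds counted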

Discussion #192 proposes a tracing mechanism that would fix this issue, but it is a big change. If anyone is up to implementing it, I think @TylerYep would be happy about it.