zhijian-liu / torchprofile

A general and accurate MACs / FLOPs profiler for PyTorch models
https://pypi.org/project/torchprofile/
MIT License

No handlers found: "aten::linear". Skipped. #16

Closed: Z-Zheng closed this issue 2 years ago

Z-Zheng commented 2 years ago

nn.Linear layers appear in virtually every Transformer-like model, but torchprofile skips them with the warning in the title. Should I add a handler for aten::linear myself when I use torchprofile to compute FLOPs for a transformer model?

Z-Zheng commented 2 years ago
def linear(node):
    # MACs = (number of input elements) * out_features; the weight has shape (out_features, in_features)
    return math.prod(node.inputs[0].shape) * node.inputs[1].shape[0]

Adding this function to handlers.py seems to work well.
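For anyone who wants to sanity-check the formula, here is a minimal standalone sketch. It does not use torchprofile itself: `make_node` is a hypothetical stand-in that only mimics the `node.inputs[i].shape` attributes the handler reads, and the shapes are made-up example values. It assumes the weight is laid out as (out_features, in_features), as in nn.Linear.

```python
import math
from types import SimpleNamespace


def linear_macs(node):
    # One multiply-accumulate per (input element, output feature) pair.
    # inputs[0] is the activation, inputs[1] is the weight (out_features, in_features).
    return math.prod(node.inputs[0].shape) * node.inputs[1].shape[0]


def make_node(input_shape, weight_shape):
    # Hypothetical stand-in for a traced graph node; not a torchprofile class.
    return SimpleNamespace(
        inputs=[
            SimpleNamespace(shape=input_shape),
            SimpleNamespace(shape=weight_shape),
        ]
    )


# Example: batch of 8 sequences, 196 tokens each, projecting 768 -> 3072 features.
batch, tokens, in_features, out_features = 8, 196, 768, 3072
node = make_node((batch, tokens, in_features), (out_features, in_features))
macs = linear_macs(node)

# Cross-check against the textbook count: batch * tokens * in_features * out_features.
assert macs == batch * tokens * in_features * out_features
print(macs)
```

Note that this counts only the matrix multiply; the bias addition is not included, which is probably fine if you are counting MACs rather than every elementwise operation.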

zhijian-liu commented 2 years ago

Thanks for reporting this issue! Would you mind submitting a PR for this change?

Z-Zheng commented 2 years ago

It seems that aten::linear is already supported in the latest version; see https://github.com/Z-Zheng/torchprofile/blob/a3b483497c30591dd9658ccc0528ecba7c467e4e/torchprofile/handlers.py#L105.