Z-Zheng closed this 2 years ago
def linear(node):
    return math.prod(node.inputs[0].shape) * node.inputs[1].shape[0]
Inserting these two lines into handlers.py seems to work well.
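For anyone checking the arithmetic, here is a minimal standalone sketch of what the handler computes. It assumes the first input is the activation tensor and the second is the weight with shape [out_features, in_features], as in nn.Linear. The SimpleNamespace objects are hypothetical stand-ins for a traced node, not torchprofile's own classes, and the shapes are made up for illustration.

```python
import math
from types import SimpleNamespace


def linear(node):
    # MACs = (number of input elements, in_features included) * out_features
    return math.prod(node.inputs[0].shape) * node.inputs[1].shape[0]


# Hypothetical stand-ins for the inputs of a traced linear node.
x = SimpleNamespace(shape=[8, 16, 512])  # batch 8, 16 tokens, in_features 512
w = SimpleNamespace(shape=[2048, 512])   # weight: [out_features, in_features]
node = SimpleNamespace(inputs=[x, w])

macs = linear(node)
print(macs)  # equals 8 * 16 * 512 * 2048
```

Note that this counts multiply-accumulates and ignores the bias add, so a FLOP figure would be roughly twice this number.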
Thanks for reporting this issue! Would you mind submitting a PR for this change?
It seems that linear is already supported in the latest version; see https://github.com/Z-Zheng/torchprofile/blob/a3b483497c30591dd9658ccc0528ecba7c467e4e/torchprofile/handlers.py#L105.
nn.Linear layers are used throughout Transformer-like models. Do I need to add this handler myself when I use torchprofile to compute FLOPs for a Transformer model?