zhijian-liu / torchprofile

A general and accurate MACs / FLOPs profiler for PyTorch models
https://pypi.org/project/torchprofile/
MIT License
571 stars, 39 forks

Support inputs of shape (b, *, l, c), matching the shape accepted by scaled_dot_product_attention #24

Closed Z-Zheng closed 6 months ago

Z-Zheng commented 6 months ago

Problem description: The scaled_dot_product_attention API accepts inputs of shape (N, ..., L, E), so the tensors in node.inputs can have a varying number of dimensions. For example, the SAM ViT encoder passes a tensor of shape (B*num_head, H*W, C), which has only three dimensions. The current version raises an error on such inputs. (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)

Solution: We can generalize def scaled_dot_product_attention(node): by indexing the shape at 0, -2, and -1, and replacing the original variable h with the product of the values in shape[1:-2]. For a three-dimensional input, shape[1:-2] is empty and the product is 1. The handler then matches the shape (N, ..., L, E) accepted by the scaled_dot_product_attention API exactly.
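As a rough illustration of the indexing described above, here is a minimal standalone sketch. It is not the torchprofile handler itself: the function name `sdpa_macs` and its tuple-based signature are hypothetical, and it assumes that only the two matrix multiplications (Q·Kᵀ and attention·V) are counted, ignoring scaling and softmax.

```python
import math


def sdpa_macs(q_shape, k_shape, v_shape):
    """Count MACs for attention over inputs shaped (N, ..., L, E).

    Hypothetical helper for illustration only; it counts the two
    matmuls and ignores scaling and softmax (an assumption).
    """
    # Index from both ends so any number of middle dims is accepted.
    n, l, e = q_shape[0], q_shape[-2], q_shape[-1]
    # Product over the middle dims replaces a fixed head count `h`.
    # math.prod of an empty sequence is 1, which covers 3-D inputs.
    h = math.prod(q_shape[1:-2])
    s = k_shape[-2]   # source sequence length
    ev = v_shape[-1]  # value embedding size
    qk = n * h * l * s * e    # Q @ K^T
    av = n * h * l * s * ev   # attn @ V
    return qk + av


# A 4-D input (B, heads, L, C) and the equivalent 3-D input
# (B * heads, L, C) should give the same count.
four_d = sdpa_macs((2, 8, 16, 64), (2, 8, 16, 64), (2, 8, 16, 64))
three_d = sdpa_macs((16, 16, 64), (16, 16, 64), (16, 16, 64))
```

Folding the head dimension into the batch dimension, as in the SAM ViT encoder case, leaves the count unchanged because the middle-dimension product simply becomes 1.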