Problem description:
The `scaled_dot_product_attention` API supports inputs of shape (N, ..., L, E), which means an element of `node.inputs` can have a varying number of dimensions. For example, the SAM ViT encoder passes a tensor of shape (B*num_head, H*W, C), which has only three dimensions, and the current version raises an error on such inputs.
(https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
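For reference, a minimal check (assuming a recent PyTorch that provides `F.scaled_dot_product_attention`) showing that the API itself accepts the 3-D layout from the SAM example; the concrete sizes here are illustrative:

```python
import torch
import torch.nn.functional as F

# A 3-D input (B*num_head, H*W, C), as produced by the SAM ViT encoder.
q = k = v = torch.randn(8 * 12, 14 * 14, 32)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([96, 196, 32])
```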
Solution:
We can generalize `def scaled_dot_product_attention(node):` by indexing the shape at 0, -2, and -1, and computing the product over the values in shape[1:-2] in place of the original variable `h`. This exactly matches the shape (N, ..., L, E) required by the `scaled_dot_product_attention` API: for a 4-D input (N, H, L, E) the slice yields H, and for a 3-D input the slice is empty, so the product is 1.
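A minimal sketch of the generalized handler. The attribute layout of `node` and the returned cost formula are assumptions for illustration; only the shape indexing reflects the proposed change:

```python
import math
from types import SimpleNamespace

import torch


def scaled_dot_product_attention(node):
    # Assumes node.inputs[0] is the query tensor of shape (N, ..., L, E).
    shape = list(node.inputs[0].shape)

    n = shape[0]    # leading batch dim N
    l = shape[-2]   # sequence length L
    e = shape[-1]   # embedding dim E

    # Product over all intermediate dims replaces the hard-coded head
    # variable h; for a 3-D input, shape[1:-2] is empty and the product is 1.
    h = math.prod(shape[1:-2])

    # Illustrative cost only: QK^T and attn @ V, two (L x E)-style matmuls.
    return 2 * n * h * l * l * e


# 3-D input as in the SAM ViT encoder: (B*num_head, H*W, C).
node = SimpleNamespace(inputs=[torch.randn(8 * 12, 14 * 14, 32)])
print(scaled_dot_product_attention(node))
```

With a 4-D input (N, num_head, L, E), `math.prod(shape[1:-2])` recovers the head count, so the same code covers both layouts without a special case.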