torch.matmul supports vector-vector and matrix-vector products. However, the second assertion in matmul_flop_jit() seems to assume that the second input tensor is at least 2-dimensional.
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    ...
If the second input tensor is a 1-dimensional vector, an IndexError: list index out of range will be raised (because input_shapes[1][-2] indexes the second-to-last dimension of a shape that has only one). Would it be possible to treat "the 2nd input is a vector" as a special case and compute its FLOPs accordingly?
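One possible way to handle this, sketched below under assumptions: 1-D operands are promoted to 2-D for counting, the way torch.matmul itself promotes them (a 1-D first argument becomes a (1, n) row, a 1-D second argument becomes an (n, 1) column). The get_shape stand-in here just returns a shape list and is not fvcore's actual helper; the FLOP formula (one multiply-accumulate per output element per inner-dimension step) matches the 2-D case the existing counter already handles.

```python
# Hedged sketch, not fvcore's actual implementation.
import operator
from functools import reduce
from numbers import Number
from typing import Any, List


def get_shape(v) -> List[int]:
    # Stand-in for fvcore's get_shape; here each input is already a shape list.
    return list(v)


def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """Count FLOPs for matmul, treating 1-D operands as a special case."""
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    # Promote 1-D operands the way torch.matmul does, for counting purposes.
    if len(input_shapes[0]) == 1:  # vector @ ... -> treat as (1, n) row
        input_shapes[0] = [1, input_shapes[0][0]]
    if len(input_shapes[1]) == 1:  # ... @ vector -> treat as (n, 1) column
        input_shapes[1] = [input_shapes[1][0], 1]
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    # batch * m * n * k multiply-accumulates for (batch, m, k) @ (k, n).
    batch = reduce(operator.mul, input_shapes[0][:-2], 1)
    m, k = input_shapes[0][-2], input_shapes[0][-1]
    n = input_shapes[1][-1]
    return batch * m * n * k
```

With this, a matrix-vector product (3, 4) @ (4,) counts 3 * 4 = 12 FLOPs and a vector-vector product (4,) @ (4,) counts 4, instead of raising IndexError.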