facebookresearch / fvcore

Collection of common code that's shared among different research projects in the FAIR computer vision team.
Apache License 2.0

Flop counter for matmul does not support matrix-vector product. #130

Open AllenYolk opened 1 year ago


torch.matmul supports vector-vector (dot) and matrix-vector products. However, the second assertion in matmul_flop_jit() assumes that the second input tensor is at least 2-dimensional.
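For reference, the shape rules that torch.matmul documents for 1-D operands can be mirrored on plain shape tuples. The sketch below is a hypothetical helper (matmul_output_shape and broadcast_batch are not part of fvcore or PyTorch). It assumes the documented behavior: a 1-D first operand gets a 1 prepended, a 1-D second operand gets a 1 appended, batch dimensions broadcast, and the added dimensions are removed from the result.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_batch(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: broadcast two batch shapes, right-aligned."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch dims {tuple(a)} and {tuple(b)} do not broadcast")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_output_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of torch.matmul(A, B) from input shapes."""
    if len(a) == 0 or len(b) == 0:
        raise ValueError("matmul needs tensors with at least one dimension")
    a_is_vec, b_is_vec = len(a) == 1, len(b) == 1
    a2 = (1, a[0]) if a_is_vec else tuple(a)  # prepend 1 to a 1-D first operand
    b2 = (b[0], 1) if b_is_vec else tuple(b)  # append 1 to a 1-D second operand
    if a2[-1] != b2[-2]:
        raise ValueError(f"inner dims differ: {tuple(a)} @ {tuple(b)}")
    out = broadcast_batch(a2[:-2], b2[:-2]) + (a2[-2], b2[-1])
    # Remove the dimensions that were added for 1-D operands.
    if a_is_vec:
        out = out[:-2] + out[-1:]
    if b_is_vec:
        out = out[:-1]
    return out


print(matmul_output_shape((3, 4), (4,)))       # (3,)   matrix-vector
print(matmul_output_shape((4,), (4,)))         # ()     vector-vector
print(matmul_output_shape((2, 3, 4), (4, 5)))  # (2, 3, 5)
```

In both 1-D cases the second operand's shape has only one entry, which is exactly the situation the second assertion does not handle.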

```python
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    ...
```

If the second input tensor is a 1-dimensional vector, its shape has only one entry, so input_shapes[1][-2] raises IndexError: list index out of range.
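The failure can be reproduced without PyTorch by applying the same indexing to plain shape lists. This assumes get_shape() reports each shape as a list of ints; the (3, 4) matrix and length-4 vector are illustrative shapes only.

```python
# Shapes for A @ v, with A of shape (3, 4) and v of shape (4,).
input_shapes = [[3, 4], [4]]

try:
    # Same check as the second assertion in matmul_flop_jit().
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
except IndexError as exc:
    print(type(exc).__name__, exc)  # IndexError list index out of range
```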

Would it be possible to treat "the second input is a vector" as a special case and compute the FLOPs accordingly?
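One possible special case, sketched on plain shapes rather than JIT values: promote a 1-D operand the same way torch.matmul does, then count one multiply-add per (batch, row, column, inner) index. The function name matmul_flops_from_shapes is hypothetical, and counting one multiply-add as one flop is an assumption here, meant to match fvcore's documented convention. For 2-D inputs it reduces to M * K * N.

```python
from itertools import zip_longest
from math import prod
from typing import Sequence


def matmul_flops_from_shapes(a: Sequence[int], b: Sequence[int]) -> int:
    """Hypothetical: count multiply-adds for torch.matmul(A, B) from input shapes.

    1-D operands are promoted as torch.matmul promotes them, so vector-vector
    and matrix-vector products no longer index b[-2] on a 1-D shape.
    """
    assert len(a) >= 1 and len(b) >= 1, (a, b)
    a2 = (1, a[0]) if len(a) == 1 else tuple(a)  # vector on the left  -> (1, K)
    b2 = (b[0], 1) if len(b) == 1 else tuple(b)  # vector on the right -> (K, 1)
    assert a2[-1] == b2[-2], (a, b)
    m, k, n = a2[-2], a2[-1], b2[-1]
    # Broadcast the leading (batch) dimensions, right-aligned.
    batch = []
    for x, y in zip_longest(reversed(a2[:-2]), reversed(b2[:-2]), fillvalue=1):
        assert x == y or 1 in (x, y), (a, b)
        batch.append(y if x == 1 else x)
    # One multiply-add per (batch, row, column, inner) index.
    return prod(batch) * m * k * n


print(matmul_flops_from_shapes((3, 4), (4,)))  # 12: matrix-vector, 3 * 4
print(matmul_flops_from_shapes((4,), (4,)))    # 4: dot product
```

Inside an actual handle, the shapes would come from get_shape(v) for v in inputs, as in the original function. Such a handle could presumably be registered via FlopCountAnalysis.set_op_handle("aten::matmul", ...), but that wiring is an assumption and is not shown here.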