Closed: Priya2698 closed this issue 4 months ago.
We should also make sure we support more than one batch dim, including, for example, batch dims between the M and K dims of A. Since there is potential to implement these differently based on nDims(), we should also add a benchmark (a single input size is fine) comparing bmm with batch_size=1 against a 2D matmul of the same size, as a perf sanity check.
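A minimal sketch of such a sanity check, using `numpy.matmul` and `timeit` as stand-ins (the actual benchmark would run through nvFuser; the sizes and repeat count here are arbitrary assumptions):

```python
import timeit

import numpy as np

M = K = N = 256
a2d = np.random.rand(M, K).astype(np.float32)
b2d = np.random.rand(K, N).astype(np.float32)
a3d = a2d[None, ...]  # batch_size=1 view: shape [1, M, K]
b3d = b2d[None, ...]  # shape [1, K, N]

# The two paths must agree numerically before timing means anything.
assert np.allclose(np.matmul(a3d, b3d)[0], np.matmul(a2d, b2d), atol=1e-4)

# Compare bmm with batch_size=1 against the equivalent 2D matmul.
t_bmm = timeit.timeit(lambda: np.matmul(a3d, b3d), number=20)
t_mm = timeit.timeit(lambda: np.matmul(a2d, b2d), number=20)
print(f"bmm (batch=1): {t_bmm:.4f}s, 2D matmul: {t_mm:.4f}s")
```

The point of the check is that the batch_size=1 path should not be meaningfully slower than the plain 2D path for the same problem size.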
This issue will now be resolved through work on new IR nodes: https://github.com/NVIDIA/Fuser/issues/2149
PRs #2175 and #2209 add support for all cases accepted by torch.matmul.
We currently only support 2D inputs for matmul. Extend the implementation to support the additional cases:

- [M,] x [M,]
- [M,] x [M, N]
- [M, N] x [N,]
- [B, M, N] x [N,]

Torch reference: https://pytorch.org/docs/stable/generated/torch.matmul.html
Thunder reference: https://github.com/Lightning-AI/lightning-thunder/blob/a28575345fcdc18bf4b9163dfb239195dca9f34d/thunder/tests/opinfos.py#L5299
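For reference, the output shapes torch.matmul produces for the cases above can be illustrated with `numpy.matmul`, which follows the same vector/matrix/batch promotion rules (dimension sizes chosen arbitrarily for the sketch):

```python
import numpy as np

M, N, B = 3, 4, 2
v_m = np.ones(M)          # [M,]
mat_mn = np.ones((M, N))  # [M, N]
v_n = np.ones(N)          # [N,]
bat = np.ones((B, M, N))  # [B, M, N]

print(np.matmul(v_m, v_m).shape)     # () -- dot product, 0-d result
print(np.matmul(v_m, mat_mn).shape)  # (4,) -- vector-matrix
print(np.matmul(mat_mn, v_n).shape)  # (3,) -- matrix-vector
print(np.matmul(bat, v_n).shape)     # (2, 3) -- batched matrix-vector
```

In each vector case the 1D operand is promoted to a matrix for the multiply and the inserted dimension is removed from the result, which is why the outputs above have fewer dims than a plain 2D matmul would.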