This is a proposal to enable MmaOp to receive inputs shaped like [M, K] and [N, K] instead of [M, 1, K] and [1, N, K].
This is an alternative to #3366.
Motivation
Currently, MmaOp requires at least 3D inputs in which all of the dimensions "line up": the M dimension must be Iteration in the A operand and Broadcast in the B operand, and the N dimension the reverse. For example:
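As a conceptual sketch (Python pseudocode, not nvFuser C++; the IterDomain numbering is illustrative, chosen to match the groups discussed next), broadcast-aligned operands all have the same rank, so axes correspond purely by position:

```python
# Conceptual sketch (not nvFuser code): with broadcast-aligned 3D operands,
# every tensor has the same rank, so axes line up by position.
# The id numbers 0-8 are an illustrative IterDomain numbering.
a = ["iM(0)", "bN(1)", "iK(2)"]    # A: [M, 1, K]
b = ["bM(3)", "iN(4)", "iK(5)"]    # B: [1, N, K]
out = ["iM(6)", "iN(7)", "rK(8)"]  # out: [M, N, rK]

# Positional exact mapping groups each "column" of ids together.
groups = [(i, i + 3, i + 6) for i in range(3)]
assert groups == [(0, 3, 6), (1, 4, 7), (2, 5, 8)]
```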
This lets us use the default exact domain mapping between operands and MmaOp output to determine the following groups: {0, 3, 6}, {1, 4, 7}, {2, 5, 8}.
However, it means that if we are translating a Fusion that has MatmulOp or LinearOp to use MmaOp, we need to introduce BroadcastOp nodes, which interferes with the optimal gmem->smem->mma pipeline on Hopper.
Proposed Approach
In order to avoid needing BroadcastOp when our segment inputs do not already have broadcasts, we need to handle cases like this:
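Sketched in the same illustrative pseudocode (not nvFuser code), the problem is that the operands are now 2D while the output is 3D, so axes no longer correspond by position:

```python
# Conceptual sketch (not nvFuser code): without broadcasts, the operands are
# 2D while the output is 3D, so axes no longer correspond by position.
a = ["iM", "iK"]          # A: [M, K]
b = ["iN", "iK"]          # B: [N, K]
out = ["iM", "iN", "rK"]  # out: [M, N, rK]

# A purely positional map would wrongly pair a[1] (K) with out[1] (N).
assert len(a) != len(out)
```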
We can no longer just specify some numbered axes to reduce, since the inputs do not have the same number of axes as the output. And even if we could specify that in the op, the IR node would need to hold some more information so that we can perform Exact and Broadcast mapping of IterDomains across the MmaOp.
We now need to specify the axes to reduce as well as the axis correspondences in the inputs. One possibility is to specify the position of the corresponding input axis for each output axis as an axis mapping:
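One hypothetical encoding (the field names here are illustrative, not the actual IR attributes): for each output axis, record the position of the corresponding axis in each input, using -1 to mean that input has no such axis:

```python
# Hypothetical encoding of the axis mapping (names are illustrative, not the
# real IR fields): for each output axis, the position of the matching input
# axis, or -1 if that input has no corresponding axis.
out_axes = ["M", "N", "K"]  # output logical domain; K is the reduced axis
axis_mapping = {
    "a": [0, -1, 1],  # A is [M, K]: M at position 0, no N axis, K at 1
    "b": [-1, 0, 1],  # B is [N, K]: no M axis, N at position 0, K at 1
}
```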
I know this looks a little verbose, but remember that we will not be creating these by hand very often: our main path will be to receive MatmulOp and LinearOp from Thunder, and that interface will not change.
I propose to do the following:
- Add attributes to the MmaOp indicating the roles of dimensions in the tv_a and tv_b inputs.
- Add a special case in PairwiseLogicalDomainMap that will map the output domains to domains in the inputs using the map above. This is similar to what we do for SdpaFwdOp and SdpaBwdOp currently.
- Update mma_utils::MatmulPattern::translateToMmaOp to skip inserting broadcasts and use this interface instead.
- Update the Ampere matmul scheduler to not assume there is a broadcast M or N dimension in the ab and bb tensors.
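The PairwiseLogicalDomainMap step can be sketched as follows (assumed behavior in illustrative Python, not the real C++ API; `map_output_to_input` and the -1 convention are hypothetical names for this sketch):

```python
# Sketch (assumed behavior, not the real PairwiseLogicalDomainMap API) of how
# a stored axis mapping could drive output->input domain mapping.
def map_output_to_input(out_domain, in_domain, axis_mapping):
    """Pair each output axis with the input axis the mapping names.

    axis_mapping[i] is the input position for output axis i, or -1 if the
    input has no axis corresponding to that output axis.
    """
    pairs = {}
    for out_pos, in_pos in enumerate(axis_mapping):
        if in_pos != -1:  # skip output axes absent from this input
            pairs[out_domain[out_pos]] = in_domain[in_pos]
    return pairs

out_domain = ["M", "N", "rK"]
a_pairs = map_output_to_input(out_domain, ["M", "K"], [0, -1, 1])
b_pairs = map_output_to_input(out_domain, ["N", "K"], [-1, 0, 1])
assert a_pairs == {"M": "M", "rK": "K"}  # N is unmapped for A
assert b_pairs == {"N": "N", "rK": "K"}  # M is unmapped for B
```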
Note that I also plan to keep the current fusedMultiplySum interface available, so that we can still use broadcasted inputs if we want to. The only caveat with keeping that old behavior around is that it might complicate the changes to the Ampere scheduler.