NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Enable MmaOp to receive unbroadcasted inputs #3372

Closed jacobhinkle closed 1 week ago

jacobhinkle commented 2 weeks ago

This is a proposal to enable MmaOp to receive inputs shaped like [M, K] and [N, K] instead of [M, 1, K] and [1, N, K].

This is an alternative to #3366.

Motivation

Currently, MmaOp requires at least 3D inputs in which all of the dimensions "line up": the M dimension must be Iteration in the A operand and Broadcast in the B operand, and likewise N must be Broadcast in A and Iteration in B. For example:

tv_a [ iS0{M}, bS1{1}, iS2{K} ]
tv_b [ bS3{1}, iS4{N}, iS5{K} ]
tv_c [ iS6{M}, iS7{N}, rS8{K} ] = fusedMultiplySum(tv_a, tv_b, axes={-1})

This lets us use the default exact domain mapping between operands and MmaOp output to determine the following groups {0, 3, 6}, {1, 4, 7}, {2, 5, 8}.
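To make the positional pairing concrete, here is a small standalone sketch (my own illustration, not nvFuser code): when A, B, and the output all have the same rank, the exact map can group axes purely by position. The IterDomains are represented by the integer labels from the example above (`iS0`..`rS8`).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: with same-rank operands, the default exact map
// pairs axis i of A, axis i of B, and axis i of the output into one
// group. IterDomains are stood in for by integer labels.
std::vector<std::vector<int>> exactGroups(
    const std::vector<int>& a_ids,
    const std::vector<int>& b_ids,
    const std::vector<int>& out_ids) {
  std::vector<std::vector<int>> groups;
  for (size_t i = 0; i < out_ids.size(); ++i) {
    // Axis i of A, B, and the output are mapped into one exact group.
    groups.push_back({a_ids[i], b_ids[i], out_ids[i]});
  }
  return groups;
}
```

Feeding in the labels from the example reproduces the groups {0, 3, 6}, {1, 4, 7}, {2, 5, 8}.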

However, it means that if we are translating a Fusion that has MatmulOp or LinearOp to use MmaOp, we need to introduce BroadcastOp nodes, which interferes with the optimal gmem->smem->mma pipeline on Hopper.

Proposed Approach

To avoid inserting BroadcastOp nodes when our segment inputs do not already have broadcasts, we need to handle cases like this:

(MatmulOp translation)
tv_a [ iS0{M}, iS1{K} ]
tv_b [ iS2{K}, iS3{N} ]
tv_c [ iS4{M}, iS5{N}, rS6{K} ] = fusedMultiplySum(tv_a, tv_b, ??)

We can no longer simply specify numbered axes to reduce, since the inputs do not have the same number of axes as the output. And even if we could express that in the op interface, the IR node would still need to hold more information so that we can perform Exact and Broadcast mapping of IterDomains across the MmaOp.

We now need to specify the axes to reduce as well as the axis correspondences in the inputs. One possibility is to specify the position of the corresponding input axis for each output axis as an axis mapping:

tv_c = fusedMultiplySum(
    tv_a,
    tv_b,
    /*init=*/nullptr,
    /*axis_mapping=*/{/*a_axes=*/{0, -1, 1}, /*b_axes=*/{-1, 1, 0}});
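One way to read this mapping (a sketch under my own assumptions, not the actual implementation): for each output axis `i`, `a_axes[i]` gives the position of the corresponding axis in A, or -1 if A has no matching axis. From that we can recover the broadcasted shape the old interface would have required, e.g. via a hypothetical helper like:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (impliedShape is not an nvFuser function): given
// the per-output-axis positions into an input (-1 meaning the input has
// no matching axis), compute the shape the input would have had if we
// had inserted explicit broadcasts instead.
std::vector<int64_t> impliedShape(
    const std::vector<int>& in_axes,
    const std::vector<int64_t>& in_shape) {
  std::vector<int64_t> out;
  for (int pos : in_axes) {
    // -1 corresponds to a broadcast (extent-1) dimension.
    out.push_back(pos == -1 ? int64_t(1) : in_shape[pos]);
  }
  return out;
}
```

With M=4, N=16, K=8: `a_axes={0, -1, 1}` maps `tv_a` [M, K] to [M, 1, K], and `b_axes={-1, 1, 0}` maps `tv_b` [K, N] to [1, N, K], matching the broadcasted layout from the first example.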

I know this looks a little verbose, but remember that we will rarely construct these by hand: our main path will be to receive MatmulOp and LinearOp from Thunder, and that interface will not change.

I propose to do the following:

Note that I also plan to keep the current interface for fusedMultiplySum available, so that we can use broadcasted inputs if we want to. The only caveat with keeping that old behavior around is that it might complicate the changes to the Ampere scheduler.