NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Accept axis mapping when defining MmaOp #3391

Closed jacobhinkle closed 1 week ago

jacobhinkle commented 1 week ago

This keeps the default interface of fusedMultiplySum but also adds an option to provide an MmaOp::AxisMapping object. This mapping defines, for each output dimension, which axis in each operand (if any) corresponds to that output dimension.
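For concreteness, here is a minimal sketch of how such a mapping might be supplied when defining an MmaOp on unbroadcasted 2D operands. The AxisMapping field names (`a_axes`/`b_axes`), the use of `-1` for "no corresponding axis", and the extra `fusedMultiplySum` parameter are assumptions inferred from the description above rather than a confirmed interface; `makeContigTensor` is the helper used in nvFuser's C++ tests.

```cpp
// Sketch only: the AxisMapping layout and the fusedMultiplySum overload below
// are assumptions based on this PR's description, not a confirmed API.
#include <fusion.h>
#include <ops/all_ops.h>
#include <tests/cpp/utils.h> // for makeContigTensor (test helper)

using namespace nvfuser;

void defineMmaWithAxisMapping(Fusion* fusion) {
  FusionGuard fg(fusion);

  // Unbroadcasted operands: A is [M, K], B is [N, K].
  TensorView* tv_a = makeContigTensor(2, DataType::Half);
  TensorView* tv_b = makeContigTensor(2, DataType::Half);
  fusion->addInput(tv_a);
  fusion->addInput(tv_b);

  // Output logical order: [M, N, rK]. For each output dimension, record the
  // matching axis in each operand, or -1 if the operand has no such axis.
  MmaOp::AxisMapping mapping{
      /*a_axes=*/{0, -1, 1},   // M <- A[0], N absent in A, K <- A[1]
      /*b_axes=*/{-1, 0, 1}};  // M absent in B, N <- B[0], K <- B[1]

  // Assumed overload: the default interface plus an optional axis mapping.
  TensorView* tv_out = fusedMultiplySum(tv_a, tv_b, /*axes=*/{-1}, mapping);
  fusion->addOutput(tv_out);
}
```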

This PR does not alter the behavior of mma_utils::MatmulPattern::translateToMmaOp, meaning we still have BroadcastOp in translations for Hopper matmuls, but that follow-up change should be relatively simple.
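For context, the broadcast-based form that the current translation still produces looks roughly like the pattern used throughout the existing MmaOp tests (the layout below is illustrative, not taken from this PR):

```cpp
// Status quo sketch: operands are broadcast to a common 3D shape before
// fusedMultiplySum, so the fusion carries explicit BroadcastOps.
// A: [M, K], B: [N, K] (a TN-style layout, for illustration).
TensorView* tv_a = makeContigTensor(2, DataType::Half);
TensorView* tv_b = makeContigTensor(2, DataType::Half);
TensorView* tv_a_b = broadcast(tv_a, {false, true, false}); // [M, 1, K]
TensorView* tv_b_b = broadcast(tv_b, {true, false, false}); // [1, N, K]
TensorView* tv_out = fusedMultiplySum(tv_a_b, tv_b_b, /*axes=*/{-1});
```

Avoiding these intermediate broadcasts when translating for Hopper is what the follow-up items below build toward.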

Fixes #3372

The included test only checks that dimensions are properly mapped in an MmaOp defined without broadcast axes. In follow-up PRs I plan to do the following:

  1. Demonstrate manually scheduling a Hopper matmul with unbroadcasted inputs. This should surface any bugs in the lowering of the MmaOp instruction when broadcasts are absent.
  2. Ensure that we don't depend on having broadcast dims in the Hopper matmul scheduler. For example, we will need to handle this case in moveInnerBroadcastLeft, and we may also need to adjust the swizzling of the TMA smem load TensorView. At this point our automatic scheduler will be able to schedule a manually defined MmaOp without broadcasted inputs.
  3. Add an avoid_intermediates option, as in MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true), and enable it in the Hopper matmul scheduler. At this point it will be safe for us to accept MatmulOp and LinearOp in the Hopper matmul scheduler.
jacobhinkle commented 1 week ago

!test

jacobhinkle commented 1 week ago

!test

jacobhinkle commented 1 week ago

!test

jacobhinkle commented 1 week ago

!test

jacobhinkle commented 1 week ago

After this PR, one thing we can do is specify the dimension order of the MmaOp output independently from the inputs. When we translate MatmulOp and LinearOp, the output already has logical order M, N, and we are free to place K wherever we want, so I'll place it last. I think this will let us avoid using commitLeafToLogical as is done here: https://github.com/NVIDIA/Fuser/blob/d34553f46e61341f8ac138630273dceb9f6cfbf8/tests/cpp/test_mma.cpp#L603. In that case we can see how the AxisMapping stands in for a root->logical reordering. Since there is one mapping for each input operand, this feels like another nice use case for read/write/compute domains, as suggested by @zasdfgbnm for indexing ldmatrix.
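As a standalone illustration of that point (plain C++, not nvFuser API): an axis mapping can be read as a per-operand reordering onto the output's logical order, with a size-1 placeholder filling the role a BroadcastOp used to play.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// View an operand's shape through an axis mapping: for each output position,
// take the mapped operand extent, or 1 where the operand has no such axis.
std::vector<int64_t> alignToOutput(
    const std::vector<int64_t>& operand_shape,
    const std::vector<int64_t>& axes) { // -1 means "no corresponding axis"
  std::vector<int64_t> aligned;
  for (int64_t ax : axes) {
    aligned.push_back(ax == -1 ? 1 : operand_shape.at(ax));
  }
  return aligned;
}

int main() {
  // A: [M, K] = [128, 64], B: [N, K] = [256, 64]; output order [M, N, rK].
  for (int64_t e : alignToOutput({128, 64}, {0, -1, 1})) {
    std::cout << e << " "; // 128 1 64
  }
  std::cout << "\n";
  for (int64_t e : alignToOutput({256, 64}, {-1, 0, 1})) {
    std::cout << e << " "; // 1 256 64
  }
  std::cout << "\n";
  return 0;
}
```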

jacobhinkle commented 1 week ago

!build

jacobhinkle commented 1 week ago

!build