jacobhinkle opened 2 months ago
Edited the example to show the tensor we might want to save for backward, which is a full matrix. That matrix would be saved by the loss kernel in the unfused case as well. Overall, fusion would save IO proportional to 2MN, as expected, by not materializing the intermediate GEMM output.
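To make the 2MN figure concrete (numbers are mine, for illustration): in the unfused case, the $M \times N$ GEMM output is written once by the matmul kernel and read once by the loss kernel. With $M = N = 8192$ in fp16, that is

$$2MN \times 2\ \mathrm{bytes} = 2 \cdot 8192^2 \cdot 2\ \mathrm{B} = 256\ \mathrm{MiB}$$

of global-memory traffic that the fused kernel never issues.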
> Is it possible to support a row-wise reduction epilogue when the matmul is performed tile by tile? I am curious.
It seems like it should be possible using the usual grid reduction approaches, with either atomics or semaphores coordinating the reduction. If the reduction result is used to compute other values in the epilogue, then we would require a cooperative launch to ensure that all CTAs in a row of tiles are placed in the same wave, so that we can broadcast the resulting reduction back across the grid.
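For intuition, here is a hedged sketch of that cooperative-launch pattern; this is not nvFuser's actual generated code, and the kernel name, buffers, and one-thread-per-element simplification are all illustrative. Each CTA accumulates its tile's contribution to a per-row buffer, grid-syncs, then reads the finished reduction to compute a dependent epilogue value:

```cpp
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Illustrative only: a row-sum epilogue over a GEMM result, where the
// reduction result is then used by every element in the row. Requires a
// cooperative launch (cudaLaunchCooperativeKernel) so that all CTAs are
// resident and grid.sync() is legal. row_sum must be zero-initialized.
__global__ void rowReductionEpilogue(
    const float* gemm_out,  // [M, N] already-computed GEMM output
    float* row_sum,         // [M] accumulator, zeroed before launch
    float* out,             // [M, N] final epilogue output
    int M, int N) {
  cg::grid_group grid = cg::this_grid();
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  // Phase 1: every CTA contributes its partial reduction via atomics.
  // (A semaphore-based serial reduction buffer would also work here.)
  if (row < M && col < N) {
    atomicAdd(&row_sum[row], gemm_out[row * N + col]);
  }

  // All CTAs in the grid wait until every partial has landed.
  grid.sync();

  // Phase 2: broadcast the finished reduction back across the grid and
  // use it to compute a dependent epilogue value (here, row-normalize).
  if (row < M && col < N) {
    out[row * N + col] = gemm_out[row * N + col] / row_sum[row];
  }
}
```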
We currently support pointwise epilogues. However, there is an opportunity for us to fuse reductions into our epilogues as well. For example, fusing a cross-entropy loss would allow us to reduce total IO considerably (ignoring the possible need to save matrix outputs for the backward pass). PyTorch has an open issue requesting such a fused op: https://github.com/pytorch/pytorch/issues/124480.
I tried the following straightforward approach:
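(The snippet itself is not reproduced above; below is a minimal sketch of what such a fusion definition might look like, assuming nvFuser's C++ ops `matmul` and `sum` and the test utility `makeContigTensor`. The function name, shapes, and comments are mine, not the issue's original code.)

```cpp
// Hypothetical reconstruction, not the issue's original snippet:
// a GEMM whose epilogue is a row-wise reduction, so the fusion
// output is 1D rather than the tiled 2D output the scheduler expects.
#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// makeContigTensor comes from nvFuser's C++ test utilities.
void defineMatmulRowSum(Fusion& fusion) {
  FusionGuard fg(&fusion);
  TensorView* a = makeContigTensor(2, DataType::Half);  // [M, K]
  TensorView* b = makeContigTensor(2, DataType::Half);  // [K, N]
  fusion.addInput(a);
  fusion.addInput(b);
  TensorView* mm = matmul(a, b);      // [M, N] GEMM result
  TensorView* rowsum = sum(mm, {1});  // reduce over N; output is 1D [M]
  fusion.addOutput(rowsum);
}
```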
However, this fails during `scheduleOutputTensor`, which expects tiled 2D outputs. To address this, we need to ensure that we remove any explicit or implicit assumptions that every fusion output is 2D. That might mean we just need to pick appropriate epilogue reference tensors, schedule those, and then propagate.