ghostplant opened this issue 1 week ago
@thakkarV Is there a CUTLASS template example for Hopper that computes a GEMM where a custom pre-computation is applied to each input (i.e. a fused matmul(x, exp(y)))? I'd like to know whether fusing exp(y) is efficient, or even possible, with the latest warp-specialized TMA strategy.
If you are doing this modification to the input tensors, it's usually much better to fuse it into the epilogue of the previous layer. We do not have a mainloop fusion layer for this reason.
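For example, something along these lines (a rough NumPy sketch; the shapes and the name prev_out are just placeholders): applying exp in the epilogue of the producer leaves the GEMM itself as a plain GEMM.

```python
import numpy as np

x = np.random.randn(128, 64).astype(np.float32)
prev_out = np.random.randn(64, 256).astype(np.float32)  # output of the previous layer

# Mainloop fusion (what the question asks for): exp(y) applied while the GEMM
# consumes its second operand.
out_mainloop = x @ np.exp(prev_out)

# Epilogue fusion of the previous layer (the suggestion above): exp is applied
# when the previous kernel writes its output, so the GEMM that follows is a
# plain, TMA-friendly GEMM with no mainloop modification.
y = np.exp(prev_out)
out_epilogue = x @ y

assert np.allclose(out_mainloop, out_epilogue)
```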
Sure, but the problem seems to be that TMA's interface isn't flexible enough, or compatible with every fusion requirement. Before Hopper, fusion could be handled by the compiler while loading from gmem to smem (e.g. smem[i] = (i < 0) ? -1 : gmem[i % 3]), but once TMA is introduced that flexibility is gone, since TMA loads gmem data directly into smem with no opportunity for a custom transformation during the load. Padding is a typical example where the gmem shape (e.g. [3x3]) doesn't match the smem shape (e.g. [5x5]), so I was hoping there would be an example in the Hopper/TMA style that handles something like smem = tma_padding_load(gmem, border_val=-1).
I could disable TMA and run the GEMM in the pre-Hopper style, which would let me apply custom padding on the gmem -> smem path, but sadly Hopper GEMM is not efficient without TMA, which defeats the purpose of fusing the padding for speed.
The padding you are describing above is something the framework or graph-compiler layer would take care of, rather than something CUTLASS does.
For F8_scaled_GEMM(x, y, scale_x, scale_y), which CUTLASS does support, the fusion requirement seems easier and in-place, effectively smem = scale_x * tma_load(gmem), but tracking the kernel code newly designed for the CUTLASS 3.x interface is pretty hard. For a warp-specialized, ping-pong scheduled GEMM kernel, can you share a hyperlink/pointer to where that row-wise scaling factor is applied around the TMA load into smem in the device code? I'll jump to that device code and check what other fusion purposes can be handled in a similar way. Thanks!
What is your question?
This is the computation requirement, written in a plain (unfused) way, for executing a GEMM with a custom input:
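(a minimal NumPy sketch of the idea; the shapes and the -1 border value are placeholders)

```python
import numpy as np

# Unfused "pad + gemm": pad the first operand with a constant border,
# then run a standard GEMM on the padded tensor.
x = np.random.randn(3, 3).astype(np.float32)   # original input, e.g. gmem shape [3x3]
y = np.random.randn(5, 7).astype(np.float32)   # other GEMM operand

x_padded = np.pad(x, pad_width=1, constant_values=-1.0)  # [3x3] -> [5x5], border = -1
out = x_padded @ y                                        # plain GEMM on the padded input
print(out.shape)  # (5, 7)
```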
Is the fused "pad + gemm" above supported by any CUTLASS template? Which example would be a good reference?