Closed by CharlieL7 1 month ago
Need to write a matcher in MIGX that will handle this case of fusing dot -> reshape -> pointwise to pass to MLIR.
I think it can be very similar to
https://github.com/ROCm/AMDMIGraphX/pull/3280 solves this: I see fusion across the reshape.
@34 = gpu::code_object[code_object=5824,symbol_name=mlir_transpose_reshape_dot_reshape_add_add,global=18432,local=256,](@33,@11,@28,@30,@31) -> half_type, {1, 348, 768}, {267264, 768, 1}
For the other case, I see this fused kernel:
@44 = gpu::code_object[code_object=11600,symbol_name=mlir_dot_add_add_mul_mul_add_mul_exp_add_div,global=36864,local=256,](@42,@41,@38,@36,@43) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
From the 22 Feb 2024 performance model review of Distilgpt2:
There are several cases of dot+reshape+pointwise:
And also here:
The dot is only used once, so we could fuse this with MLIR. MLIR should already handle a reshape, since it already fuses reshapes on the inputs to dot, but we could do some tweaking in MIGraphX to fuse the reshape when it's in between the dot and the pointwise.
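To make the pattern concrete, here is a minimal sketch of the matching condition in Python on a toy graph. It is not MIGraphX's matcher API: `Node`, `POINTWISE_OPS`, `count_uses`, and `find_dot_reshape_pointwise` are all hypothetical names made up for illustration, and the set of pointwise ops is just the ones that appear in the kernel names above. It only shows the rule described here: a pointwise op whose input is a reshape of a dot, where the dot (and the reshape) has a single user.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False: nodes compare by identity, like instructions
class Node:
    """Hypothetical stand-in for an instruction in the program."""
    op: str
    inputs: list = field(default_factory=list)


# Assumption: pointwise ops taken from the fused kernel names in this issue.
POINTWISE_OPS = {"add", "mul", "exp", "div"}


def count_uses(nodes, target):
    """Count how many input slots across the graph refer to `target`."""
    return sum(1 for n in nodes for i in n.inputs if i is target)


def find_dot_reshape_pointwise(nodes):
    """Return (dot, reshape, pointwise) triples that are safe to fuse.

    A triple qualifies only when the dot and the reshape each have exactly
    one use, so fusing them does not duplicate or orphan any work.
    """
    matches = []
    for pw in nodes:
        if pw.op not in POINTWISE_OPS:
            continue
        for rs in pw.inputs:
            if rs.op != "reshape" or not rs.inputs:
                continue
            if count_uses(nodes, rs) != 1:
                continue
            dot = rs.inputs[0]
            if dot.op == "dot" and count_uses(nodes, dot) == 1:
                matches.append((dot, rs, pw))
    return matches


# Toy graph: add(reshape(dot(a, b)), bias)
a, b, bias = Node("param"), Node("param"), Node("param")
dot = Node("dot", [a, b])
rs = Node("reshape", [dot])
add = Node("add", [rs, bias])
graph = [a, b, bias, dot, rs, add]

matches = find_dot_reshape_pointwise(graph)
```

In the toy graph the single `dot -> reshape -> add` chain is reported as one match. If the dot had a second user, the sketch would skip it, which mirrors the "dot is only used once" condition above.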
Deliverables: