ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Fuse GEMM across a reshape #2813

Closed CharlieL7 closed 1 month ago

CharlieL7 commented 6 months ago

From the 22 Feb 2024 performance model review of Distilgpt2:

There are several cases of dot+reshape+pointwise:

@47 = gpu::code_object[code_object=5688,symbol_name=mlir_reshape_dot,global=67584,local=256,](@44,@45,@46) -> half_type, {348, 3072}, {3072, 1}
@48 = hip::hip_copy_literal[id=main:@literal:66] -> half_type, {3072}, {1}
@49 = reshape_lazy[dims={1, 348, 3072}](@47) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
@50 = load[offset=1069056,end=3207168](@1) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
@51 = multibroadcast[out_lens={348, 3072},out_dyn_dims={}](@48) -> half_type, {348, 3072}, {0, 1}
@52 = reshape_lazy[dims={1, 348, 3072}](@51) -> half_type, {1, 348, 3072}, {0, 0, 1}
@53 = gpu::code_object[code_object=9624,symbol_name=add_mul_mul_mul_mul_add_neg_sub_exp_add_div_mul_kernel,global=534528,local=1024,](@49,@52,@50) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}

And also here:

@32 = gpu::code_object[code_object=5520,symbol_name=mlir_transpose_reshape_dot,global=18432,local=256,](@29,@30,@31) -> half_type, {348, 768}, {768, 1}
@33 = hip::hip_copy_literal[id=main:@literal:70] -> half_type, {768}, {1}
@34 = load[offset=534528,end=1069056](@1) -> half_type, {1, 348, 768}, {267264, 768, 1}
@35 = multibroadcast[out_lens={348, 768},out_dyn_dims={}](@33) -> half_type, {348, 768}, {0, 1}
@36 = reshape_lazy[dims={1, 348, 768}](@32) -> half_type, {1, 348, 768}, {267264, 768, 1}
@37 = reshape_lazy[dims={1, 348, 768}](@35) -> half_type, {1, 348, 768}, {0, 0, 1}
@38 = gpu::code_object[code_object=9416,symbol_name=add_add_kernel,global=133632,local=1024,](@36,@37,@10,@34) -> half_type, {1, 348, 768}, {267264, 768, 1}

The dot is used only once, so we could fuse this with MLIR. MLIR should already handle a reshape, since it already fuses reshapes on the inputs to the dot, but we could do some tweaking in MIGraphX to fuse the reshape when it's in between the dot and the pointwise.
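To illustrate why the reshape in between is harmless, here is a minimal sketch in plain Python (not MIGraphX code; `packed_strides`, `offset`, and `reshape_keeps_offsets` are made-up helper names). It assumes the standard packed row-major layout shown in the dumps above and checks that element (r, c) of the 2-D dot output sits at the same memory offset as element (0, r, c) of the reshaped 3-D view.

```python
from itertools import product


def packed_strides(lens):
    """Row-major (packed) strides for the given dimensions."""
    strides = []
    step = 1
    for n in reversed(lens):
        strides.append(step)
        step *= n
    return list(reversed(strides))


def offset(index, strides):
    """Flat memory offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


def reshape_keeps_offsets(rows, cols):
    """True if {rows, cols} -> {1, rows, cols} leaves every element in place."""
    s2 = packed_strides([rows, cols])
    s3 = packed_strides([1, rows, cols])
    return all(
        offset((r, c), s2) == offset((0, r, c), s3)
        for r, c in product(range(rows), range(cols))
    )


# Shapes taken from the two dumps above.
strides_3072 = packed_strides([1, 348, 3072])  # [1069056, 3072, 1]
strides_768 = packed_strides([1, 348, 768])    # [267264, 768, 1]
ok_3072 = reshape_keeps_offsets(348, 3072)
ok_768 = reshape_keeps_offsets(348, 768)
print(strides_3072, strides_768, ok_3072, ok_768)
```

Since the offsets agree, a pointwise op reads the same elements in the same order on either side of the reshape, which is what makes folding the reshape into the fused kernel plausible. This only demonstrates the layout argument; it says nothing about how MLIR lowers the fused op.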

Deliverables:

umangyadav commented 6 months ago

We need to write a matcher in MIGraphX that handles this case, fusing dot -> reshape -> pointwise, and passes the result to MLIR.

I think it can be very similar to:

https://github.com/ROCm/AMDMIGraphX/blob/6e505cd43c5286f4fafff9b8b863fe94e8735a25/src/fuse_pointwise.cpp#L197
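As a rough picture of what such a matcher has to detect, here is a toy sketch in plain Python. It does not use the MIGraphX matcher API; the `Ins` class, the `users` helper, and the generic op names (`dot`, `reshape`, `pointwise`) are hypothetical stand-ins. It walks a small instruction list modelled on the second dump above and reports each dot -> reshape -> pointwise chain in which the dot has a single user.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # compare instructions by identity, not by field values
class Ins:
    name: str
    op: str
    inputs: list = field(default_factory=list)


def users(ins, program):
    """All instructions in the program that take `ins` as an input."""
    return [i for i in program if any(x is ins for x in i.inputs)]


def find_dot_reshape_pointwise(program):
    """Return (dot, reshape, pointwise) triples where the dot is used once."""
    matches = []
    for pw in program:
        if pw.op != "pointwise":
            continue
        for arg in pw.inputs:
            if arg.op != "reshape" or not arg.inputs:
                continue
            src = arg.inputs[0]
            if src.op == "dot" and len(users(src, program)) == 1:
                matches.append((src, arg, pw))
    return matches


# Toy program loosely mirroring @32..@38 from the second dump.
i10 = Ins("@10", "param")
i32 = Ins("@32", "dot")
i33 = Ins("@33", "literal")
i34 = Ins("@34", "load")
i35 = Ins("@35", "multibroadcast", [i33])
i36 = Ins("@36", "reshape", [i32])
i37 = Ins("@37", "reshape", [i35])
i38 = Ins("@38", "pointwise", [i36, i37, i10, i34])
program = [i10, i32, i33, i34, i35, i36, i37, i38]

found = [(d.name, r.name, p.name) for d, r, p in find_dot_reshape_pointwise(program)]
print(found)

# If the dot had a second user, the single-use condition would reject it.
extra_user = Ins("@39", "pointwise", [i32])
found_with_extra = find_dot_reshape_pointwise(program + [extra_user])
```

The reshape of the broadcast bias (`@37`) is ignored because its input is not a dot, so only the `@32` -> `@36` -> `@38` chain matches. A real pass would also need to rewrite the matched instructions, which this sketch leaves out.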

umangyadav commented 1 month ago

https://github.com/ROCm/AMDMIGraphX/pull/3280 solves this: I see fusion across the reshape.

@34 = gpu::code_object[code_object=5824,symbol_name=mlir_transpose_reshape_dot_reshape_add_add,global=18432,local=256,](@33,@11,@28,@30,@31) -> half_type, {1, 348, 768}, {267264, 768, 1}

For the other case, I see this fused kernel:

@44 = gpu::code_object[code_object=11600,symbol_name=mlir_dot_add_add_mul_mul_add_mul_exp_add_div,global=36864,local=256,](@42,@41,@38,@36,@43) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}