ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

MLIR: GEMM -> reshape -> pointwise #2822

Closed pfultz2 closed 1 month ago

pfultz2 commented 6 months ago

There are several cases of dot+reshape+pointwise in distilgpt2:

@47 = gpu::code_object[code_object=5688,symbol_name=mlir_reshape_dot,global=67584,local=256,](@44,@45,@46) -> half_type, {348, 3072}, {3072, 1}
@48 = hip::hip_copy_literal[id=main:@literal:66] -> half_type, {3072}, {1}
@49 = reshape_lazy[dims={1, 348, 3072}](@47) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
@50 = load[offset=1069056,end=3207168](@1) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
@51 = multibroadcast[out_lens={348, 3072},out_dyn_dims={}](@48) -> half_type, {348, 3072}, {0, 1}
@52 = reshape_lazy[dims={1, 348, 3072}](@51) -> half_type, {1, 348, 3072}, {0, 0, 1}
@53 = gpu::code_object[code_object=9624,symbol_name=add_mul_mul_mul_mul_add_neg_sub_exp_add_div_mul_kernel,global=534528,local=1024,](@49,@52,@50) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
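The `reshape_lazy` on the dot output (@47 → @49) only relabels `{348, 3072}` as `{1, 348, 3072}`. Both sides use standard row-major strides, so no data moves. Below is a minimal Python sketch that checks this from the lengths and strides printed in the dump. It assumes strides are counted in elements, and the helper names are illustrative, not MIGraphX API.

```python
from math import prod


def row_major_strides(lens):
    """Strides (in elements) of a packed, row-major tensor with the given lengths."""
    strides = []
    acc = 1
    for n in reversed(lens):
        strides.append(acc)
        acc *= n
    return list(reversed(strides))


def is_free_reshape(in_lens, in_strides, out_lens, out_strides):
    """True if the reshape is a pure view of packed data: the element count is
    unchanged and both input and output are in standard row-major layout."""
    return (
        prod(in_lens) == prod(out_lens)
        and list(in_strides) == row_major_strides(in_lens)
        and list(out_strides) == row_major_strides(out_lens)
    )


# Values copied from the first dump: @47 (dot output) -> @49 (reshape_lazy)
print(row_major_strides([1, 348, 3072]))  # [1069056, 3072, 1]
print(is_free_reshape([348, 3072], [3072, 1], [1, 348, 3072], [1069056, 3072, 1]))  # True

# The broadcast bias (@51 -> @52) has zero strides, so it is not packed and
# this simple check deliberately rejects it.
print(is_free_reshape([348, 3072], [0, 1], [1, 348, 3072], [0, 0, 1]))  # False
```

The broadcast input would need stride-aware handling, which this sketch does not attempt.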

And also here:

@32 = gpu::code_object[code_object=5520,symbol_name=mlir_transpose_reshape_dot,global=18432,local=256,](@29,@30,@31) -> half_type, {348, 768}, {768, 1}
@33 = hip::hip_copy_literal[id=main:@literal:70] -> half_type, {768}, {1}
@34 = load[offset=534528,end=1069056](@1) -> half_type, {1, 348, 768}, {267264, 768, 1}
@35 = multibroadcast[out_lens={348, 768},out_dyn_dims={}](@33) -> half_type, {348, 768}, {0, 1}
@36 = reshape_lazy[dims={1, 348, 768}](@32) -> half_type, {1, 348, 768}, {267264, 768, 1}
@37 = reshape_lazy[dims={1, 348, 768}](@35) -> half_type, {1, 348, 768}, {0, 0, 1}
@38 = gpu::code_object[code_object=9416,symbol_name=add_add_kernel,global=133632,local=1024,](@36,@37,@10,@34) -> half_type, {1, 348, 768}, {267264, 768, 1}
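The `load` that supplies each pointwise kernel's output buffer (@50 in the first dump, @34 in the second) appears sized exactly for a packed `{1, 348, N}` half tensor. This sketch checks the arithmetic. It assumes `offset`/`end` are byte offsets into the scratch buffer `@1` and that `half_type` is 2 bytes per element; the first assumption is inferred from the numbers, not confirmed by the source.

```python
from math import prod

HALF_BYTES = 2  # half_type is a 16-bit float


def load_bytes(offset, end):
    """Size in bytes of a load[offset=..., end=...] slice (assumes byte offsets)."""
    return end - offset


def tensor_bytes(lens, elem_bytes=HALF_BYTES):
    """Bytes needed for a packed tensor with the given lengths."""
    return prod(lens) * elem_bytes


# @50 from the first dump, @34 from the second dump
print(load_bytes(1069056, 3207168) == tensor_bytes([1, 348, 3072]))  # True
print(load_bytes(534528, 1069056) == tensor_bytes([1, 348, 768]))    # True
```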

The dot output is only used once, so we could fuse the dot, reshape, and pointwise ops into a single MLIR kernel. MLIR should already handle a reshape, since it already fuses reshapes on the inputs to dot. We would just need some tweaking in MIGraphX to fuse the reshape when it sits in between the dot and the pointwise op.
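Here is a minimal sketch of that matching rule. It uses a hypothetical toy IR in Python, not MIGraphX's real C++ matcher infrastructure, and the `Ins` class and helper names are illustrative. The rule looks for a `dot` whose only user is a `reshape` whose only user is a pointwise op.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Ins:
    """Hypothetical, simplified IR instruction: a name, an op kind, and its inputs."""
    name: str
    op: str
    inputs: list = field(default_factory=list)


def users(program, ins):
    """Every instruction in `program` that takes `ins` as an input."""
    return [u for u in program if any(i is ins for i in u.inputs)]


def find_dot_reshape_pointwise(program):
    """Return (dot, reshape, pointwise) triples that could become one fused kernel.

    Both the dot and the reshape must have exactly one user; otherwise another
    instruction still needs the intermediate result in memory, and fusing would
    not eliminate that buffer.
    """
    matches = []
    for dot in program:
        if dot.op != "dot":
            continue
        dot_users = users(program, dot)
        if len(dot_users) != 1 or dot_users[0].op != "reshape":
            continue
        reshape = dot_users[0]
        reshape_users = users(program, reshape)
        if len(reshape_users) == 1 and reshape_users[0].op == "pointwise":
            matches.append((dot, reshape, reshape_users[0]))
    return matches


# Toy version of the first dump: dot -> reshape, bias -> multibroadcast -> reshape,
# and both reshapes feeding one pointwise kernel.
a, b, bias = Ins("a", "param"), Ins("b", "param"), Ins("bias", "literal")
dot = Ins("dot", "dot", [a, b])
r = Ins("r", "reshape", [dot])
mb = Ins("mb", "multibroadcast", [bias])
rb = Ins("rb", "reshape", [mb])
pw = Ins("pw", "pointwise", [r, rb])
program = [a, b, bias, dot, r, mb, rb, pw]

print([(d.name, s.name, p.name) for d, s, p in find_dot_reshape_pointwise(program)])
# [('dot', 'r', 'pw')]
```

The single-user requirement is what the issue's "the dot is only used once" observation captures. If another instruction also consumed the dot output, the matcher returns nothing for that dot.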

umangyadav commented 6 months ago

Is this a duplicate of https://github.com/ROCm/AMDMIGraphX/issues/2813 ?

umangyadav commented 1 month ago

Closing since this is a duplicate of #2813.