Closed pfultz2 closed 1 month ago
There are several cases of dot+reshape+pointwise in distilgpt2:
@47 = gpu::code_object[code_object=5688,symbol_name=mlir_reshape_dot,global=67584,local=256,](@44,@45,@46) -> half_type, {348, 3072}, {3072, 1}
@48 = hip::hip_copy_literal[id=main:@literal:66] -> half_type, {3072}, {1}
@49 = reshape_lazy[dims={1, 348, 3072}](@47) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
@50 = load[offset=1069056,end=3207168](@1) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
@51 = multibroadcast[out_lens={348, 3072},out_dyn_dims={}](@48) -> half_type, {348, 3072}, {0, 1}
@52 = reshape_lazy[dims={1, 348, 3072}](@51) -> half_type, {1, 348, 3072}, {0, 0, 1}
@53 = gpu::code_object[code_object=9624,symbol_name=add_mul_mul_mul_mul_add_neg_sub_exp_add_div_mul_kernel,global=534528,local=1024,](@49,@52,@50) -> half_type, {1, 348, 3072}, {1069056, 3072, 1}
And also here:
@32 = gpu::code_object[code_object=5520,symbol_name=mlir_transpose_reshape_dot,global=18432,local=256,](@29,@30,@31) -> half_type, {348, 768}, {768, 1}
@33 = hip::hip_copy_literal[id=main:@literal:70] -> half_type, {768}, {1}
@34 = load[offset=534528,end=1069056](@1) -> half_type, {1, 348, 768}, {267264, 768, 1}
@35 = multibroadcast[out_lens={348, 768},out_dyn_dims={}](@33) -> half_type, {348, 768}, {0, 1}
@36 = reshape_lazy[dims={1, 348, 768}](@32) -> half_type, {1, 348, 768}, {267264, 768, 1}
@37 = reshape_lazy[dims={1, 348, 768}](@35) -> half_type, {1, 348, 768}, {0, 0, 1}
@38 = gpu::code_object[code_object=9416,symbol_name=add_add_kernel,global=133632,local=1024,](@36,@37,@10,@34) -> half_type, {1, 348, 768}, {267264, 768, 1}
The dot is only used once, so we could fuse this with MLIR. MLIR should already handle a reshape, since it already fuses them on the inputs to dot, but we could do some tweaking in MIGraphX to fuse it when it's in between the dot and the pointwise.
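The condition described above (a dot with a single user, followed by a reshape that feeds a pointwise kernel) can be sketched as a small pattern check. This is only an illustration in plain Python: the `Ins` tuple, the op names, and `find_dot_reshape_pointwise` are hypothetical stand-ins, not MIGraphX's matcher API, and the program is a simplified copy of the first listing. Requiring the reshape to have a single user as well is a conservative assumption made for the sketch.

```python
from collections import namedtuple

# Toy stand-in for an IR instruction; NOT the MIGraphX data structure.
Ins = namedtuple("Ins", ["name", "op", "inputs"])


def users_of(program, name):
    """Return every instruction that reads `name` as an input."""
    return [ins for ins in program if name in ins.inputs]


def find_dot_reshape_pointwise(program):
    """Find dot -> reshape -> pointwise chains.

    A chain qualifies only if the dot has exactly one user (the reshape)
    and the reshape has exactly one user (the pointwise kernel).
    """
    matches = []
    for dot in program:
        if dot.op != "dot":
            continue
        dot_users = users_of(program, dot.name)
        if len(dot_users) != 1 or dot_users[0].op != "reshape":
            continue
        reshape = dot_users[0]
        reshape_users = users_of(program, reshape.name)
        if len(reshape_users) != 1 or reshape_users[0].op != "pointwise":
            continue
        matches.append((dot.name, reshape.name, reshape_users[0].name))
    return matches


# Simplified version of the first listing (@47 .. @53), with op names
# reduced to hypothetical categories.
program = [
    Ins("@47", "dot", ("@44", "@45", "@46")),
    Ins("@48", "literal", ()),
    Ins("@49", "reshape", ("@47",)),
    Ins("@50", "load", ("@1",)),
    Ins("@51", "multibroadcast", ("@48",)),
    Ins("@52", "reshape", ("@51",)),
    Ins("@53", "pointwise", ("@49", "@52", "@50")),
]

matches = find_dot_reshape_pointwise(program)
print(matches)
```

On this toy program the check reports the chain @47, @49, @53. If the dot output were read by a second instruction, the chain would be rejected, which mirrors the "only used once" requirement.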
Is this a duplicate of https://github.com/ROCm/AMDMIGraphX/issues/2813 ?
Closing since this is a duplicate of #2813.