ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

GEMM fusion (over slice or not) #2818

Open CharlieL7 opened 7 months ago

CharlieL7 commented 7 months ago

From the 22 Feb 2024 performance model review of Distilgpt2:

There are several GEMMs applied back to back (this is the tail end of attention):

```
@17 = hip::hip_copy_literal[id=main:@literal:6] -> half_type, {348, 2304}, {2304, 1}
@18 = hip::hip_copy_literal[id=main:@literal:73] -> half_type, {768, 2304}, {2304, 1}
@21 = gpu::gemm[alpha=1,beta=1,compute_fp32=1,trans_batch=0,solution_idx=0](@20,@18,@17,@19) -> half_type, {348, 2304}, {2304, 1}
@22 = reshape_lazy[dims={1, 348, 36, 64}](@21) -> half_type, {1, 348, 36, 64}, {801792, 2304, 64, 1}
@23 = transpose[permutation={0, 2, 1, 3}](@22) -> half_type, {1, 36, 348, 64}, {801792, 64, 2304, 1}
@35 = slice[axes={1},starts={24},ends={36}](@23) -> half_type, {1, 12, 348, 64}, {801792, 64, 2304, 1}
@36 = gpu::gemm[alpha=1,beta=0,compute_fp32=1,trans_batch=1,solution_idx=0](@32,@35,@34) -> half_type, {1, 12, 348, 64}, {267264, 64, 768, 1}
@37 = hip::hip_copy_literal[id=main:@literal:72] -> half_type, {768, 768}, {768, 1}
@38 = load[offset=1069056,end=1603584](@1) -> half_type, {348, 768}, {768, 1}
@39 = transpose[permutation={0, 2, 1, 3}](@36) -> half_type, {1, 348, 12, 64}, {267264, 768, 64, 1}
@40 = reshape_lazy[dims={348, 768}](@39) -> half_type, {348, 768}, {768, 1}
@41 = gpu::gemm[alpha=1,beta=0,compute_fp32=1,trans_batch=0,solution_idx=0](@40,@37,@38) -> half_type, {348, 768}, {768, 1}
```
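The slice in @35 only picks a contiguous block of columns out of the fused {348, 2304} GEMM output. A short index calculation in Python makes this concrete. It uses only the shapes and strides printed in the dump and does not call MIGraphX:

```python
# Shapes and strides copied from the IR dump (instructions @21, @22, @23, @35).
GEMM_COLS = 2304                    # @21: {348, 2304}, row stride 2304
SEQ, HEADS, HEAD_DIM = 348, 36, 64  # @22: reshape to {1, 348, 36, 64}
T_STRIDES = (801792, 64, 2304, 1)   # @23: strides after transpose {0, 2, 1, 3}

def flat_offset(b, h, s, d):
    """Offset into @21's buffer of element (b, h, s, d) of the transposed view @23."""
    return b * T_STRIDES[0] + h * T_STRIDES[1] + s * T_STRIDES[2] + d * T_STRIDES[3]

def gemm_column(h, d):
    """Column of the {348, 2304} GEMM output read for head h, feature d."""
    return flat_offset(0, h, 0, d) % GEMM_COLS

# Sanity check: sequence index s of the view is always row s of the GEMM output.
assert all(flat_offset(0, h, s, 0) // GEMM_COLS == s
           for h in range(HEADS) for s in range(SEQ))

# @35 keeps heads 24..35 (ends=36 is exclusive).
cols = sorted({gemm_column(h, d) for h in range(24, 36) for d in range(HEAD_DIM)})
print(cols[0], cols[-1], len(cols))  # 1536 2303 768
```

So the slice gives the same values as running the @21 GEMM with only columns 1536 to 2303 of @18 and @17, i.e. one of the three 768-wide projections that horizontal fusion packed together. That this block is the V projection is an assumption based on GPT-2's usual Q/K/V layout, not something the dump states.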

If we get rid of the slice (which is undoing some of the horizontal fusion), we have something like `X * (Y*A + b) * C`, where `*` is matmul (X is @32, Y is @20, A and b are the literals @18 and @17, C is the literal @37). So we could rewrite it as `X * (Y*A*C + b*C)`, which after constant folding becomes just `X * (Y*A' + b')` with `A' = A*C` and `b' = b*C`. That gets rid of the final GEMM completely.
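A quick numerical check of that rewrite, as a minimal pure-Python sketch with small integer matrices standing in for the real tensors (the values are arbitrary). It treats everything as plain 2-D matrices and ignores the reshape/transpose into heads that sits between the GEMMs in the real graph, so it only checks the algebra (associativity plus distributivity), not the layout handling:

```python
def matmul(p, q):
    """Plain matrix product of two list-of-lists matrices."""
    assert len(p[0]) == len(q), "inner dimensions must match"
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]

def add(p, q):
    """Element-wise sum of two same-shaped matrices."""
    return [[x + y for x, y in zip(rp, rq)] for rp, rq in zip(p, q)]

# Arbitrary small stand-ins for the roles in the issue.
X = [[1, 2], [0, 1]]          # runtime value (attention weights, @32)
Y = [[1, 0, 2], [3, 1, 1]]    # runtime value (layer input, @20)
A = [[1, 2], [0, 1], [2, 0]]  # constant weight (@18)
b = [[1, 1], [0, 2]]          # constant bias (@17)
C = [[2, 1], [1, 3]]          # constant output-projection weight (@37)

original = matmul(matmul(X, add(matmul(Y, A), b)), C)

# A' = A*C and b' = b*C depend only on constants, so they can be folded offline.
A_folded = matmul(A, C)
b_folded = matmul(b, C)
rewritten = matmul(X, add(matmul(Y, A_folded), b_folded))

assert original == rewritten
print(original)  # [[53, 79], [19, 32]]
```

At runtime the rewritten form needs two matrix products instead of three, because the product with C has moved into the folded constants.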

The same rewrite can be generalized to the case without the slice operator, which simplifies the manipulations needed.

Deliverables:

pfultz2 commented 7 months ago

Instead of matching the slice, we can try to apply this rewrite before horizontal fusion.

Also, we could start by just writing a matcher for A * B * C that rewrites it to A * (B * C). If this could be applied before horizontal fusion, then this would work.
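A toy version of such a matcher, as a Python sketch over a hypothetical expression tree. `Literal`, `Param`, `Dot`, `reassociate` and `const_fold` are made up for illustration and are not MIGraphX's matcher API. Restricting the rewrite to the case where B and C are both constants is an assumption about when it pays off: it turns B * C into a subexpression that constant folding can evaluate ahead of time.

```python
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class Literal:
    """A constant matrix (plays the role of a literal in the graph)."""
    name: str

@dataclass(frozen=True)
class Param:
    """A runtime input."""
    name: str

@dataclass(frozen=True)
class Dot:
    """Matrix product lhs * rhs (plays the role of a dot/gemm instruction)."""
    lhs: "Node"
    rhs: "Node"

Node = Union[Literal, Param, Dot]

def reassociate(node: Node) -> Node:
    """Rewrite (A * B) * C into A * (B * C) when B and C are both literals."""
    if isinstance(node, Dot):
        lhs, rhs = reassociate(node.lhs), reassociate(node.rhs)
        if isinstance(lhs, Dot) and isinstance(lhs.rhs, Literal) and isinstance(rhs, Literal):
            return Dot(lhs.lhs, Dot(lhs.rhs, rhs))
        return Dot(lhs, rhs)
    return node

def const_fold(node: Node) -> Node:
    """Replace every Dot of two literals with a single (symbolic) folded literal."""
    if isinstance(node, Dot):
        lhs, rhs = const_fold(node.lhs), const_fold(node.rhs)
        if isinstance(lhs, Literal) and isinstance(rhs, Literal):
            return Literal(f"({lhs.name}*{rhs.name})")
        return Dot(lhs, rhs)
    return node

expr = Dot(Dot(Param("X"), Literal("A")), Literal("C"))  # (X * A) * C
folded = const_fold(reassociate(expr))
print(folded)  # Dot(lhs=Param(name='X'), rhs=Literal(name='(A*C)'))
```

Applied to (X * A) * C with constant A and C, only one runtime Dot is left. For the attention case above, the Dot would additionally have to be distributed over the bias add before this matcher could fold `A*C` and `b*C`, which the sketch does not attempt.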