ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Optimize Dot + Slice #3267

Closed: umangyadav closed this 3 months ago

umangyadav commented 3 months ago

RNNs contain patterns where the output of a GEMM is immediately sliced.

Instead of computing the full GEMM and then discarding part of it, MIGraphX could slice the GEMM's inputs and compute only the portion of the output that is actually used.
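The rewrite relies on a simple identity for a 2-D GEMM C = A·B: slicing columns of C gives the same result as multiplying A by the corresponding columns of B, and slicing rows of C gives the same result as multiplying the corresponding rows of A by B. A minimal pure-Python check of that identity, using hypothetical helpers `matmul`, `slice_cols`, and `slice_rows` (illustrative only, not MIGraphX functions):

```python
def matmul(a, b):
    """Naive GEMM: a is m x k, b is k x n (lists of rows); returns m x n."""
    k = len(b)
    n = len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def slice_cols(x, start, end):
    """Slice along axis 1, like slice[axes={1},starts={start},ends={end}]."""
    return [row[start:end] for row in x]


def slice_rows(x, start, end):
    """Slice along axis 0."""
    return x[start:end]


# Small example: A is 2x3, B is 3x4.
A = [[1, 2, 3],
     [4, 5, 6]]
B = [[1, 0, 2, -1],
     [0, 1, 3, 2],
     [1, 1, 0, 1]]

# Original pattern: full 2x4 GEMM, then keep only columns 0..2.
full_then_slice = slice_cols(matmul(A, B), 0, 2)

# Rewritten pattern: slice B's columns first, then compute a smaller 2x2 GEMM.
slice_then_dot = matmul(A, slice_cols(B, 0, 2))

assert full_then_slice == slice_then_dot

# Row slices of the output correspond to row slices of A.
assert slice_rows(matmul(A, B), 1, 2) == matmul(slice_rows(A, 1, 2), B)
```

Slicing the inner (K) dimension has no such shortcut, which is why the rewrite only applies to slices of the output's row or column axes.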

umangyadav commented 3 months ago

For example, the following snippet comes from the test_gru_bidirct_3args test, taken before the fuse_mlir pass. Here @22 is a dot operation with exactly one use: the slice in @27.

@22 = gpu::mlir_op[op=dot](@0,r), [mlir_dot1] -> float_type, {2, 10}, {10, 1}
@23 = slice[axes={1},starts={0},ends={5}](@21) -> float_type, {2, 5}, {15, 1}
@24 = contiguous(@23) -> float_type, {2, 5}, {5, 1}
@25 = slice[axes={1},starts={10},ends={15}](@21) -> float_type, {2, 5}, {15, 1}
@26 = contiguous(@25) -> float_type, {2, 5}, {5, 1}
@27 = slice[axes={1},starts={0},ends={5}](@22) -> float_type, {2, 5}, {10, 1}
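One way to sketch the proposed rewrite, assuming a hypothetical toy IR (`Instr`, `users`, and `rewrite_dot_slice` are illustrative names, not the MIGraphX C++ pass API): find a `slice` whose input is a `dot` with no other users, slice the matching dot operand instead, and turn the old slice node into the smaller `dot` so downstream consumers stay connected. It only handles 2-D dots and single-axis slices.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based equality, so list.index/remove match by node
class Instr:
    """Hypothetical toy IR node (not the MIGraphX C++ API)."""
    name: str
    inputs: list = field(default_factory=list)
    attrs: dict = field(default_factory=dict)


def users(program, ins):
    """Return every instruction in the program that reads `ins`."""
    return [u for u in program if any(i is ins for i in u.inputs)]


def rewrite_dot_slice(program):
    """Turn slice(dot(a, b)) into dot over a sliced operand when the slice is
    the dot's only user. Assumes 2-D dots and single-axis slices: output axis 0
    (rows) maps to a's rows, output axis 1 (columns) maps to b's columns.
    Mutates `program` in place and returns the number of rewrites."""
    count = 0
    for s in list(program):
        if s.name != "slice" or not s.inputs or s.inputs[0].name != "dot":
            continue
        dot = s.inputs[0]
        axes = s.attrs.get("axes", [])
        if len(axes) != 1 or axes[0] not in (0, 1) or len(users(program, dot)) != 1:
            continue
        a, b = dot.inputs
        # Slice the operand that owns the sliced output dimension.
        operand = a if axes[0] == 0 else b
        new_slice = Instr("slice", [operand], dict(s.attrs))
        # Reuse the old slice node as the new, smaller dot so its consumers stay wired.
        s.name, s.attrs = "dot", dict(dot.attrs)
        s.inputs = [new_slice, b] if axes[0] == 0 else [a, new_slice]
        program.insert(program.index(s), new_slice)
        program.remove(dot)
        count += 1
    return count


# Usage, loosely modeled on @22/@27: a dot whose output is sliced to columns 0..5.
x = Instr("param")
w = Instr("param")
d = Instr("dot", [x, w])
s = Instr("slice", [d], {"axes": [1], "starts": [0], "ends": [5]})
c = Instr("contiguous", [s])
prog = [x, w, d, s, c]
n = rewrite_dot_slice(prog)
print([i.name for i in prog])  # ['param', 'param', 'slice', 'dot', 'contiguous']
```

Applied to the snippet above, @22 produces a {2, 10} result of which @27 keeps only columns 0 to 5, so the rewritten dot would compute a {2, 5} output, half of the original output elements. The single-user check matters: if another instruction also read the full dot result, the full GEMM would still be needed.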