ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Improve fusions with dequantizelinear #3551

Open · causten opened this issue 1 week ago

causten commented 1 week ago

dequantizelinear_kernel in the llama2 7B 16a4w model (16-bit activations, 4-bit weights) is not fusing as expected.

turneram commented 1 week ago

I see two patterns in llama2-int4 that are not getting fused:
  1. unpack_int4->dequantize_linear->reshape->transpose->dot
  2. unpack_int4->unsqueeze->multibroadcast->contiguous->reshape_lazy->dequantize_linear->reshape->transpose->dot
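Both patterns start with the same two steps: unpack the packed 4-bit weights, then dequantize them. The sketch below shows what those two ops compute on plain Python lists. It is not MIGraphX code: the function names are hypothetical, the low-nibble-first order follows the ONNX int4 packing convention and is an assumption about unpack_int4, and the scale 0.5 and zero point 8 are made-up values.

```python
# Hypothetical plain-Python model of unpack_int4 -> dequantize_linear.
# Assumption: two 4-bit values per byte, low nibble first (ONNX int4 packing).

def unpack_int4(packed):
    """Expand each uint8 byte into two unsigned 4-bit values, low nibble first."""
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # low nibble
        out.append((byte >> 4) & 0x0F)  # high nibble
    return out


def dequantize_linear(x, scale, zero_point):
    """ONNX-style DequantizeLinear, elementwise: y = (x - zero_point) * scale."""
    return [(xi - zi) * si for xi, si, zi in zip(x, scale, zero_point)]


packed = [0x21, 0xF0]                                # nibbles 1, 2, 0, 15
q = unpack_int4(packed)                              # [1, 2, 0, 15]
y = dequantize_linear(q, [0.5] * 4, [8] * 4)         # illustrative scale / zero point
print(q, y)                                          # [1, 2, 0, 15] [-3.5, -3.0, -4.0, 3.5]
```

Every output element depends only on its own packed byte, its scale, and its zero point. That locality is what makes the chain a natural candidate for fusing into the consuming dot.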

pfultz2 commented 1 week ago

Are you able to show the operators with their shapes, so we can understand how it's reshaping and broadcasting?

lakhinderwalia commented 1 week ago

A test case shows that the matcher isn't being hit for unpack_int4:

bin/driver compile ../test/onnx/matmulnbits_mm2_test.onnx

module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}
@1 = hip::hip_allocate_memory[shape=int8_type, {752}, {1},id=main:scratch] -> int8_type, {752}, {1}
@2 = load[offset=0,end=384](@1) -> float_type, {2, 3, 16}, {48, 16, 1}
scales = @param:scales -> float_type, {6}, {1}
@4 = reshape_lazy[dims={2, 3, 1}](scales) -> float_type, {2, 3, 1}, {3, 1, 1}
@5 = multibroadcast[out_lens={2, 3, 16},out_dyn_dims={}](@4) -> float_type, {2, 3, 16}, {3, 1, 0}
@6 = gpu::code_object[code_object=5016,symbol_name=contiguous_kernel,global=48,local=1024,](@5,@2) -> float_type, {2, 3, 16}, {48, 16, 1}
@7 = load[offset=656,end=752](@1) -> uint8_type, {2, 48}, {48, 1}
b = @param:b -> uint8_type, {2, 3, 8}, {24, 8, 1}
@9 = reshape_lazy[dims={2, -1}](b) -> uint8_type, {2, 24}, {24, 1}
@10 = gpu::code_object[code_object=5016,symbol_name=unpack_int4_kernel,global=48,local=1024,](@9,@7) -> uint8_type, {2, 48}, {48, 1}
@11 = reshape_lazy[dims={2, 48}](@6) -> float_type, {2, 48}, {48, 1}
@12 = load[offset=384,end=648](@1) -> float_type, {2, 33}, {33, 1}
@13 = slice[axes={1},starts={0},ends={33}](@11) -> float_type, {2, 33}, {48, 1}
@14 = slice[axes={1},starts={0},ends={33}](@10) -> uint8_type, {2, 33}, {48, 1}
@15 = gpu::code_object[code_object=5088,symbol_name=dequantizelinear_kernel,global=66,local=1024,](@14,@13,@12) -> float_type, {2, 33}, {33, 1}
main:#output_0 = @param:main:#output_0 -> float_type, {2, 2}, {2, 1}
a = @param:a -> float_type, {2, 33}, {33, 1}
@18 = gpu::code_object[code_object=5216,symbol_name=mlir_transpose_dot,global=256,local=256,](a,@15,main:#output_0) -> float_type, {2, 2}, {2, 1}
@19 = @return(@18)
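The dump traces the second pattern end to end. scales {6} are reshaped to {2, 3, 1}, broadcast across blocks of 16 to {2, 3, 16}, materialized by contiguous_kernel, reshaped to {2, 48}, and sliced to the 33 real columns. The packed weights b {2, 3, 8} follow a parallel path through unpack_int4 ({2, 24} bytes to {2, 48} values) and the same slice. dequantizelinear_kernel then runs as its own kernel before mlir_transpose_dot computes a {2, 33} times the transposed {2, 33} weights, giving {2, 2}. The sketch below mimics that shape flow with plain lists. The helper names are hypothetical, and the sizes (N = 2, 3 blocks of 16, K = 33) are read off the IR above.

```python
# Plain-Python sketch of the shape flow in the dump above (not MIGraphX code).
# Sizes read off the IR: N = 2 output rows, 3 blocks of 16 columns, K = 33.

def broadcast_block_scales(scales, n, blocks, block_size):
    """Mimic reshape {n*blocks} -> {n, blocks, 1}, multibroadcast to
    {n, blocks, block_size}, then reshape to {n, blocks * block_size}."""
    rows = []
    for i in range(n):
        row = []
        for b in range(blocks):
            row.extend([scales[i * blocks + b]] * block_size)
        rows.append(row)
    return rows


def slice_cols(mat, k):
    """Mimic slice[axes={1},starts={0},ends={k}] on a 2-D list."""
    return [row[:k] for row in mat]


def dot_transposed(a, w):
    """Mimic transpose + dot: a {M, K} times w^T with w {N, K}, giving {M, N}."""
    return [[sum(x * y for x, y in zip(a_row, w_row)) for w_row in w] for a_row in a]


N, BLOCKS, BLOCK, K = 2, 3, 16, 33
scales = [float(s) for s in range(1, N * BLOCKS + 1)]   # stands in for @param:scales {6}
full = broadcast_block_scales(scales, N, BLOCKS, BLOCK)  # {2, 48}, like @11
per_elem = slice_cols(full, K)                           # {2, 33}, like @13
print(len(full[0]), len(per_elem[0]), per_elem[0][15], per_elem[0][16], per_elem[0][32])
# 48 33 1.0 2.0 3.0
```

Column 32 is the only column kept from the third block, which is why the padded {2, 48} tensors are sliced down to K = 33 before the dequantize.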
lakhinderwalia commented 1 week ago
  1. We have an issue with the type of zero_point: when it is uint8, there is no easy way to fuse it using the types MLIR currently supports.
  2. Additionally, a bug in the matmulnbits parser is being fixed. That fix won't solve the problem above, however.
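A general arithmetic identity is relevant to the uint8 zero_point problem. Because DequantizeLinear computes y = (x - zero_point) * scale, subtracting the same offset from both x and zero_point leaves y unchanged. Shifting by 128 therefore maps uint8 data and zero points into int8 without changing the result, and shifting by 8 does the same for 4-bit values. The thread does not say whether MIGraphX or PR #3566 uses this, and whether it helps depends on MLIR accepting the signed types. The sketch below only checks the identity. The function names are hypothetical.

```python
# Sketch of the zero-point shift identity; not taken from MIGraphX or PR #3566.
# (x - zp) == ((x - offset) - (zp - offset)), so y = (x - zp) * scale is unchanged.

def dequantize(x, scale, zero_point):
    """Scalar ONNX-style DequantizeLinear."""
    return (x - zero_point) * scale


def to_signed(v, offset):
    """Shift an unsigned value into a signed range: offset=128 maps uint8
    0..255 to int8 -128..127; offset=8 maps uint4 0..15 to int4 -8..7."""
    return v - offset


def shift_preserves_result(x, scale, zero_point, offset):
    """True if dequantizing the shifted pair matches the unshifted result."""
    return dequantize(x, scale, zero_point) == dequantize(
        to_signed(x, offset), scale, to_signed(zero_point, offset)
    )


print(shift_preserves_result(200, 0.25, 128, 128))  # uint8 data and zero point
print(shift_preserves_result(15, 0.5, 8, 8))         # uint4 data, zero point 8
```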
lakhinderwalia commented 1 week ago

https://github.com/ROCm/AMDMIGraphX/pull/3566

With this PR, the new test case fuses properly: bin/driver compile ../test/onnx/matmulnbits_mm2_signed_test.onnx

However, the uint8 case remains unfixed: bin/driver compile ../test/onnx/matmulnbits_mm2_test.onnx