causten opened 1 week ago
I see two patterns in llama2-int4 that are not getting fused.
unpack_int4->dequantize_linear->reshape->transpose->dot
and
unpack_int4->unsqueeze->multibroadcast->contiguous->reshape_lazy->dequantize_linear->reshape->transpose->dot
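The front half of the first pattern can be sketched in plain Python for readers less familiar with these ops. This is a minimal sketch, not MIGraphX code. It assumes the low nibble of each byte is the first unpacked element (the ONNX 4-bit packing order) and uses a single scalar scale. The function names are hypothetical stand-ins for the unpack_int4 and dequantize_linear operators.

```python
def unpack_int4(packed):
    """Expand each uint8 into two 4-bit values (0..15) along the last axis.

    Assumption: the low nibble is element 0 and the high nibble element 1,
    the packing order ONNX uses for 4-bit types.
    """
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # low nibble
        out.append((byte >> 4) & 0x0F)  # high nibble
    return out


def dequantize_linear(q, scale, zero_point=0):
    """Elementwise ONNX-style DequantizeLinear: (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


# Pattern 1 without the reshape/transpose/dot tail: 2 bytes -> 4 values.
q = unpack_int4([0x21, 0xF0])        # [1, 2, 0, 15]
w = dequantize_linear(q, scale=0.5)  # [0.5, 1.0, 0.0, 7.5]
```

Because unpack_int4 doubles the last dimension, a {2, 24} uint8 input becomes {2, 48}.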
Are you able to show the operators with their shapes, so we can understand how it's reshaping and broadcasting?
A test case can show that the matcher isn't being hit for unpack_int4:
bin/driver compile ../test/onnx/matmulnbits_mm2_test.onnx
module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}
@1 = hip::hip_allocate_memory[shape=int8_type, {752}, {1},id=main:scratch] -> int8_type, {752}, {1}
@2 = load[offset=0,end=384](@1) -> float_type, {2, 3, 16}, {48, 16, 1}
scales = @param:scales -> float_type, {6}, {1}
@4 = reshape_lazy[dims={2, 3, 1}](scales) -> float_type, {2, 3, 1}, {3, 1, 1}
@5 = multibroadcast[out_lens={2, 3, 16},out_dyn_dims={}](@4) -> float_type, {2, 3, 16}, {3, 1, 0}
@6 = gpu::code_object[code_object=5016,symbol_name=contiguous_kernel,global=48,local=1024,](@5,@2) -> float_type, {2, 3, 16}, {48, 16, 1}
@7 = load[offset=656,end=752](@1) -> uint8_type, {2, 48}, {48, 1}
b = @param:b -> uint8_type, {2, 3, 8}, {24, 8, 1}
@9 = reshape_lazy[dims={2, -1}](b) -> uint8_type, {2, 24}, {24, 1}
@10 = gpu::code_object[code_object=5016,symbol_name=unpack_int4_kernel,global=48,local=1024,](@9,@7) -> uint8_type, {2, 48}, {48, 1}
@11 = reshape_lazy[dims={2, 48}](@6) -> float_type, {2, 48}, {48, 1}
@12 = load[offset=384,end=648](@1) -> float_type, {2, 33}, {33, 1}
@13 = slice[axes={1},starts={0},ends={33}](@11) -> float_type, {2, 33}, {48, 1}
@14 = slice[axes={1},starts={0},ends={33}](@10) -> uint8_type, {2, 33}, {48, 1}
@15 = gpu::code_object[code_object=5088,symbol_name=dequantizelinear_kernel,global=66,local=1024,](@14,@13,@12) -> float_type, {2, 33}, {33, 1}
main:#output_0 = @param:main:#output_0 -> float_type, {2, 2}, {2, 1}
a = @param:a -> float_type, {2, 33}, {33, 1}
@18 = gpu::code_object[code_object=5216,symbol_name=mlir_transpose_dot,global=256,local=256,](a,@15,main:#output_0) -> float_type, {2, 2}, {2, 1}
@19 = @return(@18)
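The hypothetical pure-Python sketch below retraces the shape flow of the compiled program above. The scales go through reshape_lazy, multibroadcast, contiguous/reshape_lazy and slice, ending up as one scale per weight element ({6} to {2, 33}). b is reshaped, unpacked and sliced to the same {2, 33} before dequantization and the transposed dot. Shapes are taken from the dump. The input values, the helper names, and the low-nibble-first order are assumptions.

```python
# Hypothetical re-creation of the shape flow in the compiled program above.
# Shapes come from the dump; input values and nibble order are assumed.

def unpack_int4(row):
    """uint8 bytes -> 4-bit values, low nibble first (assumed order)."""
    out = []
    for byte in row:
        out += [byte & 0x0F, (byte >> 4) & 0x0F]
    return out


def broadcast_block_scales(scales, n_rows, blocks, block, k):
    """{n*blocks} -> {n, blocks, 1} -> {n, blocks, block} -> {n, blocks*block} -> {n, k}."""
    rows = []
    for n in range(n_rows):
        row = []
        for blk in range(blocks):
            row += [scales[n * blocks + blk]] * block  # multibroadcast along the block
        rows.append(row[:k])                           # slice to the K real columns
    return rows


N, K, BLOCKS, BLOCK = 2, 33, 3, 16  # b is {2, 3, 8} uint8, a is {2, 33} float

scales = [0.5, 1.0, 2.0, 0.25, 4.0, 1.5]                          # scales {6}
scale_rows = broadcast_block_scales(scales, N, BLOCKS, BLOCK, K)  # @13 {2, 33}

b = [[[n * 24 + j * 8 + i for i in range(BLOCK // 2)]
      for j in range(BLOCKS)] for n in range(N)]             # b {2, 3, 8}
b_rows = [[byte for blk in row for byte in blk] for row in b]  # @9 {2, 24}
q_rows = [unpack_int4(row)[:K] for row in b_rows]             # @10 {2, 48}, @14 {2, 33}

# dequantizelinear_kernel: no zero_point operand appears in this dump.
w = [[q * s for q, s in zip(qr, sr)] for qr, sr in zip(q_rows, scale_rows)]  # @15 {2, 33}

# mlir_transpose_dot: a {2, 33} x transpose(w) {33, 2} -> {2, 2}
a = [[1.0] * K for _ in range(2)]
y = [[sum(x * wv for x, wv in zip(a_row, w_row)) for w_row in w] for a_row in a]

print(len(w), len(w[0]), len(y), len(y[0]))  # 2 33 2 2
```

Note that slicing 48 columns down to 33 cuts the last 16-wide block short. The broadcast scales therefore have to be materialized (contiguous) and sliced along with the unpacked weights.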
zero_point: when it is uint8, there is no easy way to fuse it within the types MLIR currently supports. https://github.com/ROCm/AMDMIGraphX/pull/3566
With this PR, it fuses properly in this new test case:
bin/driver compile ../test/onnx/matmulnbits_mm2_signed_test.onnx
But this uint8 case will remain unfixed:
bin/driver compile ../test/onnx/matmulnbits_mm2_test.onnx
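To illustrate the arithmetic an unsigned zero_point involves, the sketch below compares DequantizeLinear's (q - zero_point) * scale computed with uint8 wraparound against the same computation on widened signed integers. It only shows the numeric hazard. It is not a description of how MIGraphX or MLIR actually lowers this case, and the function names are hypothetical.

```python
def sub_uint8_wrapping(a, b):
    """uint8 - uint8 kept in uint8, wrapping modulo 256 as a narrow type would."""
    return (a - b) & 0xFF


def sub_widened(a, b):
    """The same subtraction on Python ints, which model a widened signed type."""
    return int(a) - int(b)


q, zp, scale = 3, 8, 0.5
print(sub_uint8_wrapping(q, zp) * scale)  # 125.5 (wrapped, wrong)
print(sub_widened(q, zp) * scale)         # -2.5 (correct)
```

Any fused kernel therefore has to widen or re-sign the uint8 operands before the subtraction, which is extra type handling compared with the signed case.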
dequantizelinear_kernel on the llama2 7B 16a4w model is not fusing as expected