iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[codegen] [gpu]: SD3 MMDiT attention dispatch fails on LinalgExtToLoops for amdgpu targets #18629

Open monorimet opened 2 weeks ago

monorimet commented 2 weeks ago

What happened?

Error log:

(turb.env) PS C:\Users\eagarvey\SHARK\SHARK-Turbine> iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-hip-target=gfx1103 --iree-vm-bytecode-module-output-format=flatbuffer-binary .\sd3_mmdit_gfx1103_dps\dispatch_27_attn.mlir
.\sd3_mmdit_gfx1103_dps\dispatch_27_attn.mlir:2:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1103", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1103", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>) {
  ^
failed to translate executables

Original MLIR: Azure
Dispatch IR: Azure
IR dump: Azure

Steps to reproduce your issue

What component(s) does this issue relate to?

Compiler

Version information

IREE compiler version 20240927.1029 @ 76c3e61d563dbc22b74aa9d3d79c11c24a799697

Additional context

Shapes for this model have in the past required a masked attention implementation via transform dialect scripts, since they do not match the available MMA intrinsics. I'm not sure whether this has any bearing on the error shown here.
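For context on why masking was needed: the WMMA intrinsics listed in the target attributes (`WMMA_*_16x16x16_*`) operate on 16x16x16 tiles. A quick divisibility check (a hypothetical helper for illustration, not IREE code) shows which attention dims fall off the tile boundary:

```python
# Sketch: flag dims that don't line up with the 16x16x16 WMMA tile shape
# reported in the #iree_gpu.target attributes for gfx1103.
INTRINSIC_TILE = 16  # M, N, and K of WMMA_*_16x16x16_* are all 16

def needs_masking(dim, tile=INTRINSIC_TILE):
    """True if `dim` is not a multiple of the intrinsic tile size."""
    return dim % tile != 0

seq_len, head_dim = 1178, 64    # from the 2x1178x24x64 attention shapes
print(needs_masking(seq_len))   # True:  1178 = 73 * 16 + 10
print(needs_masking(head_dim))  # False: 64 is a multiple of 16
```

So the sequence length 1178 is the dim that would need masking or padding to reach an intrinsic-aligned shape; whether that is related to the LinalgExtToLoops failure is unclear.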

The compile flags above are stripped down from what we normally use to compile these models; since the error reproduces without the experimental flags, I thought it best to leave them out.

nirvedhmeshram commented 1 week ago

Sharing the smallest repro with which I am able to capture this issue:

func.func @run_forward$async_dispatch_27_attention_2x1178x24x64xf16_generic(
    %12 : tensor<2x1178x24x64xf16>, %13 : tensor<2x1178x24x64xf16>,
    %14 : tensor<2x1178x24x64xf16>) -> tensor<2x24x1178x64xf16> {
  %cst = arith.constant 1.250000e-01 : f16
  %16 = tensor.empty() : tensor<2x24x1178x64xf16>
  %17 = iree_linalg_ext.attention
          {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>,
                            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d4)>,
                            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d3)>,
                            affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
                            affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]}
          ins(%12, %13, %14, %cst : tensor<2x1178x24x64xf16>, tensor<2x1178x24x64xf16>, tensor<2x1178x24x64xf16>, f16)
          outs(%16 : tensor<2x24x1178x64xf16>) -> tensor<2x24x1178x64xf16>
  return %17 : tensor<2x24x1178x64xf16>
}
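The indexing maps describe a transposed attention layout: the Q/K/V operands are laid out as (batch, seq, heads, head_dim) while the result is (batch, heads, seq, head_dim). A rough NumPy reference of those semantics (a sketch only, with batch and head counts shrunk to keep it cheap; not IREE code):

```python
import numpy as np

# Dims per the affine maps: d0=batch, d1=heads, d2=M, d3=N, d4=K1, d5=K2.
# The real operands are 2x1178x24x64; batch/heads are shrunk here for speed.
B, H, M, D = 1, 2, 1178, 64
scale = np.float32(0.125)  # matches the 1.250000e-01 : f16 constant

rng = np.random.default_rng(0)
q = rng.standard_normal((B, M, H, D), dtype=np.float32)  # map (d0, d2, d1, d4)
k = rng.standard_normal((B, M, H, D), dtype=np.float32)  # map (d0, d5, d1, d4)
v = rng.standard_normal((B, M, H, D), dtype=np.float32)  # map (d0, d5, d1, d3)

# Move heads ahead of the sequence dim: (B, H, seq, head_dim).
qt, kt, vt = (t.transpose(0, 2, 1, 3) for t in (q, k, v))

s = scale * (qt @ kt.transpose(0, 1, 3, 2))    # (B, H, M, K2) attention scores
p = np.exp(s - s.max(axis=-1, keepdims=True))  # numerically stable softmax
p /= p.sum(axis=-1, keepdims=True)
out = p @ vt                                   # (B, H, M, D), cf. 2x24x1178x64

print(out.shape)
```

Note that both reduction dims (d4 over head_dim=64, d5 over seq=1178) are folded into the two matmuls; the output map (d0, d1, d2, d3) is what yields the transposed tensor<2x24x1178x64xf16> result.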

Compile with:

iree-compile input_repro.mlir --iree-hip-target=gfx1103 --iree-hal-target-backends=rocm