iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

(Codegen) --iree-opt-aggressively-propagate-transposes does not work with fx-decomposed attention. #17624

Open · monorimet opened 3 weeks ago

monorimet commented 3 weeks ago

What happened?

This error was difficult to pin down to a specific flag, so I thought I would share.

When debugging unet numerics on gfx1100 with WMMA FA (flash attention) etc., I had to try the fx-decomposed attention route to get a baseline read on our ROCm codegen's numerical accuracy (first with minimal flags, then with all the experimental opts).
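(For context, by "fx-decomposed attention" I mean sdpa gets expanded into its constituent ops before export, so the compiler sees plain matmuls and a softmax instead of one fused attention op. A rough sketch of the math, not the exact torch decomposition:)

```python
import math
import torch

def decomposed_sdpa(q, k, v):
    # Roughly what decomposing
    # torch.nn.functional.scaled_dot_product_attention looks like.
    # Some decompositions split the scale as sqrt(scale) applied to q and k
    # separately, which would match the torch.aten.mul.Scalar constant
    # 0.35355 ~= 64**-0.25 in the error below.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
```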

I ran into the following error with "maximum perf flags", i.e., what we run when sdpa is in the torch IR:

<unknown>:0: error: cannot get concrete layout for contraction
sdxl_turbo_bs1_64_512x512_fp16_unet_rocm.mlir:2800:12: error: 'func.func' op failed to distribute
    %463 = torch.aten.mul.Scalar %462, %float3.535530e-01_571 : !torch.vtensor<[2,10,64,1024],f16>, !torch.float -> !torch.vtensor<[2,10,64,1024],f16>
           ^
sdxl_turbo_bs1_64_512x512_fp16_unet_rocm.mlir:2800:12: error: Failures have been detected while processing an MLIR pass pipeline
    %463 = torch.aten.mul.Scalar %462, %float3.535530e-01_571 : !torch.vtensor<[2,10,64,1024],f16>, !torch.float -> !torch.vtensor<[2,10,64,1024],f16>
           ^
sdxl_turbo_bs1_64_512x512_fp16_unet_rocm.mlir:2800:12: note: Pipeline failed while executing [`TranslateExecutablesPass` on 'hal.executable' operation: @run_forward$async_dispatch_13, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @run_forward$async_dispatch_21, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable' operation: @run_forward$async_dispatch_27, `TranslateTargetExecutableVariantsPass` on 'hal.executable.variant' operation: @rocm_hsaco_fb, `TranslateExecutablesPass` on 'hal.executable ........... <truncated>.................`LLVMGPUVectorDistribute` on 'func.func' operation: @run_forward$async_dispatch_75_matmul_like_2x10x64x1024x640_f16xf16xf32, `CSE` on 'func.func' operation: @run_forward$async_dispatch_78_elementwise_20971520_f32xf16]: reproducer generated at `./shark_tmp/core-reproducer.mlir`
sdxl_turbo_bs1_64_512x512_fp16_unet_rocm.mlir:2800:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<WMMA_F16_16x16x16_F32>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>
    %463 = torch.aten.mul.Scalar %462, %float3.535530e-01_571 : !torch.vtensor<[2,10,64,1024],f16>, !torch.float -> !torch.vtensor<[2,10,64,1024],f16>
           ^

Invoked with:
 iree-compile.exe C:\V\iree-build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-compile.exe sdxl_turbo_bs1_64_512x512_fp16_unet_rocm.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx1100 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-flow-inline-constants-max-byte-length=1 --iree-global-opt-propagate-transposes=true --iree-opt-outer-dim-concat=true --iree-vm-target-truncate-unsupported-floats --iree-llvmgpu-enable-prefetch=true --iree-opt-data-tiling=false --iree-opt-const-eval=false --iree-opt-aggressively-propagate-transposes=true --iree-flow-enable-aggressive-fusion --iree-global-opt-enable-fuse-horizontal-contractions=true --iree-codegen-gpu-native-math-precision=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics)) --iree-codegen-transform-dialect-library=attention_and_matmul_spec_wmma.mlir

So, naturally, my first thought was that vector distribution simply failed and wouldn't work with this IR.

(here's the IR, since I mentioned it)

The actual compile-breaking flag, however, turned out to be --iree-opt-aggressively-propagate-transposes=true: none of the other flags needed to be removed to get a successful compile and run (a sketch of how I narrowed this down is below). Since vector distribution is an important optimization that we should use whenever we can, how should we keep track of all these aggressive fusions and optimizations that can opaquely break it at compile time?
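(For anyone else chasing a flag like this, a minimal leave-one-out sketch, trimmed to a subset of the flags from the invocation above and assuming iree-compile is on PATH:)

```python
import subprocess

BASE = [
    "iree-compile",
    "sdxl_turbo_bs1_64_512x512_fp16_unet_rocm.mlir",
    "--iree-input-type=torch",
    "--iree-hal-target-backends=rocm",
    "--iree-rocm-target-chip=gfx1100",
    "-o", "unet.vmfb",
]
SUSPECTS = [
    "--iree-global-opt-propagate-transposes=true",
    "--iree-opt-aggressively-propagate-transposes=true",
    "--iree-flow-enable-aggressive-fusion",
    "--iree-global-opt-enable-fuse-horizontal-contractions=true",
    "--iree-codegen-llvmgpu-use-vector-distribution=true",
]

# Leave-one-out: if the compile only succeeds when a particular flag is
# dropped, that flag (or its interaction with the rest) is the culprit.
for i, flag in enumerate(SUSPECTS):
    rest = SUSPECTS[:i] + SUSPECTS[i + 1:]
    ok = subprocess.run(BASE + rest, capture_output=True).returncode == 0
    print(f"{'PASS' if ok else 'FAIL'} without {flag}")
```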

Steps to reproduce your issue

Download the IR and run the compile command printed in the error.

What component(s) does this issue relate to?

No response

Version information

tip of shared/tresleches-united or shared/tresleches-united-cpu-merge (more recently rebased on main) (f3e4f56477c8282ff51e0f95ac86fdf57e090c21)

Additional context

No response

monorimet commented 3 weeks ago

It would also be helpful to know which of these flags depend on vector distribution, or to have useful error messages for flag combinations that are mutually exclusive. If I'm unraveling a giant stack of compiler options so I can rule them out as the cause of numerics issues, I'd assume a big chunk of them either depend on one another or are only useful when we're actually lowering attention rather than its decomposition, etc.
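Even a simple front-end pre-check over known-bad flag combinations would go a long way. Purely a hypothetical sketch -- neither this table nor the helper exists in IREE today, and the single entry is just this issue:

```python
# Hypothetical flag-compatibility pre-check; the table and helper are
# illustrative only, seeded with the one combination from this issue.
INCOMPATIBLE = {
    ("--iree-opt-aggressively-propagate-transposes=true",
     "--iree-codegen-llvmgpu-use-vector-distribution=true"):
        "aggressive transpose propagation can yield contractions that "
        "vector distribution cannot assign a concrete layout to "
        "(#17624, seen with fx-decomposed attention)",
}

def check_flag_compat(flags: list[str]) -> list[str]:
    """Return one warning per known-incompatible pair present in flags."""
    flag_set = set(flags)
    return [
        f"{a} + {b}: {why}"
        for (a, b), why in INCOMPATIBLE.items()
        if a in flag_set and b in flag_set
    ]
```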