iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[compiler] FLUX.1 transformer compilation for gfx942 hangs with pad-to-intrinsics #19249

Open monorimet opened 22 hours ago

monorimet commented 22 hours ago

What happened?

The IREE compiler hangs when I add iree-preprocessing-pad-to-intrinsics to the preprocessing pipeline for this IR:

iree-compile \
  --iree-hal-target-device=amdgpu \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  --iree-execution-model=async-external \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-global-opt-propagate-transposes=1 \
  --iree-opt-const-eval=0 \
  --iree-opt-outer-dim-concat=1 \
  --iree-opt-aggressively-propagate-transposes=1 \
  --iree-dispatch-creation-enable-aggressive-fusion \
  --iree-hal-force-indirect-command-buffers \
  --iree-codegen-llvmgpu-use-vector-distribution=1 \
  --iree-llvmgpu-enable-prefetch=1 \
  --iree-codegen-gpu-native-math-precision=1 \
  --iree-hip-legacy-sync=0 \
  --iree-opt-data-tiling=0 \
  --iree-vm-target-truncate-unsupported-floats \
  --iree-dispatch-creation-enable-fuse-horizontal-contractions=1 \
  flux_1_dev.torch_onnx.mlir \
  -o flux-dev_transformer_bs1_512_1024x1024_fp32_amdgpu-gfx942.vmfb

The command above runs for more than 30 minutes without finishing. If I remove iree-preprocessing-pad-to-intrinsics from --iree-preprocessing-pass-pipeline and keep every other flag the same, i.e., use:

iree-compile \
  --iree-hal-target-device=amdgpu \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  --iree-execution-model=async-external \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-global-opt-propagate-transposes=1 \
  --iree-opt-const-eval=0 \
  --iree-opt-outer-dim-concat=1 \
  --iree-opt-aggressively-propagate-transposes=1 \
  --iree-dispatch-creation-enable-aggressive-fusion \
  --iree-hal-force-indirect-command-buffers \
  --iree-codegen-llvmgpu-use-vector-distribution=1 \
  --iree-llvmgpu-enable-prefetch=1 \
  --iree-codegen-gpu-native-math-precision=1 \
  --iree-hip-legacy-sync=0 \
  --iree-opt-data-tiling=0 \
  --iree-vm-target-truncate-unsupported-floats \
  --iree-dispatch-creation-enable-fuse-horizontal-contractions=1 \
  flux_1_dev.torch_onnx.mlir \
  -o flux-dev_transformer_bs1_512_1024x1024_fp32_amdgpu-gfx942.vmfb

It compiles in under 30 seconds.

I'm not sure whether this pass is still required for matching AMDGPU intrinsics.
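As its name suggests, iree-preprocessing-pad-to-intrinsics pads matmul- and convolution-like linalg ops so their dimensions line up with the target's matrix-multiply intrinsic shapes. The core arithmetic is rounding each dimension up to a multiple of the intrinsic size. The sketch below shows only that arithmetic, not the pass itself; the helper names padded_size and pad_amount and the 16-wide tile are hypothetical, chosen for illustration rather than taken from IREE.

```shell
#!/usr/bin/env bash
# Hypothetical helpers illustrating "pad to a multiple of the intrinsic size".
# The tile size used below (16) is an assumed example value, not read from IREE.

# Round dim up to the next multiple of tile.
padded_size() {
  local dim=$1 tile=$2
  echo $(( (dim + tile - 1) / tile * tile ))
}

# Number of padding elements needed to reach that multiple (0 if already aligned).
pad_amount() {
  local dim=$1 tile=$2
  echo $(( (tile - dim % tile) % tile ))
}

# Example dimensions checked against a hypothetical 16-wide intrinsic.
for d in 4096 3072 100; do
  echo "$d -> $(padded_size "$d" 16) (pad $(pad_amount "$d" 16))"
done
```

When a dimension is dynamic, these expressions cannot fold to constants at compile time, so any padding of that dimension has to be expressed as runtime index arithmetic.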

Steps to reproduce your issue

  1. Install the IREE compiler.
  2. Download the IR: wget https://gist.githubusercontent.com/zjgarvey/91c733825018b077565f668e6bda96d8/raw/de72daddf6ab27b7b07ea5e71ca4fc11504edcc8/flux_1_dev.torch_onnx.mlir
  3. Run:
    iree-compile \
      --iree-hal-target-device=amdgpu \
      --iree-hip-target=gfx942 \
      --iree-hal-target-backends=rocm \
      --iree-hip-target=gfx942 \
      --iree-execution-model=async-external \
      --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
      --iree-global-opt-propagate-transposes=1 \
      --iree-opt-const-eval=0 \
      --iree-opt-outer-dim-concat=1 \
      --iree-opt-aggressively-propagate-transposes=1 \
      --iree-dispatch-creation-enable-aggressive-fusion \
      --iree-hal-force-indirect-command-buffers \
      --iree-codegen-llvmgpu-use-vector-distribution=1 \
      --iree-llvmgpu-enable-prefetch=1 \
      --iree-codegen-gpu-native-math-precision=1 \
      --iree-hip-legacy-sync=0 \
      --iree-opt-data-tiling=0 \
      --iree-vm-target-truncate-unsupported-floats \
      --iree-dispatch-creation-enable-fuse-horizontal-contractions=1 \
      flux_1_dev.torch_onnx.mlir \
      -o flux-dev_transformer_bs1_512_1024x1024_fp32_amdgpu-gfx942.vmfb
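Because the failure mode is a hang, it can help to run the reproducer under a wall-clock limit so it fails fast and can be scripted. This is a minimal sketch assuming GNU coreutils timeout is available (it exits with status 124 when the limit is reached); run_with_limit is a hypothetical helper, not part of IREE.

```shell
#!/usr/bin/env bash
# Hypothetical helper: run a command with a wall-clock limit (in seconds).
# Relies on GNU coreutils `timeout`, which returns 124 when it kills the command.
run_with_limit() {
  local seconds=$1
  shift
  timeout "$seconds" "$@"
  local status=$?
  if [ "$status" -eq 124 ]; then
    echo "TIMEOUT after ${seconds}s: $1"
  else
    echo "exit ${status}: $1"
  fi
  # Propagate the original status so callers can branch on it.
  return "$status"
}
```

For example, running run_with_limit 1800 iree-compile followed by the same arguments as in step 3 reports a timeout after 30 minutes instead of blocking, and the same wrapper can be applied to the pipeline without iree-preprocessing-pad-to-intrinsics for comparison.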

What component(s) does this issue relate to?

Compiler

Version information

IREE compiler version 3.0.0rc20241118 @ 29c451b00ecc9f9e5466e9d1079e0d69147da700

Additional context

The MLIR is an ONNX export with externalized parameters. The model precision is fp32.

monorimet commented 20 hours ago

It just occurred to me that the model was exported with several dynamic input dims. Could that be why iree-preprocessing-pad-to-intrinsics gets stuck?

IanWood1 commented 20 hours ago

I'm not sure why iree-preprocessing-pad-to-intrinsics changes anything, but OptimizeIntArithmetic seems to be the problem: it's spending a lot of time in calls to solver.eraseState(). I previously tried to fix this with https://github.com/iree-org/iree/pull/19130
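A rough cost model may help explain how time spent in solver.eraseState() could come to dominate. It rests on an assumption this thread does not confirm: that erasing the state for one value scans every analysis state the solver currently holds, rather than only the entries for that value. Under that assumption, a pass that erases state once per rewrite does work quadratic in function size:

```latex
% Assumed model (not confirmed in this thread):
%   N = number of ops/values the pass visits
%   R = number of rewrites, each calling solver.eraseState() once
%   S = number of analysis states held by the solver
%   c = cost of visiting one stored state during an erase
W_{\mathrm{erase}} \approx c \cdot R \cdot S,
\qquad R = \Theta(N),\; S = \Theta(N)
\;\Longrightarrow\; W_{\mathrm{erase}} = \Theta(N^{2})
```

If the assumption holds, anything that adds ops to the function (for example, extra index arithmetic from padding) grows N and makes the slowdown superlinear; this is one possible link to iree-preprocessing-pad-to-intrinsics, not something the thread establishes.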