Open saienduri opened 2 months ago
When running SDXL with the latest version of torch, we are running into this error when compiling vae:
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Error code: 1
Diagnostics:
<stdin>:620:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
  %153 = torch.operator "torch.aten._safe_softmax"(%152, %int-1_176, %none_177) : (!torch.vtensor<[1,1,16384,16384],f32>, !torch.int, !torch.none) -> !torch.vtensor<[1,1,16384,16384],f32>
         ^
<stdin>:4452:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
  %674 = torch.operator "torch.aten._safe_softmax"(%673, %int-1_822, %none_823) : (!torch.vtensor<[1,1,16384,16384],f32>, !torch.int, !torch.none) -> !torch.vtensor<[1,1,16384,16384],f32>
Looks like we are missing the torch -> linalg lowering for torch.aten._safe_softmax.
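For context, _safe_softmax behaves like an ordinary softmax except that rows whose inputs are entirely -inf (e.g. fully masked attention rows) produce zeros instead of NaN. A minimal NumPy sketch of that semantics (the function name and implementation here are illustrative, not the actual torch-mlir lowering):

```python
import numpy as np

def safe_softmax(x, axis=-1):
    # Softmax that returns zeros (instead of NaN) along rows where
    # every element is -inf, matching the intent of _safe_softmax.
    x_max = np.max(x, axis=axis, keepdims=True)
    fully_masked = np.isneginf(x_max)            # True where the whole row is -inf
    x_safe = np.where(fully_masked, 0.0, x)      # avoid (-inf) - (-inf) = NaN
    e = np.exp(x_safe - np.max(x_safe, axis=axis, keepdims=True))
    out = e / np.sum(e, axis=axis, keepdims=True)
    return np.where(fully_masked, 0.0, out)

x = np.array([[1.0, 2.0, 3.0],
              [-np.inf, 0.0, 0.0],
              [-np.inf, -np.inf, -np.inf]])
y = safe_softmax(x)
```

Partially masked rows behave like a regular softmax over the unmasked entries; only the fully masked row comes out as all zeros.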
To repro the issue, here are the instructions to get the artifacts and compile:
curl https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-latest-torch/vae.mlir --output vae.mlir
curl https://raw.githubusercontent.com/nod-ai/sdxl-scripts/shared/sdxl_on_main/int8-model/specs/attention_and_matmul_spec.mlir --output attn_spec.mlir
iree-compile vae.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 \
  --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-dispatch-creation-enable-aggressive-fusion \
  --iree-dispatch-creation-enable-fuse-horizontal-contractions --iree-opt-aggressively-propagate-transposes=true \
  --iree-codegen-llvmgpu-use-vector-distribution=true --iree-opt-data-tiling=false \
  --iree-codegen-gpu-native-math-precision=true --iree-vm-target-truncate-unsupported-floats \
  --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true \
  --iree-execution-model=async-external \
  --iree-preprocessing-pass-pipeline="builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" \
  --iree-codegen-transform-dialect-library=attn_spec.mlir -o vae.vmfb
So, when will this be fixed?