iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

'func.call' op 'forward' does not reference a valid function #16797

Open aviator19941 opened 3 months ago

aviator19941 commented 3 months ago

What happened?

PNDMScheduler + Unet Torch IR file: stable_diffusion_xl_base_1_0_PNDM_64_1024x1024_fp16_unet_30.mlir
EulerDiscreteScheduler + Unet Torch IR file: stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir

When I run iree-compile on the PNDMScheduler + Unet Torch IR file, it compiles to a vmfb. However, when I run iree-compile on the EulerDiscreteScheduler + Unet Torch IR file, it fails to produce a vmfb with this error:

/home/avsharma/SHARK-Turbine/stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir:1725:10: error: 'func.call' op 'forward' does not reference a valid function
    %6 = call @forward(%0, %1, %2, %3, %4, %5) : (!torch.vtensor<[1,4,128,128],f16>, !torch.vtensor<[2,64,2048],f16>, !torch.vtensor<[2,1280],f16>, !torch.vtensor<[2,6],f16>, !torch.vtensor<[1],f16>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,4,128,128],f16>

Steps to reproduce your issue

IREE commit: 3b30ab4e53ee85eb48b2a56081a25f87cc4b2a6d

Using this iree-compile command for the EulerDiscreteScheduler + Unet Torch IR file:

../iree-build/tools/iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary \
--iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false \
--iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx940 --iree-rocm-link-bc=true \
--iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false \
--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode --iree-codegen-llvmgpu-use-vector-distribution \
--iree-opt-outer-dim-concat=true --iree-codegen-gpu-native-math-precision=true \
--iree-preprocessing-pass-pipeline="builtin.module(iree-preprocessing-transpose-convolution-pipeline)" \
--iree-codegen-transform-dialect-library=mfma_spec.mlir \
/home/avsharma/SHARK-Turbine/stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir \
-o scheduled_unet_euler.vmfb


However, the PNDMScheduler + Unet Torch IR file compiles to a vmfb.

What component(s) does this issue relate to?

Compiler

Version information

IREE: 3b30ab4e53ee85eb48b2a56081a25f87cc4b2a6d

Additional context

No response

stellaraccident commented 3 months ago

FYI - simplified repro:

./tools/iree-opt --torch-to-iree ~/tmp/stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir
stellaraccident commented 3 months ago

Appears to be something wrong with the inliner (it is not inlining something that it must). Trying to figure out why.

benvanik commented 3 months ago

Could it be unregistered/unknown ops, or ops in a dialect without a DialectInlinerInterface? I think the inliner bails if it can't ask an op's dialect whether it's OK to inline.

stellaraccident commented 3 months ago

I'm running it with -debug.

stellaraccident commented 3 months ago
* Illegal to inline because of op: %global_seed = ml_program.global_load @global_seed : tensor<i64>
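The debug line above matches benvanik's hypothesis: MLIR's inliner asks each op's dialect, through that dialect's inliner interface, whether the op may be inlined, and treats the op as illegal when the dialect provides no interface. A single such op (here `ml_program.global_load`) therefore blocks inlining of the whole callee. The following is a toy Python model of that rule, not MLIR's real C++ API; the `Op` class, the helper names, and the dialect registry set are all hypothetical, and the model ignores the per-op decisions a registered interface can still make.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """Hypothetical stand-in for an MLIR operation, identified by its full name."""
    name: str  # e.g. "ml_program.global_load"

    @property
    def dialect(self) -> str:
        # The dialect namespace is the prefix before the first '.'.
        return self.name.split(".", 1)[0]


def illegal_op(body, dialects_with_inliner):
    """Return the first op whose dialect has no inliner interface, else None."""
    for op in body:
        if op.dialect not in dialects_with_inliner:
            return op
    return None


def can_inline(body, dialects_with_inliner):
    """A callee is inlinable only if no op in its body blocks inlining."""
    return illegal_op(body, dialects_with_inliner) is None


# Hypothetical callee body and registry, loosely mirroring this issue.
body = [Op("torch.aten.add.Tensor"), Op("ml_program.global_load")]
print(can_inline(body, {"torch"}))                # blocked by ml_program.global_load
print(can_inline(body, {"torch", "ml_program"}))  # allowed once ml_program registers an interface
```

In this model, registering an inliner interface for the `ml_program` dialect is what makes the callee inlinable again.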
stellaraccident commented 3 months ago

Will send a patch.

stellaraccident commented 3 months ago

Not that I expect this to be a regular thing, but I got this by running a modified version of the above iree-opt command with --debug --debug-only=inlining.

stellaraccident commented 3 months ago

I'll send a patch to LLVM shortly to fix this. However, when I run it locally, the --torch-to-iree pipeline then fails with:

/home/stella/tmp/stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir:1735:10: error: failed to legalize operation 'torch.aten.nonzero'
    %5 = torch.aten.nonzero %4 : !torch.vtensor<[30],i1> -> !torch.vtensor<[?,1],si64>
         ^
/home/stella/tmp/stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir:1725:10: note: called from
    %6 = call @forward(%0, %1, %2, %3, %4, %5) : (!torch.vtensor<[1,4,128,128],f16>, !torch.vtensor<[2,64,2048],f16>, !torch.vtensor<[2,1280],f16>, !torch.vtensor<[2,6],f16>, !torch.vtensor<[1],f16>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,4,128,128],f16>
         ^
/home/stella/tmp/stable_diffusion_xl_base_1_0_Euler_64_1024x1024_fp16_unet_30.mlir:1735:10: note: see current operation: %93 = "torch.aten.nonzero"(%92) : (!torch.vtensor<[30],i1>) -> !torch.vtensor<[?,1],si64>
    %5 = torch.aten.nonzero %4 : !torch.vtensor<[30],i1> -> !torch.vtensor<[?,1],si64>
ScottTodd commented 3 months ago

error: failed to legalize operation 'torch.aten.nonzero'

Can confirm, that's currently failing: https://github.com/openxla/iree/blob/d05b4a16f826ca328f7e737ae473e2b581dc950f/experimental/regression_suite/external_test_suite/config_cpu_llvm_sync.json#L357

Test case with MLIR: https://github.com/nod-ai/SHARK-TestSuite/tree/main/iree_tests/onnx/node/generated/test_nonzero_example

stellaraccident commented 3 months ago

This LLVM patch will fix the inliner issue: https://github.com/llvm/llvm-project/pull/85479

saienduri commented 3 months ago

Oh nice! Is there somewhere we can find some documentation on these debugging flags in IREE?

ScottTodd commented 3 months ago

Oh nice! Is there somewhere we can find some documentation on these debugging flags in IREE?

saienduri commented 3 months ago

Thanks Scott!