iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

im2col with passpipeline: VK_ERROR_INITIALIZATION_FAILED in SIMT path (non wmma) #12057

Closed: powderluv closed this issue 1 year ago

powderluv commented 1 year ago

What happened?

When compiling UNet with:

```
% iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-target-cpu-features=host  --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-util-zero-fill-elided-attrs -iree-vulkan-target-triple=rdna2-unknown-linux -iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))" compile_tests/unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base_torch.mlir -o output.vmfb
```

Running the resulting module fails with VK_ERROR_INITIALIZATION_FAILED:

```
% iree-benchmark-module --module=./output.vmfb --function=forward --device=vulkan --input=2x4x64x64xf32 --input=1xf32 --input=2x77x768xf32
main_checkout/runtime/src/iree/hal/drivers/vulkan/native_executable.cc:153: UNAVAILABLE; VK_ERROR_INITIALIZATION_FAILED; while invoking native function hal.executable.create; while calling import;
[ 1]   native hal.executable.create:0 -
[ 0] bytecode module.__init:7566 <eval_with_key>.204:651:13
```

However, dropping the pass pipeline makes it work. A module compiled with the following command runs OK:

```
% iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-target-cpu-features=host  --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-util-zero-fill-elided-attrs -iree-vulkan-target-triple=rdna2-unknown-linux compile_tests/unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base_torch.mlir -o output.vmfb
```

Steps to reproduce your issue

No response

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

powderluv commented 1 year ago

OK, so it is just the im2col conversion that seems to be broken. This command, which runs only the padding pass, works OK:

```
% iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-target-cpu-features=host  --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-util-zero-fill-elided-attrs -iree-vulkan-target-triple=rdna2-unknown-linux -iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=32}))" ../SHARK/compile_tests/unet64_512_512_fp16_stabilityai_stable_diffusion_2_1_base_torch.mlir -o output.vmfb
```

while the following pipeline, which adds the im2col conversion in front of the padding, doesn't:

```
-iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
```
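For readers unfamiliar with the transformation the failing pass is named after, here is a minimal, hedged sketch of im2col in pure Python. It is not IREE's `iree-preprocessing-convert-conv2d-to-img2col` implementation; it assumes a stride-1, unpadded convolution with NHWC input and HWCF filter layouts, and all function names are hypothetical. It shows the core idea: every input patch is copied into one row of a matrix so the convolution becomes a plain matmul, which is why the rewritten dispatches look like large matmul-shaped `generic` ops.

```python
# Hedged illustration only: a pure-Python im2col rewrite of a stride-1,
# unpadded 2-D convolution (NHWC input, HWCF filter). This is NOT the IREE
# pass; it only shows the conv -> matmul transformation the pass is named for.


def conv2d_direct(inp, flt):
    """Reference convolution: inp[n][h][w][c], flt[kh][kw][c][f] -> out[n][oh][ow][f]."""
    N, H, W, C = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    KH, KW, F = len(flt), len(flt[0]), len(flt[0][0][0])
    OH, OW = H - KH + 1, W - KW + 1
    out = [[[[0.0] * F for _ in range(OW)] for _ in range(OH)] for _ in range(N)]
    for n in range(N):
        for i in range(OH):
            for j in range(OW):
                for f in range(F):
                    acc = 0.0
                    for a in range(KH):
                        for b in range(KW):
                            for c in range(C):
                                acc += inp[n][i + a][j + b][c] * flt[a][b][c][f]
                    out[n][i][j][f] = acc
    return out


def im2col(inp, KH, KW):
    """Copy every KH x KW x C input patch into one row of a (N*OH*OW, KH*KW*C) matrix."""
    N, H, W, C = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    OH, OW = H - KH + 1, W - KW + 1
    rows = [
        [inp[n][i + a][j + b][c] for a in range(KH) for b in range(KW) for c in range(C)]
        for n in range(N) for i in range(OH) for j in range(OW)
    ]
    return rows, (N, OH, OW)


def conv2d_im2col(inp, flt):
    """Convolution as im2col followed by a plain (N*OH*OW, K) x (K, F) matmul."""
    KH, KW, C, F = len(flt), len(flt[0]), len(flt[0][0]), len(flt[0][0][0])
    cols, (N, OH, OW) = im2col(inp, KH, KW)
    # Filter reshaped to (K, F), with K flattened in the same (a, b, c) order as each row.
    wmat = [flt[a][b][c] for a in range(KH) for b in range(KW) for c in range(C)]
    K = KH * KW * C
    mat = [[sum(row[k] * wmat[k][f] for k in range(K)) for f in range(F)] for row in cols]
    # Fold the matmul rows back into NHWC layout.
    return [[[mat[(n * OH + i) * OW + j] for j in range(OW)] for i in range(OH)] for n in range(N)]


if __name__ == "__main__":
    # 1x4x4x2 input and 2x2x2x3 filter with small integer-valued floats.
    x = [[[[float(h * 4 + w + c) for c in range(2)] for w in range(4)] for h in range(4)]]
    k = [[[[float(a + b - c + f) for f in range(3)] for c in range(2)] for b in range(2)] for a in range(2)]
    cols, _ = im2col(x, 2, 2)
    print(len(cols), "x", len(cols[0]))  # 9 x 8: one row per output pixel
    print(conv2d_im2col(x, k) == conv2d_direct(x, k))  # True
```

The design trade-off is that the patch matrix duplicates overlapping input elements (here a 4x4x2 input becomes a 9x8 matrix), in exchange for reusing the matmul code path.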

MaheshRavishankar commented 1 year ago

Maybe @antiagainst can take a look. I don't really use that pass.

ThomasRaoux commented 1 year ago

This creates large shared (workgroup) memory allocations:

    spirv.GlobalVariable @__workgroup_mem__6 : !spirv.ptr<!spirv.struct<(!spirv.array<2112 x vector<4xf32>>)>, Workgroup>
    spirv.GlobalVariable @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<640 x vector<4xf32>>)>, Workgroup>
    spirv.GlobalVariable @__workgroup_mem__4 : !spirv.ptr<!spirv.struct<(!spirv.array<2112 x vector<4xf32>>)>, Workgroup>

from forward_dispatch_161_generic_2x640x1024x320
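To put a number on "large", the sketch below totals the three declarations quoted above: each `vector<4xf32>` element is 16 bytes. It is a hedged back-of-the-envelope check, assuming all three buffers are used by the same dispatch. The 64 KiB budget is an illustrative assumption, not a value taken from this thread; the Vulkan specification only guarantees 16 KiB for `VkPhysicalDeviceLimits::maxComputeSharedMemorySize`, and the real limit has to be queried from the device. Names such as `total_workgroup_bytes` are hypothetical.

```python
# Hedged sketch: total the static workgroup (shared) memory declared by the three
# spirv.GlobalVariable lines quoted above. ASSUMPTIONS: all three buffers are used
# by the same dispatch, and the 64 KiB budget is only an illustrative figure; the
# real per-device limit is VkPhysicalDeviceLimits::maxComputeSharedMemorySize.

F32_BYTES = 4
VECTOR_WIDTH = 4  # vector<4xf32>
ASSUMED_LIMIT_BYTES = 64 * 1024  # illustrative, not queried from a device

# Element counts copied from the !spirv.array<N x vector<4xf32>> declarations.
WORKGROUP_BUFFERS = {
    "__workgroup_mem__6": 2112,
    "__workgroup_mem__5": 640,
    "__workgroup_mem__4": 2112,
}


def buffer_bytes(num_vectors: int) -> int:
    """Size in bytes of an array of num_vectors vector<4xf32> elements."""
    return num_vectors * VECTOR_WIDTH * F32_BYTES


def total_workgroup_bytes(buffers: dict) -> int:
    """Sum of all declared workgroup buffers, assuming they are live together."""
    return sum(buffer_bytes(n) for n in buffers.values())


if __name__ == "__main__":
    for name, n in WORKGROUP_BUFFERS.items():
        print(f"{name}: {buffer_bytes(n)} bytes")
    total = total_workgroup_bytes(WORKGROUP_BUFFERS)
    print(f"total: {total} bytes ({total / 1024:.1f} KiB)")  # total: 77824 bytes (76.0 KiB)
    print("over assumed limit" if total > ASSUMED_LIMIT_BYTES else "within assumed limit")
```

If the device's actual limit is below this total, pipeline creation could plausibly fail, which would be consistent with (though not proof of) the VK_ERROR_INITIALIZATION_FAILED reported from `hal.executable.create`.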

This is most likely caused by the change in pass order due to https://github.com/iree-org/iree/commit/142894130b8506cb1860729dd2a988b7756fcdca

powderluv commented 1 year ago

This is now resolved downstream with the right flags.