iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[SPIR-V] Shared memory limit exceeded for batch matmul dispatch #16937

Open IanNod opened 3 months ago

IanNod commented 3 months ago

What happened?

Compiling a phi-2 model for the vulkan-spirv backend with target triple rdna2-unknown-linux gives the following error:

failed to translate executables
haldump/configured_state_update_run_initialize$async_dispatch_19.mlir:9:6: error: 'func.func' op uses 1310720 bytes of shared memory; exceeded the limit of 65536 bytes
      func.func @run_initialize$async_dispatch_19_batch_matmul_transpose_b_32xDxDx80_f16xf32xf32() {
    ^
haldump/configured_state_update_run_initialize$async_dispatch_19.mlir:2:2: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>
  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>) {
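For context, the 1310720-byte figure in the diagnostic is the combined footprint of the operand tiles the pipeline tried to stage in workgroup shared memory. A minimal sketch of that accounting (the tile sizes below are hypothetical illustrations chosen for this example, not the configuration IREE actually deduced):

```python
# Sketch: estimate the shared-memory footprint of a tiled matmul's operand
# tiles and compare it against the device limit. Tile sizes are hypothetical.

MAX_SHARED_MEMORY_BYTES = 65536  # max_compute_shared_memory_size above

def matmul_tile_shared_memory_bytes(tile_m, tile_n, tile_k, elem_bytes):
    """Bytes needed to stage one LHS (tile_m x tile_k) tile and one RHS
    (tile_n x tile_k, transposed-B layout) tile in shared memory."""
    lhs = tile_m * tile_k * elem_bytes
    rhs = tile_n * tile_k * elem_bytes
    return lhs + rhs

# A modest f16 tiling fits comfortably within the 64 KiB budget:
small = matmul_tile_shared_memory_bytes(64, 64, 32, elem_bytes=2)
print(small, small <= MAX_SHARED_MEMORY_BYTES)   # 8192 True

# An oversized f16 tiling blows past the limit (this particular combination
# happens to reproduce the 1310720-byte figure from the diagnostic, but the
# compiler's real tile choice is not known from the error alone):
big = matmul_tile_shared_memory_bytes(512, 512, 640, elem_bytes=2)
print(big, big <= MAX_SHARED_MEMORY_BYTES)       # 1310720 False
```

Whatever the actual tile sizes were, the deduced configuration needed roughly 20x the RDNA2 workgroup shared-memory budget, which is why translation fails rather than producing a shader the driver would reject.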

Steps to reproduce your issue

phi-2 mlir can be found here: https://sharkpublic.blob.core.windows.net/sharkpublic/ian/phi_2.mlir

failed dispatch here: https://sharkpublic.blob.core.windows.net/sharkpublic/ian/configured_state_update_run_initialize$async_dispatch_19.mlir

iree compiler command used:

../iree-build/tools/iree-compile \
  --iree-input-type=torch \
  --iree-vm-bytecode-module-output-format=flatbuffer-binary \
  --iree-hal-target-backends=vulkan-spirv \
  --mlir-print-debuginfo \
  --mlir-print-op-on-diagnostic=false \
  --iree-llvmcpu-target-cpu-features=host \
  --iree-llvmcpu-target-triple=x86_64-linux-gnu \
  --iree-stream-resource-index-bits=64 \
  --iree-vm-target-index-bits=64 \
  --iree-vulkan-target-triple=rdna2-unknown-linux \
  --iree-stream-resource-max-allocation-size=4294967296 \
  --iree-hal-dump-executable-files-to=haldump \
  ../SHARK-Turbine/phi_2.mlir -o phi_2.vmfb

(The command as originally run passed `--iree-input-type`, `--mlir-print-debuginfo`, and `--mlir-print-op-on-diagnostic=false` twice; the duplicates are collapsed here, keeping the later `--iree-input-type=torch`, which overrides the earlier `auto`.)

What component(s) does this issue relate to?

Compiler

Version information

SHA: cc2ef92a232e4b6de9b845b6854d4d8667a6162b

Additional context

No response

kuhar commented 3 months ago

@Groverkss has a WIP PR for this on the LLVMGPU side here: https://github.com/openxla/iree/pull/16927. Kunwar, could you also take care of the SPIR-V path?

Groverkss commented 3 months ago

AFAIU my patch should also take care of SPIRV

kuhar commented 3 months ago

> AFAIU my patch should also take care of SPIRV

Can you also add a SPIR-V regression test based on the batch matmul from this issue?

Groverkss commented 3 months ago

Sure, will do tomorrow

IanNod commented 3 months ago

> AFAIU my patch should also take care of SPIRV

Checked out the https://github.com/openxla/iree/pull/16927 PR and still see the same error @Groverkss

Groverkss commented 3 months ago

> > AFAIU my patch should also take care of SPIRV
>
> Checked out the #16927 PR and still see the same error @Groverkss

I'll have a look and fix it if something is wrong. Thanks for checking.

Groverkss commented 3 months ago

@kuhar I looked into this more. Currently, the cooperative matrix path reuses the MMA heuristics deduction, which is where my patch adds the shared memory check. I will send a follow-up patch that makes the default SPIR-V Distribute pipeline's matmul config deduction use that check as well.
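The fix described above amounts to validating, and if necessary shrinking, the deduced tile sizes against the target's shared-memory budget before committing to a configuration. A rough sketch of that clamping idea (not IREE's actual code; the function and parameter names are invented for illustration):

```python
# Sketch: shrink the K tile until the staged LHS+RHS operand tiles fit in
# the device's shared-memory budget. Names are illustrative, not IREE's API.

def fits(tile_m, tile_n, tile_k, elem_bytes, limit):
    """True if one LHS and one RHS tile together fit in `limit` bytes."""
    return (tile_m + tile_n) * tile_k * elem_bytes <= limit

def clamp_tile_k(tile_m, tile_n, tile_k, elem_bytes, limit=65536):
    """Halve tile_k until the operand tiles fit; stop at tile_k == 1."""
    while tile_k > 1 and not fits(tile_m, tile_n, tile_k, elem_bytes, limit):
        tile_k //= 2
    return tile_k

# An f16 tiling that would need 1310720 bytes gets clamped down until its
# footprint (here (512 + 512) * 20 * 2 = 40960 bytes) fits in 64 KiB:
print(clamp_tile_k(512, 512, 640, elem_bytes=2))  # 20

# A tiling that already fits is left unchanged:
print(clamp_tile_k(64, 64, 32, elem_bytes=2))     # 32
```

A real config deduction would likely also shrink the M/N tiles or fall back to a different pipeline rather than only halving K, but the core check is the same: never emit a config whose shared-memory footprint exceeds `max_compute_shared_memory_size`.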