nod-ai / SHARK

SHARK - High Performance Machine Learning Distribution
Apache License 2.0

Shared memory issues with SD on Vulkan #2112

Closed · gpetters-amd closed this issue 2 months ago

gpetters-amd commented 2 months ago

UNet on Vulkan (RDNA2) is failing with the following shared-memory error:

iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile.exe
Error code: 1
Diagnostics:
failed to translate executables
failed to translate executables
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:843:26: error: 'func.func' op uses 10486016 bytes of shared memory; exceeded the limit of 16384 bytes
    %result0, %result1 = torch.aten.var_mean.correction %58, %59, %int0_61, %true : !torch.vtensor<[2,32,10,4096],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[2,32,1,1],f32>, !torch.vtensor<[2,32,1,1],f32>
                         ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:693:10: note: called from
    %4 = call @forward(%0, %1, %2, %3) : (!torch.vtensor<[1,4,64,64],f16>, !torch.vtensor<[1],f16>, !torch.vtensor<[2,77,1024],f16>, !torch.vtensor<[1],f16>) -> !torch.vtensor<[1,4,64,64],f16>
         ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:843:26: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>
    %result0, %result1 = torch.aten.var_mean.correction %58, %59, %int0_61, %true : !torch.vtensor<[2,32,10,4096],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[2,32,1,1],f32>, !torch.vtensor<[2,32,1,1],f32>
                         ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:693:10: note: called from
    %4 = call @forward(%0, %1, %2, %3) : (!torch.vtensor<[1,4,64,64],f16>, !torch.vtensor<[1],f16>, !torch.vtensor<[2,77,1024],f16>, !torch.vtensor<[1],f16>) -> !torch.vtensor<[1,4,64,64],f16>
         ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:850:11: error: 'func.func' op uses 10486016 bytes of shared memory; exceeded the limit of 16384 bytes
    %63 = torch.aten.mul.Tensor %62, %61 : !torch.vtensor<[2,32,10,4096],f32>, !torch.vtensor<[2,32,1,1],f32> -> !torch.vtensor<[2,32,10,4096],f32>
          ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:693:10: note: called from
    %4 = call @forward(%0, %1, %2, %3) : (!torch.vtensor<[1,4,64,64],f16>, !torch.vtensor<[1],f16>, !torch.vtensor<[2,77,1024],f16>, !torch.vtensor<[1],f16>) -> !torch.vtensor<[1,4,64,64],f16>
         ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:850:11: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>
    %63 = torch.aten.mul.Tensor %62, %61 : !torch.vtensor<[2,32,10,4096],f32>, !torch.vtensor<[2,32,1,1],f32> -> !torch.vtensor<[2,32,10,4096],f32>
          ^
C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile:693:10: note: called from
    %4 = call @forward(%0, %1, %2, %3) : (!torch.vtensor<[1,4,64,64],f16>, !torch.vtensor<[1],f16>, !torch.vtensor<[2,77,1024],f16>, !torch.vtensor<[1],f16>) -> !torch.vtensor<[1,4,64,64],f16>
         ^

Invoked with:
 iree-compile.exe C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\iree\compiler\tools\..\_mlir_libs\iree-compile.exe C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\unet.torch.tempfile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan-spirv --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-stream-resource-max-allocation-size=3221225472 --iree-flow-inline-constants-max-byte-length=0 --iree-vulkan-target-triple=rdna2-unknown-windows --iree-vulkan-target-env=<#spirv.vce<v1.3, r(120), [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_shader_float16_int8, SPV_KHR_spirv_1_4, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, VK_EXT_subgroup_size_control]>, AMD:DiscreteGPU, #spirv.resource_limits< max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = dense<[1024, 1024, 1024]>: vector<3xi32>, subgroup_size = 64, subgroup_features = 255: i32, min_subgroup_size = 32, max_subgroup_size = 64, shaderFloat16 = unit, shaderFloat64 = unit, shaderInt8 = unit, shaderInt16 = unit, shaderInt64 = unit, storageBuffer16BitAccess = unit, storagePushConstant16 = unit, uniformAndStorageBuffer16BitAccess = unit, storageBuffer8BitAccess = unit, storagePushConstant8 = unit, uniformAndStorageBuffer8BitAccess = unit, variablePointers = unit, variablePointersStorageBuffer = unit, shaderIntegerDotProduct = unit >> --iree-util-zero-fill-elided-attrs --mlir-elide-elementsattrs-if-larger=10 --iree-opt-strip-assertions=true --verify=false --iree-opt-const-eval=False

Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers.
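
For context, the 10486016 bytes in the diagnostic is roughly the full [2,32,10,4096] f32 operand (2 × 32 × 10 × 4096 × 4 = 10485760 bytes, plus what looks like a small amount of reduction scratch), so codegen appears to be staging the whole reduction input in workgroup shared memory, well past the 16384-byte limit it is checking against. The failing op is easy to isolate; a minimal PyTorch sketch, with shapes and the correction value read off the IR above (the dim list and keepdim flag are inferred from the [2,32,1,1] result shapes, so treat them as assumptions):

import torch

# Standalone sketch of the op the diagnostic points at:
# torch.aten.var_mean.correction on a [2,32,10,4096] f32 tensor.
# dim=[2, 3] and keepdim=True are inferred from the [2,32,1,1] outputs;
# correction=0 matches the %int0_61 operand in the IR.
x = torch.randn(2, 32, 10, 4096, dtype=torch.float32)
var, mean = torch.var_mean(x, dim=[2, 3], correction=0, keepdim=True)
assert var.shape == mean.shape == (2, 32, 1, 1)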
monorimet commented 2 months ago

@gpetters-amd I'm almost positive those Vulkan target-env flags are a little off. Can you try using the default target env for the triple, i.e. without the --iree-vulkan-target-env flag?
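
For concreteness, that would mean rerunning the same command with only the --iree-vulkan-target-env=<...> argument dropped, letting iree-compile derive the target environment from the rdna2 triple. A sketch, with the path and the remaining flags from the log above abbreviated to "...":

iree-compile.exe ...\unet.torch.tempfile --iree-input-type=torch --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna2-unknown-windows ...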

gpetters-amd commented 2 months ago

@monorimet Yep, that was it. I'm surprised the target-env stuff is still giving us issues; I thought we fixed it back in 2.0.