iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

WebGPU/tint SPIR-V lowering relying on detensorizing. #12509

Open benvanik opened 1 year ago

benvanik commented 1 year ago

While testing #12503 I had to disable the linalg detensorize pass. Doing so exposed some unique compilation failures in Tint. We're not going to disable the pass for real, but this indicates that something deep in the SPIR-V/Tint path relies on detensorizing. The same problem may come up in other contexts where the repro is harder to isolate.

Here are the command and the before/after IR, produced by commenting out .addPass(mlir::createLinalgDetensorizePass) in compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp: https://gist.github.com/benvanik/5ba24a269e6ac60cf4483bf94a1679b6

Tint reported 1 error(s) for a SPIR-V program, see diagnostics:
error: function parameter of pointer type cannot be in 'storage' address space
tests/e2e/tosa_ops/while.mlir:15:10: error: failed to compile SPIR-V to WGSL. Consider inspecting the shader program using -iree-hal-dump-executable-intermediates.
    %3 = "tosa.greater_equal"(%2, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i1>
         ^
tests/e2e/tosa_ops/while.mlir:15:10: error: failed to serialize executable for target backend webgpu
    %3 = "tosa.greater_equal"(%2, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i1>
         ^
// -----// IR Dump After mlir::iree_compiler::IREE::HAL::SerializeTargetExecutablesPass Failed (iree-hal-serialize-target-executables) //----- //
hal.executable private @_while_test_iter0_dispatch_0 {
  hal.executable.variant public @webgpu_wgsl_fb, target = <"webgpu", "webgpu-wgsl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, api=WebGPU, #spirv.resource_limits<>>}> {
    hal.executable.export public @d0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute>, workgroup_size = [1 : index, 1 : index, 1 : index]} {
    ^bb0(%arg0: !hal.device):
      %c1 = arith.constant 1 : index
      hal.return %c1, %c1, %c1 : index, index, index
    }
    builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, api=WebGPU, #spirv.resource_limits<>>} {
      spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
        spirv.GlobalVariable @__resource_var_0_0_ bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
        spirv.GlobalVariable @__resource_var_0_1_ bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
        spirv.func @d0() "None" {
          %cst255_i32 = spirv.Constant 255 : i32
          %cst8_i32 = spirv.Constant 8 : i32
          %cst4_i32 = spirv.Constant 4 : i32
          %cst1_i32 = spirv.Constant 1 : i32
          %cst0_i32 = spirv.Constant 0 : i32
          %cst3_i32 = spirv.Constant 3 : i32
          %__resource_var_0_0__addr = spirv.mlir.addressof @__resource_var_0_0_ : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
          %__resource_var_0_1__addr = spirv.mlir.addressof @__resource_var_0_1_ : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
          %0 = spirv.AccessChain %__resource_var_0_0__addr[%cst0_i32, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
          %1 = spirv.Load "StorageBuffer" %0 : i32
          %2 = spirv.SLessThanEqual %1, %cst3_i32 : i32
          %3 = spirv.Select %2, %cst1_i32, %cst0_i32 : i1, i32
          %4 = spirv.UMod %cst0_i32, %cst4_i32 : i32
          %5 = spirv.IMul %4, %cst8_i32 : i32
          %6 = spirv.ShiftLeftLogical %cst255_i32, %5 : i32, i32
          %7 = spirv.Not %6 : i32
          %8 = spirv.BitwiseAnd %3, %cst255_i32 : i32
          %9 = spirv.ShiftLeftLogical %8, %5 : i32, i32
          %10 = spirv.SDiv %cst0_i32, %cst4_i32 : i32
          %11 = spirv.AccessChain %__resource_var_0_1__addr[%cst0_i32, %10] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
          %12 = spirv.AtomicAnd "Device" "AcquireRelease" %11, %7 : !spirv.ptr<i32, StorageBuffer>
          %13 = spirv.AtomicOr "Device" "AcquireRelease" %11, %9 : !spirv.ptr<i32, StorageBuffer>
          spirv.Return
        }
        spirv.EntryPoint "GLCompute" @d0
        spirv.ExecutionMode @d0 "LocalSize", 1, 1, 1
      }
    }
  }
}
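To make the dump above easier to follow: the body of @d0 loads one i32, computes `%1 <= 3`, widens the result to 0 or 1, and writes it into one byte of a 32-bit storage word. It does this with a mask-and-shift read-modify-write split into spirv.AtomicAnd (clear the byte) and spirv.AtomicOr (set it). The Python below is a non-authoritative sketch that mirrors that arithmetic op by op, using plain integers in place of atomics. The function names are hypothetical. The sketch only illustrates the byte-packing pattern and says nothing about why Tint rejects the shader.

```python
# Hypothetical sketch mirroring the arithmetic of the @d0 kernel in the dump.
# Plain integer updates stand in for the AtomicAnd/AtomicOr ops; this is an
# illustration of the byte-in-word store pattern, not IREE's implementation.

MASK32 = 0xFFFFFFFF  # keep results within 32 bits, like SPIR-V i32


def store_byte(words, byte_index, value):
    """Write the low 8 bits of `value` into byte `byte_index` of `words`."""
    shift = (byte_index % 4) * 8               # %4 = UMod by 4, %5 = IMul by 8
    byte_mask = (0xFF << shift) & MASK32       # %6 = ShiftLeftLogical 255, %5
    clear_mask = ~byte_mask & MASK32           # %7 = Not %6
    bits = ((value & 0xFF) << shift) & MASK32  # %8 = BitwiseAnd, %9 = shift
    # %10 = SDiv by 4; Python's // matches it only for non-negative indices.
    word = byte_index // 4
    words[word] &= clear_mask                  # %12 = AtomicAnd (clear byte)
    words[word] |= bits                        # %13 = AtomicOr (set byte)
    return words


def emulate_d0(words, loaded, byte_index=0):
    """Hypothetical stand-in for @d0: store (loaded <= 3) as a byte."""
    flag = 1 if loaded <= 3 else 0             # %2 = SLessThanEqual, %3 = Select
    return store_byte(words, byte_index, flag)
```

For example, emulate_d0([0xFFFFFFFF], 2) clears the low byte and sets it to 1, leaving [0xFFFFFF01]. Splitting the store into a clear followed by a set lets neighbouring bytes in the same word be updated without a full load/store race. That is presumably why the lowering uses atomics here.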
ScottTodd commented 1 year ago

This may be a duplicate of https://github.com/openxla/iree/issues/10906 (detensoring could be removing the problematic ops).

benvanik commented 1 year ago

Seems likely!

ScottTodd commented 1 year ago

We can update Tint and see whether that fixes any of the older issues 🤞

dneto0 commented 1 year ago

We now have someone actively working on making the SPIR-V reader path in Tint production-ready. Please file bugs at crbug.com/tint and add the SpirvReader label.