iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.57k stars 574 forks source link

Calling one public function from another public function throws an error in Stream RefineUsagePass #16348

Open MaheshRavishankar opened 7 months ago

MaheshRavishankar commented 7 months ago

I encountered this error while trying to write some matchers. The core issue isn't important, but here is a link to the gist for the dump after all passes (https://gist.github.com/MaheshRavishankar/8def9d4a658c229410eb94956384ab62).

The failure itself is in iree-stream-refine-usage pass. This is the IR before the pass

// Reproduction IR immediately before iree-stream-refine-usage (quoted from the issue).
// Target/layout attributes for a CPU executable backed by a prebuilt .o object file.
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 32 : index, target_triple = "x86_64-none-elf"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 3, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
module @example attributes {hal.device.targets = [#device_target_llvm_cpu]} {
  // Public function called from another public function (@mlp_invocation below) —
  // this public-to-public call is what trips RefineUsagePass per the issue title.
  func.func @call_mlp(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: !stream.resource<*>, %arg5: index, %arg6: index, %arg7: index, %arg8: !stream.resource<*>, %arg9: index, %arg10: index, %arg11: index, %arg12: !stream.resource<*>, %arg13: index, %arg14: index, %arg15: index) -> (!stream.resource<*>, index, index, index) {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %0 = arith.index_cast %arg2 : index to i32
    %1 = arith.index_cast %arg7 : index to i32
    %2 = arith.index_cast %arg3 : index to i32
    // Result byte size: %arg2 * 4 * %arg7 (f32 elements of the %arg2 x %arg7 output).
    %3 = arith.muli %arg2, %c4 : index
    %4 = arith.muli %3, %arg7 : index
    // Dispatch into the external mlp kernel; only %arg0/%arg4 resources are consumed —
    // %arg8/%arg12 (and their sizes) are unused by this body.
    %5 = stream.async.dispatch @executable::@x86_64::@mlp(%arg0[%c0 to %arg1 for %arg1], %arg4[%c0 to %arg5 for %arg5], %0, %1, %2) {hal.executable.ref = [@executable], hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} : (!stream.resource<*>{%arg1}, !stream.resource<*>{%arg5}, i32, i32, i32) -> !stream.resource<*>{%4}
    return %5, %4, %arg2, %arg7 : !stream.resource<*>, index, index, index
  }
  // Hand-authored executable wrapping an externally-compiled object file.
  hal.executable private @executable {
    hal.executable.variant public @x86_64 target(#executable_target_embedded_elf_x86_64_) objects([#hal.executable.object<{path = "samples/custom_dispatch/cpu/embedded/mlp_x86_64.o"}>]) {
      hal.executable.export public @mlp ordinal(0) layout(#pipeline_layout) {
      ^bb0(%arg0: !hal.device):
        %c1 = arith.constant 1 : index
        hal.return %c1, %c1, %c1 : index, index, index
      }
      builtin.module {
        // External symbol resolved from the linked .o at runtime.
        func.func private @mlp_external(memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, i32, i32, i32) attributes {hal.import.static}
        func.func @mlp() {
          %c0 = arith.constant 0 : index
          // Push constants 0..2 carry the dynamic matrix dimensions as i32.
          %0 = hal.interface.constant.load[0] : i32
          %1 = hal.interface.constant.load[1] : i32
          %2 = hal.interface.constant.load[2] : i32
          %3 = arith.index_cast %0 : i32 to index
          %4 = arith.index_cast %1 : i32 to index
          %5 = arith.index_cast %2 : i32 to index
          %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%3, %5}
          %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%5, %4}
          %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?x?xf32>{%3, %4}
          call @mlp_external(%6, %7, %8, %0, %1, %2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, i32, i32, i32) -> ()
          return
        }
      }
    }
  }
  // Public ABI entry point; imports the two buffer views and calls @call_mlp.
  func.func @mlp_invocation(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @mlp_invocation(%input0: tensor<?x?xf32>, %input1: tensor<?x?xf32>) -> (%output0: tensor<?x?xf32>)"}} {
    %c4 = arith.constant 4 : index
    %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
    %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
    %element_type_f32 = hal.element_type<f32> : i32
    %dense_row_major = hal.encoding_type<dense_row_major> : i32
    hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%0, %1]) type(%element_type_f32) encoding(%dense_row_major)
    %2 = arith.muli %0, %c4 : index
    %3 = arith.muli %2, %1 : index
    %4 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x?xf32>{%0, %1} in !stream.resource<external>{%3}
    %5 = stream.async.transfer %4 : !stream.resource<external>{%3} -> !stream.resource<*>{%3}
    %6 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
    %7 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
    hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("input1") shape([%6, %7]) type(%element_type_f32) encoding(%dense_row_major)
    %8 = arith.muli %6, %c4 : index
    %9 = arith.muli %8, %7 : index
    %10 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<?x?xf32>{%6, %7} in !stream.resource<external>{%9}
    %11 = stream.async.transfer %10 : !stream.resource<external>{%9} -> !stream.resource<*>{%9}
    %12 = arith.muli %2, %7 : index
    %13 = stream.async.alloca : !stream.resource<*>{%12}
    // NOTE(review): %13 (the alloca) is passed twice, as operands 8 and 12.
    // Per the reported diagnostic, refinement resolves operand 8 to
    // !stream.resource<transient> while @call_mlp's refined signature expects
    // <external>, producing the 'func.call' operand type mismatch.
    %14:4 = call @call_mlp(%5, %3, %0, %1, %11, %9, %6, %7, %13, %12, %0, %7, %13, %12, %0, %7) : (!stream.resource<*>, index, index, index, !stream.resource<*>, index, index, index, !stream.resource<*>, index, index, index, !stream.resource<*>, index, index, index) -> (!stream.resource<*>, index, index, index)
    %15 = stream.async.transfer %14#0 : !stream.resource<*>{%14#1} -> !stream.resource<external>{%14#1}
    %16 = stream.tensor.export %15 : tensor<?x?xf32>{%0, %7} in !stream.resource<external>{%14#1} -> !hal.buffer_view
    return %16 : !hal.buffer_view
  }
}

The error is

/home/mahesh/iree/iree/samples/custom_dispatch/cpu/embedded//mlp.mlir:32:13: error: 'func.call' op operand type mismatch: expected operand type '!stream.resource<external>', but provided '!stream.resource<transient>' for operand number 8
    %relu = linalg.generic {
            ^
/home/mahesh/iree/iree/samples/custom_dispatch/cpu/embedded//mlp.mlir:21:3: note: called from
  func.func @mlp_invocation(%lhs: tensor<?x?xf32>,
  ^
/home/mahesh/iree/iree/samples/custom_dispatch/cpu/embedded//mlp.mlir:32:13: note: see current operation: %15:4 = "func.call"(%7, %6, %1, %2, %12, %11, %8, %9, %14, %13, %1, %9, %14, %13, %1, %9) <{callee = @call_mlp}> : (!stream.resource<external>, index, index, index, !stream.resource<external>, index, index, index, !stream.resource<transient>, index, index, index, !stream.resource<transient>, index, index, index) -> (!stream.resource<external>, index, index, index)
    %relu = linalg.generic {
ScottTodd commented 7 months ago

Probably unrelated, but I tried something like this in SHARK-Turbine (https://github.com/nod-ai/SHARK-Turbine/issues/101) but hit an error up in Python. Interesting that there are errors at a lower level too. Code was like this:

from iree.compiler.ir import Context
import shark_turbine.aot as aot
import torch

# Scalar i32 tensor used as the initial value of the exported mutable global.
counter_tensor = torch.tensor(0, dtype=torch.int32)

class CounterModule(aot.CompiledModule):
  # Mutable module-level global backed by counter_tensor above.
  counter = aot.export_global(counter_tensor, mutable=True)

  def get_value(self):
    return self.counter

  # RuntimeError: Calls to exported functions not yet supported
  # NOTE(review): `value_example` is not defined anywhere in this snippet, so
  # evaluating this default-argument expression would raise NameError before
  # the RuntimeError above is ever reached — presumably the original code
  # defined it alongside counter_tensor; verify against the full reproducer.
  def add_to_value(self, value=aot.abstractify(value_example)):
    self.counter = self.get_value() + value

# Instantiate the compiled module and print the MLIR it produces.
counter_instance = CounterModule(context=Context())
print(aot.CompiledModule.get_mlir_module(counter_instance))
benvanik commented 7 months ago

we'll probably want WrapEntryPoints (or a pass nearby) to handle internalizing calls to other exports — since exports get changed to buffer views, etc., that creates a lot of junk that we want to avoid when doing intra-module calls