iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Error in IREE official pytorch JIT Compilation Notebook #17876

Open adeel10x opened 3 months ago

adeel10x commented 3 months ago

What happened?

The last cell in the IREE official JIT Compilation Notebook fails to execute.

Error:

```
  func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[4],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {
    %int0 = torch.constant.int 0
    %0 = torch.aten.unsqueeze %arg2, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
    %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>
    %int0_0 = torch.constant.int 0
    %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>
    %int1 = torch.constant.int 1
    %3 = torch.aten.add.Tensor %2, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
    return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>
  }
}

#map = affine_map<(d0) -> (d0)>
module {
  util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<4x3xf32>
    %1 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<3xf32>
    %2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<4xf32>
    %expanded = tensor.expand_shape %2 [[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32>
    %3 = tensor.empty() : tensor<1x3xf32>
    %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x3xf32>) -> tensor<1x3xf32>
    %5 = linalg.matmul ins(%expanded, %0 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%4 : tensor<1x3xf32>) -> tensor<1x3xf32>
    %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<1x3xf32> into tensor<3xf32>
    %6 = tensor.empty() : tensor<3xf32>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%collapsed, %1 : tensor<3xf32>, tensor<3xf32>) outs(%6 : tensor<3xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %11 = arith.addf %in, %in_0 : f32
      linalg.yield %11 : f32
    } -> tensor<3xf32>
    %8:2 = hal.tensor.barrier join(%7, %expanded : tensor<3xf32>, tensor<1x4xf32>) => %arg4 : !hal.fence
    %9 = hal.tensor.export %8#0 : tensor<3xf32> -> !hal.buffer_view
    %10 = hal.tensor.export %8#1 : tensor<1x4xf32> -> !hal.buffer_view
    util.return %9, %10 : !hal.buffer_view, !hal.buffer_view
  }
  util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %0 = util.null : !hal.fence
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1:2 = util.call @main$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32
    util.return %1#0, %1#1 : !hal.buffer_view, !hal.buffer_view
  }
}

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-6aa9231f8d81> in <cell line: 2>()
      1 args = torch.randn(4)
----> 2 turbine_output = opt_linear_module(args)
      3 
      4 print("Weight:", linear_module.weight)
      5 print("Bias:", linear_module.bias)

19 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    449             prior = set_eval_frame(callback)
    450             try:
--> 451                 return fn(*args, **kwargs)
    452             finally:
    453                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:

<ipython-input-14-2b94aa6395f4> in forward(self, input)
      7     self.bias = torch.nn.Parameter(torch.randn(out_features))
      8 
----> 9   def forward(self, input):
     10     return (input @ self.weight) + self.bias
     11 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    449             prior = set_eval_frame(callback)
    450             try:
--> 451                 return fn(*args, **kwargs)
    452             finally:
    453                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py in inner(*args, **kwargs)
     34     @functools.wraps(fn)
     35     def inner(*args, **kwargs):
---> 36         return fn(*args, **kwargs)
     37 
     38     return inner

/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py in forward(*runtime_args)
    915         full_args.extend(params_flat)
    916         full_args.extend(runtime_args)
--> 917         return compiled_fn(full_args)
    918 
    919     # Just for convenience

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py in g(args)
     87 def make_boxed_func(f):
     88     def g(args):
---> 89         return f(*args)
     90 
     91     g._boxed_call = True  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py in runtime_wrapper(*args)
     86                     args_[idx] = args_[idx].detach()
     87             with torch.autograd._force_original_view_tracking(True):
---> 88                 all_outs = call_func_at_runtime_with_args(
     89                     compiled_fn,
     90                     args_,

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
    111     with context():
    112         if hasattr(f, "_boxed_call"):
--> 113             out = normalize_as_list(f(args))
    114         else:
    115             # TODO: Please remove soon

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py in g(args)
     87 def make_boxed_func(f):
     88     def g(args):
---> 89         return f(*args)
     90 
     91     g._boxed_call = True  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py in apply(cls, *args, **kwargs)
    596             # See NOTE: [functorch vjp and autograd interaction]
    597             args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598             return super().apply(*args, **kwargs)  # type: ignore[misc]
    599 
    600         if not is_setup_ctx_defined:

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py in forward(ctx, *deduped_flat_tensor_args)
    503                 #   of the original view, and not the synthetic base
    504 
--> 505                 fw_outs = call_func_at_runtime_with_args(
    506                     CompiledFunction.compiled_fw,
    507                     args,

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
    111     with context():
    112         if hasattr(f, "_boxed_call"):
--> 113             out = normalize_as_list(f(args))
    114         else:
    115             # TODO: Please remove soon

/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py in g(args)
     87 def make_boxed_func(f):
     88     def g(args):
---> 89         return f(*args)
     90 
     91     g._boxed_call = True  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/shark_turbine/dynamo/executor.py in __call__(self, *inputs)
     91 
     92         # Move inputs to the device and add to arguments.
---> 93         self._inputs_to_device(inputs, arg_list)
     94         # TODO: Append semaphores for async execution.
     95 

/usr/local/lib/python3.10/dist-packages/shark_turbine/dynamo/executor.py in _inputs_to_device(self, inputs, arg_list)
    106             # Since this is already a fallback case, just use the numpy array interop.
    107             # It isn't great, but meh... fallback case.
--> 108             device_array = asdevicearray(self.device_state.device, input_cpu)
    109             arg_list.push_ref(device_array._buffer_view)
    110 

/usr/local/lib/python3.10/dist-packages/iree/runtime/array_interop.py in asdevicearray(device, a, dtype, implicit_host_transfer, memory_type, allowed_usage, element_type)
    265         )
    266     # First get an ndarray. Needs to be C-contiguous, enforcing it here.
--> 267     a = np.asarray(a, dtype=dtype, order="C")
    268     element_type = map_dtype_to_element_type(a.dtype)
    269     if element_type is None:

ValueError: object __array__ method not producing an array
```
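
The last frame in the traceback is a plain `np.asarray(a, dtype=dtype, order="C")` call inside `asdevicearray`. A minimal sketch for checking whether that conversion already fails outside of IREE is shown below. It assumes the value reaching `asdevicearray` is a CPU `torch.Tensor` (the traceback names it `input_cpu` but does not show its type), and the helper name `check_tensor_to_numpy` is made up for illustration. It does not identify the root cause; it only helps separate a torch/NumPy interop failure from an IREE runtime failure.

```python
def check_tensor_to_numpy():
    """Try the same torch -> NumPy conversion that the failing frame performs.

    Returns a short status string instead of raising, so the result can be
    pasted into a bug report.
    """
    try:
        import numpy as np
        import torch
    except ImportError as exc:
        # One of the two packages is missing in this environment.
        return f"skipped: {exc}"

    # Assumption: the notebook input is a CPU float tensor like torch.randn(4).
    tensor = torch.randn(4)
    try:
        # Mirrors the call in iree/runtime/array_interop.py from the traceback.
        array = np.asarray(tensor, dtype=None, order="C")
    except Exception as exc:
        return f"failed: {type(exc).__name__}: {exc}"
    return f"ok: shape={array.shape}, dtype={array.dtype}"


print(check_tensor_to_numpy())
```

If this prints `failed: ...` in the same Colab runtime, the problem is likely reproducible without `shark_turbine` or `iree.runtime` in the picture.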

Steps to reproduce your issue

  1. Go to '...'
  2. Click on '....'
  3. Scroll down to '....'
  4. See error

What component(s) does this issue relate to?

Compiler, Runtime

Version information

iree-turbine, Version: 2.3.1

IREE compiler version 20240621.931 @ ac418d1f45d562bf9e9675bf69606c7d718e2432

LLVM version 19.0.0git

Optimized build

PyTorch version: 2.3.0+cpu

Additional context

No response

ScottTodd commented 3 months ago

I can reproduce this. I thought that notebook was covered by https://github.com/iree-org/iree/blob/main/samples/colab/test_notebooks.py and https://github.com/iree-org/iree/actions/workflows/samples.yml. It seems to be running: https://github.com/iree-org/iree/actions/runs/9885342089/job/27303199921#step:4:53. Maybe the CI tests are using different package versions, or something just broke entirely recently.
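
To compare the package versions in Colab against the ones CI installs, a small sketch like the following could be run in both environments and the output diffed. The package list is an assumption (the distribution names `iree-compiler`, `iree-runtime`, `iree-turbine`, `torch`, and `numpy` are guesses at what matters here), and `installed_versions` is a hypothetical helper, not part of IREE.

```python
from importlib import metadata

# Assumed set of distributions relevant to this notebook; adjust as needed.
DEFAULT_PACKAGES = ("iree-compiler", "iree-runtime", "iree-turbine", "torch", "numpy")


def installed_versions(packages=DEFAULT_PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```

Running this at the top of the notebook and as an extra step in the CI job would show whether the two environments resolve to different versions.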