iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Running heterogeneous multi-device example #18156

Open nirvedhmeshram opened 3 months ago

nirvedhmeshram commented 3 months ago

I tried running an example similar to the checked-in test here, except that I would like to run across a GPU and a CPU. Here is the sample MLIR test I used:

!A_TYPE = tensor<32x32xf32>
func.func public @three_mm(%lhs : !A_TYPE {iree.abi.affinity = #hal.device.promise<@device_a>},
    %rhs : !A_TYPE {iree.abi.affinity = #hal.device.promise<@device_a>}) -> 
    ( !A_TYPE {iree.abi.affinity = #hal.device.promise<@device_a>} ){
  %empty = tensor.empty() : !A_TYPE
  %empty_2 = tensor.empty() : !A_TYPE
  %empty_3 = tensor.empty() : !A_TYPE
  %cst = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%cst : f32) outs(%empty : !A_TYPE) -> !A_TYPE
  %2 = linalg.matmul ins(%lhs, %rhs : !A_TYPE, !A_TYPE)
      outs(%fill : !A_TYPE) -> !A_TYPE
  %transient_b = flow.tensor.transfer %2 : !A_TYPE to #hal.device.promise<@device_b>
  %3 = linalg.matmul ins(%transient_b, %transient_b : !A_TYPE, !A_TYPE)
      outs(%fill : !A_TYPE) -> !A_TYPE
  %transient_c = flow.tensor.transfer %3 : !A_TYPE to #hal.device.promise<@device_a>
  return %transient_c : !A_TYPE
}
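For reference, the program above is just two chained 32x32 matmuls with device transfers between them, so the expected numerical result is easy to check on the host. A minimal NumPy sketch, using the same all-ones and all-twos inputs passed to iree-run-module:

```python
import numpy as np

# Same inputs as the iree-run-module invocation: all ones and all twos.
lhs = np.full((32, 32), 1.0, dtype=np.float32)
rhs = np.full((32, 32), 2.0, dtype=np.float32)

# First matmul (on @device_a in the MLIR): every entry is 32 * 1 * 2 = 64.
mm1 = lhs @ rhs

# Second matmul (on @device_b after the transfer): 32 * 64 * 64 = 131072.
result = mm1 @ mm1

print(result[0, 0])  # 131072.0
```

If the multi-device path were working, iree-run-module should print a 32x32 tensor of 131072 for this test.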

Here is the iree-compile command I used to target a ROCm GPU and the CPU:

./tools/iree-compile ../matmul_multi.mlir \
  --iree-execution-model=async-external \
  --iree-hal-target-device=device_a=local \
  --iree-hal-target-device=device_b=hip \
  --iree-hal-local-target-device-backends=llvm-cpu \
  --iree-hal-local-target-device-backends=rocm \
  --iree-rocm-target-chip=gfx1100 -o output.vmfb

Here is the run command:

./tools/iree-run-module --module=output.vmfb \
  --input=32x32xf32=1 --input=32x32xf32=2 \
  --device=local-task --device=hip

And here is the error:

iree/runtime/src/iree/hal/drivers/hip/event_semaphore.c:351: ABORTED; while calling import; while invoking native function hal.fence.await; 
[ 1]   native hal.fence.await:0 -
[ 0] bytecode module.three_mm:1140 ../matmul_multi.mlir:19:18
      at ../matmul_multi.mlir:10:1; invoking function 'three_mm'; `async func @three_mm(%input0: tensor<32x32xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}, %input1: tensor<32x32xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<32x32xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})`
sogartar commented 3 months ago

I suspect that some semaphore has failed, which in turn causes the hal.fence.await to fail. We need a better error message in this case so that the original failure that caused the wait to fail is surfaced. Inside the wait function we could query all of the semaphores and chain/annotate the ABORTED status with the statuses of the failing ones. I don't think we should be too concerned about performance on the error path.
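The annotation idea above can be sketched out in pseudologic. This is purely an illustrative Python simulation, not the IREE runtime C API; the `Semaphore` class and `annotate_aborted` helper are hypothetical stand-ins for querying each semaphore on the failed wait and folding any failure messages into the returned status:

```python
# Hypothetical sketch of the proposed error annotation. These names are NOT
# the real IREE runtime API; they only illustrate the chaining behavior.
class Semaphore:
    def __init__(self, name, failure=None):
        self.name = name
        self.failure = failure  # an error message if this semaphore failed

def annotate_aborted(semaphores):
    """Build an ABORTED message that chains each failing semaphore's status.

    On the error path we query every semaphore in the wait set; any that
    report a failure get their message appended to the base status.
    """
    failures = [f"{s.name}: {s.failure}" for s in semaphores if s.failure]
    msg = "ABORTED; while waiting on fence"
    if failures:
        msg += "; failing semaphores: " + "; ".join(failures)
    return msg

sems = [Semaphore("sem_a"), Semaphore("sem_b", failure="HIP error during dispatch")]
print(annotate_aborted(sems))
```

With this kind of chaining, the user would see the device-side failure that poisoned the semaphore instead of only the opaque ABORTED from hal.fence.await.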