nod-ai / SHARK-TestSuite

Temporary home of a test suite we are evaluating
Apache License 2.0
5 stars 35 forks source link

Add e2e test for onnx.ScatterElements / torch.scatter.reduce #363

Open AmosLewis opened 1 month ago

AmosLewis commented 1 month ago

Test the linalg lowering of `torch.scatter_reduce` (`aten.scatter_reduce.two`) from https://github.com/llvm/torch-mlir/pull/3754.

`torch.scatter_reduce` step-by-step example (dim = 0, reduce = "sum", include_self = True):

src = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
self = [1, 2, 3, 4]
Step 0:
self[index[0]] += src[0]
self[0] += 1  →  1 + 1 = 2
self = [2, 2, 3, 4]

Step 1:
self[index[1]] += src[1]
self[1] += 2  →  2 + 2 = 4
self = [2, 4, 3, 4]

Step 2:
self[index[2]] += src[2]
self[0] += 3  →  2 + 3 = 5
self = [5, 4, 3, 4]

Step 3:
self[index[3]] += src[3]
self[1] += 4  →  4 + 4 = 8
self = [5, 8, 3, 4]

Step 4:
self[index[4]] += src[4]
self[2] += 5  →  3 + 5 = 8
self = [5, 8, 8, 4]

Step 5:
self[index[5]] += src[5]
self[1] += 6  →  8 + 6 = 14
self = [5, 14, 8, 4]

Final result: self = [5, 14, 8, 4]
AmosLewis commented 1 month ago

python -m torch_mlir.tools.import_onnx --opset-version=21 model.onnx -o ScatterElements.default.torch-onnx.mlir

Output (ScatterElements.default.torch-onnx.mlir):

// ONNX importer output: the model is a single onnx.ScatterElements
// (axis = 0, reduction = "add") that scatters the 6 values of %arg1 into the
// 4-element %arg0 at the constant indices [0, 1, 0, 1, 2, 1]. Indices 0 and 1
// repeat, so the "add" reduction has to accumulate duplicate hits.
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    // Not used anywhere in this function.
    %none = torch.constant.none
    // Scatter indices, stored as a constant (onnx.Constant) in the model.
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>} : () -> !torch.vtensor<[6],si64> 
    // Operands in ONNX order (data, indices, updates) = (%arg0, %0, %arg1).
    %1 = torch.operator "onnx.ScatterElements"(%arg0, %0, %arg1) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> 
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch),torch-lower-to-backend-contract,func.func(cse,canonicalize))' ScatterElements.default.torch-onnx.mlir > ScatterElements.default.onnx.torch.mlir

Output (ScatterElements.default.onnx.torch.mlir):

// Torch-dialect form after convert-torch-onnx-to-torch, backend-contract
// lowering, cse and canonicalize: onnx.ScatterElements with reduction = "add"
// has become torch.aten.scatter_reduce.two with reduce = "sum".
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    // Passed as the trailing bool operand of scatter_reduce.two (include_self in
    // the aten schema): the original %arg0 values take part in the sum, as in the
    // step-by-step example.
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %0 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    // Scatter dimension (the ONNX axis = 0).
    %int0 = torch.constant.int 0
    // scatter_reduce.two(self = %arg0, dim = 0, index = %0, src = %arg1, reduce = "sum", include_self = true)
    %1 = torch.aten.scatter_reduce.two %arg0, %int0, %0, %arg1, %str, %true : !torch.vtensor<[4],f32>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4],f32>
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg ScatterElements.default.onnx.torch.mlir > linalg.mlir

Output (linalg.mlir):

// Result of --convert-torch-to-linalg. aten.scatter_reduce.two is lowered in two
// steps: a linalg.generic materializes the (index, update) pairs, and then a
// tm_tensor.scatter accumulates them into the destination with arith.addf.
// In this pipeline --canonicalize ran before the linalg conversion, so the output
// still contains unused constants and redundant ops (noted below).
// #map: sends loop index i to element (i, 0) of the ?x1 index buffer.
#map = affine_map<(d0) -> (d0, 0)>
// #map1: identity map for the 1-D update buffer.
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    // %0: update values (%arg1). %1: destination tensor (%arg0).
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    // %true, %str and %int0 are no longer used after lowering.
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    // Index constant as a builtin tensor (si64 -> i64).
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %c0 = arith.constant 0 : index
    %c6 = arith.constant 6 : index
    %c1 = arith.constant 1 : index
    // Number of (index, update) pairs: 1 * 6 = 6, round-tripped through i64
    // and back to index.
    %4 = arith.muli %c1, %c6 : index
    %5 = arith.index_cast %4 : index to i64
    %6 = arith.index_cast %5 : i64 to index
    %c0_0 = arith.constant 0 : index
    %c6_1 = arith.constant 6 : index
    %c1_2 = arith.constant 1 : index
    // Zero-filled ?x1 i32 buffer (%6 rows) for the scatter indices.
    %7 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32 = arith.constant 0 : i32
    %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<?x1xi32>) -> tensor<?x1xi32>
    // Zero-filled ?xf32 buffer (%6 elements) for the update values.
    %9 = tensor.empty(%6) : tensor<?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?xf32>) -> tensor<?xf32>
    // For each loop index i in [0, %6): %17 = i rem 6, which equals i here because
    // %6 = 6. The loop writes index[%17] (truncated from i64 to i32) to
    // %11#0[i, 0] and src[%17] to %11#1[i]. %18 and %20 are computed but unused.
    %11:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%8, %10 : tensor<?x1xi32>, tensor<?xf32>) {
    ^bb0(%out: i32, %out_13: f32):
      %16 = linalg.index 0 : index
      %17 = arith.remsi %16, %c6_1 : index
      %18 = arith.divsi %16, %c6_1 : index
      %extracted = tensor.extract %3[%17] : tensor<6xi64>
      %extracted_14 = tensor.extract %0[%17] : tensor<6xf32>
      %19 = arith.index_cast %17 : index to i64
      %20 = arith.trunci %19 : i64 to i32
      %21 = arith.trunci %extracted : i64 to i32
      linalg.yield %21, %extracted_14 : i32, f32
    } -> (tensor<?x1xi32>, tensor<?xf32>)
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c1_5 = arith.constant 1 : index
    %c1_6 = arith.constant 1 : index
    %c1_7 = arith.constant 1 : index
    // Copy %11#0 into a fresh zero-filled buffer of the same shape. This copy
    // is redundant as emitted.
    %12 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32_8 = arith.constant 0 : i32
    %13 = linalg.fill ins(%c0_i32_8 : i32) outs(%12 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %c0_9 = arith.constant 0 : index
    %dim = tensor.dim %11#0, %c0_9 : tensor<?x1xi32>
    %c1_10 = arith.constant 1 : index
    %c1_11 = arith.constant 1 : index
    %inserted_slice = tensor.insert_slice %11#0 into %13[0, 0] [%dim, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
    %c1_12 = arith.constant 1 : index
    // Scatter the updates into %1 (the %arg0 values) along dimension 0.
    // unique_indices(false) declares that indices may repeat; here 0 and 1 do.
    // The region combines the two f32 block arguments with arith.addf, so
    // duplicate hits accumulate on top of the initial %arg0 values.
    %14 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%11#1, %inserted_slice : tensor<?xf32>, tensor<?x1xi32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %16 = arith.addf %arg2, %arg3 : f32
      tm_tensor.yield %16 : f32
    } -> tensor<4xf32>
    // Identity cast (tensor<4xf32> to tensor<4xf32>).
    %cast = tensor.cast %14 : tensor<4xf32> to tensor<4xf32>
    %15 = torch_c.from_builtin_tensor %cast : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %15 : !torch.vtensor<[4],f32>
  }
}
AmosLewis commented 1 month ago

The test passes with the most recent patch:

Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu

| tests                          | model-run   | onnx-import   | torch-mlir   | iree-compile   | inference   |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed      | passed        | passed       | passed         | passed      |