AmosLewis opened this issue 1 month ago
```shell
python -m torch_mlir.tools.import_onnx --opset-version=21 model.onnx -o ScatterElements.default.torch-onnx.mlir
```
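The model.onnx being imported is not attached here. For reference, a minimal sketch of an equivalent model built with the onnx Python helper API, reconstructed from the imported MLIR below; the value names ("data", "updates", "out") are assumptions, since only the shapes, the index constant, and the attributes are visible in the IR:

```python
import onnx
from onnx import TensorProto, helper

# Hypothetical reconstruction: ScatterElements(axis=0, reduction="add") over a
# [4] f32 data tensor, with constant int64 indices [0, 1, 0, 1, 2, 1] and a
# [6] f32 updates tensor, matching the imported graph below.
indices = helper.make_tensor("indices", TensorProto.INT64, [6], [0, 1, 0, 1, 2, 1])
nodes = [
    helper.make_node("Constant", [], ["indices_out"], value=indices),
    helper.make_node("ScatterElements", ["data", "indices_out", "updates"], ["out"],
                     axis=0, reduction="add"),
]
graph = helper.make_graph(
    nodes, "scatter_graph",
    inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4]),
            helper.make_tensor_value_info("updates", TensorProto.FLOAT, [6])],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
onnx.checker.check_model(model)
onnx.save(model, "model.onnx")
```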
`ScatterElements.default.torch-onnx.mlir`:
```mlir
module {
func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>} : () -> !torch.vtensor<[6],si64>
%1 = torch.operator "onnx.ScatterElements"(%arg0, %0, %arg1) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32>
return %1 : !torch.vtensor<[4],f32>
}
}
```
```shell
torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch),torch-lower-to-backend-contract,func.func(cse,canonicalize))' ScatterElements.default.torch-onnx.mlir > ScatterElements.default.onnx.torch.mlir
```
`ScatterElements.default.onnx.torch.mlir`:
```mlir
module {
func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%true = torch.constant.bool true
%str = torch.constant.str "sum"
%0 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
%int0 = torch.constant.int 0
%1 = torch.aten.scatter_reduce.two %arg0, %int0, %0, %arg1, %str, %true : !torch.vtensor<[4],f32>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4],f32>
return %1 : !torch.vtensor<[4],f32>
}
}
```
```shell
torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg ScatterElements.default.onnx.torch.mlir > linalg.mlir
```
`linalg.mlir`:
```mlir
#map = affine_map<(d0) -> (d0, 0)>
#map1 = affine_map<(d0) -> (d0)>
module {
func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
%1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
%true = torch.constant.bool true
%str = torch.constant.str "sum"
%2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
%3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
%int0 = torch.constant.int 0
%c0 = arith.constant 0 : index
%c6 = arith.constant 6 : index
%c1 = arith.constant 1 : index
%4 = arith.muli %c1, %c6 : index
%5 = arith.index_cast %4 : index to i64
%6 = arith.index_cast %5 : i64 to index
%c0_0 = arith.constant 0 : index
%c6_1 = arith.constant 6 : index
%c1_2 = arith.constant 1 : index
%7 = tensor.empty(%6) : tensor<?x1xi32>
%c0_i32 = arith.constant 0 : i32
%8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<?x1xi32>) -> tensor<?x1xi32>
%9 = tensor.empty(%6) : tensor<?xf32>
%cst = arith.constant 0.000000e+00 : f32
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?xf32>) -> tensor<?xf32>
%11:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%8, %10 : tensor<?x1xi32>, tensor<?xf32>) {
^bb0(%out: i32, %out_13: f32):
%16 = linalg.index 0 : index
%17 = arith.remsi %16, %c6_1 : index
%18 = arith.divsi %16, %c6_1 : index
%extracted = tensor.extract %3[%17] : tensor<6xi64>
%extracted_14 = tensor.extract %0[%17] : tensor<6xf32>
%19 = arith.index_cast %17 : index to i64
%20 = arith.trunci %19 : i64 to i32
%21 = arith.trunci %extracted : i64 to i32
linalg.yield %21, %extracted_14 : i32, f32
} -> (tensor<?x1xi32>, tensor<?xf32>)
%c0_3 = arith.constant 0 : index
%c0_4 = arith.constant 0 : index
%c1_5 = arith.constant 1 : index
%c1_6 = arith.constant 1 : index
%c1_7 = arith.constant 1 : index
%12 = tensor.empty(%6) : tensor<?x1xi32>
%c0_i32_8 = arith.constant 0 : i32
%13 = linalg.fill ins(%c0_i32_8 : i32) outs(%12 : tensor<?x1xi32>) -> tensor<?x1xi32>
%c0_9 = arith.constant 0 : index
%dim = tensor.dim %11#0, %c0_9 : tensor<?x1xi32>
%c1_10 = arith.constant 1 : index
%c1_11 = arith.constant 1 : index
%inserted_slice = tensor.insert_slice %11#0 into %13[0, 0] [%dim, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
%c1_12 = arith.constant 1 : index
%14 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%11#1, %inserted_slice : tensor<?xf32>, tensor<?x1xi32>) outs(%1 : tensor<4xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%16 = arith.addf %arg2, %arg3 : f32
tm_tensor.yield %16 : f32
} -> tensor<4xf32>
%cast = tensor.cast %14 : tensor<4xf32> to tensor<4xf32>
%15 = torch_c.from_builtin_tensor %cast : tensor<4xf32> -> !torch.vtensor<[4],f32>
return %15 : !torch.vtensor<[4],f32>
}
}
```
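In plain terms, the `linalg.generic` flattens the index constant into a `?x1` i32 index tensor and copies the updates alongside it, and the `tm_tensor.scatter` region (`arith.addf`, `unique_indices(false)`) then accumulates each update into the destination element it points at. A rough scalar-loop sketch of what that computes:

```python
# Rough Python equivalent of the tm_tensor.scatter above: every update is
# added into the destination slot named by its index (the addf reduction).
def scatter_add(dst, indices, updates):
    out = list(dst)
    for i, idx in enumerate(indices):
        out[idx] += updates[i]
    return out
```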
Passes with the most recent patch:
Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu
| tests | model-run | onnx-import | torch-mlir | iree-compile | inference |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed | passed | passed | passed | passed |
Test of the torch.scatter.reduce linalg lowering: https://github.com/llvm/torch-mlir/pull/3754
torch.scatter.reduce step-by-step example:
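A minimal sketch of the equivalent `torch.Tensor.scatter_reduce` call. The concrete values for `%arg0` and `%arg1` are made up for illustration; only the shapes, the index constant, and `reduce="sum"` / `include_self=True` come from the IR above:

```python
import torch

dst = torch.tensor([1.0, 2.0, 3.0, 4.0])                   # %arg0 : [4], f32 (example values)
index = torch.tensor([0, 1, 0, 1, 2, 1])                    # the onnx.Constant indices
src = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0])    # %arg1 : [6], f32 (example values)

# torch.aten.scatter_reduce.two(%arg0, 0, %0, %arg1, "sum", true):
# every src[i] is added into dst[index[i]], and the original dst values are
# kept because include_self=True.
out = dst.scatter_reduce(0, index, src, reduce="sum", include_self=True)
print(out)  # tensor([ 41., 122.,  53.,   4.])
```

Step by step: index 0 appears at positions 0 and 2, so out[0] = 1 + 10 + 30 = 41; index 1 appears at positions 1, 3, and 5, so out[1] = 2 + 20 + 40 + 60 = 122; index 2 appears only at position 4, so out[2] = 3 + 50 = 53; no index points at position 3, so out[3] stays 4.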