nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

onnx.ScatterElements #823

Closed: vinayakdsci closed this issue 1 month ago

vinayakdsci commented 2 months ago

Lowering fails on a torch constant int operation with the error "failed to legalize operation 'torch.constant.int'", as reported in https://github.com/nod-ai/SHARK-Turbine/issues/812.

IR to reproduce the issue:

module {
  func.func @main_graph(%arg0: !torch.vtensor<[2708,1433],f32>, %arg1: !torch.vtensor<[2,2708],si64>, %arg3:!torch.vtensor<[2,?],si64>, %arg4: !torch.vtensor<[],si64>, %arg5:!torch.vtensor<[2708],f32>) -> !torch.vtensor<[2708],f32>    attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %24 = torch.operator "onnx.Gather"(%arg3, %6) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,?],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64> 
    %51 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__15> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %52 = torch.operator "onnx.Unsqueeze"(%arg4, %51) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %53 = torch.operator "onnx.Concat"(%52) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %54 = torch.operator "onnx.ConstantOfShape"(%53) {torch.onnx.value = dense_resource<__16> : tensor<1xf32>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32> 
    %56 = torch.operator "onnx.Shape"(%24) : (!torch.vtensor<[?],si64>) -> !torch.vtensor<[1],si64> 
    %57 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__18> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %58 = torch.operator "onnx.Slice"(%54, %57, %56) : (!torch.vtensor<[?],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32> 
    %59 = torch.operator "onnx.ScatterElements"(%arg5, %24, %58) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[2708],f32>, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[2708],f32> 
    return %59 : !torch.vtensor<[2708],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      __1: "0x080000000100000000000000",
      __15: "0x080000000000000000000000",
      __16: "0x080000000000803F",
      __18: "0x080000000000000000000000"
    }
  }
#-}

Command to reproduce:

iree-compile --iree-hal-target-backends=llvm-cpu model.mlir
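For context, the reduced graph above is essentially a scatter-add of ones into a length-2708 vector. A rough PyTorch sketch of the same computation (assuming the dense resources decode to a gather index of 1, a fill value of 1.0, and zero slice offsets; the function and argument names are illustrative, not taken from the model) might look like:

import torch

def main_graph(arg3, arg4, arg5):
    # arg3: [2, ?] int64, arg4: [] int64 scalar, arg5: [2708] float32
    idx = arg3[1]                                        # onnx.Gather, axis=0, index 1
    ones = torch.ones(int(arg4), dtype=torch.float32)    # onnx.ConstantOfShape, fill 1.0
    src = ones[: idx.shape[0]]                           # onnx.Slice
    return arg5.scatter_add(0, idx, src)                 # onnx.ScatterElements, reduction="add"

out = main_graph(torch.randint(0, 2708, (2, 100)), torch.tensor(100), torch.zeros(2708))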
AmosLewis commented 2 months ago

The error comes from onnx.ScatterElements: its torch-to-linalg lowering is missing. https://github.com/nod-ai/SHARK-Turbine/issues/812

torch-mlir-opt --convert-torch-onnx-to-torch  --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg model.mlir --debug
%58 = torch.aten.scatter.reduce %arg4, %int0, %57, %48, %str : !torch.vtensor<[2708],f32>, !torch.int, !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.str -> !torch.vtensor<[2708],f32>

PyTorch C++ implementation: https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp#L858
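For reference, the semantics the missing lowering has to reproduce for the 1-D, reduction="add" case (what the kernel linked above computes) can be sketched as a plain Python loop; this is only an illustrative reference model, not the torch-mlir lowering itself:

def scatter_add_1d(dst, index, src):
    # out[index[i]] += src[i] for every i, starting from a copy of dst
    out = list(dst)
    for i in range(len(index)):
        out[index[i]] += src[i]
    return out

# scatter_add_1d([0., 0., 0., 0.], [1, 3, 1], [1., 1., 1.]) -> [0.0, 2.0, 0.0, 1.0]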

zjgarvey commented 1 month ago

@AmosLewis One of the Turbine camp folks (JeongHyun Leon Kim) wanted to pick this up. Are you currently working on this?

AmosLewis commented 1 month ago

https://github.com/llvm/torch-mlir/pull/3754