llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

convert-torch-onnx-to-torch generates invalid IR for onnx.Resize where scaling is in the first two dimensions #3453

Open mgehre-amd opened 4 months ago

mgehre-amd commented 4 months ago

The code in that pass seems to silently assume that the first two dimensions are not scaled, but ONNX imposes no such restriction.

With input

func.func @test_resize_middle(%arg0: !torch.vtensor<[1,36,42,384],f32>) -> !torch.vtensor<[1,72,84,384],f32>
    attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  %12 = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 2.000000e+00, 1.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %19 = torch.operator "onnx.Resize"(%arg0, %none, %12) {
      torch.onnx.coordinate_transformation_mode = "half_pixel",
      torch.onnx.mode = "nearest",
      torch.onnx.nearest_mode = "round_prefer_floor"} : (!torch.vtensor<[1,36,42,384],f32>, !torch.none, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,72,84,384],f32>
  return %19 : !torch.vtensor<[1,72,84,384],f32>
}

we get

module {
  func.func @test_resize_middle(%arg0: !torch.vtensor<[1,36,42,384],f32>) -> !torch.vtensor<[1,72,84,384],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 2.000000e+00, 1.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
    %none_0 = torch.constant.none
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %true = torch.constant.bool true
    %str = torch.constant.str "nearest_half_pixel,round_prefer_floor"
    %int2 = torch.constant.int 2
    %1 = torch.aten.select.int %0, %int0, %int2 : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
    %2 = torch.aten.item %1 : !torch.vtensor<[1],f32> -> !torch.float
    %int3 = torch.constant.int 3
    %3 = torch.aten.select.int %0, %int0, %int3 : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
    %4 = torch.aten.item %3 : !torch.vtensor<[1],f32> -> !torch.float
    %5 = torch.prim.ListConstruct %2, %4 : (!torch.float, !torch.float) -> !torch.list<float>
    %6 = torch.aten.__interpolate.size_list_scale_list %arg0, %none_0, %5, %str, %false, %none_0, %false : !torch.vtensor<[1,36,42,384],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,72,84,384],f32>
    return %6 : !torch.vtensor<[1,72,84,384],f32>
  }
}

Here, %1 and %3 read only the last two elements of %0, i.e. the scales for dimensions 2 and 3, even though it is dimensions 1 and 2 that are scaled. When lowering this IR to linalg (-convert-torch-to-linalg), we get

error: unexpected error: 'tensor.cast' op operand type 'tensor<1x36x?x?xf32>' and result type 'tensor<1x72x84x384xf32>' are cast incompatible

because the scales used by torch.aten.__interpolate.size_list_scale_list no longer match the output shape.
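To make the mismatch concrete, here is a minimal plain-Python sketch (not torch-mlir code) that mirrors the index choice in the generated IR, computing shapes with simple list arithmetic:

```python
# Input shape and ONNX Resize scales: dimensions 1 and 2 are scaled by 2.
input_shape = [1, 36, 42, 384]
scales = [1.0, 2.0, 2.0, 1.0]

# ONNX semantics: every dimension is multiplied by its scale.
expected = [int(d * s) for d, s in zip(input_shape, scales)]
# expected == [1, 72, 84, 384]

# The lowering instead selects scales[2] and scales[3] and applies them to
# the trailing two dimensions, as torch's interpolate would for NCHW input.
spatial_scales = scales[2:]  # [2.0, 1.0] -- the wrong pair of scales
actual = input_shape[:2] + [
    int(d * s) for d, s in zip(input_shape[2:], spatial_scales)
]
# actual == [1, 36, 84, 384], which cannot cast to the declared
# result shape [1, 72, 84, 384] -- hence the tensor.cast error.
```

This reproduces the shapes seen in the error: dimension 1 stays 36 instead of becoming 72, matching the `tensor<1x36x?x?xf32>` operand in the diagnostic.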

arnavmehta1 commented 2 months ago

I can work on this.