llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Some failed ops when legalizing from torch to stablehlo #2020

Open MUR-83 opened 1 year ago

MUR-83 commented 1 year ago

Hi, I'm experimenting with converting different pre-trained models to all 3 output dialects, and hit 6 failures when converting to stablehlo. The (simplified) examples are shown below. Are these not implemented yet?

Note: some of these examples also fail when converting to tosa/linalg.

Note 2: for the examples with torch.constant, the outcome depends on the operations around the failing one: e.g. torch.constant.int 0 sometimes passes, sometimes fails.

1. error: failed to legalize operation 'torch.constant.int'

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.int) -> !torch.int {
    %int0 = torch.constant.int 0
    return %int0 : !torch.int
  }
}

2. error: failed to legalize operation 'torch.constant.float'

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.float) -> !torch.float {
    %float-Inf = torch.constant.float 0xFFF0000000000000
    return %float-Inf : !torch.float
  }
}

3. error: failed to legalize operation 'torch.aten.any'. (torch.constant.int 0 here seems ok!?)

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2],si64>) -> !torch.vtensor<[],i1> {
    %int0 = torch.constant.int 0
    %0 = torch.aten.eq.Scalar %arg0, %int0 : !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2],i1>
    %1 = torch.aten.any %0 : !torch.vtensor<[2],i1> -> !torch.vtensor<[],i1>
    return %1 : !torch.vtensor<[],i1>
  }
}

4. error: failed to legalize operation 'torch.prim.ListConstruct'

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1],si64>) -> !torch.list<optional<vtensor>> {
    %1 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1],si64>) -> !torch.list<optional<vtensor>>
    return %1 : !torch.list<optional<vtensor>>
  }
}

5. error: failed to legalize operation 'torch.aten.detach'

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> {
    %0 = torch.aten.detach %arg0 : !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    return %0 : !torch.vtensor<[1],si64>
  }
}

6. error: failed to legalize operation 'torch.aten.pow.Tensor_Tensor'

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,801,33],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,801,33],f32> {
    %1 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[1,801,33],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,801,33],f32>
    return %1 : !torch.vtensor<[1,801,33],f32>
  }
}
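
The literal 0xFFF0000000000000 in example 2 is a raw 64-bit pattern rather than a decimal value; MLIR typically prints non-finite floats such as infinities this way. A quick way to check which value it encodes is to reinterpret the bits as an IEEE-754 double. This is a standalone Python sketch that does not depend on torch-mlir, and f64_from_bits is a helper name made up for the illustration:

```python
import math
import struct


def f64_from_bits(hex_literal: str) -> float:
    """Reinterpret a 64-bit hexadecimal literal as an IEEE-754 double."""
    bits = int(hex_literal, 16)
    # Pack as an unsigned 64-bit integer, then unpack the same bytes as a double.
    return struct.unpack(">d", struct.pack(">Q", bits))[0]


value = f64_from_bits("0xFFF0000000000000")
print(value)  # -inf
```

So the constant named %float-Inf really is negative infinity: sign bit set, exponent all ones, mantissa zero.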

I'm using the standard command torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-stablehlo-backend-pipeline)'. All examples reproduce on a fresh torch-mlir build from today.
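
For checking many snippets at once, the command above can be wrapped in a small script. The following is only a sketch: it assumes torch-mlir-opt is on PATH and that diagnostics arrive on stderr in the "failed to legalize operation '...'" form quoted in this issue. The helper names failed_ops and run_pipeline are invented for the example and are not part of torch-mlir.

```python
import re
import subprocess

PIPELINE = "builtin.module(torch-backend-to-stablehlo-backend-pipeline)"

# Matches diagnostics such as: error: failed to legalize operation 'torch.aten.any'
FAILED_OP = re.compile(r"failed to legalize operation '([^']+)'")


def failed_ops(diagnostics: str) -> list[str]:
    """Return the distinct op names reported as not legalized, in order of appearance."""
    seen = []
    for name in FAILED_OP.findall(diagnostics):
        if name not in seen:
            seen.append(name)
    return seen


def run_pipeline(mlir_path: str, tool: str = "torch-mlir-opt") -> list[str]:
    """Run the pipeline on one .mlir file and list the ops that failed to legalize.

    Assumption: `tool` is installed and on PATH. Nothing is executed until
    this function is called.
    """
    result = subprocess.run(
        [tool, f"-pass-pipeline={PIPELINE}", mlir_path],
        capture_output=True,
        text=True,
    )
    return failed_ops(result.stderr)


# The parser can be tried on a diagnostic string without the tool installed.
sample = "error: failed to legalize operation 'torch.constant.float'\n"
print(failed_ops(sample))  # ['torch.constant.float']
```

Calling run_pipeline on each reproducer file would then give one list of failing ops per example.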

ramiro050 commented 1 year ago

Hi @MUR-83,

  1. This is probably because the stablehlo path does not have patterns for torch.constants. In the case of the linalg path, the TorchToArith pass is the one that handles these.
  2. Same as 1
  3. Stablehlo path currently does not handle torch.aten.any (in linalg this is also handled in TorchToArith). Here the torch.constant op is not an issue because stablehlo does have a pattern that handles the torch.aten.eq.Scalar op, which absorbs the constant.
  4. Returning lists is not supported.
  5. Someone recently made a PR for torch.aten.detach: https://github.com/llvm/torch-mlir/pull/2021. That PR should fix it.
  6. Currently only the linalg path has support for torch.aten.pow.Tensor_Tensor
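
One way to picture items 1 and 3 together is a toy model in which a constant only disappears once every op that uses it has been rewritten. The Python sketch below is a simplified illustration, not torch-mlir's actual conversion machinery: the Op class, the CONVERTED and ALREADY_LEGAL sets, and leftover_ops are all hypothetical, and the sets contain only what the explanation above states.

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    result: str = ""
    operands: tuple = ()


# Hypothetical model: ops the toy "stablehlo path" can rewrite. Constants are
# deliberately absent, mirroring items 1 and 2.
CONVERTED = {"torch.aten.eq.Scalar"}
# Ops that are acceptable in the output as they are.
ALREADY_LEGAL = {"func.return"}


def leftover_ops(ops):
    """Return names of torch ops that would still be present after conversion."""
    # Rewritten ops read the constant's value directly and stop using it.
    surviving = [op for op in ops if op.name not in CONVERTED]
    still_used = {operand for op in surviving for operand in op.operands}
    leftovers = []
    for op in surviving:
        if op.name in ALREADY_LEGAL:
            continue
        is_constant = op.name.startswith("torch.constant.")
        if is_constant and op.result not in still_used:
            continue  # dead constant: nothing uses it any more, so it can be dropped
        leftovers.append(op.name)
    return leftovers


# Example 1: the constant is returned directly, so it stays in use.
example_1 = [
    Op("torch.constant.int", "%int0"),
    Op("func.return", operands=("%int0",)),
]
# Example 3: the only user of the constant is eq.Scalar, which is rewritten.
example_3 = [
    Op("torch.constant.int", "%int0"),
    Op("torch.aten.eq.Scalar", "%0", ("%arg0", "%int0")),
    Op("torch.aten.any", "%1", ("%0",)),
    Op("func.return", operands=("%1",)),
]

print(leftover_ops(example_1))  # ['torch.constant.int']
print(leftover_ops(example_3))  # ['torch.aten.any']
```

In this picture the same torch.constant.int 0 is reported in example 1 but not in example 3, which matches the "sometimes passes, sometimes fails" observation in the original report.
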

MUR-83 commented 1 year ago

Thank you for the detailed answer. Is there any expectation of when these will be implemented? Or are they not considered a priority at the moment?

ramiro050 commented 1 year ago

@tanyokwok, do you know of anyone working on these ops? @MUR-83, anyone can contribute ops, so if you wanted to add support for these, we can help guide you along.

FengJungle commented 2 weeks ago

func.func @forward(%arg0: !torch.vtensor<[1,4,8,8],f16>) -> !torch.vtensor<[1,4,16,16],f16> {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4,8,8],f16> -> tensor<1x4x8x8xf16>
  %none = torch.constant.none
  %cst = arith.constant 2.000000e+00 : f64
  %1 = torch_c.from_f64 %cst
  %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<1x4x8x8xf16>
  %2 = stablehlo.maximum %0, %cst_0 : tensor<1x4x8x8xf16>
  %3 = torch_c.from_builtin_tensor %2 : tensor<1x4x8x8xf16> -> !torch.vtensor<[1,4,8,8],f16>
  %4 = torch.prim.ListConstruct %1, %1 : (!torch.float, !torch.float) -> !torch.list<float>
  %5 = torch.aten.upsample_nearest2d.vec %3, %none, %4 : !torch.vtensor<[1,4,8,8],f16>, !torch.none, !torch.list<float> -> !torch.vtensor<[1,4,16,16],f16>
  return %5 : !torch.vtensor<[1,4,16,16],f16>
}

Hi @ramiro050, torch.constant.float has been converted to arith.constant (see the IR above), but lowering still fails in the next step:

Traceback (most recent call last):
  File "/workspace/torch-mlir/projects/pt1/examples/torchscript_stablehlo_backend_test_upsample.py", line 31, in <module>
    module = torchscript.compile(
             ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/torchscript.py", line 410, in compile
    return lower_mlir_module(verbose, output_type, mb.module, disable_composite_ops)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 181, in lower_mlir_module
    run_pipeline_with_repro_report(
  File "/workspace/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 81, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> StableHLO Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.float'
note: see current operation: %2 = "torch.constant.float"() <{value = 2.000000e+00 : f64}> : () -> !torch.float

FengJungle commented 2 weeks ago

https://github.com/openxla/stablehlo/issues/2604

Here is an explanation.