llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Failed to lower "torch.aten.convolution" to linalg #3429

Open josel-amd opened 1 month ago

josel-amd commented 1 month ago

Hi all,

I'm currently observing a failure when trying to lower the torch.aten.convolution operator. The failure message I see is: unimplemented: only 2D grouped convolution supported.

Command to reproduce:

torch-mlir-opt --torch-backend-to-linalg-on-tensors-backend-pipeline /scratch/vaimlpl/vitis_flexml/conv3d.mlir -debug

The snippet to reproduce the error is in [0].

Observed result

** Insert  : 'arith.subi'(0x55f80d6c9a20)
    ** Insert  : 'arith.subi'(0x55f80d6c9b00)
    ** Insert  : 'arith.floordivsi'(0x55f80d6c9be0)
    ** Insert  : 'arith.addi'(0x55f80d6c9cc0)
    ** Insert  : 'arith.index_cast'(0x55f80d6c9da0)
    ** Insert  : 'tensor.empty'(0x55f80d6c9e60)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::detail::GenericOpGenericAdaptorBase::Properties)
    ** Insert Block into detached Region (nullptr parent op)
    ** Insert  : 'linalg.yield'(0x55f80d6cf2d0)
    ** Insert  : 'linalg.generic'(0x55f80d5e5280)
    ** Insert  : 'arith.floordivsi'(0x55f80d6d0510)
    ** Insert  : 'arith.floordivsi'(0x55f80d6d05f0)
    ** Insert  : 'arith.constant'(0x55f80d6d06d0)
    ** Insert  : 'arith.constant'(0x55f80d6d0770)
    ** Failure : unimplemented: only 2D grouped convolution supported
"(anonymous namespace)::ConvertAtenConvolutionOp" result 0
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenScalarToTensorLike"
    ** Failure : not a supported Scalar to Tensor like op
"(anonymous namespace)::ConvertAtenScalarToTensorLike" result 0
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
Trying to match "(anonymous namespace)::ConvertElementwiseOp"
    ** Failure : not a supported elementwise op
"(anonymous namespace)::ConvertElementwiseOp" result 0
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
Trying to match "(anonymous namespace)::ConvertReductionOp"
    ** Failure : not a supported reduce op
"(anonymous namespace)::ConvertReductionOp" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<unknown>:0: error: failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
<unknown>:0: note: see current operation: %12 = "torch.aten.convolution"(%6, %7, %8, %10, %9, %10, %1, %11, %0) : (!torch.vtensor<[2,4,4,5,4],f32>, !torch.vtensor<[6,2,3,3,3],f32>, !torch.vtensor<[6],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[2,6,2,3,2],f32>

[0]

#loc = loc(unknown)
module attributes {
  llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128",
  llvm.target_triple = "x86_64-unknown-linux-gnu",
  "onnx-mlir.symbol-postfix" = "onnxmodel.onnx.mlir"} {
  func.func @forward(%arg0: tensor<2x4x4x5x4xf32> {onnx.name = "0"} loc(unknown)) -> (tensor<2x6x2x3x2xf32> {onnx.name = "3"}) attributes {
    torch.onnx_meta.opset_version = 19 : si64} {
    %int2 = torch.constant.int 2 loc(#loc)
    %false = torch.constant.bool false loc(#loc)
    %int1 = torch.constant.int 1 loc(#loc)
    %int0 = torch.constant.int 0 loc(#loc)
    %0 = "tosa.const"() <{value = dense<[0.0272595733, 0.0564398244, -6.167640e-02, -0.00770454854, -0.0622025988, -0.0172998682]> : tensor<6xf32>}> : () -> tensor<6xf32> loc(#loc)
    %1 = "tosa.const"() <{value = dense_resource<__elided__> : tensor<6x2x3x3x3xf32>}> : () -> tensor<6x2x3x3x3xf32> loc(#loc)
    %2 = torch_c.from_builtin_tensor %arg0 : tensor<2x4x4x5x4xf32> -> !torch.vtensor<[2,4,4,5,4],f32> loc(#loc)
    %3 = torch_c.from_builtin_tensor %1 : tensor<6x2x3x3x3xf32> -> !torch.vtensor<[6,2,3,3,3],f32> loc(#loc)
    %4 = torch_c.from_builtin_tensor %0 : tensor<6xf32> -> !torch.vtensor<[6],f32> loc(#loc)
    %5 = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %6 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %7 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %8 = torch.aten.convolution %2, %3, %4, %6, %5, %6, %false, %7, %int2 : !torch.vtensor<[2,4,4,5,4],f32>, !torch.vtensor<[6,2,3,3,3],f32>, !torch.vtensor<[6],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,6,2,3,2],f32> loc(#loc)
    %9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2,6,2,3,2],f32> -> tensor<2x6x2x3x2xf32> loc(#loc)
    return %9 : tensor<2x6x2x3x2xf32> loc(#loc)
  } loc(#loc)
} loc(#loc)
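As a sanity check on the IR above, the result type [2,6,2,3,2] is consistent with the standard convolution output-size formula for kernel 3, stride 1, padding 0, dilation 1 (this is a plain-Python sketch, not torch-mlir code; the helper name is hypothetical):

```python
def conv_out_dim(in_dim, kernel, stride=1, padding=0, dilation=1):
    # Standard convolution output-size formula.
    return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Input [2, 4, 4, 5, 4] (N, C_in, D, H, W), weight [6, 2, 3, 3, 3].
spatial = [conv_out_dim(d, 3) for d in (4, 5, 4)]
print(spatial)  # [2, 3, 2], matching the result type [2, 6, 2, 3, 2]
```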
srcarroll commented 1 month ago

I don't know if anyone here would be interested, but a few months ago I started working on unifying grouped and depthwise convolutions by creating an interface for them. At the time I couldn't find many people who cared enough about it, so I lost steam. I recently resurrected it: https://github.com/llvm/llvm-project/pull/94796. It is still a work in progress. I actually first started doing this with depthwise only: https://github.com/llvm/llvm-project/pull/75017. That one is more complete, but I then decided to switch to grouped conv and make depthwise a specialization of it.
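The "depthwise as a specialization of grouped" relationship comes down to channel bookkeeping: a grouped conv weight has shape [C_out, C_in/groups, k...], and depthwise is the groups == C_in case where each filter sees exactly one input channel. A small sketch (hypothetical helper, plain Python, shapes taken from the failing op above):

```python
def grouped_weight_shape(c_out, c_in, groups, kernel):
    # Weight shape for a grouped conv: each filter sees c_in // groups channels.
    assert c_out % groups == 0 and c_in % groups == 0
    return (c_out, c_in // groups, *kernel)

# The failing op: groups=2, so each of the 6 filters sees 4 // 2 = 2 input channels.
print(grouped_weight_shape(6, 4, 2, (3, 3, 3)))   # (6, 2, 3, 3, 3)

# Depthwise is the groups == c_in specialization: each filter sees one channel.
print(grouped_weight_shape(4, 4, 4, (3, 3, 3)))   # (4, 1, 3, 3, 3)
```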

zjgarvey commented 1 month ago

@josel-amd How high a priority is 3D grouped convolution for your team?

If we need the support immediately, I could temporarily lower this signature to a linalg generic until a better upstream fix is landed. The longer term convolution changes mentioned by @srcarroll sound attractive in comparison to upstreaming (yet another) convolution named op to the linalg dialect.
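For reference, the arithmetic such a linalg.generic fallback would have to encode is just a grouped reduction: each output channel reads only its group's slice of the input channels. A plain-Python reference (not torch-mlir code; stride 1, no padding, no dilation assumed, all names hypothetical):

```python
def grouped_conv3d(x, w, bias, groups):
    # x: [N][Cin][D][H][W] nested lists, w: [Cout][Cin//groups][kD][kH][kW].
    n, cout, cing = len(x), len(w), len(w[0])
    kd, kh, kw = len(w[0][0]), len(w[0][0][0]), len(w[0][0][0][0])
    d, h, wd = len(x[0][0]), len(x[0][0][0]), len(x[0][0][0][0])
    od, oh, ow = d - kd + 1, h - kh + 1, wd - kw + 1
    cout_per_g = cout // groups
    out = [[[[[bias[f] for _ in range(ow)] for _ in range(oh)]
             for _ in range(od)] for f in range(cout)] for _ in range(n)]
    for b in range(n):
        for f in range(cout):
            g = f // cout_per_g  # which group this filter belongs to
            for z in range(od):
                for y in range(oh):
                    for xx in range(ow):
                        acc = 0.0
                        for c in range(cing):  # only this group's input channels
                            ci = g * cing + c
                            for i in range(kd):
                                for j in range(kh):
                                    for k in range(kw):
                                        acc += x[b][ci][z+i][y+j][xx+k] * w[f][c][i][j][k]
                        out[b][f][z][y][xx] += acc
    return out

def full(shape, v=1.0):
    # Build a nested list of the given shape filled with v.
    return v if not shape else [full(shape[1:], v) for _ in range(shape[0])]

# Tiny all-ones case: 4 input channels, 6 filters, groups=2, 2x2x2 kernel.
x = full((1, 4, 3, 3, 3))
w = full((6, 2, 2, 2, 2))
y = grouped_conv3d(x, w, [0.0] * 6, groups=2)
print(y[0][0][0][0][0])  # 16.0: 2 channels * 2*2*2 kernel, all ones
```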

josel-amd commented 1 month ago

@zjgarvey I can't discuss priorities, but this is part of a wider effort to get comprehensive ONNX support (you can see that by looking at the other issues I've raised). This is important for us but not urgent. I wonder if some kind of action should still be taken, but I don't have enough information to suggest a course of action. We can follow whatever the proper process is here.