llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

[RFC] Selective Op decomposition: Propagate down backend-legal Ops as Custom Ops #1519

Open Svoch opened 1 year ago

Svoch commented 1 year ago

Hi folks, I'd like to get your feedback on my proposed design of selective decomposition of operations in Torch-MLIR.

I am building mainly on the notion of "backend-legal Ops" (merged in this PR) and the ideas proposed in the design for the custom Ops support (mentioned in this RFC).

When an operation is marked as legal in the backend, it is not decomposed in the "DecomposeComplexOpsPass" and is instead propagated down to the backend. There are then two possibilities:

  1. A lowering pattern for the undecomposed Op exists in the backend: In this case the compilation will succeed.
  2. There is no lowering pattern for the undecomposed Op in the backend: In this case the compilation fails. For example, marking an Op like AtenLayerNormOp as backend-legal and compiling will yield an IR that cannot be lowered into the backend dialects.
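
The two outcomes above can be modeled with a small, self-contained sketch. This is plain Python, not Torch-MLIR code: the `compile_op` helper, the `toy.*` and `backend.*` op names, and both tables are hypothetical and only illustrate the control flow.

```python
# Toy model of selective decomposition; this is NOT the Torch-MLIR API.
# Every name below is hypothetical and chosen only for illustration.

# Ops the (toy) decomposition pass can break into simpler ops.
DECOMPOSITIONS = {
    "torch.aten.layer_norm": ["toy.mean", "toy.variance", "toy.normalize"],
}

# Ops for which the (toy) backend has a lowering pattern.
LOWERINGS = {
    "toy.mean": "backend.mean",
    "toy.variance": "backend.variance",
    "toy.normalize": "backend.normalize",
}


def compile_op(op, backend_legal, decompositions, lowerings):
    """Return the backend ops produced for `op`, or raise if lowering fails."""
    if op not in backend_legal and op in decompositions:
        # Not marked backend-legal: decompose, then lower each piece.
        result = []
        for piece in decompositions[op]:
            result.extend(compile_op(piece, backend_legal, decompositions, lowerings))
        return result
    if op in lowerings:
        # Case 1: a lowering pattern exists for the undecomposed op.
        return [lowerings[op]]
    # Case 2: the op reached the backend without any lowering pattern.
    raise RuntimeError(f"failed to legalize operation '{op}'")
```

Calling `compile_op("torch.aten.layer_norm", {"torch.aten.layer_norm"}, DECOMPOSITIONS, LOWERINGS)` raises, which corresponds to the second case; adding an entry for the op to the lowering table corresponds to the first.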

To handle the second case, we can mark such operations during (or before) the decomposition pass and have the backend conversion passes lower them into custom Ops. For example, consider the Torch module below:

import torch
from torch import nn

class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(normalized_shape=4)

    def forward(self, input_tensor):
        model_out = self.layernorm(input_tensor)
        return model_out

When compiling this to a backend, say TOSA, using the commands below

import torch_mlir

torch_mlir.BACKEND_LEGAL_OPS[torch_mlir.OutputType.TOSA].append("torch.aten.layer_norm")
module = torch_mlir.compile(SimpleModel(), torch.randn(4, 4), output_type=torch_mlir.OutputType.TOSA)

the compilation fails because unhandled Ops cannot be legalized, such as aten.layer_norm.int and its attributes that appear as constant.int in the Torch IR:

====================
Torch Backend IR
module attributes {torch.debug_module_name = "SimpleModel"} {
  func.func @forward(%arg0: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
    %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
    %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
    %int4 = torch.constant.int 4
    %float1.000000e-05 = torch.constant.float 1.000000e-05
    %true = torch.constant.bool true
    %2 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
    %3 = torch.aten.layer_norm %arg0, %2, %1, %0, %float1.000000e-05, %true {backend_legal} : !torch.vtensor<[4,4],f32>, !torch.list<int>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,4],f32>
    return %3 : !torch.vtensor<[4,4],f32>
  }
}

But we can mark the AtenLayerNormOp as "backend-legal" during or before the decomposition pass, and introduce a conversion pass to lower the marked Ops to custom Ops in the backend dialect. For example, this could be a valid TOSA IR for a full selective decomposition compilation:

====================
TOSA Backend IR
module attributes {torch.debug_module_name = "3D_LayerNorm"} {
  func.func @forward(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
    %normalized_shape = "tosa.const"() {value = dense<4> : tensor<1xi64>} : () -> tensor<1xi64>
    %bias = "tosa.const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
    %weight = "tosa.const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
    %eps = "tosa.const"() {value = dense<9.99999996E-13> : tensor<f32>} : () -> tensor<f32>
    %layernorm = "tosa.custom"(%arg0, %normalized_shape, %weight, %bias, %eps) {identifier = "torch.aten.layer_norm"} : (tensor<4x4xf32>, tensor<1xi64>, tensor<4xf32>, tensor<4xf32>, tensor<f32>) -> tensor<4x4xf32>
    return %layernorm : tensor<4x4xf32>
  }
}

Please note that torch.aten.layer_norm was marked with a unit attribute (UnitAttr) named backend_legal and converted into a tosa.custom operation, while the rest of the operations used by it were converted to tosa.const Ops via the existing lowering patterns in the Torch to TOSA conversion pass. Since there is no way for the conversion to distinguish between attributes and constant input arguments, all of the Ops used by the custom Op are treated as inputs. It will be up to the downstream dialects ingesting this IR to deduce the arguments and attributes based on their custom Op lowering methods.
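
As a rough illustration of what "deducing the arguments" could look like on the consuming side, the sketch below pairs the positional operands of a custom Op with parameter names. The `SCHEMAS` table and the `bind_operands` helper are hypothetical; the parameter names follow the ATen layer_norm signature, and a real consumer would need its own source of schemas.

```python
# Hypothetical consumer-side helper; not part of Torch-MLIR or TOSA.
# Maps a custom op identifier to the positional parameter names of the
# wrapped ATen op, so that operands can be recovered by position.
SCHEMAS = {
    "torch.aten.layer_norm": [
        "input", "normalized_shape", "weight", "bias", "eps", "cudnn_enable",
    ],
}


def bind_operands(identifier, operands):
    """Pair each operand of a custom op with the parameter it stands for."""
    names = SCHEMAS[identifier]
    if len(operands) > len(names):
        raise ValueError(
            f"{identifier}: got {len(operands)} operands, "
            f"schema has only {len(names)} parameters"
        )
    # zip stops at the shorter sequence, so trailing parameters that were
    # not passed are simply absent from the result.
    return dict(zip(names, operands))


bound = bind_operands(
    "torch.aten.layer_norm",
    ["%arg0", "%normalized_shape", "%weight", "%bias", "%eps"],
)
```

Because the binding is purely positional, it only works if the operand order of the original Torch operation is kept intact by the conversion.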

On a final note, the concept of "backend-legal" can be extended to distinguish between the two cases mentioned above: Ops that should be handled by the lowering targets in Torch-MLIR, and Ops that should be tunneled down to the downstream dialects as custom Ops. The latter could perhaps be marked as "backend-custom" Ops.

cc @silvasean @powderluv

eric-k256 commented 1 year ago

This sounds like a reasonably useful tool to me. This seems generic enough to me to work across multiple backends, assuming that they have some way to wrap the op similar to what tosa.custom allows. So for TOSA I'd say yes, although I'd like to get indications from other backends that this would work for them, so that we can have a consistent experience.

As a side note, I'm working on a change for tosa.custom to accept two additional strings, one for a 'config', which is like a namespace for the identifier, and another one to pass attributes of the wrapped operation if needed. I should be posting that in the TOSA dialect soon.

ramiro050 commented 1 year ago

@Svoch, thanks for presenting on this during the community meeting! I'll mention here the feedback given during the meeting.

The approach mentioned here seems good to me. We definitely want to have a way to specify backend-custom ops that can then be fed to a TorchTo{Tosa/Mhlo/Linalg}Custom pass that systematically converts each op specified to a custom op.

One thing we need to figure out is how to deal with the type conversions for certain torch dialect types that are expected to vanish once we finish the TorchToTosa pass. For example, many ops take optional arguments as input, which can potentially be a !torch.none. If an argument is None, we need a way to preserve this knowledge at the TOSA level. Perhaps this will require encoding these arguments in the attributes field @eric-k256 is working on. Another type that will be tricky to deal with is !torch.list<...>. In your example, the op aten.layer_norm takes a list of ints as input, which you converted to a tensor. However, sometimes ops take as input a list of tensors, or a list of optional tensors, where a simple transformation to a tensor will not work. We need to have a way to encode this into the op.

Off the top of my head, the types that we need to worry about are: !torch.list<...>, !torch.device, !torch.string, !torch.none, !torch.generator.
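
To make the concern concrete, here is a minimal sketch that scans the operand types of an op and reports those from the list above. The `tricky_types` helper is hypothetical, and the type prefixes are copied verbatim from this comment rather than checked against the Torch dialect definitions.

```python
# Hypothetical helper; the prefixes are taken verbatim from the comment above.
TRICKY_PREFIXES = (
    "!torch.list",
    "!torch.device",
    "!torch.string",
    "!torch.none",
    "!torch.generator",
)


def tricky_types(operand_types):
    """Return the operand types that need special handling, in order."""
    # str.startswith accepts a tuple and matches if any prefix matches.
    return [t for t in operand_types if t.startswith(TRICKY_PREFIXES)]


# Operand types of the layer_norm op from the example earlier in the thread.
layer_norm_types = [
    "!torch.vtensor<[4,4],f32>",
    "!torch.list<int>",
    "!torch.vtensor<[4],f32>",
    "!torch.vtensor<[4],f32>",
    "!torch.float",
    "!torch.bool",
]
flagged = tricky_types(layer_norm_types)
```

For the layer_norm example only the `!torch.list<int>` operand is flagged; an op called with an omitted optional argument would also report `!torch.none`.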

silvasean commented 1 year ago

I wasn't present at the discussion, but my general thoughts on this are that there isn't much to do here. TorchLoweringPipelineOptions already has a way to set backend legal ops. So as Ramiro points out, we just need a pass convert-torch-to-tosa-custom={ops-to-convert=torch.aten.layer_norm.int,...} that systematically converts any given Torch ops to tosa.custom ops in a well-specified way (with a well-documented contract). Whether other backends want to do something similar is orthogonal and does not interact with the TOSA feature in any way.

We don't want to have a backend_legal attribute monkey-patched on ops. Whether an op is legal for the backend is a property of the backend, not of the IR. Integrators of Torch-MLIR (which the torch_mlir.compile Python API counts as one) need to be aware of what each backend supports and how to configure the lowering to the backend contract (in particular, setting the backend legal ops) so that it is compatible with the desired backend.

ramiro050 commented 1 year ago

@silvasean, I noticed you commented here https://reviews.llvm.org/D137133 regarding the design of tosa::CustomOp. Is the idea to incorporate everything that is not a tensor into the attributes field? Would this work for the case of ints/floats that don't have a constant value at compile time? At the moment the approach taken in https://github.com/llvm/torch-mlir/pull/1514 is to convert scalars to tensors and pass them as normal arguments.

silvasean commented 1 year ago

Sometimes ints/floats are optional so for that we won't be able to represent them as operands due to the "None" case. And I think that makes sense since for the other backends we often pattern-match them as constants anyway during lowering.

The important thing here semantically is to consider what is truly a dynamically computed value and what is not, as far as TOSA is concerned. Since TOSA does not lower e.g. AtenAddIntOp/AtenAddFloatOp, it is fair to say that we do not expect !torch.int or !torch.float to ever be dynamically computed for the TOSA backend. At the very least, I think it would be a major extension to what TOSA can do to really support this in any generality, since often the computed values end up in lists/etc. and not just raw operands.

For cross-project, stable IR contracts, it is preferable to start with a constrained contract that is loosened over time to meet real demands, and putting all non-Tensors in attributes (which for now probably means JSONifying) is the most consistent with that philosophy at the moment. TBD how to handle optional<vtensor> but for now we can prohibit that (maybe something horrible like an integer list of "omitted-because-they-are-None operands").
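
One way to picture such a constrained contract is sketched below: tensor arguments stay as operands, and everything else is serialized by position into a JSON string. The `TensorRef` class, the `split_operands` helper, and the JSON layout are all assumptions made for illustration; no format was agreed on in this thread.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorRef:
    """Stand-in for an SSA tensor value such as %arg0 (hypothetical)."""
    name: str


def split_operands(args):
    """Split op arguments into tensor operands and a JSON attribute string.

    Tensor arguments remain operands. Every other argument (int, float,
    bool, list of scalars, None) is recorded under its argument position,
    so the consumer can reconstruct the full argument list.
    """
    operands = []
    attrs = {}
    for index, arg in enumerate(args):
        if isinstance(arg, TensorRef):
            operands.append(arg.name)
        else:
            # JSON object keys are strings, so the position is stringified.
            attrs[str(index)] = arg
    return operands, json.dumps(attrs, sort_keys=True)


# Arguments of the layer_norm example: input, normalized_shape, weight,
# bias, eps, cudnn_enable.
operands, payload = split_operands(
    [TensorRef("%arg0"), [4], TensorRef("%1"), TensorRef("%0"), 1e-05, True]
)
```

In this layout a None argument becomes a JSON `null` at its position, which is one possible alternative to a separate list of omitted operands.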

I would recommend that, as part of the development of this feature, we run the e2e test suite at least once in a dummy mode that turns all ops into tosa custom ops and see what ops don't work. We should have a pretty clear explanation for each kind of op that we don't systematically convert in the v1 of this feature.

Svoch commented 1 year ago

Thanks for your comments folks, I have created this PR to create a PoC to systematically lower backend-legal Torch operations to tosa.Custom Ops. I've added the design details and the backend contract in the PR. Here is a copy of the contract:

The pass convert-torch-backend-legal-to-tosa-custom extends the selective decomposition patterns to the TOSA backend, such that it will convert operations marked as backend-legal by the user for the TOSA backend to CustomOps in the TOSA dialect. The backend-legal operations will not be decomposed or lowered through the existing Torch to TOSA conversion patterns, even if such patterns exist in the subsequent Torch to TOSA conversion pass; instead a TOSA CustomOp will be created with the specifications below:

  1. The 'identifier' attribute in the CustomOp will be the operation name of the corresponding ATen operation.
  2. All inputs to the ATen operation will be converted into legal TOSA operations. There will be no distinction between inputs and constant attributes; all will be treated as inputs to the CustomOp. It is up to the consuming backend to deduce them based on the ATen operation semantics.
  3. Since the TOSA conversion patterns of Torch operations are responsible for consuming and converting the constant attributes (such as 'axis' for a reduction operation), the TOSA CustomOp conversion will also match and rewrite these attributes as TOSA ConstOps. Specifically:
    • torch.constant.bool -> tosa.ConstOp (of tensor type i1)
    • torch.constant.int -> tosa.ConstOp (of tensor type i64)
    • torch.constant.float -> tosa.ConstOp (of tensor type f32)
    • torch.prim.ListConstruct -> tosa.ConstOp (of tensor type i64)
    • torch.constant.none -> tosa.CustomOp (of tensor type i1). The 'identifier' attribute of this CustomOp is 'torch.constant.none'.
     All other Torch ATen operations will be lowered to TOSA by the Torch to TOSA conversion pass after this one.
  4. The order of the input operands of the backend-legal Torch operation is preserved, and the TOSA CustomOp will have the same order.
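
The constant conversions in item 3 can be summarized with a toy model. The `to_tosa_operand` function and its tuple representation are hypothetical and are not the code from the PR; the sketch only mirrors the mapping listed above, using rank-0 tensors for scalars and a rank-1 tensor for an integer list.

```python
# Toy model of the constant mapping in the contract; not the PR's code.
def to_tosa_operand(value):
    """Describe the TOSA op that the contract assigns to a Torch constant.

    Returns a tuple (tosa_op, element_type, shape, payload).
    """
    if value is None:
        # torch.constant.none becomes a custom op carrying an identifier.
        # The rank-0 shape used here is an assumption of this sketch.
        return ("tosa.custom", "i1", (), {"identifier": "torch.constant.none"})
    # bool must be tested before int: in Python, bool is a subclass of int.
    if isinstance(value, bool):
        return ("tosa.const", "i1", (), value)
    if isinstance(value, int):
        return ("tosa.const", "i64", (), value)
    if isinstance(value, float):
        return ("tosa.const", "f32", (), value)
    if isinstance(value, list) and all(
        isinstance(v, int) and not isinstance(v, bool) for v in value
    ):
        # torch.prim.ListConstruct of ints becomes a rank-1 i64 tensor.
        return ("tosa.const", "i64", (len(value),), list(value))
    raise TypeError(f"no conversion defined for {value!r}")


# Non-tensor arguments of the layer_norm example, in operand order.
converted = [to_tosa_operand(v) for v in ([4], 1e-05, True)]
```

Lists of tensors or of optional tensors are deliberately rejected here, since the contract above does not define a conversion for them.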

@eric-k256 regarding your comment on considering the new config attribute for the lowering, I have marked it as a TODO, but I'll be happy to update this PR to include this enhancement. We need to wait for an LLVM uplift in Torch-MLIR, however.

@ramiro050 @silvasean Your point on handling cases like !torch.none is spot on. I think I have handled a few of these cases in the PoC PR that I have linked. Specifically, for handling "None", I am creating yet another tosa.Custom Op. I've also refrained from attaching attributes to the IR; instead I am propagating the list of backend-legal operations to the TOSA backend pipeline.

Please let me know how this approach sounds, and thanks for the reviews again!

Svoch commented 1 year ago

BTW, here is an example of performing selective decomposition on the module defined below, with the torch.aten.layer_norm operation marked as backend-legal:

class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(normalized_shape=4)

    def forward(self, input_tensor):
        return self.layernorm(input_tensor)

====================
Torch Backend IR
module attributes {torch.debug_module_name = "SimpleModel"} {
  func.func @forward(%arg0: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
    %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
    %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
    %int4 = torch.constant.int 4
    %float1.000000e-05 = torch.constant.float 1.000000e-05
    %true = torch.constant.bool true
    %2 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
    %3 = torch.aten.layer_norm %arg0, %2, %1, %0, %float1.000000e-05, %true : !torch.vtensor<[4,4],f32>, !torch.list<int>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,4],f32>
    return %3 : !torch.vtensor<[4,4],f32>
  }
}

====================
TOSA Backend IR
module attributes {torch.debug_module_name = "SimpleModel"} {
  func.func @forward(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
    %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
    %1 = "tosa.const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
    %2 = "tosa.const"() {value = dense<4> : tensor<1xi64>} : () -> tensor<1xi64>
    %3 = "tosa.const"() {value = dense<9.99999974E-6> : tensor<f32>} : () -> tensor<f32>
    %4 = "tosa.const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
    %5 = "tosa.custom"(%arg0, %2, %1, %0, %3, %4) {identifier = "torch.aten.layer_norm"} : (tensor<4x4xf32>, tensor<1xi64>, tensor<4xf32>, tensor<4xf32>, tensor<f32>, tensor<i64>) -> tensor<4x4xf32>
    return %5 : tensor<4x4xf32>
  }
}