llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

[RFC] Adding support for non-constant dims for aten.view #1131

Closed · gpetters94 closed 2 years ago

gpetters94 commented 2 years ago

I'm working on lowering OPT, and I'm running into the following:

error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
note: see current operation: %932 = "torch.aten.view"(%898, %931) : (!torch.vtensor<[1,12,7,64],f32>, !torch.list<int>) -> !torch.vtensor<[12,7,64],f32>

Inspecting the lowering of aten.view, it looks like the output shape comes out as [-1, -1, 64] because the first two target dims aren't constants. The solution I'd like to write is to recursively follow each dim up the tree, verifying that all the ops along the way are either constants, no-ops (e.g. NumToTensor), or math ops (e.g. multiplication, addition), and then performing the math statically to determine the output shape. Does this sound like how we want to implement this?
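
Roughly, I'm imagining a helper like this (an untested, hand-written sketch; the helper name is made up and the set of handled ops would need to grow):

  // Hypothetical helper: try to resolve a !torch.int dim to a static
  // value by walking its defining ops. Not actual torch-mlir code.
  static std::optional<int64_t> resolveStaticDim(Value dim) {
    int64_t value;
    // Base case: the dim is a literal torch.constant.int.
    if (matchPattern(dim, m_TorchConstantInt(&value)))
      return value;
    Operation *def = dim.getDefiningOp();
    if (!def)
      return std::nullopt;
    // No-ops: look through scalar<->tensor conversions.
    if (isa<PrimNumToTensorScalarOp, AtenIntTensorOp>(def))
      return resolveStaticDim(def->getOperand(0));
    // Math ops: evaluate statically when both inputs resolve.
    if (isa<AtenMulTensorOp>(def)) {
      auto lhs = resolveStaticDim(def->getOperand(0));
      auto rhs = resolveStaticDim(def->getOperand(1));
      if (lhs && rhs)
        return *lhs * *rhs;
    }
    return std::nullopt;
  }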

silvasean commented 2 years ago

The output shape looks like [12, 7, 64] in your snippet and not [-1, -1, 64]. Can you show the actual IR snippet you are dealing with?

gpetters94 commented 2 years ago

In the actual processing of aten.view, the lowering checks whether each target dim is a constant. If a dim is not constant, it assigns kUnknownDim to it, and in this case the first two dims are not constants. The code is here.
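
Paraphrasing the relevant logic (from memory, not the exact source; names approximate):

  // A target size only counts as static if it is a literal
  // torch.constant.int; anything else becomes kUnknownDim.
  static SmallVector<int64_t> inferOutputShape(ValueRange sizeOperands) {
    SmallVector<int64_t> outputShape;
    for (Value size : sizeOperands) {
      int64_t dim;
      if (matchPattern(size, m_TorchConstantInt(&dim)))
        outputShape.push_back(dim);
      else
        outputShape.push_back(kUnknownDim);
    }
    return outputShape;
  }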

silvasean commented 2 years ago

Can you show the IR before the pass?

silvasean commented 2 years ago

(for future reference, it's usually important to show a reduced, fully valid IR example with any bug reports like this)

gpetters94 commented 2 years ago

Here's the IR after failure: https://gist.github.com/gpetters94/af96b032acb0e6c6274af9aff62ec5e3

The relevant part is:

  %136 = torch.aten.mul.Tensor %123, %71 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
  %137 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %138 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %139 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %140 = torch.prim.ListConstruct %int1, %int7, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %141 = torch.aten.view %126, %140 : !torch.vtensor<[1,7,768],f32>, !torch.list<int> -> !torch.vtensor<[1,7,12,64],f32>
  %142 = torch.aten.transpose.int %141, %int1, %int2 : !torch.vtensor<[1,7,12,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
  %143 = torch.aten.contiguous %142, %int0 : !torch.vtensor<[1,12,7,64],f32>, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
  %144 = torch.aten.numel %143 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
  %145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
  %146 = torch.aten.div.Tensor_mode %145, %136, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %147 = torch.aten.div.Tensor_mode %146, %70, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
  %149 = torch.prim.ListConstruct %139, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %150 = torch.aten.view %143, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>

gpetters94 commented 2 years ago

Here's the distilled version:

func.func @forward(%arg0: !torch.vtensor<[1,12,7,64],f32>) -> !torch.vtensor<[12,7,64],f32> {
  %str = torch.constant.str "floor"
  %int7 = torch.constant.int 7
  %int12 = torch.constant.int 12
  %int64 = torch.constant.int 64
  %144 = torch.aten.numel %arg0 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
  %145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
  %tensor7 = torch.prim.NumToTensor.Scalar %int7 : !torch.int -> !torch.vtensor<[],si64>
  %tensor64 = torch.prim.NumToTensor.Scalar %int64 : !torch.int -> !torch.vtensor<[],si64>
  %146 = torch.aten.div.Tensor_mode %145, %tensor7, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %147 = torch.aten.div.Tensor_mode %146, %tensor64, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
  %149 = torch.prim.ListConstruct %int12, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %150 = torch.aten.view %arg0, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>
  return %150 : !torch.vtensor<[12,7,64],f32>
}

silvasean commented 2 years ago

It looks like we have already done all the shape math statically, because the result shape is inferred as !torch.vtensor<[12,7,64],f32>. So I don't want to add any special local logic here for that.

You should be able to extend https://github.com/llvm/torch-mlir/pull/935 for torch.aten.div.Tensor_mode to do more folding here if that is useful as well.
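
For reference, such a folder might look roughly like this (an untested sketch; it assumes hasFolder is set for the op, accessor names are approximate, and a real version would need to handle non-splat values and other dtypes):

  // Schematic folder for aten.div.Tensor_mode on 0-d integer tensors:
  // fold when both operands are constant splats and the rounding mode
  // is the constant string "floor".
  OpFoldResult AtenDivTensorModeOp::fold(ArrayRef<Attribute> operands) {
    auto lhs = operands[0].dyn_cast_or_null<DenseElementsAttr>();
    auto rhs = operands[1].dyn_cast_or_null<DenseElementsAttr>();
    if (!lhs || !rhs || !lhs.isSplat() || !rhs.isSplat())
      return nullptr;
    std::string mode;
    if (!matchPattern(rounding_mode(), m_TorchConstantStr(mode)) ||
        mode != "floor")
      return nullptr;
    int64_t a = lhs.getSplatValue<APInt>().getSExtValue();
    int64_t b = rhs.getSplatValue<APInt>().getSExtValue();
    if (b == 0)
      return nullptr;
    // Floor division: round toward negative infinity.
    int64_t q = a / b - ((a % b != 0 && (a < 0) != (b < 0)) ? 1 : 0);
    return DenseElementsAttr::get(lhs.getType(),
                                  APInt(64, q, /*isSigned=*/true));
  }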

gpetters94 commented 2 years ago

So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?

silvasean commented 2 years ago

> So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?

That would make sense to me. Actually, I would add a canonicalization that replaces the view sizes operand with a constant list if the result shape is static (and the operand is not already a constant list).
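
A minimal sketch of that pattern (names from memory, untested; accessors like size()/self() may differ across torch-mlir versions):

  // If the result shape of aten.view is fully static but the size list
  // operand is not a constant list, rewrite the op to take a literal
  // torch.prim.ListConstruct of torch.constant.int sizes.
  struct ViewSizesToConstantList : public OpRewritePattern<AtenViewOp> {
    using OpRewritePattern<AtenViewOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(AtenViewOp op,
                                  PatternRewriter &rewriter) const override {
      auto resultType = op.getType().dyn_cast<ValueTensorType>();
      if (!resultType || !resultType.hasSizes())
        return failure();
      ArrayRef<int64_t> sizes = resultType.getSizes();
      if (llvm::is_contained(sizes, kUnknownSize))
        return failure(); // result shape is not fully static
      SmallVector<int64_t> ignored;
      if (matchPattern(op.size(), m_TorchListOfConstantInts(ignored)))
        return failure(); // sizes operand is already a constant list
      SmallVector<Value> dims;
      for (int64_t s : sizes)
        dims.push_back(rewriter.create<ConstantIntOp>(
            op.getLoc(), rewriter.getI64IntegerAttr(s)));
      Value newSizes = rewriter.create<PrimListConstructOp>(
          op.getLoc(), op.size().getType(), dims);
      rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
                                              newSizes);
      return success();
    }
  };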

gpetters94 commented 2 years ago

Sure, I can do that. Where are canonicalizations added?

silvasean commented 2 years ago

TorchOps.cpp -- you need to add let hasCanonicalizer = 1 to the ODS definition.
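
With that flag set, the hook in TorchOps.cpp registers the patterns, e.g. (assuming a pattern like the sketch above):

  void AtenViewOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
    patterns.add<ViewSizesToConstantList>(context);
  }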

silvasean commented 2 years ago

See here for more info: https://mlir.llvm.org/docs/Canonicalization/

gpetters94 commented 2 years ago

Implemented this in #1337