I'm working on lowering OPT, and I'm running into the following:
Inspecting the lowering of aten.view, it looks like the output shape is -1, -1, 64 because the first two input dims aren't constants. The solution I'd like to write is to recursively follow the dims up the tree, verifying that all the ops are either constants, no-ops (e.g. NumToTensor), or math ops (e.g. multiplication, addition), and then performing the math statically to determine the output shape. Does this sound like how we want to implement this?
The output shape looks like [12, 7, 64] in your snippet and not [-1, -1, 64]. Can you show the actual IR snippet you are dealing with?
In the actual processing of aten.view, it checks whether each input dim is a constant. If not, it assigns kUnknownDim to it, and in this case the first two inputs are not constants. The code is here.
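For illustration, here is a rough C++ sketch of the check being described. This is not the actual torch-mlir source: the helper name inferViewSizes, the kUnknownDim parameter, and accessors such as size()/elements() are assumptions and vary across torch-mlir versions.
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
using namespace mlir::torch::Torch;
// Hypothetical helper mirroring the described logic: keep a view dim as a
// static extent only when it comes from a torch.constant.int, otherwise fall
// back to an unknown-dim sentinel.
static SmallVector<int64_t> inferViewSizes(AtenViewOp op, int64_t kUnknownDim) {
  SmallVector<int64_t> sizes;
  auto listConstruct = op.size().getDefiningOp<PrimListConstructOp>();
  if (!listConstruct)
    return sizes; // no statically analyzable size list
  for (Value dim : listConstruct.elements()) {
    int64_t constDim;
    if (matchPattern(dim, m_TorchConstantInt(&constDim)))
      sizes.push_back(constDim);
    else
      sizes.push_back(kUnknownDim); // non-constant dims land here
  }
  return sizes;
}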
Can you show the IR before the pass?
(for future reference, it's usually important to show a reduced, fully valid IR example with any bug reports like this)
Here's the IR after failure: https://gist.github.com/gpetters94/af96b032acb0e6c6274af9aff62ec5e3
The relevant part is:
%136 = torch.aten.mul.Tensor %123, %71 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%137 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
%138 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
%139 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
%140 = torch.prim.ListConstruct %int1, %int7, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%141 = torch.aten.view %126, %140 : !torch.vtensor<[1,7,768],f32>, !torch.list<int> -> !torch.vtensor<[1,7,12,64],f32>
%142 = torch.aten.transpose.int %141, %int1, %int2 : !torch.vtensor<[1,7,12,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
%143 = torch.aten.contiguous %142, %int0 : !torch.vtensor<[1,12,7,64],f32>, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
%144 = torch.aten.numel %143 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
%145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
%146 = torch.aten.div.Tensor_mode %145, %136, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
%147 = torch.aten.div.Tensor_mode %146, %70, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
%148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
%149 = torch.prim.ListConstruct %139, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%150 = torch.aten.view %143, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>
Here's the distilled version:
func.func @forward(%arg0: !torch.vtensor<[1,12,7,64],f32>) -> !torch.vtensor<[12,7,64],f32> {
%str = torch.constant.str "floor"
%int7 = torch.constant.int 7
%int12 = torch.constant.int 12
%int64 = torch.constant.int 64
%144 = torch.aten.numel %arg0 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
%145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
%tensor7 = torch.prim.NumToTensor.Scalar %int7 : !torch.int -> !torch.vtensor<[],si64>
%tensor64 = torch.prim.NumToTensor.Scalar %int64 : !torch.int -> !torch.vtensor<[],si64>
%146 = torch.aten.div.Tensor_mode %145, %tensor7, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
%147 = torch.aten.div.Tensor_mode %146, %tensor64, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
%148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
%149 = torch.prim.ListConstruct %int12, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%150 = torch.aten.view %arg0, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>
return %150 : !torch.vtensor<[12,7,64],f32>
}
It looks like we have already done all the shape math statically, because the result shape is inferred as !torch.vtensor<[12,7,64],f32>. So I don't want to do any special local logic here for that.
You should be able to extend https://github.com/llvm/torch-mlir/pull/935 for torch.aten.div.Tensor_mode to do more folding here if that is useful as well.
So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?
That would make sense to me. Actually, I would add a canonicalization that replaces the view sizes operand with a constant list if the result shape is static (and the operand is not already a constant list).
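For illustration, a minimal sketch of the kind of canonicalization being suggested, written against torch-mlir's C++ API of that era. This is not the actual #1337 change; the pattern name is made up, and accessors such as self(), size(), elements(), and areAllSizesKnown() are assumptions that differ between torch-mlir versions.
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
using namespace mlir::torch::Torch;
// Hypothetical pattern: if aten.view's result shape is fully static but its
// sizes operand is not already a list of constant ints, rebuild the sizes
// operand as a constant list taken from the result type.
struct ReplaceViewSizesWithStaticShape : public OpRewritePattern<AtenViewOp> {
  using OpRewritePattern<AtenViewOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenViewOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = op.getType().dyn_cast<BaseTensorType>();
    if (!resultType || !resultType.hasSizes() || !resultType.areAllSizesKnown())
      return failure();
    // Bail out if the sizes operand is already a ListConstruct of constants,
    // so the pattern does not re-apply forever.
    auto listConstruct = op.size().getDefiningOp<PrimListConstructOp>();
    if (listConstruct && llvm::all_of(listConstruct.elements(), [](Value v) {
          return v.getDefiningOp<ConstantIntOp>() != nullptr;
        }))
      return failure();
    // Build a constant size list from the statically known result shape.
    Location loc = op.getLoc();
    SmallVector<Value> sizes;
    for (int64_t size : resultType.getSizes())
      sizes.push_back(rewriter.create<ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(size)));
    Value newSizeList = rewriter.create<PrimListConstructOp>(
        loc, op.size().getType(), sizes);
    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
                                            newSizeList);
    return success();
  }
};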
Sure, I can do that. Where are canonicalizations added?
TorchOps.cpp -- you need to add let hasCanonicalizer = 1 to the ODS definition.
See here for more info: https://mlir.llvm.org/docs/Canonicalization/
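A hedged sketch of how that wiring might look, reusing the hypothetical ReplaceViewSizesWithStaticShape pattern from the sketch above (again, not the exact #1337 change):
// In the op's ODS definition:
//   let hasCanonicalizer = 1;
// This declares the hook below, which TorchOps.cpp then defines to register
// the canonicalization pattern.
void AtenViewOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<ReplaceViewSizesWithStaticShape>(context);
}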
Implemented this in #1337