AmosLewis opened this issue 1 year ago
There is a fill op in the manually created python file test_maked_fill.py, which is the same as the e2e test model definition:
%10 = torch.aten.fill.Tensor %9, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
It looks like the fill op is not decomposed:
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
%5 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
%6 = torch.aten.slice.Tensor %4, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],si64>
%7 = torch.aten.squeeze.dim %6, %int1 : !torch.tensor<[1,1],si64>, !torch.int -> !torch.tensor<[1],si64>
%8 = torch.tensor_static_info_cast %7 : !torch.tensor<[1],si64> to !torch.tensor<*,si64>
%9 = torch.copy.to_vtensor %8 : !torch.vtensor<*,si64>
%10 = torch.aten.fill.Tensor %9, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
%11 = torch.tensor_static_info_cast %10 : !torch.vtensor<[1],si64> to !torch.vtensor<*,si64>
torch.overwrite.tensor.contents %11 overwrites %8 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
%12 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
return %12 : !torch.vtensor<*,si64>
}
}
Here the pattern seems to be select.int + copy_:
module attributes {torch.debug_module_name = "MyModule"} {
func.func private @__torch__.MyModule.forward(%arg0: !torch.nn.Module<"__torch__.MyModule">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],f32>}) -> !torch.tensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%1 = torch.aten.select.int %arg1, %int-1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
%2 = torch.prim.dtype %1 : !torch.tensor -> !torch.int
%3 = torch.prim.device %1 : !torch.tensor -> !torch.Device
%4 = torch.aten.tensor.int %int1, %2, %3, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
%5 = torch.aten.copy_ %1, %4, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
return %arg1 : !torch.tensor
}
torch.class_type @__torch__.MyModule {
torch.attr private "training" : !torch.bool
torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
torch.method "forward", @__torch__.MyModule.forward
}
%true = torch.constant.bool true
%none = torch.constant.none
%0 = torch.nn_module {
torch.slot "training", %true : !torch.bool
torch.slot "_is_full_backward_hook", %none : !torch.none
} : !torch.nn.Module<"__torch__.MyModule">
}
This should be handled in a similar way. In the index_put op, every indices tensor will be None except for the tensor at dim selectInt.getDim(), which will be a single-element tensor with value selectInt.getIndex().
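To make that concrete, here is a small PyTorch-level sketch (my own illustration, not code from torch-mlir) of the equivalence between select + fill_ and an index_put whose indices list is None everywhere except at the selected dim:

import torch

# select + fill_ (the pattern seen in the IR above)
x = torch.zeros(1, 4, dtype=torch.int64)
x.select(1, 0).fill_(torch.tensor(1, dtype=torch.int64))

# equivalent index_put: indices are None except at dim 1, where a
# single-element index tensor holds the selected index 0
ref = torch.zeros(1, 4, dtype=torch.int64)
ref = torch.ops.aten.index_put(ref, [None, torch.tensor([0])], torch.tensor(1, dtype=torch.int64))

assert torch.equal(x, ref)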
I just pushed a patch to fold the select and copy but failed to get the input rank: https://github.com/llvm/torch-mlir/pull/2000. BTW, I didn't find any aten.copy op in my IR dump, only aten.fill. Why do you think it is a select + copy_ issue?
I tried to print the IR before and after createRecomposeComplexOpsPass, but it looks like the pattern is not select op + copy_ but rather select op + fill_.
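One way to double-check at the PyTorch level (my own sketch, not part of the issue) which aten ops the TorchScript frontend emits for this kind of indexed assignment; a scalar right-hand side should show up as aten::select followed by aten::fill_ rather than aten::copy_, matching the dumps below:

import torch

def f(x):
    x_new = x.new_zeros([1, 4], dtype=torch.int64)
    x_new[..., 0] = 0  # indexed assignment with a scalar RHS
    return x_new

# Print the TorchScript graph and look for aten::select / aten::fill_.
print(torch.jit.script(f).graph)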
// -----// IR Dump After EraseModuleInitializer (torch-erase-module-initializer) //----- //
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%cpu = torch.constant.device "cpu"
%2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.new_zeros %1, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
%4 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
%5 = torch.tensor_static_info_cast %4 : !torch.tensor<[],si64> to !torch.tensor
%6 = torch.aten.lift_fresh_copy %5 : !torch.tensor -> !torch.tensor
%7 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
%8 = torch.aten.fill_.Tensor %7, %6 : !torch.tensor, !torch.tensor -> !torch.tensor
return %3 : !torch.tensor
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%cpu = torch.constant.device "cpu"
%2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.new_zeros %1, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
%4 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
%5 = torch.aten.lift_fresh_copy %4 : !torch.tensor<[],si64> -> !torch.tensor
%6 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
%7 = torch.aten.fill_.Tensor %6, %5 : !torch.tensor, !torch.tensor -> !torch.tensor
return %3 : !torch.tensor
}
// -----// IR Dump After RecomposeComplexOps (torch-recompose-complex-ops) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%cpu = torch.constant.device "cpu"
%2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.new_zeros %1, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
%4 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
%5 = torch.aten.lift_fresh_copy %4 : !torch.tensor<[],si64> -> !torch.tensor
%6 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
%7 = torch.aten.fill_.Tensor %6, %5 : !torch.tensor, !torch.tensor -> !torch.tensor
return %3 : !torch.tensor
}
And the * shape appears before the SelectOp is decomposed into SliceOp; it shows up after RefineTypes:
// -----// IR Dump After DropAbstractInterpCalculations (torch-drop-abstract-interp-calculations) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor {
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],unk>
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],unk> to !torch.vtensor
%4 = torch.copy.to_tensor %3 : !torch.tensor
%5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],unk>
%6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
%7 = torch.copy.to_vtensor %6 : !torch.vtensor
%8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor, !torch.vtensor<[],unk> -> !torch.vtensor
torch.overwrite.tensor.contents %8 overwrites %6 : !torch.vtensor, !torch.tensor
%9 = torch.copy.to_vtensor %4 : !torch.vtensor
return %9 : !torch.vtensor
}
// -----// IR Dump After RefineTypes (torch-refine-types) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor {
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
%5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor<*,si64>, !torch.int, !torch.int -> !torch.tensor<*,si64>
%7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,si64>
%8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
%9 = torch.tensor_static_info_cast %8 : !torch.vtensor<*,si64> to !torch.vtensor<*,si64>
torch.overwrite.tensor.contents %9 overwrites %6 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
%10 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
%11 = torch.tensor_static_info_cast %10 : !torch.vtensor<*,si64> to !torch.vtensor
return %11 : !torch.vtensor
}
// -----// IR Dump After RefinePublicReturn (torch-refine-public-return) //----- //
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
%5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor<*,si64>, !torch.int, !torch.int -> !torch.tensor<*,si64>
%7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,si64>
%8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
%9 = torch.tensor_static_info_cast %8 : !torch.vtensor<*,si64> to !torch.vtensor<*,si64>
torch.overwrite.tensor.contents %9 overwrites %6 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
%10 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
%11 = torch.tensor_static_info_cast %10 : !torch.vtensor<*,si64> to !torch.vtensor
return %10 : !torch.vtensor<*,si64>
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
%5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
%6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor<*,si64>, !torch.int, !torch.int -> !torch.tensor<*,si64>
%7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,si64>
%8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
torch.overwrite.tensor.contents %8 overwrites %6 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
%9 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
return %9 : !torch.vtensor<*,si64>
}
// -----// IR Dump After DecomposeComplexOps (torch-decompose-complex-ops) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
%5 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
%6 = torch.aten.slice.Tensor %4, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<*,si64>
%7 = torch.aten.squeeze.dim %6, %int1 : !torch.tensor<*,si64>, !torch.int -> !torch.tensor<*,si64>
%8 = torch.copy.to_vtensor %7 : !torch.vtensor<*,si64>
%9 = torch.aten.fill.Tensor %8, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
torch.overwrite.tensor.contents %9 overwrites %7 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
%10 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
return %10 : !torch.vtensor<*,si64>
}
I found this bug while fixing the slice and copy shape * issue https://github.com/llvm/torch-mlir/issues/1953 (PR: https://github.com/llvm/torch-mlir/pull/1970).
Success: test_slicecopy.py; FAIL: test_maked_fill.py. KEY PYTHON PART:
x_new[..., 0] = 1
In t5_model: %144 = torch.aten.masked_fill_.Scalar %134, %143, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor
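For reference, a minimal sketch of the kind of module that hits this pattern (my own reconstruction from the key line above and the new_zeros/select/fill_ ops in the IR dumps; the actual SliceCopyMaskedFillModule lives in the torch-mlir e2e test suite):

import torch

class SliceCopyFillModule(torch.nn.Module):
    # Hypothetical reproducer: new_zeros plus an indexed assignment, which
    # TorchScript lowers to select.int followed by fill_.Tensor.
    def forward(self, x):
        x_new = x.new_zeros([1, 4], dtype=torch.int64)
        x_new[..., 0] = 1
        return x_new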
e2e tests:
Run e2e test:
python -m e2e_testing.main -c tosa -f "SliceCopyMaskedFillModule" -v
TorchScript to Torch backend:
torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' /tmp/SliceCopyMaskedFillModule.mlir -mlir-print-ir-after-failure -mlir-disable-threading
@ramiro050 any idea how to fix this?