llvm / torch-mlir


[MLIR] select+fill_ op shape * support #1979

Open · AmosLewis opened this issue 1 year ago

AmosLewis commented 1 year ago

Found this bug while fixing the slice and copy shape * issue https://github.com/llvm/torch-mlir/issues/1953. PR: https://github.com/llvm/torch-mlir/pull/1970

Success: test_slicecopy.py
FAIL: test_maked_fill.py
Key Python part: x_new[..., 0] = 1
In t5_model: %144 = torch.aten.masked_fill_.Scalar %134, %143, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor

e2e tests:

class SliceCopyMaskedFillModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 4], torch.float32, True),
    ])
    def forward(self, x):
        x_new = x.new_zeros(x.shape)  # tensor([[0, 0, 0, 0]])
        x_new[..., 0] = 1   # tensor([[1, 0,  0,   0]])
        return x_new

@register_test_case(module_factory=lambda: SliceCopyMaskedFillModule())
def SliceCopyMaskedFillModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 4))
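
For a quick look at what the importer actually sees for x_new[..., 0] = 1, the same body can be scripted outside the e2e harness and its TorchScript graph printed; on the PyTorch build used here the graph should contain an aten::select followed by an in-place aten::fill_ (or aten::copy_), which is exactly the pattern discussed below. A standalone sketch (the function name and printouts are illustrative, not part of the test suite):

import torch

# Standalone sketch: script the same body and inspect the TorchScript graph
# that the torch-mlir importer starts from.
def forward(x):
    x_new = x.new_zeros(x.shape)  # tensor([[0., 0., 0., 0.]])
    x_new[..., 0] = 1             # tensor([[1., 0., 0., 0.]])
    return x_new

print(forward(torch.rand(1, 4)))        # eager result: tensor([[1., 0., 0., 0.]])
print(torch.jit.script(forward).graph)  # expect aten::select + an in-place fill_/copy_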

Run e2e test: python -m e2e_testing.main -c tosa -f "SliceCopyMaskedFillModule" -v

2023-03-27 15:25:17.594159: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-27 15:25:17.682064: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-03-27 15:25:17.682082: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2023-03-27 15:25:18.338940: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-27 15:25:18.338991: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-03-27 15:25:18.339018: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Compiling SliceCopyMaskedFillModule_basic...
XFAIL - "SliceCopyMaskedFillModule_basic"

Summary:
    Expectedly Failed: 1

TorchScript to Torch backend: torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' /tmp/SliceCopyMaskedFillModule.mlir -mlir-print-ir-after-failure -mlir-disable-threading

/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/slice_like.py:574:16: error: unsupported by backend contract: tensor with unknown rank
        x_new = x.new_zeros(x.shape)  # tensor([[0, 0, 0, 0]])
               ^
/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/slice_like.py:574:16: note: see current operation: %9 = "torch.tensor_static_info_cast"(%8) : (!torch.vtensor<[1,4],f32>) -> !torch.vtensor<*,f32>
/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/slice_like.py:574:16: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "SliceCopyMaskedFillModule"} {
  func.func @forward(%arg0: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<*,f32> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %int6 = torch.constant.int 6
    %none = torch.constant.none
    %int-1 = torch.constant.int -1
    %0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,4],f32>
    %2 = torch.tensor_static_info_cast %1 : !torch.vtensor<[1,4],f32> to !torch.vtensor<*,f32>
    %3 = torch.copy.to_tensor %2 : !torch.tensor<*,f32>
    %4 = torch.aten.slice.Tensor %3, %int-1, %int0, %int1, %int1 : !torch.tensor<*,f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],f32>
    %5 = torch.aten.squeeze.dim %4, %int-1 : !torch.tensor<[1,1],f32>, !torch.int -> !torch.tensor<[1],f32>
    %6 = torch.tensor_static_info_cast %5 : !torch.tensor<[1],f32> to !torch.tensor<*,f32>
    %7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,f32>
    %8 = torch.prim.device %7 : !torch.vtensor<*,f32> -> !torch.Device
    %9 = torch.aten.tensor.int %int1, %int6, %8, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[],f32>
    %10 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
    %11 = torch.aten.broadcast_to %9, %10 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
    %12 = torch.tensor_static_info_cast %11 : !torch.vtensor<[1],f32> to !torch.vtensor<*,f32>
    torch.overwrite.tensor.contents %12 overwrites %6 : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
    %13 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
    return %13 : !torch.vtensor<*,f32>
  }
}
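
For reference, a standalone repro outside the e2e harness should hit the same backend-contract error. A sketch, assuming the torch_mlir.compile Python API (with output_type="torch") available in this build:

import torch
import torch_mlir

class SliceCopyMaskedFillModule(torch.nn.Module):
    def forward(self, x):
        x_new = x.new_zeros(x.shape)
        x_new[..., 0] = 1
        return x_new

# Lowering to the torch backend contract is where the e2e harness fails above,
# so this call should raise the same "tensor with unknown rank" diagnostic.
compiled = torch_mlir.compile(SliceCopyMaskedFillModule(), torch.rand(1, 4),
                              output_type="torch")
print(compiled)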

@ramiro050 any idea how to fix this?

AmosLewis commented 1 year ago

There is a fill op in the manually created Python file test_maked_fill.py, which has the same model definition as the e2e test: %10 = torch.aten.fill.Tensor %9, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>. It looks like the fill op is decomposed.

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %cpu = torch.constant.device "cpu"
    %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
    %4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
    %5 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
    %6 = torch.aten.slice.Tensor %4, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],si64>
    %7 = torch.aten.squeeze.dim %6, %int1 : !torch.tensor<[1,1],si64>, !torch.int -> !torch.tensor<[1],si64>
    %8 = torch.tensor_static_info_cast %7 : !torch.tensor<[1],si64> to !torch.tensor<*,si64>
    %9 = torch.copy.to_vtensor %8 : !torch.vtensor<*,si64>
    %10 = torch.aten.fill.Tensor %9, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
    %11 = torch.tensor_static_info_cast %10 : !torch.vtensor<[1],si64> to !torch.vtensor<*,si64>
    torch.overwrite.tensor.contents %11 overwrites %8 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
    %12 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
    return %12 : !torch.vtensor<*,si64>
  }
}

ramiro050 commented 1 year ago

Here the pattern seems to be select.int + copy_:

module attributes {torch.debug_module_name = "MyModule"} {
  func.func private @__torch__.MyModule.forward(%arg0: !torch.nn.Module<"__torch__.MyModule">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],f32>}) -> !torch.tensor {
    %false = torch.constant.bool false
    %int-1 = torch.constant.int -1
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %1 = torch.aten.select.int %arg1, %int-1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
    %2 = torch.prim.dtype %1 : !torch.tensor -> !torch.int
    %3 = torch.prim.device %1 : !torch.tensor -> !torch.Device
    %4 = torch.aten.tensor.int %int1, %2, %3, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
    %5 = torch.aten.copy_ %1, %4, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
    return %arg1 : !torch.tensor
  }
  torch.class_type @__torch__.MyModule {
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.method "forward", @__torch__.MyModule.forward
  }
  %true = torch.constant.bool true
  %none = torch.constant.none
  %0 = torch.nn_module {
    torch.slot "training", %true : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
  } : !torch.nn.Module<"__torch__.MyModule">
}

This should be handled in a similar way. In the index_put op, every indices tensor will be None except for the tensor at dim selectInt.getDim(), which will be a single-element tensor with value selectInt.getIndex(), as sketched below.
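
At the PyTorch level, the suggested mapping looks roughly like the following eager-mode sketch (the variable names are illustrative; this shows the semantics of the rewrite, not the torch-mlir recompose pattern itself):

import torch

# Sketch of the suggested mapping: x.select(dim, idx).fill_(v) / .copy_(v)
# has the same effect as index_put_ with an indices list that is None at
# every dim except `dim`, where it holds a single-element index tensor [idx].
x = torch.zeros(1, 4)
dim, idx = -1, 0                       # selectInt.getDim(), selectInt.getIndex()
indices = [None] * x.dim()
indices[dim] = torch.tensor([idx])     # single-element index tensor at `dim`
torch.ops.aten.index_put_(x, indices, torch.tensor(1.0))
print(x)                               # tensor([[1., 0., 0., 0.]])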

AmosLewis commented 1 year ago

(Quoting @ramiro050's comment and IR above.)

I just pushed a patch to fold the select and copy ops, but it fails to get the input rank: https://github.com/llvm/torch-mlir/pull/2000. BTW, I didn't find any aten.copy op in my IR dump; only aten.fill. Why do you think it is a select + copy_ issue?

AmosLewis commented 1 year ago

I tried printing the IR before and after createRecomposeComplexOpsPass, but it looks like the pattern is not select op + copy_ but select op + fill_.

// -----// IR Dump After EraseModuleInitializer (torch-erase-module-initializer) //----- //
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
    %false = torch.constant.bool false
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %cpu = torch.constant.device "cpu"
    %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.new_zeros %1, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
    %4 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
    %5 = torch.tensor_static_info_cast %4 : !torch.tensor<[],si64> to !torch.tensor
    %6 = torch.aten.lift_fresh_copy %5 : !torch.tensor -> !torch.tensor
    %7 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
    %8 = torch.aten.fill_.Tensor %7, %6 : !torch.tensor, !torch.tensor -> !torch.tensor
    return %3 : !torch.tensor
  }
}

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %cpu = torch.constant.device "cpu"
  %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.new_zeros %1, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
  %4 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
  %5 = torch.aten.lift_fresh_copy %4 : !torch.tensor<[],si64> -> !torch.tensor
  %6 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  %7 = torch.aten.fill_.Tensor %6, %5 : !torch.tensor, !torch.tensor -> !torch.tensor
  return %3 : !torch.tensor
}

// -----// IR Dump After RecomposeComplexOps (torch-recompose-complex-ops) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.tensor {
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %0 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[1,4],si64> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %cpu = torch.constant.device "cpu"
  %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.new_zeros %1, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
  %4 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
  %5 = torch.aten.lift_fresh_copy %4 : !torch.tensor<[],si64> -> !torch.tensor
  %6 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  %7 = torch.aten.fill_.Tensor %6, %5 : !torch.tensor, !torch.tensor -> !torch.tensor
  return %3 : !torch.tensor
}

AmosLewis commented 1 year ago

And the shape * appears before the SelectOp is decomposed into SliceOp. The dumps below show it after RefineTypes:

// -----// IR Dump After DropAbstractInterpCalculations (torch-drop-abstract-interp-calculations) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor {
  %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %cpu = torch.constant.device "cpu"
  %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],unk>
  %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],unk> to !torch.vtensor
  %4 = torch.copy.to_tensor %3 : !torch.tensor
  %5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],unk>
  %6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  %7 = torch.copy.to_vtensor %6 : !torch.vtensor
  %8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor, !torch.vtensor<[],unk> -> !torch.vtensor
  torch.overwrite.tensor.contents %8 overwrites %6 : !torch.vtensor, !torch.tensor
  %9 = torch.copy.to_vtensor %4 : !torch.vtensor
  return %9 : !torch.vtensor
}

// -----// IR Dump After RefineTypes (torch-refine-types) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor {
  %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %cpu = torch.constant.device "cpu"
  %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
  %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
  %4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
  %5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
  %6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor<*,si64>, !torch.int, !torch.int -> !torch.tensor<*,si64>
  %7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,si64>
  %8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
  %9 = torch.tensor_static_info_cast %8 : !torch.vtensor<*,si64> to !torch.vtensor<*,si64>
  torch.overwrite.tensor.contents %9 overwrites %6 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
  %10 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
  %11 = torch.tensor_static_info_cast %10 : !torch.vtensor<*,si64> to !torch.vtensor
  return %11 : !torch.vtensor
}

// -----// IR Dump After RefinePublicReturn (torch-refine-public-return) //----- //
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
    %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %false = torch.constant.bool false
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %cpu = torch.constant.device "cpu"
    %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
    %4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
    %5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    %6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor<*,si64>, !torch.int, !torch.int -> !torch.tensor<*,si64>
    %7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,si64>
    %8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
    %9 = torch.tensor_static_info_cast %8 : !torch.vtensor<*,si64> to !torch.vtensor<*,si64>
    torch.overwrite.tensor.contents %9 overwrites %6 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
    %10 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
    %11 = torch.tensor_static_info_cast %10 : !torch.vtensor<*,si64> to !torch.vtensor
    return %10 : !torch.vtensor<*,si64>
  }
}

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
  %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int0 = torch.constant.int 0
  %cpu = torch.constant.device "cpu"
  %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.new_zeros %arg1, %1, %int4, %int0, %cpu, %false : !torch.vtensor<[1,4],si64>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
  %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
  %4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
  %5 = torch.aten.lift_fresh_copy %0 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
  %6 = torch.aten.select.int %4, %int1, %int0 : !torch.tensor<*,si64>, !torch.int, !torch.int -> !torch.tensor<*,si64>
  %7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,si64>
  %8 = torch.aten.fill.Tensor %7, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
  torch.overwrite.tensor.contents %8 overwrites %6 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
  %9 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
  return %9 : !torch.vtensor<*,si64>
}

// -----// IR Dump After DecomposeComplexOps (torch-decompose-complex-ops) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %int4 = torch.constant.int 4
  %none = torch.constant.none
  %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %cpu = torch.constant.device "cpu"
  %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
  %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
  %4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
  %5 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
  %6 = torch.aten.slice.Tensor %4, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<*,si64>
  %7 = torch.aten.squeeze.dim %6, %int1 : !torch.tensor<*,si64>, !torch.int -> !torch.tensor<*,si64>
  %8 = torch.copy.to_vtensor %7 : !torch.vtensor<*,si64>
  %9 = torch.aten.fill.Tensor %8, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<*,si64>
  torch.overwrite.tensor.contents %9 overwrites %7 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
  %10 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
  return %10 : !torch.vtensor<*,si64>
}