llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

'unsupported by backend contract: non-value tensor type' with torch.zeros #2523

Open JBloodless opened 1 year ago

JBloodless commented 1 year ago

Hello. I’m trying to convert my model to MLIR and I’m stuck with this error:

python exception: Failure while executing pass pipeline:
error: "__module.cnn/aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/utils/nn_utils.py":28:0): unsupported by backend contract: non-value tensor type
note: "__module.cnn/aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/utils/nn_utils.py":28:0): see current operation: %223 = "torch.copy.to_tensor"(%222) : (!torch.vtensor<[63,1,384],f32>) -> !torch.tensor<[63,1,384],f32>
note: "__module.cnn/aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/utils/nn_utils.py":28:0): this is likely due to a missing case in the MaximizeValueSemantics pass

I’m using many custom functions (because torch_mlir doesn’t support many parts of this model, like LSTMs and adaptive_max_pool2d), and this is part of one such custom function:

def pad_packed_sequence(packed_input, lengths, batch_first=False):
    max_length = max(lengths)
    batch_size = len(lengths)

    # here’s where torch_mlir fails
    padded_output = torch.zeros((max_length, batch_size, *packed_input.shape[1:]), dtype=packed_input.dtype,
                                device=packed_input.device)

    mask = torch.arange(max_length).unsqueeze(1).to(packed_input.device) < lengths.unsqueeze(0)

    mask_cumsum = mask.cumsum(dim=0) - 1
    mask_cumsum[~mask] = packed_input.shape[0]
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    flattened_output[mask_cumsum[mask]] = packed_input

    if batch_first:
        padded_output = padded_output.transpose(0, 1)

    return padded_output

It mimics pad_packed_sequence from torch.nn.utils.rnn. Any idea why this happens and how to avoid it? I’m not sure that I can get rid of torch.zeros here.
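
For reference, the stock API being mimicked is used roughly like this (a minimal sketch of the standard torch.nn.utils.rnn calls, not code from my model):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# batch of 2 sequences padded to length 3, with true lengths 3 and 2
x = torch.randn(2, 3, 4)
lengths = torch.tensor([3, 2])

packed = pack_padded_sequence(x, lengths, batch_first=True)          # PackedSequence
padded, out_lengths = pad_packed_sequence(packed, batch_first=True)  # (2, 3, 4) and lengths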

As far as I know, torch.zeros and torch.empty are supported with #2440 and #604, so I have no idea why this shouldn't work.

ramiro050 commented 1 year ago

Hi @JBloodless, do you have an example of how you're calling pad_packed_sequence and compiling it with torch-mlir? I can take a look. You're right that torch.zeros should work.

JBloodless commented 1 year ago

> Hi @JBloodless, do you have an example of how you're calling pad_packed_sequence and compiling it with torch-mlir? I can take a look. You're right that torch.zeros should work.

Here's full repro:

import torch
import torch.nn as nn
import torch_mlir

def pack_padded_sequence(input, lengths, batch_first=False):
    if batch_first:
        input = input.transpose(0, 1)

    sorted_lengths, indices = torch.sort(lengths, descending=True)
    sorted_input = input.index_select(1, indices.to(input.device))

    # Create a mask for valid entries in the sorted_input tensor
    mask = torch.arange(sorted_input.size(0)).unsqueeze(1) < sorted_lengths.unsqueeze(0)
    mask = mask.to(input.device)

    # Flatten the tensor and mask out invalid entries
    packed_output = sorted_input[mask]

    return packed_output

def pad_packed_sequence(packed_input, lengths, batch_first=False):
    max_length = torch.max(lengths).item()
    batch_size = lengths.shape[0]

    # Create an output tensor filled with zeros
    shapes = (max_length, batch_size, *packed_input.shape[1:])
    padded_output = torch.zeros(shapes, dtype=packed_input.dtype,
                                device=packed_input.device)

    # Create a mask based on lengths
    mask = torch.arange(max_length).unsqueeze(1).to(packed_input.device) < lengths.unsqueeze(0).to(packed_input.device)

    # Use mask to get the correct placement of packed values
    mask_cumsum = mask.cumsum(dim=0) - 1
    mask_cumsum[~mask] = packed_input.shape[0]  # Use an index that's out of bounds for the invalid places
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    flattened_output[mask_cumsum[mask]] = packed_input

    if batch_first:
        padded_output = padded_output.transpose(0, 1)

    return padded_output

class model(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, n_wins):
        x = pack_padded_sequence(
            x,
            n_wins.cpu(),
            batch_first=True
        )
        x = pad_packed_sequence(
            x, lengths=n_wins.cpu(),
            batch_first=True)
        return x

model = model()
model.eval()

x = torch.zeros((63, 1, 48, 15))
n_wins = 63
example_input = (x.unsqueeze(0).float(), torch.as_tensor(n_wins).unsqueeze(0))

mlir = torch_mlir.compile(
    model,
    example_input,
    output_type="linalg-on-tensors",
    use_tracing=True)

And full error:

Traceback (most recent call last):
  File "/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py", line 71, in <module>
    mlir = torch_mlir.compile(
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch_mlir/__init__.py", line 451, in compile
    run_pipeline_with_repro_report(
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:

python exception: Failure while executing pass pipeline:
error: "aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py":29:0): unsupported by backend contract: non-value tensor type
note: "aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py":29:0): see current operation: %28 = "torch.copy.to_tensor"(%27) : (!torch.vtensor<[63,1,1,48,15],f32>) -> !torch.tensor<[63,1,1,48,15],f32>
note: "aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py":29:0): this is likely due to a missing case in the MaximizeValueSemantics pass

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints extra-library=})' /var/folders/7h/lx81hv9s2gd25ytr1bqg25700000gp/T/model.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
gptsarthak commented 1 year ago

I believe that the MaximizeValueSemantics pass is not working as expected here. %8 is not being converted to value tensors because of the torch.overwrite.tensor.contents operation, which indirectly uses %8. Overwrite ops are not supported by the pass in the way that viewLikeOps and returnOps are.
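
In Python terms, the pattern that trips the pass boils down to something like this (a self-contained distillation of the repro, with comments mapping each line to the ops in the IR dump below; the sizes are made up):

import torch

def aliasing_pattern(packed_input):
    # torch.zeros creates the buffer; in the IR it becomes aten.zeros followed by
    # torch.copy.to_tensor, because a mutable (non-value) tensor is needed below
    padded_output = torch.zeros(4, *packed_input.shape[1:], dtype=packed_input.dtype)
    # a view of that non-value tensor (torch.aten.view on !torch.tensor)
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    # in-place indexed write through the view: aten._index_put_impl plus
    # torch.overwrite.tensor.contents overwriting the view
    flattened_output[: packed_input.shape[0]] = packed_input
    # the original buffer (aliasing the mutated view) is still used afterwards,
    # which is the case MaximizeValueSemantics cannot currently rewrite
    return padded_output.transpose(0, 1)

print(aliasing_pattern(torch.ones(2, 1, 3)).shape)  # torch.Size([1, 4, 3])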

// -----// IR Dump After MaximizeValueSemantics (torch-maximize-value-semantics) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,63,1,48,15],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,63,1,48,15],f32> {
  %false = torch.constant.bool false
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int63 = torch.constant.int 63
  %0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
  %none = torch.constant.none
  %int-1 = torch.constant.int -1
  %true = torch.constant.bool true
  %int6 = torch.constant.int 6
  %int15 = torch.constant.int 15
  %int48 = torch.constant.int 48
  %cpu = torch.constant.device "cpu"
  %1 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[1,63,1,48,15],f32>, !torch.int, !torch.int -> !torch.vtensor<[63,1,1,48,15],f32>
  %values, %indices = torch.aten.sort %arg1, %int-1, %true : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
  %2 = torch.aten.index_select %1, %int1, %indices : !torch.vtensor<[63,1,1,48,15],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[63,1,1,48,15],f32>
  %3 = torch.prim.NumToTensor.Scalar %int63 : !torch.int -> !torch.vtensor<[],si64>
  %4 = torch.aten.arange.start_step %int0, %int63, %int1, %none, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[63],si64>
  %5 = torch.aten.unsqueeze %4, %int1 : !torch.vtensor<[63],si64>, !torch.int -> !torch.vtensor<[63,1],si64>
  %6 = torch.aten.unsqueeze %values, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
  %7 = torch.aten.lt.Tensor %5, %6 : !torch.vtensor<[63,1],si64>, !torch.vtensor<[1,1],si64> -> !torch.vtensor<[63,1],i1>
  %8 = torch.prim.ListConstruct %7 : (!torch.vtensor<[63,1],i1>) -> !torch.list<vtensor>
  %9 = torch.aten.index.Tensor_hacked_twin %2, %8 : !torch.vtensor<[63,1,1,48,15],f32>, !torch.list<vtensor> -> !torch.vtensor<[63,1,48,15],f32>
  %10 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
  %11 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
  %12 = torch.prim.NumToTensor.Scalar %int48 : !torch.int -> !torch.vtensor<[],si64>
  %13 = torch.prim.NumToTensor.Scalar %int15 : !torch.int -> !torch.vtensor<[],si64>
  %14 = torch.prim.ListConstruct %int63, %int1, %int1, %int48, %int15 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %15 = torch.aten.zeros %14, %int6, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[63,1,1,48,15],f32>
  %16 = torch.copy.to_tensor %15 : !torch.tensor<[63,1,1,48,15],f32>
  %17 = torch.aten.arange.start_step %int0, %int63, %int1, %none, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[63],si64>
  %18 = torch.aten.unsqueeze %17, %int1 : !torch.vtensor<[63],si64>, !torch.int -> !torch.vtensor<[63,1],si64>
  %19 = torch.aten.unsqueeze %arg1, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
  %20 = torch.aten.lt.Tensor %18, %19 : !torch.vtensor<[63,1],si64>, !torch.vtensor<[1,1],si64> -> !torch.vtensor<[63,1],i1>
  %21 = torch.aten.cumsum %20, %int0, %none : !torch.vtensor<[63,1],i1>, !torch.int, !torch.none -> !torch.vtensor<[63,1],si64>
  %22 = torch.aten.sub.Tensor %21, %0, %int1 : !torch.vtensor<[63,1],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[63,1],si64>
  %23 = torch.prim.NumToTensor.Scalar %int63 : !torch.int -> !torch.vtensor<[],si64>
  %24 = torch.aten.bitwise_not %20 : !torch.vtensor<[63,1],i1> -> !torch.vtensor<[63,1],i1>
  %25 = torch.prim.ListConstruct %24 : (!torch.vtensor<[63,1],i1>) -> !torch.list<optional<vtensor>>
  %26 = torch.aten._index_put_impl %22, %25, %23, %false, %false : !torch.vtensor<[63,1],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[63,1],si64>
  %27 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
  %28 = torch.prim.NumToTensor.Scalar %int48 : !torch.int -> !torch.vtensor<[],si64>
  %29 = torch.prim.NumToTensor.Scalar %int15 : !torch.int -> !torch.vtensor<[],si64>
  %30 = torch.prim.ListConstruct %int-1, %int1, %int48, %int15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %31 = torch.aten.view %16, %30 : !torch.tensor<[63,1,1,48,15],f32>, !torch.list<int> -> !torch.tensor<[63,1,48,15],f32>
  %32 = torch.prim.ListConstruct %20 : (!torch.vtensor<[63,1],i1>) -> !torch.list<vtensor>
  %33 = torch.aten.index.Tensor_hacked_twin %26, %32 : !torch.vtensor<[63,1],si64>, !torch.list<vtensor> -> !torch.vtensor<[63],si64>
  %34 = torch.copy.to_vtensor %31 : !torch.vtensor<[63,1,48,15],f32>
  %35 = torch.prim.ListConstruct %33 : (!torch.vtensor<[63],si64>) -> !torch.list<optional<vtensor>>
  %36 = torch.aten._index_put_impl %34, %35, %9, %false, %false : !torch.vtensor<[63,1,48,15],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[63,1,48,15],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[63,1,48,15],f32>
  torch.overwrite.tensor.contents %36 overwrites %31 : !torch.vtensor<[63,1,48,15],f32>, !torch.tensor<[63,1,48,15],f32>
  %37 = torch.aten.transpose.int %16, %int0, %int1 : !torch.tensor<[63,1,1,48,15],f32>, !torch.int, !torch.int -> !torch.tensor<[1,63,1,48,15],f32>
  %38 = torch.copy.to_vtensor %37 : !torch.vtensor<[1,63,1,48,15],f32>
  return %38 : !torch.vtensor<[1,63,1,48,15],f32>
}
ramiro050 commented 1 year ago

Yeah, the mutation case here is not currently supported by maximize value semantics. Adding support for it might require a bit of work. One way to circumvent this is to help torch-mlir out a little by getting rid of the aliasing in the Python code, for example:

      # Use mask to get the correct placement of packed values
      mask_cumsum = mask.cumsum(dim=0) - 1
      mask_cumsum[~mask] = packed_input.shape[0]  # Use an index that's out of bounds for the invalid places
      flattened_output = padded_output.view(-1, *packed_input.shape[1:])
      flattened_output[mask_cumsum[mask]] = packed_input
+     padded_output = flattened_output.view(shapes)

      if batch_first:
          padded_output = padded_output.transpose(0, 1)

      return padded_output
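
For reference, pad_packed_sequence from the repro with that suggested line folded in would look like this (untested sketch of the workaround):

def pad_packed_sequence(packed_input, lengths, batch_first=False):
    max_length = torch.max(lengths).item()
    batch_size = lengths.shape[0]

    # Create an output tensor filled with zeros
    shapes = (max_length, batch_size, *packed_input.shape[1:])
    padded_output = torch.zeros(shapes, dtype=packed_input.dtype,
                                device=packed_input.device)

    # Create a mask based on lengths
    mask = torch.arange(max_length).unsqueeze(1).to(packed_input.device) < lengths.unsqueeze(0).to(packed_input.device)

    # Use mask to get the correct placement of packed values
    mask_cumsum = mask.cumsum(dim=0) - 1
    mask_cumsum[~mask] = packed_input.shape[0]  # Use an index that's out of bounds for the invalid places
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    flattened_output[mask_cumsum[mask]] = packed_input
    # Rebuild padded_output from the mutated view so the rest of the function
    # no longer reads through the alias of the zeros buffer
    padded_output = flattened_output.view(shapes)

    if batch_first:
        padded_output = padded_output.transpose(0, 1)

    return padded_output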
sdalvi-quic commented 2 months ago

Hi, is there any update on this issue? I was also running into a similar issue while lowering GPT-2 from Torch to linalg IR. @fhossein-quic, @trahman-quic