JBloodless opened this issue 1 year ago
Hi @JBloodless, do you have an example of how you're calling pad_packed_sequence and compiling it with torch-mlir? I can take a look. You're right that torch.zeros should work.
Here's the full repro:
import torch
import torch.nn as nn
import torch_mlir


def pack_padded_sequence(input, lengths, batch_first=False):
    if batch_first:
        input = input.transpose(0, 1)
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    sorted_input = input.index_select(1, indices.to(input.device))
    # Create a mask for valid entries in the sorted_input tensor
    mask = torch.arange(sorted_input.size(0)).unsqueeze(1) < sorted_lengths.unsqueeze(0)
    mask = mask.to(input.device)
    # Flatten the tensor and mask out invalid entries
    packed_output = sorted_input[mask]
    return packed_output


def pad_packed_sequence(packed_input, lengths, batch_first=False):
    max_length = torch.max(lengths).item()
    batch_size = lengths.shape[0]
    # Create an output tensor filled with zeros
    shapes = (max_length, batch_size, *packed_input.shape[1:])
    padded_output = torch.zeros(shapes, dtype=packed_input.dtype,
                                device=packed_input.device)
    # Create a mask based on lengths
    mask = torch.arange(max_length).unsqueeze(1).to(packed_input.device) < lengths.unsqueeze(0).to(packed_input.device)
    # Use mask to get the correct placement of packed values
    mask_cumsum = mask.cumsum(dim=0) - 1
    mask_cumsum[~mask] = packed_input.shape[0]  # Use an index that's out of bounds for the invalid places
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    flattened_output[mask_cumsum[mask]] = packed_input
    if batch_first:
        padded_output = padded_output.transpose(0, 1)
    return padded_output


class model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, n_wins):
        x = pack_padded_sequence(
            x,
            n_wins.cpu(),
            batch_first=True
        )
        x = pad_packed_sequence(
            x, lengths=n_wins.cpu(),
            batch_first=True)
        return x


model = model()
model.eval()

x = torch.zeros((63, 1, 48, 15))
n_wins = 63
example_input = (x.unsqueeze(0).float(), torch.as_tensor(n_wins).unsqueeze(0))

mlir = torch_mlir.compile(
    model,
    example_input,
    output_type="linalg-on-tensors",
    use_tracing=True)
And the full error:
Traceback (most recent call last):
File "/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py", line 71, in <module>
mlir = torch_mlir.compile(
^^^^^^^^^^^^^^^^^^^
File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch_mlir/__init__.py", line 451, in compile
run_pipeline_with_repro_report(
File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline:
error: "aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py":29:0): unsupported by backend contract: non-value tensor type
note: "aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py":29:0): see current operation: %28 = "torch.copy.to_tensor"(%27) : (!torch.vtensor<[63,1,1,48,15],f32>) -> !torch.tensor<[63,1,1,48,15],f32>
note: "aten::zeros"("/Users/i.beskrovnyy/tts/NISQA-s/repro_mlir.py":29:0): this is likely due to a missing case in the MaximizeValueSemantics pass
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints extra-library=})' /var/folders/7h/lx81hv9s2gd25ytr1bqg25700000gp/T/model.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
I believe that the MaximizeValueSemantics pass is not working as expected here. %8 is not being converted to value tensors because of the torch.overwrite.tensor.contents operation, which indirectly uses %8. Overwrite statements are not supported by the pass in the way the viewLikeOps and returnOps are.
// -----// IR Dump After MaximizeValueSemantics (torch-maximize-value-semantics) //----- //
func.func @forward(%arg0: !torch.vtensor<[1,63,1,48,15],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,63,1,48,15],f32> {
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int63 = torch.constant.int 63
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%none = torch.constant.none
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
%int6 = torch.constant.int 6
%int15 = torch.constant.int 15
%int48 = torch.constant.int 48
%cpu = torch.constant.device "cpu"
%1 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[1,63,1,48,15],f32>, !torch.int, !torch.int -> !torch.vtensor<[63,1,1,48,15],f32>
%values, %indices = torch.aten.sort %arg1, %int-1, %true : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
%2 = torch.aten.index_select %1, %int1, %indices : !torch.vtensor<[63,1,1,48,15],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[63,1,1,48,15],f32>
%3 = torch.prim.NumToTensor.Scalar %int63 : !torch.int -> !torch.vtensor<[],si64>
%4 = torch.aten.arange.start_step %int0, %int63, %int1, %none, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[63],si64>
%5 = torch.aten.unsqueeze %4, %int1 : !torch.vtensor<[63],si64>, !torch.int -> !torch.vtensor<[63,1],si64>
%6 = torch.aten.unsqueeze %values, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%7 = torch.aten.lt.Tensor %5, %6 : !torch.vtensor<[63,1],si64>, !torch.vtensor<[1,1],si64> -> !torch.vtensor<[63,1],i1>
%8 = torch.prim.ListConstruct %7 : (!torch.vtensor<[63,1],i1>) -> !torch.list<vtensor>
%9 = torch.aten.index.Tensor_hacked_twin %2, %8 : !torch.vtensor<[63,1,1,48,15],f32>, !torch.list<vtensor> -> !torch.vtensor<[63,1,48,15],f32>
%10 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%11 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%12 = torch.prim.NumToTensor.Scalar %int48 : !torch.int -> !torch.vtensor<[],si64>
%13 = torch.prim.NumToTensor.Scalar %int15 : !torch.int -> !torch.vtensor<[],si64>
%14 = torch.prim.ListConstruct %int63, %int1, %int1, %int48, %int15 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%15 = torch.aten.zeros %14, %int6, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[63,1,1,48,15],f32>
%16 = torch.copy.to_tensor %15 : !torch.tensor<[63,1,1,48,15],f32>
%17 = torch.aten.arange.start_step %int0, %int63, %int1, %none, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[63],si64>
%18 = torch.aten.unsqueeze %17, %int1 : !torch.vtensor<[63],si64>, !torch.int -> !torch.vtensor<[63,1],si64>
%19 = torch.aten.unsqueeze %arg1, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
%20 = torch.aten.lt.Tensor %18, %19 : !torch.vtensor<[63,1],si64>, !torch.vtensor<[1,1],si64> -> !torch.vtensor<[63,1],i1>
%21 = torch.aten.cumsum %20, %int0, %none : !torch.vtensor<[63,1],i1>, !torch.int, !torch.none -> !torch.vtensor<[63,1],si64>
%22 = torch.aten.sub.Tensor %21, %0, %int1 : !torch.vtensor<[63,1],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[63,1],si64>
%23 = torch.prim.NumToTensor.Scalar %int63 : !torch.int -> !torch.vtensor<[],si64>
%24 = torch.aten.bitwise_not %20 : !torch.vtensor<[63,1],i1> -> !torch.vtensor<[63,1],i1>
%25 = torch.prim.ListConstruct %24 : (!torch.vtensor<[63,1],i1>) -> !torch.list<optional<vtensor>>
%26 = torch.aten._index_put_impl %22, %25, %23, %false, %false : !torch.vtensor<[63,1],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[63,1],si64>
%27 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%28 = torch.prim.NumToTensor.Scalar %int48 : !torch.int -> !torch.vtensor<[],si64>
%29 = torch.prim.NumToTensor.Scalar %int15 : !torch.int -> !torch.vtensor<[],si64>
%30 = torch.prim.ListConstruct %int-1, %int1, %int48, %int15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%31 = torch.aten.view %16, %30 : !torch.tensor<[63,1,1,48,15],f32>, !torch.list<int> -> !torch.tensor<[63,1,48,15],f32>
%32 = torch.prim.ListConstruct %20 : (!torch.vtensor<[63,1],i1>) -> !torch.list<vtensor>
%33 = torch.aten.index.Tensor_hacked_twin %26, %32 : !torch.vtensor<[63,1],si64>, !torch.list<vtensor> -> !torch.vtensor<[63],si64>
%34 = torch.copy.to_vtensor %31 : !torch.vtensor<[63,1,48,15],f32>
%35 = torch.prim.ListConstruct %33 : (!torch.vtensor<[63],si64>) -> !torch.list<optional<vtensor>>
%36 = torch.aten._index_put_impl %34, %35, %9, %false, %false : !torch.vtensor<[63,1,48,15],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[63,1,48,15],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[63,1,48,15],f32>
torch.overwrite.tensor.contents %36 overwrites %31 : !torch.vtensor<[63,1,48,15],f32>, !torch.tensor<[63,1,48,15],f32>
%37 = torch.aten.transpose.int %16, %int0, %int1 : !torch.tensor<[63,1,1,48,15],f32>, !torch.int, !torch.int -> !torch.tensor<[1,63,1,48,15],f32>
%38 = torch.copy.to_vtensor %37 : !torch.vtensor<[1,63,1,48,15],f32>
return %38 : !torch.vtensor<[1,63,1,48,15],f32>
}
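To make the aliasing concrete: the overwrite comes from writing into a tensor through a view and then reading the original tensor afterwards. A hypothetical, stripped-down version of that pattern (function name and shapes are made up for illustration, not taken from the repro):

```python
import torch

def write_through_view(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(2, 3, dtype=x.dtype)
    flat = out.view(-1)                 # `flat` aliases `out`
    flat[: x.numel()] = x.reshape(-1)   # in-place write through the alias
    return out                          # `out` is read after the aliased mutation
```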
Yeah, the mutation case here is not currently supported by maximize value semantics. Adding support for it might require a bit of work. A way to circumvent this is to help torch-mlir out a little by getting rid of the aliasing in the Python code. One possible way is to do:
    # Use mask to get the correct placement of packed values
    mask_cumsum = mask.cumsum(dim=0) - 1
    mask_cumsum[~mask] = packed_input.shape[0]  # Use an index that's out of bounds for the invalid places
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    flattened_output[mask_cumsum[mask]] = packed_input
+   padded_output = flattened_output.view(shapes)
    if batch_first:
        padded_output = padded_output.transpose(0, 1)
    return padded_output
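For reference, here is what the whole function looks like with that line folded in; this is just the repro's pad_packed_sequence with the suggested reassignment added, a sketch rather than a separately validated implementation:

```python
def pad_packed_sequence(packed_input, lengths, batch_first=False):
    max_length = torch.max(lengths).item()
    batch_size = lengths.shape[0]
    # Output tensor filled with zeros
    shapes = (max_length, batch_size, *packed_input.shape[1:])
    padded_output = torch.zeros(shapes, dtype=packed_input.dtype,
                                device=packed_input.device)
    # Mask of valid positions based on lengths
    mask = torch.arange(max_length).unsqueeze(1).to(packed_input.device) \
        < lengths.unsqueeze(0).to(packed_input.device)
    mask_cumsum = mask.cumsum(dim=0) - 1
    mask_cumsum[~mask] = packed_input.shape[0]  # out-of-bounds index for the invalid places
    flattened_output = padded_output.view(-1, *packed_input.shape[1:])
    flattened_output[mask_cumsum[mask]] = packed_input
    # Reassign from the written view instead of relying on the aliased
    # mutation being visible through padded_output.
    padded_output = flattened_output.view(shapes)
    if batch_first:
        padded_output = padded_output.transpose(0, 1)
    return padded_output
```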
Hi, is there any update on this issue? I was also running into a similar issue while lowering GPT2 from Torch to linalg IR. @fhossein-quic, @trahman-quic
Hello. I’m trying to convert my model to MLIR and I’m stuck with this error:
I’m using many custom functions (because torch_mlir doesn’t support many parts of this model, like LSTMs and adaptive_max_pool2d), and this is part of one such custom function:
It mimics pad_packed_sequence from torch.nn.utils.rnn. Any idea why this happens and how to avoid it? I’m not sure that I can get rid of torch.zeros here. As far as I know, torch.zeros and torch.empty are supported with #2440 and #604, so I have no idea why this shouldn't work.
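If it helps narrow things down, torch.zeros on its own (with no view-and-mutate pattern afterwards) should lower fine. A minimal, hypothetical sanity check along these lines (module name and shapes are made up):

```python
import torch
import torch.nn as nn
import torch_mlir

class ZerosOnly(nn.Module):
    def forward(self, x):
        # The zeros tensor is only read, never written through a view,
        # so value semantics can be preserved.
        return x + torch.zeros(x.shape, dtype=x.dtype)

m = ZerosOnly().eval()
mlir = torch_mlir.compile(m, (torch.ones(2, 3),),
                          output_type="linalg-on-tensors",
                          use_tracing=True)
```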