apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Possible error in strided_slice_checker #14142

Closed: AreopagX closed this issue 1 year ago

AreopagX commented 1 year ago

Hello, I'm trying to use TensorRT to tune my model. To do so, I call partition_for_tensorrt before building the model. However, this function fails when it processes strided_slice ops. The failure seems to involve strided_slice ops that slice fewer axes than the input tensor has dimensions, e.g., %435 = strided_slice(%376, begin=[0i64], end=[1i64], strides=[1i64], axes=[0i64]) applied to a 4-D input tensor. I suspect these ops fail because the for loop at https://github.com/apache/tvm/blob/7e3dc45fed385e12f8cfd9c9284b5551ada82f95/python/tvm/relay/op/contrib/tensorrt.py#L712-L713 iterates over the number of input dimensions, whereas a strided_slice op does not necessarily cover every dimension. The call above, for instance, slices only axis 0, so begin and end each have a single entry.
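The suspected failure mode can be reproduced in a minimal plain-Python sketch. It assumes the checker loops over range(rank) and indexes begin/end directly. The function name naive_slice_sizes is hypothetical and this is not the actual TVM code. Plain lists stand in for TVM's Array, which, per the traceback above, also raises IndexError on an out-of-range index. The 4-D shape [8, 3, 32, 32] is an assumed example; the issue only states that the input has 4 dimensions.

```python
def naive_slice_sizes(input_shape, begin, end):
    """Hypothetical checker that assumes len(begin) == len(end) == rank."""
    sizes = []
    for i in range(len(input_shape)):
        # Raises IndexError once i >= len(begin), i.e. when only some axes are sliced.
        sizes.append(end[i] - begin[i])
    return sizes


# Mirrors strided_slice(%376, begin=[0], end=[1], strides=[1], axes=[0])
# on a 4-D input (the concrete shape is an assumption).
shape = [8, 3, 32, 32]
try:
    naive_slice_sizes(shape, begin=[0], end=[1])
except IndexError as err:
    print("IndexError:", err)
```

When begin and end cover every dimension, the same loop works, which matches the assumption the checker was written under.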

Expected behavior

I expected the partition_for_tensorrt function to return without error.

Actual behavior

An IndexError is raised in strided_slice_checker at https://github.com/apache/tvm/blob/7e3dc45fed385e12f8cfd9c9284b5551ada82f95/python/tvm/relay/op/contrib/tensorrt.py#L713; the error itself is thrown at https://github.com/apache/tvm/blob/7e3dc45fed385e12f8cfd9c9284b5551ada82f95/python/tvm/runtime/container.py#L57

Environment

OS: Ubuntu 20.04, kernel 5.15
TVM: built from the master branch, commit 0e046daf9e51724b3910aa7ba199069b09e2707e

Triage

masahi commented 1 year ago

You're right: the axes argument of strided_slice was added later. When the TRT backend was developed, it was therefore assumed that the lengths of begin and end always equal the rank of the input.

Can you send a PR to fix this?
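One possible shape for such a fix is to expand a partial begin/end/strides to full rank before the per-dimension check runs, filling unsliced axes with their full extent. The per-dimension loop can then index every position safely. This is a plain-Python sketch, not the actual change to strided_slice_checker. It assumes that begin, end, strides and axes all have the same length, that axes lists the sliced axes in the same order as the other arguments, and that negative axes count from the end. The helper name expand_to_full_rank is hypothetical.

```python
def expand_to_full_rank(input_shape, begin, end, strides, axes=None):
    """Return begin/end/strides with one entry per input dimension (hypothetical helper).

    If axes is None, the inputs are assumed to already cover every dimension,
    which is the assumption the original TRT checker relies on.
    """
    rank = len(input_shape)
    if axes is None:
        return list(begin), list(end), list(strides)
    full_begin = [0] * rank
    full_end = list(input_shape)  # unsliced axes keep their full extent
    full_strides = [1] * rank
    for k, axis in enumerate(axes):
        axis = axis % rank  # normalize negative axes (assumption)
        full_begin[axis] = begin[k]
        full_end[axis] = end[k]
        full_strides[axis] = strides[k]
    return full_begin, full_end, full_strides


# The call from the report, on an assumed 4-D shape:
b, e, s = expand_to_full_rank([8, 3, 32, 32], [0], [1], [1], axes=[0])
print(b, e, s)  # [0, 0, 0, 0] [1, 3, 32, 32] [1, 1, 1, 1]
```

Normalizing up front keeps the existing per-dimension logic unchanged. Special end values such as -1 or None are not handled here.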