NVIDIA / TensorRT-Incubator

Experimental projects related to TensorRT

Concatenating a shape tensor using the `*` op does not work with fill #207

Open pranavm-nvidia opened 3 weeks ago

pranavm-nvidia commented 3 weeks ago

The `*` op is supposed to concatenate a shape tensor to itself. However, if we use the resulting shape tensor with a fill op, we get a failure:

>>> a = tp.Tensor([1, 2, 3])
>>> tp.ones(a.shape[-1:] * 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
...
  File "/tripy/tripy/frontend/trace/ops/base.py", line 52, in build_internal
    op.infer_rank()
  File "/tripy/tripy/frontend/trace/ops/fill.py", line 53, in infer_rank
    input_shape[0] >= 0
AssertionError: Incorrect shape of shape tensor, expected shape to be positive, got -9223372036854775808
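
For reference, the behavior described above can be sketched in plain Python without Tripy. This is only an illustration of the intended semantics (repeating a shape concatenates it with itself), not Tripy's implementation; `repeat_shape` and `ones` are hypothetical helpers written for this example.

```python
def repeat_shape(shape, times):
    """Concatenate a shape with itself `times` times (tuple repetition)."""
    return tuple(shape) * times


def ones(shape):
    """Build a nested list filled with 1.0 that has the given shape."""
    if len(shape) == 0:
        return 1.0
    return [ones(shape[1:]) for _ in range(shape[0])]


# Shape of a rank-1 tensor holding [1, 2, 3].
a_shape = (3,)

# Mirrors `a.shape[-1:] * 2` from the report: (3,) repeated twice.
s = repeat_shape(a_shape[-1:], 2)
result = ones(s)

print(s)       # (3, 3)
print(result)  # three rows of [1.0, 1.0, 1.0]
```

Under this reading, `tp.ones(a.shape[-1:] * 2)` would be expected to produce a 3x3 tensor of ones rather than an assertion failure.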

pranavm-nvidia commented 3 weeks ago

Possibly related to https://github.com/NVIDIA/TensorRT-Incubator/issues/206

slyubomirsky commented 3 weeks ago

Note that this works if the product is forced to be evaluated before it is passed to `tp.ones`. This suggests the problem has to do with the lowering into MLIR.

a = tp.Tensor([1, 2, 3])
s = a.shape[-1:]*2
print(s) # removing the print leads to the error above
tp.ones(s) 

The location of the exception suggests that the incorrect value comes from op_utils.get_trace_shape(self.inputs[0]), which in turn calls into the shape context:

if input.shape is None:
    from tripy.backend.mlir.utils import ShapeContext

    # memoize while we're at it
    input.shape = ShapeContext().get_shape_of_dynamic_trace_tensor(input)
return input.shape
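
One observation about the value in the assertion message: it is exactly the smallest signed 64-bit integer. The check below confirms that with the standard library only. As far as I know, recent MLIR versions use this same value as the sentinel for a dynamic dimension (`ShapedType::kDynamic`), so the shape lookup may be returning a dimension that is still unknown. That interpretation is an assumption and has not been confirmed in this thread.

```python
import struct

# Value copied from the AssertionError message in the report.
reported = -9223372036854775808

# Smallest value representable as a signed 64-bit integer.
INT64_MIN = -(2 ** 63)

is_int64_min = reported == INT64_MIN
print(is_int64_min)  # True

# The value still fits in a signed 64-bit slot; one less would not.
packed = struct.pack("<q", reported)
print(len(packed))  # 8
```

If the sentinel reading is right, the `input_shape[0] >= 0` assertion would be tripping on an unresolved dynamic dimension rather than on a miscomputed size.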

slyubomirsky commented 2 weeks ago

Note that no MLIR is dumped by running this example, so I would expect that the error has to do with how the shape context works. I will look further into that.

slyubomirsky commented 2 weeks ago

Here is some of the MLIR created inside ShapeContext (there is much more). It is massive and very repetitive, so I wonder if there are some ways for us to reduce it. Some of the code corresponds to the lowering for the slice, while other parts correspond to the ones operator.

module @ins_t94_outs_t414_3 {
  func.func @main(%arg0: tensor<i32>) -> tensor<?xi32> {
    %c = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi32>
    %c_0 = stablehlo.constant dense<3> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<3> : tensor<1xi32>
    %c_3 = stablehlo.constant dense<1> : tensor<i32>
    %c_4 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_5 = stablehlo.constant dense<0> : tensor<i32>
    %c_6 = stablehlo.constant dense<1> : tensor<i32>
    %c_7 = stablehlo.constant dense<0> : tensor<1xi32>
    %c_8 = stablehlo.constant dense<1> : tensor<1xi32>
    %0 = stablehlo.compare  LE, %c_7, %c_8 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %1 = stablehlo.select %0, %c_7, %c_8 : tensor<1xi1>, tensor<1xi32>
    %c_9 = stablehlo.constant dense<1> : tensor<1xi32>
    %2 = stablehlo.real_dynamic_slice %c_4, %1, %c_8, %c_9 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %c_10 = stablehlo.constant dense<> : tensor<0xi32>
    %3 = stablehlo.dynamic_reshape %2, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %c_11 = stablehlo.constant dense<-1> : tensor<i32>
    %c_12 = stablehlo.constant dense<> : tensor<0xi32>
    %4 = stablehlo.compare  EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %7 = stablehlo.dynamic_broadcast_in_dim %c_11, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %8 = stablehlo.add %6, %7 : tensor<i32>
    %9 = stablehlo.compare  EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %10 = stablehlo.select %9, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
    %11 = stablehlo.dynamic_broadcast_in_dim %8, %10, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %12 = stablehlo.dynamic_broadcast_in_dim %c_5, %10, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %13 = stablehlo.compare  LT, %11, %12 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %c_13 = stablehlo.constant dense<0> : tensor<1xi32>
    %c_14 = stablehlo.constant dense<0> : tensor<1xi32>
    %c_15 = stablehlo.constant dense<1> : tensor<1xi32>
    %14 = stablehlo.compare  LE, %c_14, %c_15 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %15 = stablehlo.select %14, %c_14, %c_15 : tensor<1xi1>, tensor<1xi32>
    %c_16 = stablehlo.constant dense<1> : tensor<1xi32>
    %16 = stablehlo.real_dynamic_slice %c_4, %15, %c_15, %c_16 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %17 = stablehlo.dynamic_reshape %16, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %18 = stablehlo.compare  EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %19 = stablehlo.select %18, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
    %20 = stablehlo.dynamic_broadcast_in_dim %17, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %21 = stablehlo.dynamic_broadcast_in_dim %8, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %22 = stablehlo.compare  LT, %20, %21 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %c_17 = stablehlo.constant dense<0> : tensor<1xi32>
    %c_18 = stablehlo.constant dense<1> : tensor<1xi32>
    %23 = stablehlo.compare  LE, %c_17, %c_18 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %24 = stablehlo.select %23, %c_17, %c_18 : tensor<1xi1>, tensor<1xi32>
    %c_19 = stablehlo.constant dense<1> : tensor<1xi32>
    %25 = stablehlo.real_dynamic_slice %c_4, %24, %c_18, %c_19 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    [ About 600 lines elided ]
    %635 = stablehlo.compare  EQ, %634, %c_1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %c_146 = stablehlo.constant dense<1> : tensor<i32>
    %c_147 = stablehlo.constant dense<1> : tensor<1xi32>
    %636 = stablehlo.compare  EQ, %c_147, %c_1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %637 = stablehlo.get_dimension_size %630, dim = 0 : (tensor<?xi32>) -> tensor<i32>
    %638 = stablehlo.reshape %637 : (tensor<i32>) -> tensor<1xi32>
    %639 = stablehlo.select %636, %638, %c_147 : tensor<1xi1>, tensor<1xi32>
    %640 = stablehlo.select %635, %639, %634 : tensor<1xi1>, tensor<1xi32>
    %641 = stablehlo.dynamic_broadcast_in_dim %591, %640, dims = [0] : (tensor<?xi1>, tensor<1xi32>) -> tensor<?xi1>
    %642 = stablehlo.dynamic_broadcast_in_dim %c_13, %640, dims = [0] : (tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %643 = stablehlo.dynamic_broadcast_in_dim %630, %640, dims = [0] : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
    %644 = stablehlo.select %641, %642, %643 : tensor<?xi1>, tensor<?xi32>
    %645 = stablehlo.reshape %535 : (tensor<?xi32>) -> tensor<1xi32>
    %646 = stablehlo.reshape %644 : (tensor<?xi32>) -> tensor<1xi32>
    %647 = stablehlo.compare  LE, %645, %646 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %648 = stablehlo.select %647, %645, %646 : tensor<1xi1>, tensor<1xi32>
    %c_148 = stablehlo.constant dense<1> : tensor<1xi32>
    %649 = stablehlo.real_dynamic_slice %502, %648, %646, %c_148 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32>
    %650 = stablehlo.concatenate %497, %649, dim = 0 : (tensor<1xi32>, tensor<0xi32>) -> tensor<1xi32>
    %651 = stablehlo.dynamic_reshape %391, %650 : (tensor<?x?xi32>, tensor<1xi32>) -> tensor<?xi32>
    return %651 : tensor<?xi32>
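
To put a number on the repetition mentioned above, a small script can count how often the same constant (same value and same type) is defined in a dump. This is a rough sketch that uses a regular expression over the textual MLIR, not an MLIR parser; `count_duplicate_constants` is a hypothetical helper and `SAMPLE` is copied from the first few constants of the dump.

```python
import re
from collections import Counter

# Matches lines such as:
#   %c_4 = stablehlo.constant dense<1> : tensor<1xi32>
CONSTANT_RE = re.compile(
    r"%\w+\s*=\s*stablehlo\.constant\s+(dense<[^>]*>)\s*:\s*(tensor<[^>]*>)"
)


def count_duplicate_constants(mlir_text):
    """Return {(value, type): count} for constants defined more than once."""
    counts = Counter()
    for line in mlir_text.splitlines():
        match = CONSTANT_RE.search(line)
        if match:
            counts[(match.group(1), match.group(2))] += 1
    return {key: n for key, n in counts.items() if n > 1}


SAMPLE = """
    %c = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi32>
    %c_0 = stablehlo.constant dense<3> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<3> : tensor<1xi32>
    %c_3 = stablehlo.constant dense<1> : tensor<i32>
    %c_4 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_5 = stablehlo.constant dense<0> : tensor<i32>
    %c_6 = stablehlo.constant dense<1> : tensor<i32>
    %c_7 = stablehlo.constant dense<0> : tensor<1xi32>
    %c_8 = stablehlo.constant dense<1> : tensor<1xi32>
"""

duplicates = count_duplicate_constants(SAMPLE)
for (value, tensor_type), n in duplicates.items():
    print(f"{value} : {tensor_type} defined {n} times")
```

In just these ten lines, `dense<1> : tensor<1xi32>` is defined three times and `dense<1> : tensor<i32>` twice. Running the same count over the full dump would show how much of it could in principle be removed by deduplicating constants.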