NVIDIA / TensorRT-Incubator

Experimental projects related to TensorRT

`stablehlo-to-tensorrt` conversion pass doesn't support `stablehlo.reduce` with multiple reduction dims #279


farazkh80 commented 1 month ago

This issue relates to tp.mean and tp.var failures when implementing BatchNorm with Tripy for a ResNet-50 model.

import time

import tripy as tp


class TPBatchNorm(tp.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps

        # Initialize learnable parameters (gamma and beta) to the correct shape
        self.gamma = tp.ones((1, num_features, 1, 1), dtype=tp.float32)
        self.beta = tp.zeros((1, num_features, 1, 1), dtype=tp.float32)

    def __call__(self, x):
        print(f"x.shape before normalization: {x.shape}")

        # Calculate mean and variance across the batch and spatial dimensions
        mean = tp.mean(x, dim=(0, 2, 3), keepdim=True)
        variance = tp.var(x, dim=(0, 2, 3), keepdim=True)

        print(f"mean.shape: {mean.shape}, variance.shape: {variance.shape}")

        # Normalize the input
        x_normalized = (x - mean) / tp.sqrt(variance + self.eps)

        print(f"x_normalized.shape: {x_normalized.shape}, gamma.shape: {self.gamma.shape}, beta.shape: {self.beta.shape}")

        # Apply the learned scaling (gamma) and shifting (beta)
        x_scaled = self.gamma * x_normalized + self.beta

        print(f"output shape after batchnorm: {x_scaled.shape}")
        return x_scaled

module = TPBatchNorm(num_features=128)
input_shape = [1, 128, 56, 56]
x = tp.ones(input_shape, dtype=tp.float32)

# Compile the module
start_compile = time.perf_counter()
compiled_module = tp.compile(module, args=[tp.InputInfo(input_shape, dtype=tp.float32)])
print(f"Compilation of {module.__class__.__name__} took {time.perf_counter() - start_compile:.4f} seconds.")

# Forward pass
start_forward = time.perf_counter()
output = module(x)
print(f"Forward pass of {module.__class__.__name__} took {time.perf_counter()-start_forward:.4f} seconds.")

stdout (helper prints for shapes)

x.shape before normalization: shape(1, 128, 56, 56)
mean.shape: shape(1, 128, 1, 1), variance.shape: shape(1, 128, 1, 1)
x_normalized.shape: shape(1, 128, 56, 56), gamma.shape: shape(1, 128, 1, 1), beta.shape: shape(1, 128, 1, 1)
output shape after batchnorm: shape(1, 128, 56, 56)

stderr

Traceback (most recent call last):
  File "/tripy/examples/resent50/batchnorm.py", line 40, in <module>
    output = module(x)
  File "/tripy/tripy/function_registry.py", line 358, in wrapper
    return self.find_overload(key, args, kwargs)(*args, **kwargs)
  File "/tripy/tripy/function_registry.py", line 262, in __call__
    return self.func(*args, **kwargs)
  File "/tripy/tripy/backend/api/compile.py", line 192, in compile
    executable = compiler.compile(mlir, flat_ir=flat_ir)
  File "/tripy/tripy/utils/utils.py", line 74, in wrapper
    result = func(*args, **kwargs)
  File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
    map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
  File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
    raise_error(
  File "/tripy/tripy/common/exception.py", line 195, in raise_error
    raise TripyException(msg) from None
tripy.common.exception.TripyException: 

--> /tripy/examples/resent50/batchnorm.py:40 in <module>()
      |
   40 | output = module(x)
      | 

MTRTException: InternalError: failed to run compilation on module with symbol name: ins_x_outs_t1392_1

Additional context:
Traceback (most recent call last):
  File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
    executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: ins_x_outs_t1392_1
.
    Loaded TensorRT version 10.5.0.18 but compiled for TensorRT 10.2.0.19. This can result in crashes or unintended behavior.
    (t8,t465)): error: op: %13 = "stablehlo.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t8,t465)): error: op: "stablehlo.return"(%13) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t8,t465)): error: op: 
    %1 = "stablehlo.reduce"(%arg0, %0) <{dimensions = array<i64: 0, 2, 3>}> ({
    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
      %13 = "stablehlo.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%13) : (tensor<f32>) -> ()
    }) : (tensor<1x128x56x56xf32>, tensor<f32>) -> tensor<128xf32> from function main is invalid, post clustering.
    (t8,t465)): error: op: %12 = "stablehlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t8,t465)): error: op: "stablehlo.return"(%12) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t8,t465)): error: op: 
    %7 = "stablehlo.reduce"(%6#1, %0) <{dimensions = array<i64: 0, 2, 3>}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      %12 = "stablehlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%12) : (tensor<f32>) -> ()
    }) : (tensor<1x128x56x56xf32>, tensor<f32>) -> tensor<128xf32> from function main is invalid, post clustering.

    This error occured while trying to compile the following FlatIR expression:
          |
          | t_inter9: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter8, t_inter10, reduce_mode='sum', reduce_dims=[0, 2, 3])
          | 

    This operation was introduced to Cloning tensor t8: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.

    Note: This originated from the following expression:

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:18 in __call__()
          |
       18 |         mean = tp.mean(x, dim=(0, 2, 3), keepdim=True)
          |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:408 in var()
          |
      408 |     mean_val = mean(input, dim=dim, keepdim=dim is not None)
          |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:19 in __call__()
          |
       19 |         variance = tp.var(x, dim=(0, 2, 3), keepdim=True)
          |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    Input 0:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/fill.py:141 in full()
          |
      141 |     return full_impl(shape, value, dtype, output_rank)
          | 

    --> /tripy/tripy/backend/api/compile.py:143 in process_arg()
          |
      143 |             tensor = full(shape=arg.shape_bounds.opt, value=init_value, dtype=arg.dtype)
          |                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/backend/api/compile.py:155 in compile()
          |
      155 |         new_args.append(process_arg(name, arg))
          |                         ^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:40 in <module>()
          |
       40 | output = module(x)
          | 

    Input 1:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/fill.py:141 in full()
          |
      141 |     return full_impl(shape, value, dtype, output_rank)
          | 

    --> /tripy/tripy/backend/api/compile.py:143 in process_arg()
          |
      143 |             tensor = full(shape=arg.shape_bounds.opt, value=init_value, dtype=arg.dtype)
          |                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/backend/api/compile.py:155 in compile()
          |
      155 |         new_args.append(process_arg(name, arg))
          |                         ^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:40 in <module>()
          |
       40 | output = module(x)
          | 

MLIR dumps

tripy-mlir-batchnorm.zip

It seems like the tp.mean and tp.var reduction operations fail during MLIR compilation.

farazkh80 commented 1 month ago

New finding: tp.mean only fails if we skip a dimension. @parthchadha

Example

x = tp.reshape(tp.arange(12), (2, 3, 2))

then if we do

>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
    [[[3.0000, 4.0000],
      [5.0000, 6.0000],
      [7.0000, 8.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 2))

and then

>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
    [[[5.0000, 6.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 1, 2))

and even

>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))

but if you skip a dim

>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
    data_list = self.tolist()
  File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
    data_memref = self.eval()
  File "/tripy/tripy/frontend/tensor.py", line 180, in eval
    executable = compiler.compile(mlir, flat_ir=flat_ir)
  File "/tripy/tripy/utils/utils.py", line 74, in wrapper
    result = func(*args, **kwargs)
  File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
    map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
  File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
    raise_error(
  File "/tripy/tripy/common/exception.py", line 195, in raise_error
    raise TripyException(msg) from None
tripy.common.exception.TripyException: 

--> <stdin>:1 in <module>()

MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12

Additional context:
Traceback (most recent call last):
  File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
    executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
    (t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t1926)): error: op: 
    %2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%7) : (tensor<f32>) -> ()
    }) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.

    This error occured while trying to compile the following FlatIR expression:
          |
          | t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
          | 

    This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.

    Note: This originated from the following expression:

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    Input 0:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
          |
      145 |     return reshape_impl(input, shape, len(shape), output_len)
          |
farazkh80 commented 1 month ago

I was able to work around this and implement BatchNorm as follows:

# Transpose the channel dimension (dim 1) with the batch dimension (dim 0)
x_transposed = tp.transpose(x, 0, 1)

# Reshape to combine the batch and spatial dimensions
C, N, H, W = x_transposed.shape
x_reshaped = tp.reshape(x_transposed, (C, N * H * W))

# Calculate mean and variance across the merged dimensions for each channel (C)
mean = tp.mean(x_reshaped, dim=1, keepdim=True)
variance = tp.var(x_reshaped, dim=1, keepdim=True)
mean = tp.reshape(mean, (1, C, 1, 1))
variance = tp.reshape(variance, (1, C, 1, 1))

# Transpose back to the original shape
x_transposed_back = tp.transpose(x_transposed, 0, 1)

# Normalize the input
x_normalized = (x_transposed_back - mean) / tp.sqrt(variance + self.eps)

# Apply the learned scaling (gamma) and shifting (beta)
x_scaled = self.gamma * x_normalized + self.beta

return x_scaled
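
This works because the transpose and reshape collapse the batch and spatial axes into a single contiguous axis, so each reduction only ever runs over one dimension.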
christopherbate commented 2 weeks ago

TensorRT doesn't actually support doing reductions across multiple dimensions.

In MLIR-TRT we don't do anything special to work around this limitation. We would have to add a transformation that decomposes a stablehlo.reduce over multiple dimensions into a sequence of single-dim reductions, or into a reshape+reduce+reshape.
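
At the Tripy level, a user-side sketch of the first option could look like the hypothetical helper below (not a library API): chaining single-axis means reproduces the joint mean, since each step divides by the size of the axis it reduces.

import tripy as tp

def mean_seq(x, dims, keepdim=True):
    # Hypothetical workaround: emulate a multi-axis mean as a chain of
    # single-axis reductions, which the stablehlo-to-tensorrt pass can convert.
    out = x
    # Reduce back-to-front so earlier axis indices stay valid when keepdim=False.
    for d in sorted(dims, reverse=True):
        out = tp.mean(out, dim=d, keepdim=keepdim)
    return out

# e.g. mean_seq(x, (0, 2, 3)) in place of tp.mean(x, dim=(0, 2, 3), keepdim=True)

Note that tp.var cannot be chained the same way; a variance would instead have to be assembled from means (population variance: mean_seq(x * x, dims) - mean_seq(x, dims) ** 2).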

pranavm-nvidia commented 2 weeks ago

@christopherbate TRT does support reduction across multiple dimensions - the axes parameter is a bitset. MLIR-TRT also seems to support this in the case where the reduced axes are contiguous (see #297 for an example of multiple contiguous axes working).

It seems to be only when we skip over certain dimensions that compilation fails. TRT should work fine with skipped dimensions, so is this just a lowering bug?
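
For reference, TensorRT's IReduceLayer takes its reduction axes as a bitmask with one bit per input dimension, so non-contiguous axes are expressible in a single layer. A minimal illustration (assuming an already-built INetworkDefinition network and an input ITensor inp):

import tensorrt as trt

# Reduce over dims (0, 2, 3) of an NCHW tensor in one layer;
# the axes argument is a bitmask with one bit per dimension.
axes = (1 << 0) | (1 << 2) | (1 << 3)  # 0b1101 == 13
layer = network.add_reduce(inp, trt.ReduceOperation.AVG, axes, keep_dims=True)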

christopherbate commented 2 weeks ago

@pranavm-nvidia

I'm aware the axes parameter is a bitset, but IIRC if you actually try to reduce multiple dimensions, TRT will return an error.

This is the reason we have the current restriction, although maybe it has been lifted since TRT 8; we would need to confirm. Right now in StableHLO-to-TRT, we only convert single-axis reductions. The preprocessing pipeline tries to take care of flattening in the case where the reduction is over multiple dims, IIRC.