NVIDIA / TensorRT-Incubator

Experimental projects related to TensorRT
72 stars 12 forks source link

BatchNorm MLIR compile failure #279

Open farazkh80 opened 4 days ago

farazkh80 commented 4 days ago

This issue is related to tp.mean and tp.var failures when implementing BatchNorm using Tripy for Resnet50 model.

class TPBatchNorm(tp.Module):
    def __init__(self, num_features, eps=1e-5):
        super(TPBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps

        # Initialize learnable parameters (gamma and beta) to the correct shape
        self.gamma = tp.ones((1, num_features,1,1), dtype=tp.float32)
        self.beta = tp.zeros((1, num_features,1,1), dtype=tp.float32)

    def __call__(self, x):
        print(f"x.shape before normalization: {x.shape}")

        # Calculate mean and variance across the batch and spatial dimensions
        mean = tp.mean(x, dim=(0, 2, 3), keepdim=True)
        variance = tp.var(x, dim=(0, 2, 3), keepdim=True)

        print(f"mean.shape: {mean.shape}, variance.shape: {variance.shape}")

        # Normalize the input
        x_normalized = (x - mean) / tp.sqrt(variance + self.eps)

        print(f"x_normalized.shape: {x_normalized.shape}, gamma.shape: {self.gamma.shape}, beta.shape: {self.beta.shape}")

        # Apply the learned scaling (gamma) and shifting (beta)
        x_scaled = self.gamma * x_normalized + self.beta

        print(f"output shape after batchnorm: {x_scaled.shape}")
        return x_scaled

module = TPBatchNorm(num_features=128)
input_shape = [1, 128, 56, 56]
x = tp.ones([1, 128, 56, 56], dtype=tp.float32)

# # Compile the module
start_compile = time.perf_counter()
compiled_module = tp.compile(module, args=[tp.InputInfo(input_shape, dtype=tp.float32)])
print(f"Compilation of {module.__class__.__name__} took {time.perf_counter() - start_compile:.4f} seconds.")

# Forward pass
start_forward = time.perf_counter()
output = module(x)
print(f"Forward pass of {module.__class__.__name__} took {time.perf_counter()-start_forward:.4f} seconds.")

stout (helper prints for shapes)

x.shape before normalization: shape(1, 128, 56, 56)
mean.shape: shape(1, 128, 1, 1), variance.shape: shape(1, 128, 1, 1)
x_normalized.shape: shape(1, 128, 56, 56), gamma.shape: shape(1, 128, 1, 1), beta.shape: shape(1, 128, 1, 1)
output shape after batchnorm: shape(1, 128, 56, 56)

stderr

Traceback (most recent call last):
  File "/tripy/examples/resent50/batchnorm.py", line 40, in <module>
    output = module(x)
  File "/tripy/tripy/function_registry.py", line 358, in wrapper
    return self.find_overload(key, args, kwargs)(*args, **kwargs)
  File "/tripy/tripy/function_registry.py", line 262, in __call__
    return self.func(*args, **kwargs)
  File "/tripy/tripy/backend/api/compile.py", line 192, in compile
    executable = compiler.compile(mlir, flat_ir=flat_ir)
  File "/tripy/tripy/utils/utils.py", line 74, in wrapper
    result = func(*args, **kwargs)
  File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
    map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
  File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
    raise_error(
  File "/tripy/tripy/common/exception.py", line 195, in raise_error
    raise TripyException(msg) from None
tripy.common.exception.TripyException: 

--> /tripy/examples/resent50/batchnorm.py:40 in <module>()
      |
   40 | output = module(x)
      | 

MTRTException: InternalError: failed to run compilation on module with symbol name: ins_x_outs_t1392_1

Additional context:
Traceback (most recent call last):
  File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
    executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: ins_x_outs_t1392_1
.
    Loaded TensorRT version 10.5.0.18 but compiled for TensorRT 10.2.0.19. This can result in crashes or unintended behavior.
    (t8,t465)): error: op: %13 = "stablehlo.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t8,t465)): error: op: "stablehlo.return"(%13) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t8,t465)): error: op: 
    %1 = "stablehlo.reduce"(%arg0, %0) <{dimensions = array<i64: 0, 2, 3>}> ({
    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
      %13 = "stablehlo.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%13) : (tensor<f32>) -> ()
    }) : (tensor<1x128x56x56xf32>, tensor<f32>) -> tensor<128xf32> from function main is invalid, post clustering.
    (t8,t465)): error: op: %12 = "stablehlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t8,t465)): error: op: "stablehlo.return"(%12) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t8,t465)): error: op: 
    %7 = "stablehlo.reduce"(%6#1, %0) <{dimensions = array<i64: 0, 2, 3>}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      %12 = "stablehlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%12) : (tensor<f32>) -> ()
    }) : (tensor<1x128x56x56xf32>, tensor<f32>) -> tensor<128xf32> from function main is invalid, post clustering.

    This error occured while trying to compile the following FlatIR expression:
          |
          | t_inter9: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter8, t_inter10, reduce_mode='sum', reduce_dims=[0, 2, 3])
          | 

    This operation was introduced to Cloning tensor t8: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.

    Note: This originated from the following expression:

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:18 in __call__()
          |
       18 |         mean = tp.mean(x, dim=(0, 2, 3), keepdim=True)
          |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:408 in var()
          |
      408 |     mean_val = mean(input, dim=dim, keepdim=dim is not None)
          |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:19 in __call__()
          |
       19 |         variance = tp.var(x, dim=(0, 2, 3), keepdim=True)
          |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    Input 0:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/fill.py:141 in full()
          |
      141 |     return full_impl(shape, value, dtype, output_rank)
          | 

    --> /tripy/tripy/backend/api/compile.py:143 in process_arg()
          |
      143 |             tensor = full(shape=arg.shape_bounds.opt, value=init_value, dtype=arg.dtype)
          |                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/backend/api/compile.py:155 in compile()
          |
      155 |         new_args.append(process_arg(name, arg))
          |                         ^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:40 in <module>()
          |
       40 | output = module(x)
          | 

    Input 1:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/fill.py:141 in full()
          |
      141 |     return full_impl(shape, value, dtype, output_rank)
          | 

    --> /tripy/tripy/backend/api/compile.py:143 in process_arg()
          |
      143 |             tensor = full(shape=arg.shape_bounds.opt, value=init_value, dtype=arg.dtype)
          |                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/backend/api/compile.py:155 in compile()
          |
      155 |         new_args.append(process_arg(name, arg))
          |                         ^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/examples/resent50/batchnorm.py:40 in <module>()
          |
       40 | output = module(x)
          | 

mlir dumps

tripy-mlir-batchnorm.zip

seems like tp.mean and tp.variance reduction operation failure at MLIR compile.

farazkh80 commented 3 days ago

new finding:tp.mean only fails if we skip a dimension @parthchadha

Example

x = tp.reshape(tp.arange(12), (2,3,2))

then if we do

>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
    [[[3.0000, 4.0000],
      [5.0000, 6.0000],
      [7.0000, 8.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 2))

and then

>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
    [[[5.0000, 6.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 1, 2))

and even

>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))

but if you skip a dim

>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
    data_list = self.tolist()
  File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
    data_memref = self.eval()
  File "/tripy/tripy/frontend/tensor.py", line 180, in eval
    executable = compiler.compile(mlir, flat_ir=flat_ir)
  File "/tripy/tripy/utils/utils.py", line 74, in wrapper
    result = func(*args, **kwargs)
  File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
    map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
  File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
    raise_error(
  File "/tripy/tripy/common/exception.py", line 195, in raise_error
    raise TripyException(msg) from None
tripy.common.exception.TripyException: 

--> <stdin>:1 in <module>()

MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12

Additional context:
Traceback (most recent call last):
  File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
    executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
    (t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t1926)): error: op: 
    %2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%7) : (tensor<f32>) -> ()
    }) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.

    This error occured while trying to compile the following FlatIR expression:
          |
          | t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
          | 

    This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.

    Note: This originated from the following expression:

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    Input 0:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
          |
      145 |     return reshape_impl(input, shape, len(shape), output_len)
          |
farazkh80 commented 3 days ago

I was able to around this to implement bathnorm as follows

# Transpose the channel dimension (dim 1) with the batch dimension (dim 0)
x_transposed = tp.transpose(x, 0, 1)

# Reshape to combine the batch and spatial dimensions
C, N, H, W = x_transposed.shape
x_reshaped = tp.reshape(x_transposed, (C, N * H * W))

# Calculate mean and variance across the merged dimensions for each channel (C)
mean = tp.mean(x_reshaped, dim=1, keepdim=True)
variance = tp.var(x_reshaped, dim=1, keepdim=True)
mean = tp.reshape(mean, (1, C, 1, 1))
variance = tp.reshape(variance, (1, C, 1, 1))

# Transpose back to the original shape
x_transposed_back = tp.transpose(x_transposed, 0, 1)

# Normalize the input
x_normalized = (x_transposed_back - mean) / tp.sqrt(variance + self.eps)

# Apply the learned scaling (gamma) and shifting (beta)
x_scaled = self.gamma * x_normalized + self.beta

return x_scaled