farazkh80 opened 1 month ago
New finding: `tp.mean` only fails if we skip a dimension @parthchadha

```
>>> x = tp.reshape(tp.arange(12), (2, 3, 2))
```

Then if we do:

```
>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
    [[[3.0000, 4.0000],
      [5.0000, 6.0000],
      [7.0000, 8.0000]]],
    dtype=float32, loc=gpu:0, shape=(1, 3, 2))
```

and then:

```
>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
    [[[5.0000, 6.0000]]],
    dtype=float32, loc=gpu:0, shape=(1, 1, 2))
```

and even:

```
>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))
```

but if you skip a dim:

```
>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
data_list = self.tolist()
File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
data_memref = self.eval()
File "/tripy/tripy/frontend/tensor.py", line 180, in eval
executable = compiler.compile(mlir, flat_ir=flat_ir)
File "/tripy/tripy/utils/utils.py", line 74, in wrapper
result = func(*args, **kwargs)
File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
raise_error(
File "/tripy/tripy/common/exception.py", line 195, in raise_error
raise TripyException(msg) from None
tripy.common.exception.TripyException:
--> <stdin>:1 in <module>()
MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
Additional context:
Traceback (most recent call last):
File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
(t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
(t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
(t1926)): error: op:
%2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%7) : (tensor<f32>) -> ()
}) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.
This error occured while trying to compile the following FlatIR expression:
|
| t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
|
This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.
Note: This originated from the following expression:
--> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
|
174 | return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
--> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
|
318 | sum_val = sum(tensor, dim=dim, keepdim=keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
--> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
|
361 | return mean_impl(input, dim, keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
Input 0:
--> /tripy/tripy/frontend/utils.py:455 in wrapper()
|
455 | return func(*new_args, **new_kwargs)
|
--> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
|
145 | return reshape_impl(input, shape, len(shape), output_len)
|
```
I was able to work around this to implement batchnorm as follows:

```python
# Transpose the channel dimension (dim 1) with the batch dimension (dim 0)
x_transposed = tp.transpose(x, 0, 1)
# Reshape to combine the batch and spatial dimensions
C, N, H, W = x_transposed.shape
x_reshaped = tp.reshape(x_transposed, (C, N * H * W))
# Calculate mean and variance across the merged dimensions for each channel (C)
mean = tp.mean(x_reshaped, dim=1, keepdim=True)
variance = tp.var(x_reshaped, dim=1, keepdim=True)
mean = tp.reshape(mean, (1, C, 1, 1))
variance = tp.reshape(variance, (1, C, 1, 1))
# Transpose back to the original shape
x_transposed_back = tp.transpose(x_transposed, 0, 1)
# Normalize the input
x_normalized = (x_transposed_back - mean) / tp.sqrt(variance + self.eps)
# Apply the learned scaling (gamma) and shifting (beta)
x_scaled = self.gamma * x_normalized + self.beta
return x_scaled
```
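For context, here is a minimal sketch of the module the snippet above would live in. The workaround sidesteps the bug because, after the transpose and reshape, each statistic is reduced over a single axis. The class name, constructor signature, and `tp.ones`/`tp.zeros` initializers below are assumptions for illustration, not the actual ResNet-50 code:

```python
import tripy as tp

class BatchNorm:
    # Hypothetical wrapper for the workaround above; only eps, gamma,
    # and beta are taken from the snippet, the rest is assumed.
    def __init__(self, num_channels, eps=1e-5):
        self.eps = eps
        self.gamma = tp.ones((1, num_channels, 1, 1))  # learned scale
        self.beta = tp.zeros((1, num_channels, 1, 1))  # learned shift

    def __call__(self, x):  # x: (N, C, H, W)
        # ... body of the workaround snippet above goes here ...
        ...
```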
TensorRT doesn't actually support doing reduction across multiple dimensions.
In MLIR-TRT we don't do anything special to work around this limitation. We would have to add a transformation that decomposes `stablehlo.reduce` with multiple dimensions into a sequence of single-dim reductions, or into reshape+reduce+reshape.
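To make the first option concrete, here is a minimal sketch of that decomposition expressed at the Tripy level rather than as a StableHLO rewrite; `mean_multi_axis` is a hypothetical helper, and it relies on the fact that a joint mean over several axes equals the composition of the per-axis means:

```python
import tripy as tp

def mean_multi_axis(t, dims, keepdim=True):
    # Hypothetical helper: reduce one axis at a time. Iterating from the
    # highest axis down keeps the remaining indices valid when
    # keepdim=False; with keepdim=True the order doesn't matter.
    for d in sorted(dims, reverse=True):
        t = tp.mean(t, dim=d, keepdim=keepdim)
    return t

x = tp.reshape(tp.arange(12), (2, 3, 2))
# Stands in for the failing tp.mean(x, dim=[0, 2], keepdim=True); each
# step divides by one axis length, so the composition equals the joint
# mean. Expected shape: (1, 3, 1).
y = mean_multi_axis(x, [0, 2])
```

The same loop would work for sum; var does not compose this way and would instead need the E[x^2] - E[x]^2 form.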
@christopherbate TRT does support reduction across multiple dimensions - the axes parameter is a bitset. MLIR-TRT also seems to support this in the case where the reduced axes are contiguous (see #297 for an example of multiple contiguous axes working).
It seems to be only when we skip over certain dimensions that compilation fails. TRT should work fine with skipped dimensions, so is this just a lowering bug?
@pranavm-nvidia
I'm aware the axes parameter is a bitset, but IIRC if you actually tried to reduce multiple dimensions, TRT would return an error.
This is the reason we have the current restriction, although maybe it has been lifted since TRT 8; we would need to confirm. Right now in StableHLO-to-TRT, we only convert single-axis reductions. The preprocessing pipeline tries to take care of flattening in the case where the reduction is over multiple dims, IIRC.
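Extending that flattening to skipped dimensions would need a permutation first so the reduced axes become adjacent. A hedged sketch of the idea at the Tripy level for the failing dim=[0, 2] case, using only ops that already appear in this thread (tp.transpose, tp.reshape, tp.mean):

```python
import tripy as tp

x = tp.reshape(tp.arange(12), (2, 3, 2))  # want the mean over dims [0, 2]

# Swap dims 1 and 2 so the reduced axes become adjacent:
# (2, 3, 2) -> (2, 2, 3); the reduced axes are now 0 and 1.
x_t = tp.transpose(x, 1, 2)
# Flatten the adjacent reduced axes into one: (2, 2, 3) -> (4, 3).
x_flat = tp.reshape(x_t, (4, 3))
# Single-axis reduction, which the current lowering already handles.
m = tp.mean(x_flat, dim=0, keepdim=False)  # shape (3,)
# Restore the keepdim=True layout of dim=[0, 2]: (3,) -> (1, 3, 1).
m = tp.reshape(m, (1, 3, 1))
```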
This issue is related to `tp.mean` and `tp.var` failures when implementing `BatchNorm` using Tripy for a ResNet-50 model.
Attached (stdout with helper prints for shapes, stderr, and MLIR dumps): tripy-mlir-batchnorm.zip
Seems like the `tp.mean` and `tp.var` reduction operations fail at MLIR compile time.