Open farazkh80 opened 4 days ago
New finding about `tp.mean`: it only fails if we skip a dimension. @parthchadha
Given:
x = tp.reshape(tp.arange(12), (2,3,2))
then, if we run:
>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
[[[3.0000, 4.0000],
[5.0000, 6.0000],
[7.0000, 8.0000]]],
dtype=float32, loc=gpu:0, shape=(1, 3, 2))
and then
>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
[[[5.0000, 6.0000]]],
dtype=float32, loc=gpu:0, shape=(1, 1, 2))
and even
>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))
but if we skip a dimension (reducing over dims 0 and 2 while keeping dim 1), compilation fails:
>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
data_list = self.tolist()
File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
data_memref = self.eval()
File "/tripy/tripy/frontend/tensor.py", line 180, in eval
executable = compiler.compile(mlir, flat_ir=flat_ir)
File "/tripy/tripy/utils/utils.py", line 74, in wrapper
result = func(*args, **kwargs)
File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
raise_error(
File "/tripy/tripy/common/exception.py", line 195, in raise_error
raise TripyException(msg) from None
tripy.common.exception.TripyException:
--> <stdin>:1 in <module>()
MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
Additional context:
Traceback (most recent call last):
File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
(t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
(t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
(t1926)): error: op:
%2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%7) : (tensor<f32>) -> ()
}) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.
This error occured while trying to compile the following FlatIR expression:
|
| t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
|
This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.
Note: This originated from the following expression:
--> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
|
174 | return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
--> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
|
318 | sum_val = sum(tensor, dim=dim, keepdim=keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
--> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
|
361 | return mean_impl(input, dim, keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
Input 0:
--> /tripy/tripy/frontend/utils.py:455 in wrapper()
|
455 | return func(*new_args, **new_kwargs)
|
--> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
|
145 | return reshape_impl(input, shape, len(shape), output_len)
|
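I was able to work around this and implement batchnorm as follows (merging the reduced dimensions into one contiguous dimension so that no dimension is skipped):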
# Transpose the channel dimension (dim 1) with the batch dimension (dim 0)
x_transposed = tp.transpose(x, 0, 1)
# Reshape to combine the batch and spatial dimensions
C, N, H, W = x_transposed.shape
x_reshaped = tp.reshape(x_transposed, (C, N * H * W))
# Calculate mean and variance across the merged dimensions for each channel (C)
mean = tp.mean(x_reshaped, dim=1, keepdim=True)
variance = tp.var(x_reshaped, dim=1, keepdim=True)
mean = tp.reshape(mean, (1, C, 1, 1))
variance = tp.reshape(variance, (1, C, 1, 1))
# Transpose back to the original shape
x_transposed_back = tp.transpose(x_transposed, 0, 1)
# Normalize the input
x_normalized = (x_transposed_back - mean) / tp.sqrt(variance + self.eps)
# Apply the learned scaling (gamma) and shifting (beta)
x_scaled = self.gamma * x_normalized + self.beta
return x_scaled
This issue is related to the `tp.mean` and `tp.var` failures seen when implementing BatchNorm with Tripy for the ResNet50 model.
Attachments:
stdout (helper prints for shapes)
stderr
mlir dumps
tripy-mlir-batchnorm.zip
It seems that the reduction operations behind `tp.mean` and `tp.var` fail during MLIR compilation whenever the reduced dimensions are non-contiguous (e.g. `dim=[0, 2]`).