WDYT?
I have a general question before I dig more into the review:
I can see that with the "use reshape to ensure size(broadcast_dimensions) = rank(operand)" transformation, even the scalar broadcast_in_dim is transformed into the reshape; broadcast_in_dim pair:

```mlir
stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i32>) -> tensor<4x32xi32>
```

to

```mlir
stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1x1xi32>
stablehlo.broadcast_in_dim [[R0]], dims = [0, 1] : (tensor<1x1xi32>) -> tensor<4x32xi32>
```

This is how I interpret the transformation.
Yes, that's right.
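As a quick sanity check of that equivalence, here is a NumPy sketch (illustrative only; `np.broadcast_to` stands in for `stablehlo.broadcast_in_dim`, and the scalar value is made up):

```python
import numpy as np

# Hypothetical scalar input, standing in for tensor<i32>.
x = np.array(7, dtype=np.int32)

# Original op: broadcast_in_dim %arg0, dims = [] : (tensor<i32>) -> tensor<4x32xi32>
direct = np.broadcast_to(x, (4, 32))

# After the transformation: reshape to rank 2, then broadcast with dims = [0, 1].
reshaped = np.reshape(x, (1, 1))            # tensor<i32> -> tensor<1x1xi32>
staged = np.broadcast_to(reshaped, (4, 32)) # tensor<1x1xi32> -> tensor<4x32xi32>
```

Both paths produce the same 4x32 array of the scalar value.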
The StableHLO `broadcast_in_dim` can simultaneously do:
1. Broadcasting of a lower-rank array to a higher-rank array
2. Broadcasting using degenerate dimensions
However, the proposed transformation seems to transform the code such that `broadcast_in_dim` only has to deal with case (2). Indeed, the transformation makes sense when there is an implicit transpose involved in the `broadcast_in_dim`, but for cases when an implicit transpose is not there, I was wondering about the optimization opportunities this transformation would open up (other than simplifying the code for readability).
I agree that this transformation doesn’t offer any optimization opportunities for stablehlo itself. However, such transformations can be valuable for bringing operations closer to the hardware level. For example, in a typical non-scalar accelerator that performs optimally with data batches, it would be advantageous to break down complex operations into simpler, more basic operations resembling RISC instructions:
```mlir
%0 = stablehlo.broadcast_in_dim %arg0, dims = [2, 4, 1] : (tensor<32x64x16xf32>) -> tensor<2x16x32x4x64xf32>
```
to a series of trivial operations:
```mlir
%0 = stablehlo.reshape %arg0 : (tensor<32x64x16xf32>) -> tensor<1x32x64x1x16xf32>
// swap axis 1 with axis 4
%1 = stablehlo.transpose %0, dims = [0, 4, 2, 3, 1] : (tensor<1x32x64x1x16xf32>) -> tensor<1x16x64x1x32xf32>
// swap axis 2 with axis 4
%2 = stablehlo.transpose %1, dims = [0, 1, 4, 3, 2] : (tensor<1x16x64x1x32xf32>) -> tensor<1x16x32x1x64xf32>
// broadcast axis 3 (tail strategy)
%3 = stablehlo.broadcast_in_dim %2, dims = [0, 1, 2, 3, 4] : (tensor<1x16x32x1x64xf32>) -> tensor<1x16x32x4x64xf32>
// broadcast axis 0
%4 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2, 3, 4] : (tensor<1x16x32x4x64xf32>) -> tensor<2x16x32x4x64xf32>
```
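To check that the expanded sequence computes the same thing as the original op, here is a NumPy sketch mirroring the shapes above (NumPy ops stand in for the StableHLO ones; the input data is made up):

```python
import numpy as np

x = np.random.default_rng(0).random((32, 64, 16)).astype(np.float32)

# Expanded sequence, mirroring the IR above.
r0 = x.reshape(1, 32, 64, 1, 16)                 # stablehlo.reshape
r1 = r0.transpose(0, 4, 2, 3, 1)                 # swap axes 1 and 4 -> (1,16,64,1,32)
r2 = r1.transpose(0, 1, 4, 3, 2)                 # swap axes 2 and 4 -> (1,16,32,1,64)
r3 = np.broadcast_to(r2, (1, 16, 32, 4, 64))     # broadcast axis 3
r4 = np.broadcast_to(r3, (2, 16, 32, 4, 64))     # broadcast axis 0

# Reference: broadcast_in_dim with dims = [2, 4, 1] sends operand axis i to
# result axis dims[i]; equivalently, transpose then insert and broadcast axes.
ref = np.broadcast_to(x.transpose(2, 0, 1)[None, :, :, None, :],
                      (2, 16, 32, 4, 64))
```

Both `r4` and `ref` agree element-wise, so the five trivial ops reproduce the combined op.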
I can envision that such transformations are usually done in hardware-focused dialects. This brings up a question of ideology: does stablehlo assume that these transformations are done in target-specific dialects, or is it open to such hardware-related simplifications?
PS: IMO having a pass aligned with https://github.com/openxla/xla/blob/9c7314893fb16b897370e12c73409d5d9b5eab5c/xla/mlir_hlo/mhlo/IR/hlo_ops.cc#L2364 seems sufficient (unless I am missing something).
Agree
My take: StableHLO transformations are mostly hardware agnostic. If for some backend (as opposed to ones that have native support for some of these composite patterns) it is relevant to batch simpler operations for better performance, then there should be passes, closer to the backend, achieving the same.
How about for now we make the pass aligned with https://github.com/openxla/xla/blob/9c7314893fb16b897370e12c73409d5d9b5eab5c/xla/mlir_hlo/mhlo/IR/hlo_ops.cc#L2364. That will include expanding the broadcast with implicit transpose.
cc @GleasonK
> ... for some backend (as opposed to ones that have native support for some of these composite patterns) it is relevant to batch simpler operations for better performance ...
Agree.
For now I'm reverting patch №2. When it comes to breaking it down into a series of operations (how and what), that appears to be backend-specific. However, I'm not entirely sure that reshaping alone (without extracting the transpose) isn't useful, at least for readability and for simplifying the code for further transformations.
Thanks! The changes LGTM.
This patch is two-stage:

1) Change the `isIotaRange` constraint to `llvm::sorted` in order to enable the rewrite from `broadcast_in_dim` to `reshape` (hlo_ops.cpp: BroadcastInDimSimplifier uses the same approach).
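If I read stage 1 right: with a sorted (but not necessarily iota) dims attribute, and no dimension actually growing in size, the op only inserts size-1 axes and degenerates to a pure reshape. A NumPy illustration of such a case (the shapes are invented for the example):

```python
import numpy as np

x = np.arange(12, dtype=np.int32).reshape(3, 4)

# broadcast_in_dim %x, dims = [0, 2] : (tensor<3x4xi32>) -> tensor<3x1x4xi32>
# dims = [0, 2] is sorted but not iota, and every mapped size is unchanged,
# so the op only inserts a size-1 axis. Element-wise reference semantics:
ref = np.empty((3, 1, 4), dtype=np.int32)
for i in range(3):
    for k in range(4):
        ref[i, 0, k] = x[i, k]

# The rewrite replaces the whole op with a single reshape.
rewritten = np.reshape(x, (3, 1, 4))
```

Under the old iota-only constraint, only dims like `[0, 1]` qualified; allowing any sorted dims covers size-1 axes inserted at arbitrary positions.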
2) The `broadcast_in_dim` operator can be quite complex in its semantics, sometimes functioning as a combination of `transpose` and `broadcast`. In order to simplify analysis and optimization, it can be beneficial to fully expand operand dimensions by ensuring that the size of broadcast_dimensions matches the rank of the result. This can be achieved by reshaping the input operand while preserving the constraint `size(broadcast_dimensions) = rank(operand)`.
Later it can be simplified even more by extracting the transpose logic into a standalone operation.
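A NumPy sketch of what extracting the implicit transpose might look like (my own illustration, not the actual pass; shapes and data are made up): for unsorted dims, transpose the operand by `argsort(dims)` first, after which the remaining `broadcast_in_dim` has sorted dims.

```python
import numpy as np

x = np.random.default_rng(1).random((3, 4, 2)).astype(np.float32)
dims = [2, 4, 1]                 # operand axis i -> result axis dims[i] (unsorted)
out_shape = (5, 2, 3, 6, 4)      # result axes 0 and 3 are broadcast from size 1

# Reference semantics of the combined broadcast_in_dim with an implicit transpose.
ref = np.empty(out_shape, dtype=np.float32)
for idx in np.ndindex(out_shape):
    ref[idx] = x[idx[2], idx[4], idx[1]]

# Step 1: standalone transpose so the axis mapping becomes order-preserving.
perm = np.argsort(dims)          # [2, 0, 1]
transposed = x.transpose(perm)   # (2, 3, 4), now maps to sorted result axes [1, 2, 4]
# Step 2: broadcast_in_dim with sorted dims = [1, 2, 4], i.e. reshape + broadcast.
split = np.broadcast_to(transposed.reshape(1, 2, 3, 1, 4), out_shape)
```

The split form agrees with the combined reference, and the remaining broadcast no longer reorders axes.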
Since `broadcast` is on the way out from stablehlo.

UPD: only patch 1 is accepted.