openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Align canonicalization of stablehlo.broadcast_in_dim with XLA #2371

Closed mvpant closed 2 weeks ago

mvpant commented 1 month ago

This patch has two stages. 1) Change the isIotaRange constraint to llvm::is_sorted in order to enable the rewrite from broadcast_in_dim to reshape:

  %0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<3x6xi32>) -> tensor<1x3x6xi32>

is transformed into:

  %0 = stablehlo.reshape %arg0 : (tensor<3x6xi32>) -> tensor<1x3x6xi32>

(BroadcastInDimSimplifier in hlo_ops.cc uses the same approach.)
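To make the relaxed condition concrete, here is a minimal Python sketch of the check on plain shape lists. The helper names are hypothetical, and the equal-element-count guard is an assumption added here so that the rewrite is only reported when no actual broadcasting happens; it is not a copy of the C++ pattern.

```python
from math import prod


def is_iota_range(dims):
    """Old, stricter check: dims must be exactly [0, 1, ..., len(dims) - 1]."""
    return list(dims) == list(range(len(dims)))


def can_rewrite_to_reshape(operand_shape, dims, result_shape):
    """Relaxed check (sketch): dims are sorted and no element is duplicated.

    Assumption: equal element counts are required, so the op only inserts
    size-1 axes and never replicates data.
    """
    dims_sorted = all(a <= b for a, b in zip(dims, dims[1:]))
    return dims_sorted and prod(operand_shape) == prod(result_shape)


# The example from the description: dims = [1, 2] is sorted but not an iota range.
print(is_iota_range([1, 2]))
print(can_rewrite_to_reshape([3, 6], [1, 2], [1, 3, 6]))
```

With the old iota check, dims = [1, 2] is rejected because it does not start at 0; the sorted check accepts it, which is what allows the tensor<3x6xi32> to tensor<1x3x6xi32> case to become a reshape.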

2) The broadcast_in_dim operator can be quite complex in its semantics, sometimes functioning as a combination of transpose and broadcast. To simplify analysis and optimization, it can be beneficial to fully expand the operand dimensions so that the size of broadcast_dimensions matches the rank of the result. This can be achieved by reshaping the input operand, which preserves the constraint size(broadcast_dimensions) = rank(operand).

  %3 = stablehlo.broadcast_in_dim %arg3, dims = [2, 1, 4] : (tensor<32x64x16xf32>) -> tensor<4x64x32x2x16xf32>

is transformed into:

  %2 = stablehlo.reshape %arg3 : (tensor<32x64x16xf32>) -> tensor<1x32x64x1x16xf32>
  %3 = stablehlo.broadcast_in_dim %2, dims = [0, 2, 1, 3, 4] : (tensor<1x32x64x1x16xf32>) -> tensor<4x64x32x2x16xf32>
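One way to derive the reshaped type and the new dims is sketched below in Python on plain shape lists. The function name and the slot-assignment strategy (operand axes keep their order and fill the result positions that were already mapped, size-1 axes fill the rest) are assumptions inferred from the example above, not the patch's actual code.

```python
def expand_operand(operand_shape, dims, result_rank):
    """Return (reshaped_shape, new_dims) with len(new_dims) == result_rank.

    Sketch only: unmapped result axes get a new size-1 operand axis that maps
    to itself; mapped result axes are filled with operand axes in their
    original order, so the reshape never reorders data.
    """
    mapped = sorted(dims)  # result axes that already have an operand axis
    slot_of = {axis: k for k, axis in enumerate(mapped)}
    reshaped_shape, new_dims = [], []
    for axis in range(result_rank):
        if axis in slot_of:
            k = slot_of[axis]
            reshaped_shape.append(operand_shape[k])
            new_dims.append(dims[k])
        else:
            reshaped_shape.append(1)
            new_dims.append(axis)
    return reshaped_shape, new_dims


# The example above: tensor<32x64x16xf32>, dims = [2, 1, 4], result rank 5.
print(expand_operand([32, 64, 16], [2, 1, 4], 5))
```

For the example this yields the shape 1x32x64x1x16 and dims = [0, 2, 1, 3, 4], matching the reshape and broadcast_in_dim pair shown above.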

Later it can be simplified even further by extracting the transpose logic into a standalone operation, something like this:

  %2 = stablehlo.reshape %arg3 : (tensor<32x64x16xf32>) -> tensor<1x32x64x1x16xf32>
  %3 = stablehlo.transpose %2, dims = [0, 2, 1, 3, 4] : (tensor<1x32x64x1x16xf32>) -> tensor<1x64x32x1x16xf32>
  %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2, 3, 4] : (tensor<1x64x32x1x16xf32>) -> tensor<4x64x32x2x16xf32> 
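The split above can be sketched in Python as follows. Once dims is a full permutation, the transpose permutation is its inverse, because broadcast_in_dim maps operand axis i to result axis dims[i], while transpose takes result axis j from operand axis perm[j]. The function name is hypothetical and the sketch only tracks shapes.

```python
def extract_transpose(shape, dims):
    """Split a full-rank broadcast_in_dim into transpose + identity broadcast.

    Returns (perm, transposed_shape, identity_dims). Assumes dims is a
    permutation of range(len(shape)).
    """
    rank = len(dims)
    perm = [0] * rank
    for operand_axis, result_axis in enumerate(dims):
        perm[result_axis] = operand_axis  # invert the axis mapping
    transposed_shape = [shape[p] for p in perm]
    return perm, transposed_shape, list(range(rank))


# The example above: tensor<1x32x64x1x16xf32> with dims = [0, 2, 1, 3, 4].
print(extract_transpose([1, 32, 64, 1, 16], [0, 2, 1, 3, 4]))
```

In this example the permutation is its own inverse, so the transpose dims equal the original broadcast dims; in general the two differ.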

This is useful since broadcast is on its way out of stablehlo.

UPD: only patch 1 is accepted.

mvpant commented 1 month ago

WDYT?

mvpant commented 3 weeks ago

I have a general question before I dig further into the review:

I can see that with the "use reshape to ensure size(broadcast_dimensions) = rank(operand)" transformation, even a scalar broadcast_in_dim is transformed into a reshape followed by a broadcast_in_dim:

stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i32>) -> tensor<4x32xi32>

to

stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1x1xi32>
stablehlo.broadcast_in_dim [[R0]], dims = [0, 1] : (tensor<1x1xi32>) -> tensor<4x32xi32>

This is how I interpret the transformation.

Yes, that's right.

The StableHLO broadcast_in_dim can simultaneously do:

  1. Broadcasting of a lower-rank array to a higher-rank array
  2. Broadcasting using degenerate dimensions

However, the proposed transformation seems to rewrite the code so that broadcast_in_dim only has to deal with case (2). The transformation does make sense when the broadcast_in_dim involves an implicit transpose, but for cases without an implicit transpose, I was wondering what optimization opportunities this transformation would open up (other than simplifying the code for readability).

I agree that this transformation doesn’t offer any optimization opportunities for stablehlo itself. However, such transformations can be valuable for bringing operations closer to the hardware level. For example, in a typical non-scalar accelerator that performs optimally with data batches, it would be advantageous to break down complex operations into simpler, more basic operations resembling RISC instructions:

  %0 = stablehlo.broadcast_in_dim %arg0, dims = [2, 4, 1] : (tensor<32x64x16xf32>) -> tensor<2x16x32x4x64xf32>

to a series of trivial operations:

  %0 = stablehlo.reshape %arg0 : (tensor<32x64x16xf32>) -> tensor<1x32x64x1x16xf32>
  // swap 1-axis with 4-axis
  %1 = stablehlo.transpose %0, dims = [0, 4, 2, 3, 1] : (tensor<1x32x64x1x16xf32>) -> tensor<1x16x64x1x32xf32>
  // swap 2-axis with 4-axis
  %2 = stablehlo.transpose %1, dims = [0, 1, 4, 3, 2] : (tensor<1x16x64x1x32xf32>) -> tensor<1x16x32x1x64xf32>
  // broadcast 3 axis (tail strategy)
  %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1, 2, 3, 4] : (tensor<1x16x32x1x64xf32>) -> tensor<1x16x32x4x64xf32>
  // broadcast 0 axis
  %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2, 3, 4] : (tensor<1x16x32x4x64xf32>) -> tensor<2x16x32x4x64xf32>
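As a sanity check on the chain above, the following Python sketch traces the shapes through the two transposes and composes them into a single permutation. It works on shape lists only, and the helper names are hypothetical.

```python
def transpose_shape(shape, perm):
    """Shape after a transpose: result axis i has the size of operand axis perm[i]."""
    return [shape[p] for p in perm]


def compose(first, second):
    """Single permutation equivalent to applying `first`, then `second`."""
    return [first[p] for p in second]


def only_grows_unit_axes(src, dst):
    """Identity-dims broadcast is valid if each axis is unchanged or grows from 1."""
    return len(src) == len(dst) and all(s == d or s == 1 for s, d in zip(src, dst))


reshaped = [1, 32, 64, 1, 16]
swap_1_4 = [0, 4, 2, 3, 1]
swap_2_4 = [0, 1, 4, 3, 2]

after_first = transpose_shape(reshaped, swap_1_4)
after_second = transpose_shape(after_first, swap_2_4)
combined = compose(swap_1_4, swap_2_4)

print(after_first, after_second, combined)
print(only_grows_unit_axes(after_second, [2, 16, 32, 4, 64]))
```

The two swaps collapse into the single permutation [0, 4, 1, 3, 2], so a backend that handles arbitrary transposes natively would not need the intermediate step; splitting into pairwise swaps is a lowering choice.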

I can envision that such transformations are usually done in hardware-focused dialects. This brings up a question of ideology: does stablehlo assume that these transformations are done in target-specific dialects, or is it open to such hardware-related simplifications?

PS: IMO having a pass aligned with https://github.com/openxla/xla/blob/9c7314893fb16b897370e12c73409d5d9b5eab5c/xla/mlir_hlo/mhlo/IR/hlo_ops.cc#L2364 seems sufficient (unless I am missing something).

Agree

sdasgup3 commented 2 weeks ago

My take: StableHLO transformations are mostly hardware-agnostic. If for some backend (as opposed to ops which have native support for some of these composite patterns) it is relevant to batch simpler operations for better performance, then there should be passes achieving the same closer to the backend.

How about, for now, we align the pass with https://github.com/openxla/xla/blob/9c7314893fb16b897370e12c73409d5d9b5eab5c/xla/mlir_hlo/mhlo/IR/hlo_ops.cc#L2364? That will include expanding the broadcast with implicit transpose.

cc @GleasonK

mvpant commented 2 weeks ago

... for some backend (as opposed to ops which have native support for some of these composite patterns) it is relevant to batch simpler operations for better performance ...

Agree.

For now, I am reverting patch 2. Breaking the operation down into a series of simpler operations (how, and into what) appears to be a backend-specific decision. However, I am not convinced that the reshape on its own (without extracting the transpose) is useless; it may at least help readability and simplify the code for further transformations.

sdasgup3 commented 2 weeks ago

Thanks! The changes LGTM.