google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal: reduce with multiple operands failed to legalize #21384

Open jonatanklosko opened 1 month ago

jonatanklosko commented 1 month ago

Description

import jax
import jax.numpy as jnp

def f(x):
  def reducer(op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    return (op_val + acc_val, op_index + acc_index)

  idx = jax.lax.broadcasted_iota(jnp.int32, jnp.shape(x), 0)
  return jax.lax.reduce([x, idx], [jnp.array(0, x.dtype), jnp.array(0, idx.dtype)], reducer, [0])

x = jnp.array([1, 2])

print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO

```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.iota dim = 0 : tensor<2xi32>
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2 = stablehlo.constant dense<0> : tensor<i32>
    %3:2 = stablehlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi32>, tensor<2xi32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
     reducer(%arg1: tensor<i32>, %arg3: tensor<i32>) (%arg2: tensor<i32>, %arg4: tensor<i32>)  {
      %4 = stablehlo.add %arg1, %arg3 : tensor<i32>
      %5 = stablehlo.add %arg2, %arg4 : tensor<i32>
      stablehlo.return %4, %5 : tensor<i32>, tensor<i32>
    }
    return %3#0, %3#1 : tensor<i32>, tensor<i32>
  }
}
```

This fails with:

Traceback (most recent call last):
  File "/Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py", line 16, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py:11:0: error: failed to legalize operation 'mhlo.reduce'
/Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py:15:0: note: called from
/Users/jonatanklosko/tmp/jax_metal_repro/reduce_multi_arg_illegal.py:11:0: note: see current operation:
%5:2 = "mhlo.reduce"(%arg0, %4, %1, %1) ({
^bb0(%arg1: tensor<si32>, %arg2: tensor<si32>, %arg3: tensor<si32>, %arg4: tensor<si32>):
  %6 = "mhlo.add"(%arg1, %arg3) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  %7 = "mhlo.add"(%arg2, %arg4) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  "mhlo.return"(%6, %7) : (tensor<si32>, tensor<si32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xsi32>, tensor<2xsi32>, tensor<si32>, tensor<si32>) -> (tensor<si32>, tensor<si32>)

Interestingly, jnp.argmax works, even though it is lowered to a similar reduce over the operand and an index (just more elaborate).
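To compare the two lowerings side by side, the IR for jnp.argmax can be printed the same way as for the repro. This is only a small sketch; the names `argmax_ir` and `argmax_result` are made up for illustration, and the exact IR text will depend on the jax/jaxlib version.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2])

# Print the IR that jnp.argmax lowers to, for comparison with the IR of
# the hand-written multi-operand reduce from the repro.
argmax_ir = jax.jit(jnp.argmax).lower(x).as_text()
print(argmax_ir)

# Run it as well: the index of the largest element of x.
argmax_result = jax.jit(jnp.argmax)(x)
print(argmax_result)
```

Diffing this output against the IR of the failing function should show which parts of the reduce body differ.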

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

(I also tried with jax/jaxlib 0.4.28 and ENABLE_PJRT_COMPATIBILITY=1, but got the same result.)

shuhand0 commented 1 month ago

jax-metal (and its backend) doesn't yet support reducers with custom computation functions, nor reductions over multiple operands. Argmax/argmin are mapped to the corresponding backend ops as special cases.
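For a reducer like the one in the repro, which combines each operand independently of the other, one possible workaround is to split the variadic reduce into one built-in reduction per operand. The sketch below has not been verified on jax-metal; it assumes that built-in single-operand sums are supported there. The function names `f_multi` and `f_split` are hypothetical, and the equivalence check is meant to be run on a backend that supports both forms (for example CPU).

```python
import jax
import jax.numpy as jnp

def f_multi(x):
    # Variadic reduce with a custom reducer, as in the repro.
    def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        return (op_val + acc_val, op_index + acc_index)

    idx = jax.lax.broadcasted_iota(jnp.int32, jnp.shape(x), 0)
    init = [jnp.array(0, x.dtype), jnp.array(0, idx.dtype)]
    return jax.lax.reduce([x, idx], init, reducer, [0])

def f_split(x):
    # Same result, expressed as two independent single-operand sums.
    # Assumption: the backend handles built-in sums.
    idx = jax.lax.broadcasted_iota(jnp.int32, jnp.shape(x), 0)
    return jnp.sum(x, axis=0), jnp.sum(idx, axis=0)

x = jnp.array([1, 2])
multi_val, multi_idx = jax.jit(f_multi)(x)
split_val, split_idx = jax.jit(f_split)(x)
print(multi_val, multi_idx, split_val, split_idx)
```

This only applies when the reducer does not mix the operands. A reducer that compares values to choose an index, as argmax does, cannot be split this way.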

jonatanklosko commented 1 month ago

I see. So I assume that for reduce with a single operand there are also special cases for +, * and similar; that makes sense. I saw the reducer x, y -> x + y + 1 ignore the + 1, so that's what I suspected :D
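A small probe along these lines can show whether a backend actually runs the reducer body or replaces it with a plain sum. It is a sketch only: `reduce_plus_one` and `honors_reducer` are hypothetical names, and the exact value depends on how often the backend folds in the init value, so the check is only that the result exceeds the plain sum.

```python
import jax
import jax.numpy as jnp

def reduce_plus_one(x):
    # Custom reducer: a sum that adds an extra 1 at every combine step.
    return jax.lax.reduce(x, jnp.array(0, x.dtype), lambda a, b: a + b + 1, [0])

x = jnp.array([1, 2])
plain_sum = jnp.sum(x)
probed = jax.jit(reduce_plus_one)(x)
print(plain_sum, probed)

# If the + 1 is applied, the probed value is larger than the plain sum.
# If both values are equal, the backend ignored the custom reducer body.
honors_reducer = bool(probed > plain_sum)
print(honors_reducer)
```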

For context, we are trying to integrate the Metal plugin into Nx (an Elixir project that uses XLA similarly to JAX). We implement argmax/argmin on top of reduce, but our IR does not match JAX's exactly. I may try to align the IR in the meantime.