Open jonatanklosko opened 5 months ago
jax-metal (and its backend) doesn't yet support reducers with custom computation functions, nor reducers with multiple operands. Argmax/argmin are mapped to the corresponding backend ops as special cases.
I see, and so I assume for reduce with a single operand there are also special cases for `+`, `*` and similar, that makes sense. I saw `x, y -> x + y + 1` ignore the `+ 1` and that's what I thought :D
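To make the observation above concrete, here is a minimal pure-Python sketch (not the JAX or XLA API) of how a custom reducer differs from a plain sum. It models reduce as a sequential left fold, which is an assumption: XLA is free to apply the reducer in a different order, so the exact value for the custom reducer may differ there. The helper name `fold` is hypothetical.

```python
from functools import reduce


def fold(reducer, values, init):
    """Sequential left fold, used here as a stand-in for a single-operand reduce."""
    return reduce(reducer, values, init)


values = [1, 2]

# Plain addition: the kind of reducer a backend could special-case as "sum".
plain_sum = fold(lambda x, y: x + y, values, 0)

# Custom reducer from the comment above: x, y -> x + y + 1.
custom = fold(lambda x, y: x + y + 1, values, 0)

print(plain_sum, custom)
```

If a backend pattern-matches the reducer to a built-in sum and drops the `+ 1`, both calls would return the same value, which matches what was described above.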
For context, we are trying to integrate the Metal plugin into Nx (an Elixir project that uses XLA similarly to JAX). We implement argmax/argmin on top of reduce, but the IR does not match JAX exactly. I may try to align the IR in the meantime.
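For readers unfamiliar with the pattern, this is roughly what "argmax on top of reduce" means: a reduce over two operands (values and their indices) with a custom comparison function. The sketch below is plain Python, not Nx or JAX code, and the names `argmax_reducer` and `argmax_via_reduce` are made up for illustration. The tie-breaking rule (smallest index wins) is also an assumption.

```python
from functools import reduce


def argmax_reducer(acc, item):
    """Combine two (value, index) pairs, keeping the one with the larger value.

    On equal values the smaller index is kept (assumed tie-breaking rule).
    """
    acc_val, acc_idx = acc
    val, idx = item
    if val > acc_val or (val == acc_val and idx < acc_idx):
        return val, idx
    return acc_val, acc_idx


def argmax_via_reduce(values):
    """Argmax expressed as a two-operand reduce over values and indices."""
    indices = range(len(values))
    # Init pair: lowest possible value, plus a sentinel index past the end.
    init = (float("-inf"), len(values))
    _, best_idx = reduce(argmax_reducer, zip(values, indices), init)
    return best_idx


print(argmax_via_reduce([3, 7, 7, 1]))
```

This is the shape of reducer (multiple operands plus a custom function) that, per the reply above, jax-metal does not support yet outside the special-cased argmax/argmin lowering.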
Description
HLO
```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2xi32> {mhlo.layout_mode = "default"}) -> (tensor
```
This fails with:
Interestingly, `jnp.argmax` works, and it is lowered to a similar reduce on operand and index (just more elaborate).
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7 (I also tried with jax/jaxlib 0.4.28 and `ENABLE_PJRT_COMPATIBILITY=1`, but same result)