google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal: count leading zeros not supported #21389

Open · jonatanklosko opened this issue 1 month ago

jonatanklosko commented 1 month ago

Description

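```python
import jax
import jax.numpy as jnp

def f(x):
    return jax.lax.clz(x)

x = jnp.array([0, 1 << 10, 1 << 20])

# Print the lowered StableHLO module
print(jax.jit(f).lower(x).as_text())
# Compile and run on the default device (METAL here); this is the call that fails
print(jax.jit(f)(x))
```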
HLO:

```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<3xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.count_leading_zeros %arg0 : tensor<3xi32>
    return %0 : tensor<3xi32>
  }
}
```

fails with

```
Traceback (most recent call last):
  File "/Users/jonatanklosko/tmp/jax_mlir.py", line 71, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/jonatanklosko/tmp/jax_mlir.py:63:0: error: failed to legalize operation 'mhlo.reduce_window'
/Users/jonatanklosko/tmp/jax_mlir.py:70:0: note: called from
/Users/jonatanklosko/tmp/jax_mlir.py:63:0: note: see current operation:
%2 = "mhlo.reduce_window"(%arg0, %1) ({
^bb0(%arg1: tensor<si32>, %arg2: tensor<si32>):
  %3 = "mhlo.add"(%arg1, %arg2) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  "mhlo.return"(%3) : (tensor<si32>) -> ()
}) {base_dilations = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, window_dilations = dense<1> : tensor<1xi64>, window_dimensions = dense<2> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<3xsi32>, tensor<si32>) -> tensor<4xsi32>
```

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
```

jax-metal 0.0.7
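Until `stablehlo.count_leading_zeros` lowers on jax-metal, one possible workaround is to compute clz from more basic integer ops. The sketch below uses the classic binary-search formulation for 32-bit integers; `clz32_fallback` is a hypothetical helper name, and the sketch assumes int32 inputs and that jax-metal can lower bitcasts, shifts, comparisons and `select` (not verified on a Metal device).

```python
import jax
import jax.numpy as jnp

def clz32_fallback(x):
    """Hypothetical stand-in for jax.lax.clz on int32 inputs."""
    # Reinterpret the int32 bits as uint32 so right shifts are logical.
    u = jax.lax.bitcast_convert_type(x, jnp.uint32)
    n = jnp.zeros(u.shape, dtype=jnp.int32)
    # Binary search: if the top `shift` bits are all zero, count them and
    # shift them out. Shifting only drops zero bits, so u stays nonzero
    # for any nonzero input.
    for shift in (16, 8, 4, 2, 1):
        top_is_zero = (u >> (32 - shift)) == 0
        n = jnp.where(top_is_zero, n + shift, n)
        u = jnp.where(top_is_zero, u << shift, u)
    # Only x == 0 ends with u == 0; the loop counted 31 zeros, so add one.
    return jnp.where(u == 0, n + 1, n)
```

For the input from the report, `jax.jit(clz32_fallback)(x)` should give `[32, 21, 11]`, the same as `jax.lax.clz` on CPU. The loop is unrolled at trace time, so the compiled program is a fixed sequence of elementwise ops with no control flow.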

jonatanklosko commented 1 month ago

The same happens with `jax.lax.population_count` and `jax.lax.cbrt`. None of the three is very widely used, so I'm reporting them all in this issue.
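As with clz, both ops can be expressed with more common primitives. A minimal sketch, assuming int32 inputs for the popcount and real floating-point inputs for the cube root; `popcount32_fallback` and `cbrt_fallback` are hypothetical names, the popcount uses the well-known SWAR bit-counting trick, and whether jax-metal lowers every op used here has not been verified.

```python
import jax
import jax.numpy as jnp

def popcount32_fallback(x):
    """Hypothetical stand-in for jax.lax.population_count on int32 inputs."""
    # Work on the raw bits as uint32 so right shifts are logical.
    v = jax.lax.bitcast_convert_type(x, jnp.uint32)
    # SWAR bit counting: per-2-bit counts, then per-4-bit, then per-byte.
    v = v - ((v >> 1) & 0x55555555)
    v = (v & 0x33333333) + ((v >> 2) & 0x33333333)
    v = (v + (v >> 4)) & 0x0F0F0F0F
    # Multiplying by 0x01010101 sums the four byte counts into the top byte
    # (uint32 multiplication wraps, which is intended here).
    v = (v * 0x01010101) >> 24
    return v.astype(jnp.int32)

def cbrt_fallback(x):
    """Hypothetical stand-in for jax.lax.cbrt on real floating-point inputs."""
    # pow(negative, 1/3) is NaN, so take the root of |x| and restore the sign.
    return jnp.sign(x) * jnp.abs(x) ** (1.0 / 3.0)
```

The `pow`-based cube root is less accurate than a dedicated `cbrt` kernel in the last few bits, so results may differ slightly from `jax.lax.cbrt` even where both run.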

jonatanklosko commented 1 month ago

If there is a list of operations that are inherently not supported by jax-metal, it would be very helpful to have it documented somewhere :)
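In the absence of such a list, one way to check a particular install is to jit-compile and run each op of interest and record which ones raise. A minimal sketch, assuming that a legalization failure surfaces as an exception at call time (as the `XlaRuntimeError` in the traceback above does); `CANDIDATES` and `probe` are hypothetical names, not part of JAX or jax-metal.

```python
import jax
import jax.numpy as jnp

# Hypothetical table of ops to probe: name -> (function, sample input).
CANDIDATES = {
    "clz": (jax.lax.clz, jnp.array([0, 1 << 10, 1 << 20], dtype=jnp.int32)),
    "population_count": (jax.lax.population_count, jnp.array([0, 1, 7], dtype=jnp.int32)),
    "cbrt": (jax.lax.cbrt, jnp.array([-8.0, 0.0, 27.0], dtype=jnp.float32)),
}

def probe(candidates):
    """Return {name: None if the op compiled and ran, else an error string}."""
    results = {}
    for name, (fn, arg) in candidates.items():
        try:
            # block_until_ready forces compilation and execution on the default device.
            jax.jit(fn)(arg).block_until_ready()
            results[name] = None
        except Exception as e:  # broad on purpose: backends raise different error types
            results[name] = f"{type(e).__name__}: {e}"
    return results

if __name__ == "__main__":
    print("device:", jax.devices()[0])
    for name, err in probe(CANDIDATES).items():
        print(f"{name}: {'ok' if err is None else 'FAILED ' + err.splitlines()[0]}")
```

This only detects ops that fail loudly; an op that compiles but returns wrong values would still report `ok`, so comparing outputs against a CPU run would be a natural extension.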