jonatanklosko opened 4 months ago
The same failure occurs with `jax.lax.population_count` and `jax.lax.cbrt` as with `jax.lax.clz` (`stablehlo.count_leading_zeros`, shown below). None of the three is especially widespread, so I'm mentioning them all in this one issue.
If there is a list of operations that are inherently unsupported by jax-metal, it would be very helpful to have it documented somewhere :)
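As a possible starting point for such a list, the sketch below lowers each of the three functions for a small placeholder input and checks which StableHLO op appears in the lowered module. That makes it easier to match them against jax-metal error messages. It assumes a recent JAX where `Lowered.as_text()` emits StableHLO by default. The expected op names (`stablehlo.count_leading_zeros`, `stablehlo.popcnt`, `stablehlo.cbrt`) are assumptions based on how JAX currently lowers these primitives. Lowering alone does not show whether jax-metal can compile the op.

```python
import jax
import jax.numpy as jnp

# Placeholder abstract inputs (hypothetical shapes): no device arrays are
# needed just to lower a function to StableHLO.
int_in = jax.ShapeDtypeStruct((3,), jnp.int32)
float_in = jax.ShapeDtypeStruct((3,), jnp.float32)

# Assumed mapping from jax.lax function to the StableHLO op it lowers to.
ops = {
    "jax.lax.clz": (jax.lax.clz, int_in, "stablehlo.count_leading_zeros"),
    "jax.lax.population_count": (jax.lax.population_count, int_in, "stablehlo.popcnt"),
    "jax.lax.cbrt": (jax.lax.cbrt, float_in, "stablehlo.cbrt"),
}

found = {}
for name, (fn, spec, hlo_op) in ops.items():
    # Lower (but do not compile) the jitted function and inspect its text.
    text = jax.jit(fn).lower(spec).as_text()
    found[name] = hlo_op in text
    status = "present" if found[name] else "not found"
    print(f"{name:26s} -> {hlo_op}: {status}")
```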
Description
HLO
```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<3xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.count_leading_zeros %arg0 : tensor<3xi32>
    return %0 : tensor<3xi32>
  }
}
```

fails with
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7
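The HLO module in the description can likely be reproduced from Python with a minimal sketch like the one below. It assumes the traced function was named `f`, as the `@jit_f` module name suggests, and took an `int32[3]` input. The input values are hypothetical. On a backend that supports `stablehlo.count_leading_zeros`, such as the CPU backend, the call returns the per-element leading-zero counts. With jax-metal 0.0.7, the compile-and-run step is where the reported failure would be expected.

```python
import jax
import jax.numpy as jnp

def f(x):
    # jax.lax.clz lowers to stablehlo.count_leading_zeros.
    return jax.lax.clz(x)

# Hypothetical input matching the tensor<3xi32> signature in the HLO.
x = jnp.array([1, 2, 3], dtype=jnp.int32)

# Lowered StableHLO text; it should resemble the module in the description.
hlo_text = jax.jit(f).lower(x).as_text()
print(hlo_text)

# Compile and run on the default backend.
result = jax.jit(f)(x)
print(result)  # [31 30 30] on a backend that supports the op
```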