jonatanklosko opened 4 weeks ago
On a side note, I am wondering why `jax.lax.cond(pred, true_fun, false_fun, *operands)` lowers to `stablehlo.case` rather than `stablehlo.if`, which is semantically closer. I found https://github.com/openxla/stablehlo/issues/599, so perhaps there's nothing to be gained from `stablehlo.if`, and `stablehlo.case` is used just because it's more generic? I would love someone from the JAX team to confirm :)
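For reference, below is a minimal sketch of how to check which op `lax.cond` is lowered to. It assumes a recent JAX where `jax.jit(f).lower(...).as_text()` returns StableHLO text by default. The branch functions (`v + 1.0`, `v - 1.0`) and the helper `g` are hypothetical, for illustration only. The second half shows why `case` is the more general op: a two-branch `cond` behaves like `lax.switch` on the predicate cast to an integer index, with the false branch at index 0. This illustrates the equivalence and is not a confirmation of JAX's design rationale.

```python
import jax
import jax.numpy as jnp

def f(pred, x):
    # Two-branch conditional; under jit, pred is traced, so cond is staged out.
    return jax.lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)

x = jnp.float32(2.0)

# Lower to StableHLO text without compiling or running the function.
text = jax.jit(f).lower(True, x).as_text()
print("stablehlo.case" in text)

# Hypothetical helper: the same conditional written with lax.switch.
# cond(pred, t, f) acts like switch(int(pred), [f, t]): false branch is index 0.
def g(pred, x):
    index = jnp.asarray(pred, jnp.int32)
    return jax.lax.switch(index, [lambda v: v - 1.0, lambda v: v + 1.0], x)

# Both formulations agree for either predicate value.
for p in (True, False):
    assert jax.jit(f)(p, x) == jax.jit(g)(p, x)
```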
Description
HLO
```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor
```

The above fails with
However, it works if we change the branches slightly (just adding a constant):
Another example that fails is this:
but works when changed to
From the above I gather that this operation is already supposed to work, but there is clearly some inconsistency.
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7