google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal: cond fails in compile in certain cases #21601

Open jonatanklosko opened 4 weeks ago

jonatanklosko commented 4 weeks ago

Description

import jax
import jax.numpy as jnp

def f(pred, x, y):
  return jax.lax.cond(pred, lambda xy: xy[0], lambda xy: xy[1], (x, y))

pred = jnp.array(0)
x = jnp.array(10.0)
y = jnp.array(20.0)

# Print lowered HLO
print(jax.jit(f).lower(pred, x, y).as_text())
print(jax.jit(f)(pred, x, y))
HLO

```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}, %arg1: tensor {mhlo.layout_mode = "default"}, %arg2: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor
    %1 = stablehlo.compare NE, %arg0, %0, SIGNED : (tensor, tensor) -> tensor
    %2 = stablehlo.convert %1 : (tensor) -> tensor
    %3 = "stablehlo.case"(%2) ({
      stablehlo.return %arg2 : tensor
    }, {
      stablehlo.return %arg1 : tensor
    }) : (tensor) -> tensor
    return %3 : tensor
  }
}
```

The above fails with

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unable to serialize MPS module

However, it works if we change the branches slightly (just adding a constant):

jax.lax.cond(pred, lambda xy: xy[0] + 1.0, lambda xy: xy[1] + 2.0, (x, y))

Another example that fails is this:

def f(pred, x):
  return jax.lax.cond(pred, lambda x: (x + 1.0, 1.0), lambda x: (x + 2.0, 2.0), x)

pred = jnp.array(0)
x = jnp.array([1, 2, 3])

but works when changed to

jax.lax.cond(pred, lambda x: (x + 1.0, x + 1.0), lambda x: (x + 2.0, x + 2.0), x)

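From the above I gather that this operation is already supposed to work, but there is clearly some inconsistency: the failing cases are the ones where a branch returns an operand unchanged or returns a plain constant.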
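To compare the four variants side by side, something like the following sketch could be run on the affected machine. The function names and the `try_case` helper are hypothetical, introduced only for this illustration. The sketch only records whether each jitted call succeeds or raises. It does not assume any particular outcome on Metal, and on other backends (e.g. CPU) all four are expected to run.

```python
import jax
import jax.numpy as jnp

# Variant 1 (reported failing): branches return an operand unchanged.
def identity_branches(pred, x, y):
    return jax.lax.cond(pred, lambda xy: xy[0], lambda xy: xy[1], (x, y))

# Variant 2 (reported working): branches do some arithmetic.
def arithmetic_branches(pred, x, y):
    return jax.lax.cond(pred, lambda xy: xy[0] + 1.0, lambda xy: xy[1] + 2.0, (x, y))

# Variant 3 (reported failing): one output is a plain constant.
def constant_output(pred, x):
    return jax.lax.cond(pred, lambda x: (x + 1.0, 1.0), lambda x: (x + 2.0, 2.0), x)

# Variant 4 (reported working): every output depends on the operand.
def computed_outputs(pred, x):
    return jax.lax.cond(pred, lambda x: (x + 1.0, x + 1.0), lambda x: (x + 2.0, x + 2.0), x)

def try_case(fn, *args):
    """Hypothetical helper: jit and run fn, returning (status, value)."""
    try:
        out = jax.jit(fn)(*args)
        jax.block_until_ready(out)  # surface asynchronous runtime errors here
        return "ok", out
    except Exception as err:  # the report shows an XlaRuntimeError on Metal
        return f"failed: {type(err).__name__}: {err}", None

pred = jnp.array(0)
x, y = jnp.array(10.0), jnp.array(20.0)
xs = jnp.array([1, 2, 3])

outcomes = {
    "identity_branches": try_case(identity_branches, pred, x, y),
    "arithmetic_branches": try_case(arithmetic_branches, pred, x, y),
    "constant_output": try_case(constant_output, pred, xs),
    "computed_outputs": try_case(computed_outputs, pred, xs),
}

for name, (status, value) in outcomes.items():
    print(f"{name}: {status} {value if value is not None else ''}")
```

Running the same script with `JAX_PLATFORMS=cpu` set gives a baseline to compare the Metal results against.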

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7
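For anyone reproducing this, the block above appears to be the output of `jax.print_environment_info()`, with the plugin version added by hand. A minimal sketch for collecting both in one go, assuming the plugin is installed under the distribution name `jax-metal`:

```python
import importlib.metadata

import jax

# Collect the same fields as the report (jax, jaxlib, numpy, python, devices, platform).
env_info = jax.print_environment_info(return_string=True)

# Assumption: the Metal plugin is distributed as "jax-metal"; fall back gracefully if absent.
try:
    plugin_version = importlib.metadata.version("jax-metal")
except importlib.metadata.PackageNotFoundError:
    plugin_version = "not installed"

print(env_info)
print(f"jax-metal {plugin_version}")
```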

jonatanklosko commented 4 weeks ago

On a side note, I am wondering why jax.lax.cond(pred, true_fun, false_fun, *operands) lowers to stablehlo.case and not stablehlo.if (which is closer semantically). I found https://github.com/openxla/stablehlo/issues/599, so perhaps there's nothing to be gained from stablehlo.if, and stablehlo.case is used just because it's more generic? I would love someone from the JAX team to confirm :)
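A quick way to check which control-flow op a given JAX version emits is to search the lowered text. This is only a sketch: it inspects the lowering output and says nothing about why one op is chosen over the other. On the versions listed above it should report `stablehlo.case`, matching the HLO dump, but other versions may differ.

```python
import jax
import jax.numpy as jnp

def f(pred, x, y):
    return jax.lax.cond(pred, lambda xy: xy[0], lambda xy: xy[1], (x, y))

# Lower without compiling, so this also works where compilation fails.
lowered_text = jax.jit(f).lower(
    jnp.array(0), jnp.array(10.0), jnp.array(20.0)
).as_text()

# Report which of the two candidate ops appear in the lowered module.
control_flow_ops = [
    op for op in ("stablehlo.case", "stablehlo.if") if op in lowered_text
]
print(control_flow_ops)
```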