jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

0.4.21 release: "JAX does not support string indexing" error, but it used to work fine (for jax2tf only!) #18903

Closed: kjabon closed this issue 9 months ago

kjabon commented 9 months ago

Description

I pass a dictionary mapping string keys to JAX arrays into my jitted function as an argument. The jitted function runs with no issues. Something like the following:

...

state = env.init(key)
# state is a dictionary containing things like 
# {"numSteps": jnp.array([0]), "legal_actions": jnp.array([[True, False],[...]]), ... }
step_fn = jax.jit(env.step)
action = jnp.array([1])
state = step_fn(state, action)

...
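Because a Python dict is a JAX pytree, `jax.jit` traces each array leaf and treats the string keys as static structure, so `state['legal_actions']` inside the jitted function is an ordinary dict lookup, not array indexing. A minimal sketch, assuming a toy `step` function as a hypothetical stand-in for the (unshown) `env.step`; the field names mirror the comment above, but the step logic is made up for illustration:

```python
import jax
import jax.numpy as jnp


def step(state, action):
    # Hypothetical stand-in for env.step. `state` is a dict (a pytree), so
    # state["legal_actions"] is a plain Python lookup by string key.
    legal = state["legal_actions"]
    # Integer indexing into a JAX array with a traced action is fine under jit.
    is_illegal = ~legal[0, action[0]]
    # Toy logic: only count the step if the action was legal.
    num_steps = state["numSteps"] + jnp.where(is_illegal, 0, 1)
    return {"numSteps": num_steps, "legal_actions": legal}


state = {
    "numSteps": jnp.array([0]),
    "legal_actions": jnp.array([[True, False]]),
}
step_fn = jax.jit(step)
state = step_fn(state, jnp.array([0]))   # action 0 is legal, so numSteps becomes [1]
state2 = step_fn(state, jnp.array([1]))  # action 1 is illegal, so numSteps stays [1]
```

The output is again a dict with the same keys, so it can be fed straight back into `step_fn`.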

However, when I convert the same functions with jax2tf, it breaks:

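import tensorflow as tf
from jax.experimental import jax2tf

# jit_compile and autograph are tf.function options, so wrap the converted function
tf_init = tf.function(jax2tf.convert(jax.jit(env.init)), jit_compile=True, autograph=False)
tf_state = tf_init(tf.constant(key))
tf_step = tf.function(jax2tf.convert(step_fn), jit_compile=True, autograph=False)
state = tf_step(tf_state, tf.constant(action))  # <-- breaks here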
Relevant stack trace

My code:

  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\myEnv\core.py", line 205, in step
    is_illegal = ~state['legal_actions'][action]

JAX code:

  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\array_methods.py", line 741, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\array_methods.py", line 354, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4497, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4571, in _split_index_for_jit
    raise TypeError(f"JAX does not support string indexing; got {idx=}")
TypeError: JAX does not support string indexing; got idx=('legal_actions',)
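The frames above suggest where the message comes from: `state['legal_actions']` was dispatched to JAX's array `__getitem__`, so at line 205 `state` was apparently a JAX array (or tracer) rather than a dict. A minimal sketch of that distinction; the variable names are illustrative only:

```python
import jax.numpy as jnp

legal_actions = jnp.array([[True, False]])
state = {"legal_actions": legal_actions}

# String lookup on a dict is ordinary Python and never reaches JAX indexing.
field = state["legal_actions"]

# String "indexing" on a JAX array is rejected with a TypeError; per the
# traceback above, jax 0.4.21 reports "JAX does not support string indexing".
try:
    legal_actions["legal_actions"]
except TypeError as err:
    message = str(err)
else:
    message = None
```

So the check fires only when a string reaches an array's `__getitem__`; string keys on a dict are unaffected.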

This error check was added here to address this issue. That issue seems unrelated, so perhaps the condition that triggers this error should be more specific? I would like to keep indexing into dictionaries with string keys in jax2tf-converted functions! :)

For now I'm going back to v0.4.20, where the code above works. Thanks!

What jax/jaxlib version are you using?

0.4.21 for both

Which accelerator(s) are you using?

cpu

Additional system info?

Both Ubuntu and Windows 10

NVIDIA GPU info

n/a

kjabon commented 9 months ago

Turns out this was an unrelated user-code bug. My bad, closing.
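Since the traceback points to `state` not being the expected dict at the point of indexing, one way to surface this class of user-code bug earlier is to validate the state's structure before calling the step function. A minimal sketch; `check_state` and `EXPECTED_KEYS` are hypothetical names, not part of JAX or jax2tf:

```python
import jax.numpy as jnp

# Hypothetical set of fields the environment state is expected to carry.
EXPECTED_KEYS = {"numSteps", "legal_actions"}


def check_state(state):
    """Raise a descriptive error if `state` is not a dict with the expected keys."""
    if not isinstance(state, dict):
        raise TypeError(
            f"expected state to be a dict, got {type(state).__name__}; "
            "indexing it with a string key would fail inside JAX"
        )
    missing = EXPECTED_KEYS - set(state)
    if missing:
        raise KeyError(f"state is missing keys: {sorted(missing)}")
    return state


good = {"numSteps": jnp.array([0]), "legal_actions": jnp.array([[True, False]])}
check_state(good)  # passes and returns the dict unchanged
```

Calling `check_state(good["legal_actions"])` (an array passed by mistake) raises a `TypeError` that names the actual type, which is easier to trace than the string-indexing error raised deep inside JAX.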