I pass a dictionary into my jitted function as an argument; it contains several string keys mapped to JAX arrays. Running the jitted function directly works without issues, with something like the following:
```python
import jax
import jax.numpy as jnp

...
# env and key come from myEnv setup code not shown here
state = env.init(key)
# state is a dictionary containing things like
# {"numSteps": jnp.array([0]), "legal_actions": jnp.array([[True, False],[...]]), ... }
step_fn = jax.jit(env.step)
action = jnp.array([1])
state = step_fn(state, action)
...
```
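A self-contained version of the working case might look like the sketch below. It assumes a hypothetical `step` function standing in for `env.step` (whose real body is not shown) and a simplified 1-D `legal_actions` mask. Because a dict argument is a pytree to `jax.jit`, its string keys are part of the static tree structure, so `state["legal_actions"]` inside the traced function is plain Python dict lookup.

```python
import jax
import jax.numpy as jnp


def step(state, action):
    # Hypothetical stand-in for env.step. `state` is rebuilt as a dict of
    # traced arrays, so the string lookup happens in Python, not in JAX indexing.
    legal_actions = state["legal_actions"]
    is_illegal = ~legal_actions[action]  # integer-array indexing on a bool mask
    return {
        "numSteps": state["numSteps"] + 1,
        "legal_actions": legal_actions,
        "is_illegal": is_illegal,
    }


# Simplified state: a 1-D mask instead of the report's 2-D legal_actions.
state = {
    "numSteps": jnp.array([0]),
    "legal_actions": jnp.array([True, False]),
}

step_fn = jax.jit(step)  # dict in, dict out: both are handled as pytrees
new_state = step_fn(state, jnp.array([1]))
print(new_state["numSteps"], new_state["is_illegal"])
```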
However, when I run `env.step` through jax2tf, it breaks:

Relevant stack trace

My code
```
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\myEnv\core.py", line 205, in step
    is_illegal = ~state['legal_actions'][action]
```

Jax code
```
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\array_methods.py", line 741, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\array_methods.py", line 354, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4497, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
  File "C:\Users\kmjab\miniconda3\envs\myEnv\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4571, in _split_index_for_jit
    raise TypeError(f"JAX does not support string indexing; got {idx=}")
TypeError: JAX does not support string indexing; got idx=('legal_actions',)
```
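The failing frame is the array `__getitem__` (`array_methods._getitem`) rather than a dict lookup, which suggests that inside the converted function `state` reached line 205 as an array instead of a dict. The sketch below, assuming JAX 0.4.21 or later (where this check exists, per the report) and using hypothetical values, contrasts the two cases: string lookup on a dict of arrays works, while string indexing on an array raises the `TypeError` seen above.

```python
import jax.numpy as jnp

# Hypothetical values mirroring the report's state["legal_actions"].
state_dict = {"legal_actions": jnp.array([True, False])}
state_array = jnp.array([True, False])

# String lookup on a Python dict is ordinary dict access; JAX never sees the key.
mask = state_dict["legal_actions"]

# String indexing on a JAX array goes through jnp's __getitem__,
# which raises TypeError in JAX 0.4.21 and later.
try:
    state_array["legal_actions"]
    error_message = None
except TypeError as err:
    error_message = str(err)

print(mask)
print(error_message)
```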
This error check was added here to address this issue. That issue seems unrelated to my use case, so perhaps the error should be triggered by a more specific condition? I would like to keep indexing into dictionaries with string keys in jax2tf-converted functions! :)
I'm going back to v0.4.20 (where the above code works) for now. Thanks!
What jax/jaxlib version are you using?
0.4.21 for both
Which accelerator(s) are you using?
CPU
Additional system info?
Both Ubuntu and Windows 10
NVIDIA GPU info
n/a