JonyEpsilon opened this issue 1 year ago
Hi @JonyEpsilon,
In JAX 0.4.23, the error raised by `jax.pmap` now states explicitly that its `in_axes` argument must be a tree prefix of the tuple of positional arguments, so a top-level list is rejected where a tuple is expected. The error message in JAX 0.3.18 did not make this requirement clear.
I ran the code from this issue with the latest version of JAX (0.4.23), and the error message now indicates that a tuple is required instead of a list. Could you please verify this behavior with JAX 0.4.23 and close this issue if it is resolved?
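A minimal sketch of why the list is rejected, using only `jax.tree_util`. The positional arguments of the pmapped call always form a tuple, and a list is a different pytree node type from a tuple even when it holds the same leaves. `is_tree_prefix` is a hypothetical helper written for this illustration and is not a JAX API. It relies on `jax.tree_util.tree_map` raising when a later tree does not have the first tree as a prefix. As an assumption so the sketch runs without chex, a plain dict stands in for the `Batch` dataclass.

```python
import jax
import jax.numpy as jnp


def is_tree_prefix(prefix, full) -> bool:
    """Hypothetical helper (not part of JAX): is `prefix` a tree prefix of `full`?

    jax.tree_util.tree_map(f, tree, *rest) requires every tree in `rest` to
    have `tree` as a prefix and raises an error otherwise. None counts as an
    empty subtree in JAX, so `prefix` should not contain None here.
    """
    try:
        jax.tree_util.tree_map(lambda _axis, _subtree: None, prefix, full)
    except (ValueError, TypeError):
        return False
    return True


# Positional arguments as pmap sees them: always a tuple, here (keys, batch).
# A dict stands in for the chex `Batch` dataclass (assumption for this sketch).
args = (
    jnp.zeros((1, 2), dtype=jnp.uint32),
    {"x": jnp.ones((1, 8)), "y": jnp.ones((1, 8))},
)

# Same leaves, different container type -> different tree structures.
list_def = jax.tree_util.tree_structure([0, 0])
tuple_def = jax.tree_util.tree_structure((0, 0))
print(list_def == tuple_def)  # False

print(is_tree_prefix((0, 0), args))                 # True: tuple matches the args tuple
print(is_tree_prefix((0, {"x": 0, "y": 0}), args))  # True: fully spelled-out axes
print(is_tree_prefix(0, args))                      # True: a single leaf prefixes anything
print(is_tree_prefix([0, 0], args))                 # False: list vs. tuple at the root
print(is_tree_prefix(tuple([0, 0]), args))          # True: converting the list fixes it
```

The docs' "nested Python container" wording is therefore accurate, but the container types still have to line up with the argument structure, starting with the outer tuple.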
```python
import chex
import jax
import jax.numpy as jnp


@chex.dataclass
class Batch:
    x: jnp.ndarray
    y: jnp.ndarray


def f(key, b: Batch):
    return b.x + b.y + jax.random.uniform(key)


# This works (works = the evaluation on the final line below works).
# fp = jax.pmap(f, axis_name="superbatch", in_axes=(0, Batch(x=0, y=0)))
# As does this, using a tree prefix.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=(0, 0))
# Will raise an error: `in_axes` argument to `pmap` must be a tree prefix of that tuple.
fp = jax.pmap(f, axis_name="superbatch", in_axes=[0, Batch(x=0, y=0)])
# Will raise an error: `in_axes` argument to `pmap` must be a tree prefix of that tuple.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=[0, 0])

NUM_DEVICES = jax.local_device_count()
TEST_ARRAY_SIZE = 8

k = jax.random.PRNGKey(42)
k, *keys = jax.random.split(k, NUM_DEVICES + 1)
super_batch = Batch(
    x=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)),
    y=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)))
fp(jnp.stack(keys), super_batch)
```
Output (traceback):
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: pytree structure error: different types at key path
pmap in_axes[0]
At that key path, the prefix pytree pmap in_axes has a subtree of type
<class 'list'>
but at the same key path the full pytree has a subtree of different type
<class 'tuple'>.
The 'full pytree' here is the tuple of arguments passed positionally to the pmapped function, and the value of `in_axes` must be a tree prefix of that tuple. But it was not a prefix.
Check that the value of the `in_axes` argument to `pmap` is a tree prefix of the tuple of arguments passed positionally to the pmapped function.
```
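Following the advice in that error message, the reproduction should run once `in_axes` is a tuple. The sketch below is the same code with the top-level `in_axes` container converted from a list to a tuple. As an assumption so it runs without chex, a `typing.NamedTuple` (already a JAX pytree) stands in for the `@chex.dataclass` `Batch`. On a CPU-only machine `jax.local_device_count()` is typically 1, so the mapped axis has size 1.

```python
import typing

import jax
import jax.numpy as jnp


# Stand-in for the chex dataclass (assumption): NamedTuples are pytrees out of the box.
class Batch(typing.NamedTuple):
    x: jnp.ndarray
    y: jnp.ndarray


def f(key, b: Batch):
    return b.x + b.y + jax.random.uniform(key)


# Axes built up as a list, then converted so the outer container is a tuple,
# matching the tuple of positional arguments (key, batch).
axes = [0, Batch(x=0, y=0)]
fp = jax.pmap(f, axis_name="superbatch", in_axes=tuple(axes))

NUM_DEVICES = jax.local_device_count()
TEST_ARRAY_SIZE = 8

k = jax.random.PRNGKey(42)
k, *keys = jax.random.split(k, NUM_DEVICES + 1)
super_batch = Batch(
    x=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)),
    y=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)),
)
out = fp(jnp.stack(keys), super_batch)
print(out.shape)  # expected: (NUM_DEVICES, TEST_ARRAY_SIZE)
```

Every element of `out` is 1 + 1 plus a uniform sample in [0, 1), so the values lie between 2 and 3.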
Description
I noticed that the `in_axes` argument of `jax.pmap` is sensitive to list vs. tuple input in a way that I found unexpected. The docs describe this argument as "A non-negative integer, None, or nested Python container"; however, I find in the code above that while a tuple works, a list does not. Additionally, the error message is quite inscrutable! The error message given if the third definition of `fp` above is used is:
What jax/jaxlib version are you using?
0.3.18 (g3 HEAD)
Which accelerator(s) are you using?
TPU