jax-ml / jax


jax.pmap argument in_axes can lead to error when passed a list #12557

Open JonyEpsilon opened 1 year ago

JonyEpsilon commented 1 year ago

Description

I noticed that jax.pmap's in_axes argument is sensitive to list vs. tuple input in a way that I found unexpected. The docs indicate that this argument must be "A non-negative integer, None, or nested Python container". However, in the code below a tuple works while a list does not. Additionally, the error message is quite inscrutable!

import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class Batch:
  x: jnp.ndarray
  y: jnp.ndarray

def f(key, b: Batch):
  return b.x + b.y + jax.random.uniform(key)

# This works (works = the evaluation on the final line below works).
fp = jax.pmap(f, axis_name="superbatch", in_axes=(0, Batch(x=0, y=0)))
# As does this, using a tree prefix.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=(0, 0))
# But this doesn't.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=[0, Batch(x=0, y=0)])
# Nor does this.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=[0, 0])

NUM_DEVICES = jax.local_device_count()
TEST_ARRAY_SIZE = 8

k = jax.random.PRNGKey(42)
k, *keys = jax.random.split(k, NUM_DEVICES + 1)
super_batch = Batch(
    x=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)),
    y=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)))

fp(jnp.stack(keys), super_batch)

The error message given if the third definition of fp above is used is:

ValueError: pmap in_axes specification must be a tree prefix of the corresponding value, got specification ([0, Batch(x=0, y=0)], 0) for value tree PyTreeDef(((*, CustomNode(Batch[('x', 'y')], [*, *])), {})).
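For context on why the two spellings are treated differently: JAX pytrees record the container type of each node, so a list and a tuple with the same leaves have different tree structures. The snippet below is a minimal sketch of that point using jax.tree_util.tree_structure. It does not reproduce pmap's internal check, and the variable names are illustrative only.

```python
import jax

# Two in_axes candidates with identical leaves but different container types.
list_spec = [0, 0]
tuple_spec = (0, 0)

list_def = jax.tree_util.tree_structure(list_spec)
tuple_def = jax.tree_util.tree_structure(tuple_spec)

# Both trees have two leaves, but the treedefs compare unequal because
# list and tuple are distinct pytree node types.
same_structure = list_def == tuple_def
print(list_def)
print(tuple_def)
print("same structure:", same_structure)
```

Since the positional arguments of the mapped function are collected into a tuple, a top-level list cannot be a tree prefix of them, which is consistent with the error above.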

What jax/jaxlib version are you using?

0.3.18 (g3 HEAD)

Which accelerator(s) are you using?

TPU

Additional system info

No response

NVIDIA GPU info

No response

selamw1 commented 7 months ago

Hi @JonyEpsilon, I ran the code above with JAX 0.4.23, the latest version at the time. jax.pmap still rejects a list for its in_axes argument, but the error message now states explicitly that a tuple is required, which the message in JAX 0.3.18 did not make clear. Could you please verify this behavior with JAX 0.4.23 and confirm whether the issue can be closed?

import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class Batch:
  x: jnp.ndarray
  y: jnp.ndarray

def f(key, b: Batch):
  return b.x + b.y + jax.random.uniform(key)

# This works (works = the evaluation on the final line below works).
# fp = jax.pmap(f, axis_name="superbatch", in_axes=(0, Batch(x=0, y=0)))
# As does this, using a tree prefix.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=(0, 0))
# Raises a ValueError: `in_axes` must be a tree prefix of the tuple of positional arguments.
fp = jax.pmap(f, axis_name="superbatch", in_axes=[0, Batch(x=0, y=0)])
# Also raises a ValueError, for the same reason.
# fp = jax.pmap(f, axis_name="superbatch", in_axes=[0, 0])

NUM_DEVICES = jax.local_device_count()
TEST_ARRAY_SIZE = 8

k = jax.random.PRNGKey(42)
k, *keys = jax.random.split(k, NUM_DEVICES + 1)
super_batch = Batch(
    x=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)),
    y=jnp.ones((NUM_DEVICES, TEST_ARRAY_SIZE)))

fp(jnp.stack(keys), super_batch)

Output (traceback):

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: pytree structure error: different types at key path
    pmap in_axes[0]
At that key path, the prefix pytree pmap in_axes has a subtree of type
    <class 'list'>
but at the same key path the full pytree has a subtree of different type
    <class 'tuple'>.

The 'full pytree' here is the tuple of arguments passed positionally to the pmapped function, and the value of `in_axes` must be a tree prefix of that tuple. But it was not a prefix.

Check that the value of the `in_axes` argument to `pmap` is a tree prefix of the tuple of arguments passed positionally to the pmapped function.
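One possible workaround, when the in_axes value arrives as a list (for example from a config file), is to convert the top-level container to a tuple before calling jax.pmap. The sketch below assumes a simplified two-argument function in place of the chex dataclass example. The helper as_in_axes is hypothetical and is not part of JAX.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x + y


def as_in_axes(spec):
    """Hypothetical helper: turn a top-level list into a tuple.

    Only the outermost container is converted, because only that level has
    to match the tuple of positional arguments. Nested containers still have
    to mirror the structure of the corresponding argument.
    """
    return tuple(spec) if isinstance(spec, list) else spec


n = jax.local_device_count()
x = jnp.ones((n, 8))
y = jnp.ones((n, 8))

axes_from_config = [0, 0]  # a list, which pmap would reject as-is
fp = jax.pmap(f, axis_name="superbatch", in_axes=as_in_axes(axes_from_config))
out = fp(x, y)
print(out.shape)
```

This only sidesteps the error at the call site. Whether pmap should accept a top-level list directly is the open question in this issue.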