google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

vmap(SPMD axis)/shmap/(vmap with capture) pattern breaks batching #23476

Open · jkr26 opened this issue 1 week ago

jkr26 commented 1 week ago

Description

With a sufficiently exciting combination of closure capture, sharding, and mapping, it is possible to make jax's batching machinery see inconsistent sizes for the batch axis. The following code snippet is a full repro on Google-internal TPU-backed Colab runtimes:

import jax
from jax import numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.experimental import shard_map

device_mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape([2, 4]), ('x', 'y'))

x_sharding = jax.sharding.NamedSharding(mesh=device_mesh, spec=P('x'))
xy_sharding = jax.sharding.NamedSharding(mesh=device_mesh, spec=P('x', 'y'))

def foo_with_float_arg_no_cond(to_add_slice, global_state):

  # We're going to vmap this function over the shard_map slice of the state.
  def foo_capturing_something(state_slice):
    return to_add_slice + state_slice

  vmap_capture = jax.vmap(foo_capturing_something)

  # Note: check_rep=True and check_rep=False result in the same error.
  shmap_vmap_capture = shard_map.shard_map(vmap_capture, mesh=device_mesh, in_specs=P('y'), out_specs=P('y'), check_rep=True)
  result = shmap_vmap_capture(global_state)
  return result

global_state = jax.device_put(jnp.ones(shape=[2, 4]), xy_sharding)

float_vector = jax.device_put(jnp.array([0.0, 1.0]), x_sharding)
# Works
jax.vmap(foo_with_float_arg_no_cond)(float_vector, global_state)
# Doesn't: batching sees inconsistent sizes for the batch axis
jax.vmap(foo_with_float_arg_no_cond, spmd_axis_name='x')(float_vector, global_state)
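
For reference, here is a hypothetical sanity check (not part of the original report) spelling out the intended semantics: each row of global_state should simply get the corresponding element of float_vector added, so the working call can be compared against plain NumPy:

expected = np.asarray(global_state) + np.asarray(float_vector)[:, None]
np.testing.assert_allclose(
    jax.vmap(foo_with_float_arg_no_cond)(float_vector, global_state), expected)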

I'm also sending out a failing test, which will reproduce this problem.
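
A rough sketch of what such a failing test could look like (hypothetical; the actual test being sent out is not shown here), reusing the repro above and asserting that the spmd_axis_name variant agrees with plain vmap:

def test_vmap_spmd_axis_name_with_shard_map_capture():
  # Passes for plain vmap; currently fails for vmap(..., spmd_axis_name='x')
  # because batching sees inconsistent sizes for the batch axis.
  expected = jax.vmap(foo_with_float_arg_no_cond)(float_vector, global_state)
  actual = jax.vmap(foo_with_float_arg_no_cond, spmd_axis_name='x')(
      float_vector, global_state)
  np.testing.assert_allclose(actual, expected)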

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.32
jaxlib: 0.4.32
numpy:  1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (7f7f4feaf07dd3bb4b22d0c25d34b6c99c753aa2)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='d00ef52addb6d6c9-532494bfbbe.borgtask.google.com', release='5.10.0-smp-1103.32.0.0', version='#1 [v5.10.0-1103.32.0.0] SMP @1721941885', machine='x86_64')

mattjj commented 1 week ago

Mentioned in chat: this is a very nice reproducer of a known bug, and that bug is the main reason spmd_axis_name is undocumented!