google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Apache License 2.0

pmap race condition (?) #21946

Closed LeonEricsson closed 2 months ago

LeonEricsson commented 2 months ago


I have a self-play function (its details aren't important) that I want to run on the CPU, but execution mysteriously freezes on the 32nd iteration. When I say freeze or halt, I don't actually know whether this is a complete freeze or just an immense slowdown: the first 31 iterations run in less than a second, and I've waited 5 minutes without the 32nd iteration completing. Let me walk through the code and what I've found so far.

This is the selfplay function, which is called externally N times and grinds to a halt on exactly the 32nd call, every time:

@partial(jax.pmap, backend='cpu')
def selfplay(
    rng_key: jnp.ndarray,
) -> Sample:
    model_params, model_state = model
    batch_size = config.selfplay_batch_size // num_cpu_devices

    def step_fn(state, key):
        key1, key2 = jax.random.split(key)
        observation = state.observation

        (logits, value), _ = forward.apply(
        root = RootFnOutput(prior_logits=logits, value=value, embedding=state)

        policy_output = modified_gumbel_muzero_policy(

        actor = state.current_player
        keys = jax.random.split(key2, batch_size)
        state = jax.vmap(auto_reset(auto_chance(env.step), env.init))(

        return state, observation

I've identified the modified_gumbel_muzero_policy function as the culprit (swapping it out for a dummy initialization fixes the issue). My initial thought was that some complex logic in modified_gumbel_muzero_policy was, for some unknown reason, really resource-intensive. So I started replacing the logic step by step with dummy initializers to see when the issue went away, and I ended up with this:

def modified_gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 4,
    gumbel_scale: chex.Numeric = 1.0,
) -> base.PolicyOutputWithValue[action_selection.GumbelMuZeroExtraData]:
    """Runs Gumbel MuZero search and returns the `PolicyOutput`.

    action = jnp.argmax(root.prior_logits, axis=-1)
    action_weights = root.prior_logits

    return base.PolicyOutput(

There's hardly anything left?! I've been able to remove almost all the logic from this function and everything still freezes on the 32nd selfplay call.

Swapping out action and action_weights for dummy initializers does fix the problem:

def modified_gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 4,
    gumbel_scale: chex.Numeric = 1.0,
) -> base.PolicyOutputWithValue[action_selection.GumbelMuZeroExtraData]:
    """Runs Gumbel MuZero search and returns the `PolicyOutput`.

    # action = jnp.argmax(root.prior_logits, axis=-1)
    # action_weights = root.prior_logits

    action = jnp.zeros(64, dtype=jnp.int32)
    action_weights = jnp.zeros((64, 89))

    return base.PolicyOutput(

This runs.

What is going on here? Does anyone have any intuition as to why this is happening?

I'm enabling CPU execution across cores using:

os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=' + str(
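
For readers reproducing this setup, here is a minimal, self-contained sketch of the same idea. The device count of 4 is a hypothetical value chosen for illustration, not the value used in this issue. The environment variables are set before `jax` is imported so that they are in place when the backend initializes. Note that assigning `XLA_FLAGS` like this overwrites any flags already present in the environment.

```python
import os

# Hypothetical device count for illustration; the issue's actual value is not shown.
num_cpu_devices = 4

# Set these before importing jax so they are in place when the backend starts.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = (
    f"--xla_force_host_platform_device_count={num_cpu_devices}"
)

import jax  # imported only after the environment variables are set

devices = jax.devices()
print(len(devices), devices[0].platform)
```

If the flag took effect, `jax.devices()` should list one `CpuDevice` per requested virtual device, as in the system info in this issue.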

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (16 total, 16 local): [CpuDevice(id=0) CpuDevice(id=1) ... CpuDevice(id=14) CpuDevice(id=15)]
process_count: 1
platform: uname_result(system='Linux', node='75k', release='6.2.0-39-generic', version='#40~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Nov 16 10:53:04 UTC 2', machine='x86_64')

$ nvidia-smi
Tue Jun 18 13:41:13 2024       
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090         Off| 00000000:01:00.0 Off |                  Off |
|  0%   39C    P8               18W / 450W|    601MiB / 24564MiB |      2%      Default |
|                                         |                      |                  N/A |

| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|    0   N/A  N/A      1534      G   /usr/lib/xorg/Xorg                          202MiB |
|    0   N/A  N/A      2531      G   /usr/bin/gnome-shell                         80MiB |
|    0   N/A  N/A      3150      G   ...irefox/1635/usr/lib/firefox/firefox      139MiB |
|    0   N/A  N/A      7585      G   ...sion,SpareRendererForSitePerProcess      175MiB |

hawkinsp commented 2 months ago

Can you try adding a call to .block_until_ready() after each pmap call on one of the outputs? 32 sounds like the size of one of our internal queues that limits how far ahead we can enqueue computations. A possible explanation is that the very first computation you run isn't terminating for some reason, and if you add block_until_ready() then you should see that behavior right away.

(As to why it isn't terminating? I don't know without a repro, but my guess is it's a bug in your code.)
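
The suggestion above can be sketched roughly as follows. This is a minimal stand-in rather than the code from this issue: `toy_selfplay` is a hypothetical placeholder for `selfplay`, and the iteration count is arbitrary. Blocking on an output of each call makes the Python loop wait for that call to finish, so a slow or non-terminating computation should become visible on the first iteration instead of only once the dispatch queue fills up.

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()


@jax.pmap
def toy_selfplay(rng_key):
    # Hypothetical stand-in for the real selfplay function.
    return jax.random.normal(rng_key, (8,)).sum()


key = jax.random.PRNGKey(0)
results = []
for iteration in range(3):
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num_devices)  # one key per device

    out = toy_selfplay(keys)   # returns once the work is enqueued
    out.block_until_ready()    # wait here until this call has actually finished
    results.append(out)
    print(f"iteration {iteration} done")
```

With the blocking call in place, each "done" message is printed only after the corresponding computation has completed.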

LeonEricsson commented 2 months ago

Thanks for the quick response.

Note that this is how selfplay is called:

while True:
        # Log to tensorboard
        for k, v in log.items():
            writer.add_scalar(f'{k}', v, iteration)

        if iteration >= config.max_num_iters:

        log = {}

        # Selfplay
        key, subkey = jax.random.split(key)
        keys = jax.random.split(subkey, num_cpu_devices)

        env_state, samples = selfplay(

                'stats/samples_per_sec': SAMPLES_PER_ITERATION / time_sample_generation,

        iteration += 1

Adding the block does seem to reveal the problem. The selfplay calls are extremely slow; it takes a good 15 minutes to reach the 32nd iteration. Given your suggestion of an internal queue, this makes a lot of sense. I had failed to consider asynchronous dispatch and was fooled by the print statements!

Thank you again.
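
For anyone who hits the same confusion, here is a small sketch of how asynchronous dispatch can distort timings and print statements. The function `heavy` and the array size are hypothetical, chosen only for illustration. The first measurement stops the clock as soon as the call returns, which may happen before the computation has finished; the second stops it only after `block_until_ready()`.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # Hypothetical workload: a few chained matrix products.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((500, 500))
heavy(x).block_until_ready()  # warm-up call so compilation is not timed

start = time.perf_counter()
y = heavy(x)                  # may return before the result is computed
dispatch_time = time.perf_counter() - start

y.block_until_ready()         # wait for the actual result
total_time = time.perf_counter() - start

print(f"time until call returned: {dispatch_time:.6f} s")
print(f"time until result ready:  {total_time:.6f} s")
```

How far apart the two numbers are depends on the backend and workload. Measuring after `block_until_ready()` is what gives a figure such as samples per second its intended meaning.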