google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

pmap race condition (?) #21946

Closed LeonEricsson closed 2 months ago

LeonEricsson commented 2 months ago

Description

I've got a self-play function (the details aren't important) that I would like to execute on the CPU, but I'm running into an issue where execution mysteriously freezes on the 32nd iteration. When I say freeze or halt, I actually don't know whether it's a complete freeze or just an immense slowdown: the first 31 iterations each run in less than a second, and I've waited 5 minutes without the 32nd completing. Anyway, let me walk through the code and what I've found so far.

This is the selfplay function that is called externally N times, and grinds to a halt on exactly the 32nd call (every time):

@partial(jax.pmap, backend='cpu')
def selfplay(
    model,
    rng_key: jnp.ndarray,
    state,
) -> Sample:
    model_params, model_state = model
    batch_size = config.selfplay_batch_size // num_cpu_devices

    def step_fn(state, key):
        key1, key2 = jax.random.split(key)
        observation = state.observation

        (logits, value), _ = forward.apply(
            model_params,
            model_state,
            state.observation,
            is_eval=True,
        )
        root = RootFnOutput(prior_logits=logits, value=value, embedding=state)

        policy_output = modified_gumbel_muzero_policy(
            params=model,
            rng_key=key1,
            root=root,
            recurrent_fn=recurrent_fn,
            num_simulations=config.num_simulations,
            invalid_actions=~state.legal_action_mask,
        )  

        actor = state.current_player
        keys = jax.random.split(key2, batch_size)
        state = jax.vmap(auto_reset(auto_chance(env.step), env.init))(
            state,
            policy_output.action,
            keys,
        )

        return state, observation
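
(The snippet cuts off here; the remainder of selfplay is roughly the usual scan-over-step_fn pattern, sketched below. config.max_num_steps is an assumed field name, and the packing of the stacked outputs into Sample is omitted.)

    # Rough sketch of the elided tail of selfplay (names are assumptions):
    rng_key, subkey = jax.random.split(rng_key)
    step_keys = jax.random.split(subkey, config.max_num_steps)
    state, observations = jax.lax.scan(step_fn, state, step_keys)
    return state, observations  # in the real code the second output becomes a Sample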

I've identified the modified_gumbel_muzero_policy function as the culprit (swapping it out for a dummy initialization fixes the issue). My initial thought was that some complex logic in modified_gumbel_muzero_policy was, for some unknown reason, extremely resource-intensive. So I started replacing its logic step by step with dummy initializers to see when the issue disappeared, and I ended up with this:

def modified_gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 4,
    gumbel_scale: chex.Numeric = 1.0,
) -> base.PolicyOutputWithValue[action_selection.GumbelMuZeroExtraData]:
    """Runs Gumbel MuZero search and returns the `PolicyOutput`.

    action = jnp.argmax(root.prior_logits, axis=-1)
    action_weights = root.prior_logits

    return base.PolicyOutput(
        action=action,
        action_weights=action_weights,
        search_tree=None,
    )

There's hardly anything left?! I've been able to remove almost all the logic from this function and everything still freezes on the 32nd selfplay call.

Swapping action and action_weights out for dummy initializers does fix the problem:

def modified_gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 4,
    gumbel_scale: chex.Numeric = 1.0,
) -> base.PolicyOutputWithValue[action_selection.GumbelMuZeroExtraData]:
    """Runs Gumbel MuZero search and returns the `PolicyOutput`.

    # action = jnp.argmax(root.prior_logits, axis=-1)
    # action_weights = root.prior_logits

    action = jnp.zeros(64, dtype=jnp.int32)
    action_weights = jnp.zeros((64, 89))

    return base.PolicyOutput(
        action=action,
        action_weights=action_weights,
        search_tree=None,
    )

This runs.
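
My only guess (pure speculation) is that with constant outputs nothing downstream depends on the forward pass any more, so XLA is free to dead-code-eliminate the expensive part. A toy sketch of that effect, unrelated to my actual code:

import jax
import jax.numpy as jnp

def heavy(x):
    # Stand-in for an expensive computation (e.g. a network forward pass).
    for _ in range(20):
        x = jnp.tanh(x @ x)
    return x

@jax.jit
def real_output(x):
    logits = heavy(x)
    return jnp.argmax(logits, axis=-1)   # depends on the heavy computation

@jax.jit
def dummy_output(x):
    logits = heavy(x)                    # result never used ...
    return jnp.zeros(x.shape[0], dtype=jnp.int32)  # ... so XLA can drop it

x = jnp.ones((1000, 1000))
jax.block_until_ready(real_output(x))    # pays for the matmuls
jax.block_until_ready(dummy_output(x))   # essentially free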

What is going on here? Does anyone have any intuition as to why this is happening?

I'm enabling CPU execution across cores using:

USE_N_CORES = 16
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=' + str(
    USE_N_CORES
)
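
A quick sanity check (a minimal standalone sketch mirroring the flags above; they have to be in place before JAX initializes its backends) confirms that all 16 CpuDevices show up:

import os

os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=16'

# Import jax only after the flags are set, otherwise the backend may
# already have been initialized with a single default CPU device.
import jax

print(jax.local_device_count())  # 16
print(jax.devices())             # [CpuDevice(id=0), ..., CpuDevice(id=15)]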

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (16 total, 16 local): [CpuDevice(id=0) CpuDevice(id=1) ... CpuDevice(id=14) CpuDevice(id=15)]
process_count: 1
platform: uname_result(system='Linux', node='75k', release='6.2.0-39-generic', version='#40~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Nov 16 10:53:04 UTC 2', machine='x86_64')

$ nvidia-smi
Tue Jun 18 13:41:13 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090         Off| 00000000:01:00.0 Off |                  Off |
|  0%   39C    P8               18W / 450W|    601MiB / 24564MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1534      G   /usr/lib/xorg/Xorg                          202MiB |
|    0   N/A  N/A      2531      G   /usr/bin/gnome-shell                         80MiB |
|    0   N/A  N/A      3150      G   ...irefox/1635/usr/lib/firefox/firefox      139MiB |
|    0   N/A  N/A      7585      G   ...sion,SpareRendererForSitePerProcess      175MiB |
+---------------------------------------------------------------------------------------+
hawkinsp commented 2 months ago

Can you try adding a call to .block_until_ready() after each pmap call on one of the outputs? 32 sounds like the size of one of our internal queues that limits how far ahead we can enqueue computations. A possible explanation is that the very first computation you run isn't terminating for some reason, and if you add block_until_ready() then you should see that behavior right away.

(As to why it isn't terminating? I don't know without a repro, but my guess is it's a bug in your code.)
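
For background: pmap (like jit) dispatches asynchronously, so the Python call returns as soon as the computation is enqueued, and you only wait when a result is actually needed. A toy sketch, unrelated to your code, that shows the effect:

import time
import jax
import jax.numpy as jnp

@jax.pmap
def heavy(x):
    # Deliberately expensive per-device work.
    for _ in range(20):
        x = jnp.tanh(x @ x)
    return x

x = jnp.stack([jnp.eye(1000)] * jax.local_device_count())
jax.block_until_ready(heavy(x))  # warm-up so compilation is excluded

t0 = time.perf_counter()
y = heavy(x)
print(f'dispatch returned after {time.perf_counter() - t0:.4f}s')    # near-instant

t0 = time.perf_counter()
jax.block_until_ready(y)
print(f'computation finished after {time.perf_counter() - t0:.2f}s') # the real cost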

LeonEricsson commented 2 months ago

Thanks for the swift response.

Note this is how selfplay is called:

while True:
    # Log to tensorboard
    for k, v in log.items():
        writer.add_scalar(f'{k}', v, iteration)
    writer.flush()

    if iteration >= config.max_num_iters:
        break

    log = {}

    # Selfplay
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num_cpu_devices)

    env_state, samples = selfplay(
        models,
        keys,
        env_state,
    )
    samples.obs.block_until_ready()

    log.update(
        {
            'stats/samples_per_sec': SAMPLES_PER_ITERATION / time_sample_generation,
        },
    )

    print(iteration)
    iteration += 1

Adding the block does indeed reveal the problem. The selfplay calls are extremely slow, and it takes a good 15 minutes to reach the 32nd iteration. Given your suggestion of an internal queue, this makes a lot of sense. I had failed to consider asynchronous dispatch and was fooled by the print statements!
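
For anyone timing a loop like this: blocking on a single leaf is enough to serialize the iterations, but a sketch that blocks on the whole output pytree (using jax.block_until_ready) before reading the clock looks like this (names taken from my loop above):

import time
import jax

t0 = time.perf_counter()
env_state, samples = selfplay(models, keys, env_state)
# Block on every leaf of both outputs, not just samples.obs, so the
# measured time covers the actual computation rather than the async dispatch.
env_state, samples = jax.block_until_ready((env_state, samples))
print(f'selfplay took {time.perf_counter() - t0:.2f}s')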

Thank you again.