google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

PMAP w/ mctx policy #28

Closed · wbrenton closed this issue 2 years ago

wbrenton commented 2 years ago

How would one go about pmap'ing the mctx policy fn of choice when the root (the argument to the policy) contains an embedding (an argument to the root constructor) that is a dataclass? I would like to map across the 0-axis of every attribute of the dataclass, but when I specify 0 for that argument in pmap's in_axes (e.g. pmap(mctx.gumbel_muzero_policy, in_axes=(0, ...))(root, ...)), it throws

"ValueError: pmap in_axes specification must be a tree prefix of the corresponding value"

My guess is that specifying an axis to map over for a dataclass whose attributes contain the arrays that are the ultimate target of the mapping is not supported. Any suggested workarounds?

root = mctx.RootFnOutput(
    prior_logits=policy_logits,
    value=value_scalar,
    embedding=env_state) # env_state is a dataclass 

recurrent_fn = get_recurrent_fn(env, env_state, env_params, net_func)

key, subkey = jax.random.split(key)

policy_output_maml = pmap(mctx.gumbel_muzero_policy, in_axes=(None, None, None, 0, None, None, None))(
    params=net_params,
    max_depth=config.max_search_depth,
    rng_key=subkey,
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=config.num_simulations,
    qtransform=partial(
        mctx.qtransform_completed_by_mix_value,
        use_mixed_value=config.use_mixed_value,
        value_scale=config.value_scale,
        rescale_values=True))

Thanks in advance.

fidlej commented 2 years ago
  1. Be aware that in_axes applies to positional arguments only. It is not used for keyword arguments.
  2. To get the best speed, pmap should be called on a top-level function. For example, if the root's policy_logits and value_scalar are produced by a network, put the network call and the mctx.gumbel_muzero_policy call into a new function, and call pmap on that new function. A sketch of this pattern follows.
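Roughly, the pattern could look like this sketch, where net_apply (the network forward pass), recurrent_fn, and the per-device rng_keys are hypothetical stand-ins for your own code rather than mctx API:

import jax
import mctx

def act(net_params, env_state, rng_key):
    # Hypothetical network call producing per-device policy_logits and value.
    policy_logits, value = net_apply(net_params, env_state)
    root = mctx.RootFnOutput(
        prior_logits=policy_logits,
        value=value,
        embedding=env_state)
    return mctx.gumbel_muzero_policy(
        params=net_params,
        rng_key=rng_key,
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=32)

# All three arguments are positional, so in_axes works as intended: None
# broadcasts the params, and 0 is a valid tree prefix for the env_state
# dataclass, mapping axis 0 of every leaf (and of the rng_keys).
policy_output = jax.pmap(act, in_axes=(None, 0, 0))(net_params, env_state, rng_keys)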
carlosgmartin commented 1 year ago

@fidlej My understanding is that search (and therefore gumbel_muzero_policy) works with batches only. For example:

from jax import random, numpy as jnp
import mctx

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros(3),
    value=jnp.array(0.),
    embedding=jnp.array(7.),
)

def recurrent_fn(params, key, action, state):
    new_state = state + action - 1
    reward = -state ** 2
    discount = jnp.array(1.)
    value = jnp.array(0.)
    logits = jnp.zeros(3)
    return mctx.RecurrentFnOutput(
        reward=reward,
        discount=discount,
        prior_logits=logits,
        value=value,
    ), new_state

output = mctx.gumbel_muzero_policy(
    params=None,
    rng_key=random.PRNGKey(0),
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=100,
)
Traceback (most recent call last):
  File "/Users/carlos/Desktop/example.py", line 23, in <module>
    output = mctx.gumbel_muzero_policy(
  File "/usr/local/lib/python3.10/site-packages/mctx/_src/policies.py", line 189, in gumbel_muzero_policy
    search_tree = search.search(
  File "/usr/local/lib/python3.10/site-packages/mctx/_src/search.py", line 81, in search
    batch_size = root.value.shape[0]
IndexError: tuple index out of range

Is there a reason search was designed to work with batches rather than individuals—that is, letting the user apply vmap externally to the whole process, outside the call to gumbel_muzero_policy, if they want to work with batches? Is there any performance difference between these two approaches? Just want to make sure I'm not missing something.

wbrenton commented 1 year ago

@carlosgmartin I believe the idea behind implementing mctx, generally speaking, was to offer the research community a batched MCTS that is Python native, with an emphasis on ease of use (no need to worry about binding C or C++ to Python). It is correct that it works with batches only. You can add a "dummy" batch dimension to "individuals" relatively easily, and then vmap that function if that is the functionality you want (see the sketch below). I'm not confident in this, but I think it would lead to slightly poorer GPU/TPU utilization. @fidlej I'd be interested to hear your comment on that.
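The dummy-batch-dimension idea might look roughly like the sketch below. This is an untested illustration rather than mctx API guidance, and it assumes your recurrent_fn already handles a leading batch axis (element-wise operations like the ones in the example above do):

import jax
import jax.numpy as jnp
import mctx

def policy_for_one(params, rng_key, root, recurrent_fn, num_simulations):
    # Add a leading batch axis of size 1 to every leaf of the unbatched root.
    batched_root = jax.tree_util.tree_map(lambda x: x[None], root)
    output = mctx.gumbel_muzero_policy(
        params=params,
        rng_key=rng_key,
        root=batched_root,
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations)
    # Strip the dummy batch axis from every leaf of the policy output.
    return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), output)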

@carlosgmartin I had to do what I outlined above to use mctx inside the Anakin architecture. Would be happy to share that code if you're interested. Feel free to email me.

fidlej commented 1 year ago

Thanks for asking. We wanted to support usage with already-batched recurrent_fn functions. For example, we usually write a neural network that expects inputs with shape [batch_size, num_features]; a sketch of such a batched recurrent_fn follows the list. Two benefits:

  1. Working explicitly with the batch is in some situations faster or more memory-efficient than vmap.
  2. Support for batch normalization. (Maybe someone is still using it.)
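For illustration, an already-batched recurrent_fn might look like the sketch below, where apply_fn (a batched network forward pass) and step_env (batched environment dynamics) are hypothetical stand-ins, and the zero rewards and unit discounts are placeholders:

import jax.numpy as jnp
import mctx

def batched_recurrent_fn(params, rng_key, action, embedding):
    # embedding: [batch_size, num_features]; action: [batch_size].
    next_embedding = step_env(embedding, action)      # hypothetical batched dynamics
    logits, value = apply_fn(params, next_embedding)  # hypothetical batched network
    return mctx.RecurrentFnOutput(
        reward=jnp.zeros_like(value),   # placeholder per-example rewards, shape [batch_size]
        discount=jnp.ones_like(value),  # placeholder per-example discounts, shape [batch_size]
        prior_logits=logits,
        value=value,
    ), next_embedding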