lowrollr / turbozero

fast + parallel AlphaZero in JAX
Apache License 2.0

bug #8

Closed Nightbringers closed 7 months ago

Nightbringers commented 7 months ago

It seems root_metadata is not used in update_root, and the parameters are mismatched in mcts.py.

line 81 eval_state = self.update_root(eval_state, env_state, root_metadata, params)

line 97

def update_root(self, tree: MCTSTree, root_embedding: chex.ArrayTree, 
                    params: chex.ArrayTree, **kwargs) -> MCTSTree:

        key, tree = get_rng(tree)
        root_policy_logits, root_value = self.eval_fn(root_embedding, params, key)
        root_policy = jax.nn.softmax(root_policy_logits)
        root_node = tree.at(tree.ROOT_INDEX)
        root_node = self.update_root_node(root_node, root_policy, root_value, root_embedding)
        return set_root(tree, root_node)
lowrollr commented 7 months ago

Thank you for pointing this out!

All examples currently use the AlphaZero class, which implements update_root with the correct parameter ordering.

For MCTS, the parameter ordering is indeed incorrect.

lowrollr commented 7 months ago

https://github.com/lowrollr/turbozero/commit/4e55c7291576966a5acd75df75ce34e4422d4e63

Nightbringers commented 7 months ago

Can you give an example of how to train using multiple GPUs?

I'm trying to integrate turbozero with my custom JAX environment and have run into a problem. This is the key code:

def step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

def init_fn(key):
    state = env.init(key)
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
    )
az_evaluator = AlphaZero(MCTS)(
    eval_fn = eval_fn,
    num_iterations = 100,
    max_nodes = 200,
    branching_factor=82,
    action_selector = PUCTSelector()
)
def env_step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

eval_key, rng_key = jax.random.split(rng_key)
eval_keys = jax.random.split(eval_key, batch_size)
env, env_state_metadata = jax.vmap(init_fn)(eval_keys)
evaluator_init = partial(az_evaluator.init, template_embedding=env)
eval_state = jax.vmap(evaluator_init)(eval_keys)
output = az_evaluator.evaluate(
            eval_state=eval_state,
            env_state=env,
            root_metadata=env_state_metadata,
            params=param,
            env_step_fn=env_step_fn
        )

This is the error:

    eval_state = self.update_root(eval_state, env_state, root_metadata, params)
  File "/turbozero/core/evaluators/alphazero.py", line 31, in update_root
    key, tree = get_rng(tree)
  File "/turbozero/core/trees/tree.py", line 143, in get_rng
    rng, new_rng = jax.random.split(tree.key, 2)
  File "/lib/python3.11/site-packages/jax/_src/random.py", line 303, in split
    return _return_prng_keys(wrapped, _split(typed_key, num))
  File "/python3.11/site-packages/jax/_src/random.py", line 286, in _split
    raise TypeError("split accepts a single key, but was given a key array of"
TypeError: split accepts a single key, but was given a key array ofshape (100,) != (). Use jax.vmap for batching.

lowrollr commented 7 months ago

evaluate should be vmapped, try something like this:

eval_key, env_key, rng_key = jax.random.split(rng_key, 3)
eval_keys = jax.random.split(eval_key, batch_size)
env_keys = jax.random.split(env_key, batch_size)

env_state, env_state_metadata = jax.vmap(init_fn)(env_keys)

# template embedding should not have a batch dimension
template_env_state, _ = init_fn(jax.random.PRNGKey(0))
evaluator_init = partial(az_evaluator.init, template_embedding=template_env_state)
eval_state = jax.vmap(evaluator_init)(eval_keys)

evaluate = partial(az_evaluator.evaluate, 
                   env_step_fn=env_step_fn,
                   params=param)

output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)
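The pattern above can be tried in isolation with a toy function. This is a minimal sketch, not turbozero code: evaluate_one, its dictionary state, and the array shapes are hypothetical stand-ins. It shows a function written for a single example (one PRNG key, one state) being batched with jax.vmap, with the shared params bound through partial so that only the keyword arguments passed at call time are mapped over their leading axis.

```python
from functools import partial

import jax
import jax.numpy as jnp


def evaluate_one(eval_state, env_state, params):
    """Hypothetical per-example function: expects ONE key and ONE state."""
    # jax.random.split requires a single key here, not a batch of keys.
    key, _ = jax.random.split(eval_state["key"])
    value = jnp.sum(params * env_state)
    return {"key": key}, value


batch_size = 4
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
eval_state = {"key": keys}                # every leaf has a leading batch axis
env_state = jnp.ones((batch_size, 3))
params = jnp.array([1.0, 2.0, 3.0])       # shared, so it has no batch axis

# Bind the unbatched argument first; vmap then maps the remaining keyword
# arguments over axis 0.
evaluate = partial(evaluate_one, params=params)
new_eval_state, values = jax.vmap(evaluate)(
    eval_state=eval_state,
    env_state=env_state,
)
```

Calling evaluate_one directly on the batched eval_state would pass a whole array of keys to jax.random.split, which is the situation the TypeError above describes.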

I recommend using the Trainer class as described here, https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb

I haven't fully documented a lot of the underlying classes yet which do have their peculiarities -- Trainer should be more straightforward to work with.

Nightbringers commented 7 months ago

Thanks, now I have a new problem.

  File "/turbozero/core/evaluators/mcts/mcts.py", line 81, in evaluate
    eval_state = self.update_root(eval_state, env_state, root_metadata, params)
  File "/turbozero/core/evaluators/alphazero.py", line 56, in update_root
    return set_root(tree, root_node)
  File "/turbozero/core/trees/tree.py", line 121, in set_root
    data=jax.tree_util.tree_map(
  File "/turbozero/core/trees/tree.py", line 122, in <lambda>
    lambda x, y: x.at[tree.ROOT_INDEX].set(y),
  File "/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 497, in set
    return scatter._scatter_update(self.array, self.index, values, lax.scatter,
  File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 80, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
  File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 115, in _scatter_impl
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  File "/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
    return util._broadcast_to(array, shape)
  File "//lib/python3.11/site-packages/jax/_src/numpy/util.py", line 425, in _broadcast_to
    for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
ValueError: safe_zip() argument 2 is shorter than argument 1

lowrollr commented 7 months ago

Could you share your code?

Nightbringers commented 7 months ago
def one_step(prev):
        """Execute one self-play move using MCTS.
        """
        env_state, rng_key, step, eval_state = prev
        rng_key, rng_key_next = jax.random.split(rng_key, 2)
        env_state_metadata = StepMetadata(
            rewards=env_state.rewards,
            action_mask=env_state.legal_action_mask,
            terminated=env_state.terminated,
            cur_player_id=env_state.current_player,
        )
        terminated = env_state.terminated

        output = jax.vmap(evaluate)(
            eval_state=eval_state,
            env_state=env_state,
            root_metadata=env_state_metadata)

        env_state = step_fn_move(env_state, output.action)

        eval_state = output.eval_state
        eval_state = az_evaluator.step(eval_state, output.action)

        return (env_state, rng_key_next, step + 1, env_new3, eval_state)
Nightbringers commented 7 months ago

It uses jax.lax.scan:

output = jax.lax.scan(one_step, (env_state, rng_key, step,eval_state), None, length=400, unroll=1)

Maybe this part is incorrect?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
lowrollr commented 7 months ago

The error suggests to me that some of the data might have been passed to evaluate without a batch dimension, just from debugging similar errors before.

Trainer already implements the necessary functions to collect episodes and progress the env and evaluator to the next state, I'm curious why you are implementing these yourself? Is there a feature you need that's missing or something that's confusing or unclear? I'm hopeful that most users won't need to provide anything besides environment dynamics functions.

Nightbringers commented 7 months ago

Yes, there is a feature I need that's missing: Go self-play data processing, which needs symmetries.

lowrollr commented 7 months ago

Could you describe how the feature should work? I can help you find the right spot to integrate it.

Nightbringers commented 7 months ago

I also need to use multiple GPUs, and I don't see anything like pmap in the code. I need an example of how to train using multiple GPUs.

lowrollr commented 7 months ago

Supporting multiple GPUs should be straightforward to add but is not currently included. I can work on this and let you know when it is supported. Created a separate issue to track.
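For reference, data parallelism across devices in plain JAX is commonly expressed with jax.pmap. The following is a generic sketch under stated assumptions, not turbozero's multi-GPU support (which does not exist at the time of this thread): evaluate_one and the shapes are hypothetical. The batch is given a leading device axis, the per-example function is vmapped within each device, and the parameters are replicated to every device.

```python
import jax
import jax.numpy as jnp


def evaluate_one(x, params):
    """Hypothetical per-example computation standing in for an evaluator call."""
    return jnp.dot(x, params)


n_devices = jax.local_device_count()      # 1 on a CPU-only machine
per_device_batch = 4
params = jnp.array([1.0, 2.0, 3.0])

# Leading axis = device, second axis = per-device batch.
inputs = jnp.ones((n_devices, per_device_batch, 3))

# in_axes=(0, None): split inputs along axis 0, broadcast params unchanged.
per_device_eval = jax.vmap(evaluate_one, in_axes=(0, None))
sharded_eval = jax.pmap(per_device_eval, in_axes=(0, None))

values = sharded_eval(inputs, params)     # shape (n_devices, per_device_batch)
```

The same code runs on one device or several, because the device axis is sized from jax.local_device_count().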

Nightbringers commented 7 months ago

Could you describe how the feature should work? I can help you find the right spot to integrate it.

In Go, self-play data is usually augmented with something like np.rot90 and np.fliplr.

Nightbringers commented 7 months ago

About that error, could it be that this place needs vmap?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
lowrollr commented 7 months ago

About that error, could it be that this place needs vmap?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)

Yes, every function that operates on an evaluator state assumes a singular input rather than a batch and should be vmapped.
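As an illustration of that rule inside a jax.lax.scan loop, here is a minimal sketch with hypothetical names (step_one, one_step) and a scalar counter standing in for the evaluator state. The per-example step is wrapped in jax.vmap inside the scan body, and the body returns a (carry, output) pair whose carry has the same structure as the one it received, which jax.lax.scan requires.

```python
import jax
import jax.numpy as jnp


def step_one(state, action):
    """Hypothetical per-example step: one state, one action."""
    return state + action


def one_step(carry, _):
    states, step = carry
    actions = jnp.ones_like(states)                  # one action per example
    # The per-example function is batched here, not rewritten for batches.
    states = jax.vmap(step_one)(states, actions)
    # Return (new carry, per-step output); the carry keeps its structure.
    return (states, step + 1), actions


batch_size = 4
init_carry = (jnp.zeros(batch_size), jnp.zeros((), dtype=jnp.int32))
(final_states, final_step), history = jax.lax.scan(
    one_step, init_carry, None, length=5
)
```

Here history stacks the per-step outputs along a new leading axis of size length.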

Could you describe how the feature should work? I can help you find the right spot to integrate it.

In Go, self-play data is usually augmented with something like np.rot90 and np.fliplr.

If I were you, I would implement a custom class extending Trainer, and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L103-L140

You should only need to extend the behavior of collect.

It might be a good idea for me to allow for users to specify any number of transforms to apply to augment experiences prior to storing in replay memory -- this is a fairly common use-case so it ideally should not require a custom class.

lowrollr commented 7 months ago

If I were you, I would implement a custom class extending Trainer, and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

class GoTrainer(Trainer):
    def collect(self,
        state: CollectionState,
        params: chex.ArrayTree
    ) -> CollectionState:
        step_key, new_key = jax.random.split(state.key)
        eval_output, new_env_state, new_metadata, terminated, rewards = \
            self.step_train(
                key = step_key,
                env_state = state.env_state,
                env_state_metadata = state.metadata,
                eval_state = state.eval_state,
                params = params
            )
        buffer_state = self.memory_buffer.add_experience(
            state = state.buffer_state,
            experience = BaseExperience(
                env_state=state.env_state,
                policy_mask=state.metadata.action_mask,
                policy_weights=eval_output.policy_weights,
                reward=jnp.empty_like(state.metadata.rewards)
            )
        )

        # generate symmetries here and add to replay memory just like above

        buffer_state = jax.lax.cond(
            terminated,
            lambda s: self.memory_buffer.assign_rewards(s, rewards),
            lambda s: s,
            buffer_state
        )

        return state.replace(
            key=new_key,
            eval_state=eval_output.eval_state,
            env_state=new_env_state,
            buffer_state=buffer_state,
            metadata=new_metadata
        )
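The symmetry-generation step could look roughly like the sketch below. It is an assumption-laden illustration rather than part of turbozero: dihedral_symmetries is a hypothetical helper, the observation is assumed to have shape (board_size, board_size, channels), and the policy is assumed to be a flat vector of board_size * board_size move entries followed by a single pass entry. The point is that the board and the policy go through exactly the same rotation and flip, so they stay aligned.

```python
import jax.numpy as jnp


def dihedral_symmetries(obs, policy, board_size):
    """Return the 8 rotations/reflections of an (obs, policy) pair.

    Assumed layout (hypothetical): obs is (board_size, board_size, channels),
    policy is (board_size * board_size + 1,) with the pass move last.
    """
    board_policy = policy[:-1].reshape(board_size, board_size)
    pass_entry = policy[-1:]                       # unaffected by symmetries
    all_obs, all_policies = [], []
    for flip in (False, True):
        for k in range(4):
            # Rotate the two spatial axes of the board and of the policy grid.
            sym_obs = jnp.rot90(obs, k=k, axes=(0, 1))
            sym_policy = jnp.rot90(board_policy, k=k)
            if flip:
                sym_obs = jnp.flip(sym_obs, axis=1)
                sym_policy = jnp.fliplr(sym_policy)
            all_obs.append(sym_obs)
            all_policies.append(
                jnp.concatenate([sym_policy.reshape(-1), pass_entry])
            )
    return jnp.stack(all_obs), jnp.stack(all_policies)


# Toy check on a 3x3 board: one stone and the matching policy entry.
toy_obs = jnp.zeros((3, 3, 1)).at[0, 1, 0].set(1.0)
toy_policy = jnp.zeros(10).at[1].set(1.0)
sym_obs, sym_policies = dihedral_symmetries(toy_obs, toy_policy, board_size=3)
```

If the action mask uses the same flat layout as the policy, it can be passed through the same helper in place of policy so that policy_mask stays consistent with policy_weights.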
Nightbringers commented 7 months ago

Yes, I'm going to try this, but if I can't use multiple GPUs, training will be slow.

I changed it to this:

eval_state = output.eval_state
eval_state = jax.vmap(az_evaluator.step)(eval_state, output.action)

But the error occurs before the code reaches that point. It is raised here:

output = jax.vmap(evaluate)(
            eval_state=eval_state,
            env_state=env_state,
            root_metadata=env_state_metadata)

Also, when I print(state.observation.shape) in eval_fn, it shows (19, 19, 17), without a batch dimension. Is this normal?

Nightbringers commented 7 months ago

I seem to have found why this error happened.

lowrollr commented 7 months ago

I am going to close this issue as the original problem has been resolved.

I hope to continue to expand the documentation of this project so that it is more clear as to how to approach unique use-cases.

Thank you for your questions and feedback! Please create another issue if you run into problems and feel free to email me if you have more questions.

Nightbringers commented 7 months ago

The speed I measured was slow. I then tested the https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb example, and it also seems slow judging by GPU utilization.

What does max_nodes mean? I found that max_nodes has a significant impact on speed: the larger max_nodes is, the slower it runs. mctx-az behaves the same way, and mctx-az is faster than turbozero in my test.

lowrollr commented 7 months ago

max_nodes reflects the maximum capacity of the tree. Trees cannot be sized dynamically, so a maximum number of nodes must be set prior to collection. It makes sense that performance gets worse as max_nodes increases, since operations then run on larger arrays.

I am aware that the backend is currently less performant than mctx. It is a priority for me to fix this.

Nightbringers commented 7 months ago

I'm still confused about max_nodes. If max_nodes=100, does it mean every node has 100 child nodes? Or does it mean that at most 100 nodes will be searched in this evaluate call?
Why do you suggest making it larger than num_simulations? That will be very slow.

Yes, I look forward to it running faster.

Nightbringers commented 7 months ago

I think the main factor affecting speed is CPU computation. The GPU is not fully utilized.

lowrollr commented 7 months ago

max_nodes refers to the maximum capacity of the tree -- so yes it does mean a tree with max_nodes=100 will at most evaluate 100 distinct game states.

I advise setting it higher than num_iterations because this implementation re-uses subtrees from a previous search -- so most of the time the tree will already be partially populated when a new search is started. Setting max_nodes higher than num_iterations means that there will be room for num_iterations nodes in the tree more often. I have yet to document out-of-bounds behavior but it works similarly to https://github.com/lowrollr/mctx-az.

Increasing max_nodes linearly increases the memory footprint of the search tree data structure. It's definitely a trade-off of speed vs. accuracy to set it higher/lower, and its value relative to 'num_iterations' should be problem-dependent (branching factor and # of iterations both matter). Setting max_nodes = num_iterations is fine if you're worried about speed.
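To make the linear scaling concrete, here is a back-of-the-envelope sketch. The per-node layout is hypothetical and is not turbozero's actual tree layout: it assumes one int32 child index and one float32 prior per action, one visit count and one value per node, and a stored float32 Go observation of shape (19, 19, 17) as the node embedding. The 362-action branching factor (19 * 19 moves plus pass) is likewise an assumption.

```python
def tree_bytes(max_nodes, branching_factor, embedding_bytes, batch_size=1):
    """Rough memory estimate for a preallocated search tree.

    Hypothetical layout: every node slot is allocated up front, so the
    total is simply batch_size * max_nodes * bytes_per_node.
    """
    edge_bytes = branching_factor * (4 + 4)   # int32 child index + float32 prior
    stat_bytes = 4 + 4                        # int32 visit count + float32 value
    bytes_per_node = edge_bytes + stat_bytes + embedding_bytes
    return batch_size * max_nodes * bytes_per_node


go_embedding_bytes = 19 * 19 * 17 * 4         # assumed float32 observation
small_tree = tree_bytes(32, 362, go_embedding_bytes, batch_size=50)
large_tree = tree_bytes(600, 362, go_embedding_bytes, batch_size=50)
growth = large_tree / small_tree              # equals 600 / 32
```

Under this layout the footprint depends on max_nodes only through a single multiplicative factor, which is what "linearly" means here; the constant in front is dominated by whatever is stored per node.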

lowrollr commented 7 months ago

I think the main factor affecting speed is CPU computation. The GPU is not fully utilized.

Do you have evidence of this? I'm not aware of any CPU-bound portion of the training loop and in my experiments I've had no issues with GPU utilization.

Nightbringers commented 7 months ago

In my experiments, keeping num_simulations unchanged, a run takes 460 seconds with max_nodes = 32 and 2200 seconds with max_nodes = 600, and GPU Pwr:Usage/Cap is much lower than with max_nodes = 32. This was tested with mctx-az; turbozero behaves the same way.
Since num_simulations is unchanged, the computational workload of the GPU should be unchanged too, so is the problem on the CPU?

lowrollr commented 7 months ago

Some more details could be useful here.

What are you running, one call to search? What is num_simulations set to? What environment are you using?

I'm not sure this is entirely unexpected behavior.

Computational workload is higher when max_nodes is increased. Search operates on tensors of size [num_batches, max_nodes, ... ]

Nightbringers commented 7 months ago

environment: Go 19*19
model_size: like the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410

There is another strange phenomenon: as the number of steps increases, GPU utilization becomes lower and lower.

lowrollr commented 7 months ago

Thank you for letting me know, I will look into what you are describing to see if I can replicate it and diagnose.

Are you running more than one call to search (MCTS.evaluate)? Some of the weird behavior you are describing could be down to JIT-compilation overhead on the first call.
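One way to separate compilation time from steady-state run time is to time the first and second calls of a jitted function separately. This is a generic sketch with a hypothetical toy_search function standing in for a search call; jax.block_until_ready is used so the timer does not stop before the asynchronously dispatched computation has finished.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_search(x):
    """Hypothetical stand-in for a jitted search call."""
    return jnp.tanh(x @ x.T).sum()


def timed(fn, *args):
    """Run fn and wait for the result before reading the clock."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


x = jnp.ones((256, 256))
first_out, first_seconds = timed(toy_search, x)    # tracing + compilation + run
second_out, second_seconds = timed(toy_search, x)  # reuses the compiled program
```

Comparing first_seconds with second_seconds gives a rough idea of how much of a measurement is one-off compilation. Changing the input shape or dtype triggers a fresh compilation, so both calls should use identical inputs.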

Nightbringers commented 7 months ago

It uses jax.lax.scan, like this: jax.lax.scan(one_step, (env_state, rng_key, step,eval_state), None, length=410, unroll=1)

JIT compilation should only affect the first call. The setup is:

environment: Go 19*19
model_size: like the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410

With max_nodes = 32 it takes 460 seconds; with max_nodes = 600 it takes 2200 seconds, and GPU Pwr:Usage/Cap is much lower than with max_nodes = 32. This was tested with mctx-az; turbozero behaves the same way.

lowrollr commented 7 months ago

Thanks for pointing this out, I will see if I can replicate.

Nightbringers commented 7 months ago

Are you familiar with the MuZero algorithm? I have some questions and hope you can help me.

lowrollr commented 7 months ago

I haven't worked with MuZero specifically as much but have read the paper. Feel free to send me an email with your questions and I'll see if I can answer.

I still plan on looking into the issues you mention but have been very busy and have not had a chance yet.