Thank you for pointing this out! The parameter ordering is indeed incorrect for MCTS. All examples currently use the AlphaZero class, which implements update_root with the correct parameter ordering.
Can you give an example of how to train using multiple GPUs?
I'm trying to integrate with my custom JAX environment and have run into a problem. This is the key code:
    def step_fn(state, action):
        new_state = env.step(state, action)
        return new_state, StepMetadata(
            rewards=new_state.rewards,
            action_mask=new_state.legal_action_mask,
            terminated=new_state.terminated,
            cur_player_id=new_state.current_player,
        )

    def init_fn(key):
        state = env.init(key)
        return state, StepMetadata(
            rewards=state.rewards,
            action_mask=state.legal_action_mask,
            terminated=state.terminated,
            cur_player_id=state.current_player,
        )

    az_evaluator = AlphaZero(MCTS)(
        eval_fn=eval_fn,
        num_iterations=100,
        max_nodes=200,
        branching_factor=82,
        action_selector=PUCTSelector()
    )

    def env_step_fn(state, action):
        new_state = env.step(state, action)
        return new_state, StepMetadata(
            rewards=new_state.rewards,
            action_mask=new_state.legal_action_mask,
            terminated=new_state.terminated,
            cur_player_id=new_state.current_player,
        )

    eval_key, rng_key = jax.random.split(rng_key)
    eval_keys = jax.random.split(eval_key, batch_size)
    env, env_state_metadata = jax.vmap(init_fn)(eval_keys)
    evaluator_init = partial(az_evaluator.init, template_embedding=env)
    eval_state = jax.vmap(evaluator_init)(eval_keys)
    output = az_evaluator.evaluate(
        eval_state=eval_state,
        env_state=env,
        root_metadata=env_state_metadata,
        params=param,
        env_step_fn=env_step_fn
    )
This is the error:
    eval_state = self.update_root(eval_state, env_state, root_metadata, params)
      File "/turbozero/core/evaluators/alphazero.py", line 31, in update_root
        key, tree = get_rng(tree)
      File "/turbozero/core/trees/tree.py", line 143, in get_rng
        rng, new_rng = jax.random.split(tree.key, 2)
      File "/lib/python3.11/site-packages/jax/_src/random.py", line 303, in split
        return _return_prng_keys(wrapped, _split(typed_key, num))
      File "/python3.11/site-packages/jax/_src/random.py", line 286, in _split
        raise TypeError("split accepts a single key, but was given a key array of"
    TypeError: split accepts a single key, but was given a key array of shape (100,) != (). Use jax.vmap for batching.
evaluate should be vmapped, try something like this:
    eval_key, env_key, rng_key = jax.random.split(rng_key, 3)
    eval_keys = jax.random.split(eval_key, batch_size)
    env_keys = jax.random.split(env_key, batch_size)
    env_state, env_state_metadata = jax.vmap(init_fn)(env_keys)

    # template embedding should not have a batch dimension
    template_env_state, _ = init_fn(jax.random.PRNGKey(0))
    evaluator_init = partial(az_evaluator.init, template_embedding=template_env_state)
    eval_state = jax.vmap(evaluator_init)(eval_keys)

    evaluate = partial(az_evaluator.evaluate,
        env_step_fn=env_step_fn,
        params=param)

    output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)
I recommend using the Trainer class as described here: https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb. I haven't fully documented a lot of the underlying classes yet, which do have their peculiarities -- Trainer should be more straightforward to work with.
Thanks, now I have a new problem.
    File "/turbozero/core/evaluators/mcts/mcts.py", line 81, in evaluate
      eval_state = self.update_root(eval_state, env_state, root_metadata, params)
    File "/turbozero/core/evaluators/alphazero.py", line 56, in update_root
      return set_root(tree, root_node)
    File "/turbozero/core/trees/tree.py", line 121, in set_root
      data=jax.tree_util.tree_map(
    File "/turbozero/core/trees/tree.py", line 122, in
Could you share your code?
    def one_step(prev, _):
        """Execute one self-play move using MCTS."""
        env_state, rng_key, step, eval_state = prev
        rng_key, rng_key_next = jax.random.split(rng_key, 2)
        env_state_metadata = StepMetadata(
            rewards=env_state.rewards,
            action_mask=env_state.legal_action_mask,
            terminated=env_state.terminated,
            cur_player_id=env_state.current_player,
        )
        terminated = env_state.terminated
        output = jax.vmap(evaluate)(
            eval_state=eval_state,
            env_state=env_state,
            root_metadata=env_state_metadata)
        env_state = step_fn_move(env_state, output.action)
        eval_state = output.eval_state
        eval_state = az_evaluator.step(eval_state, output.action)
        return (env_state, rng_key_next, step + 1, eval_state), None

It uses jax.lax.scan:

    output, _ = jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=400, unroll=1)
Maybe this part is incorrect?

    eval_state = output.eval_state
    eval_state = az_evaluator.step(eval_state, output.action)
The error suggests to me that some of the data might have been passed to evaluate without a batch dimension, just from debugging similar errors before.
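One generic way to check this (a plain-JAX debugging sketch, not a turbozero API; print_batch_shapes is a hypothetical helper) is to print the shape of every array leaf going into the vmapped call and confirm they all share the same leading batch dimension:

    import jax
    import jax.numpy as jnp

    # Hypothetical debugging helper: print the shape of every leaf in a pytree
    # so you can confirm each input to the vmapped `evaluate` carries the same
    # leading batch dimension.
    def print_batch_shapes(name, tree):
        shapes = jax.tree_util.tree_map(jnp.shape, tree)
        print(name, shapes)

    print_batch_shapes("eval_state", eval_state)
    print_batch_shapes("env_state", env_state)
    print_batch_shapes("root_metadata", env_state_metadata)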
Trainer already implements the necessary functions to collect episodes and progress the env and evaluator to the next state, so I'm curious why you are implementing these yourself. Is there a feature you need that's missing, or something that's confusing or unclear? I'm hopeful that most users won't need to provide anything besides environment dynamics functions.
Yes, there's a feature I need that's missing: Go self-play data processing, which needs symmetries.
Could you describe how the feature should work? I can help you find the right spot to integrate it.
And I need to use multiple GPUs. I don't see anything like pmap in the code. I need an example of how to train with multiple GPUs.
Supporting multiple GPUs should be straightforward to add, but it is not currently included. I can work on this and let you know when it is supported. I created a separate issue to track it.
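For reference, the usual JAX pattern for this looks roughly like the sketch below (a generic data-parallel example, not a turbozero API; loss_fn, params, and batch are hypothetical placeholders):

    import functools
    import jax

    # Generic JAX data-parallel pattern (NOT a turbozero API). `loss_fn`,
    # `params`, and `batch` are hypothetical placeholders.
    num_devices = jax.local_device_count()

    def shard(x):
        # [num_devices * b, ...] -> [num_devices, b, ...]
        return x.reshape(num_devices, -1, *x.shape[1:])

    @functools.partial(jax.pmap, axis_name="devices")
    def train_step(params, batch):
        grads = jax.grad(loss_fn)(params, batch)
        # average gradients across devices
        grads = jax.lax.pmean(grads, axis_name="devices")
        return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

    # replicate params on every device, shard the batch, then step
    replicated_params = jax.device_put_replicated(params, jax.local_devices())
    sharded_batch = jax.tree_util.tree_map(shard, batch)
    replicated_params = train_step(replicated_params, sharded_batch)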
> Could you describe how the feature should work? I can help you find the right spot to integrate it.
In the game of Go, self-play data is usually augmented with something like np.rot90 and np.fliplr.
About that error, could it be that this place needs vmap?

    eval_state = output.eval_state
    eval_state = az_evaluator.step(eval_state, output.action)
> About that error, could it be that this place needs vmap?
> eval_state = output.eval_state
> eval_state = az_evaluator.step(eval_state, output.action)
Yes, every function that operates on an evaluator state assumes a singular input rather than a batch and should be vmapped.
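For example, applied to the snippet you shared (a sketch of the same pattern as the evaluate fix above):

    # evaluator-state functions expect a single (unbatched) state,
    # so wrap batched calls in vmap
    eval_state = jax.vmap(az_evaluator.step)(eval_state, output.action)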
> In the game of Go, self-play data is usually augmented with something like np.rot90 and np.fliplr.
If I were you, I would implement a custom class extending Trainer, and overwrite collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.
Here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L103-L140. You should only need to extend the behavior of collect.
It might be a good idea for me to allow users to specify any number of transforms to apply to augment experiences prior to storing them in replay memory -- this is a fairly common use-case, so it ideally should not require a custom class.
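As a rough sketch of that idea (entirely hypothetical -- experience_transforms is not a real turbozero parameter), each transform would map one experience to a list of augmented experiences, applied before insertion into replay memory:

    # Hypothetical design sketch: `experience_transforms` is NOT a real
    # turbozero parameter. Each transform maps one experience to a list of
    # augmented experiences, applied before storing into replay memory.
    def apply_transforms(experience, transforms):
        experiences = [experience]
        for transform in transforms:
            experiences = [aug for e in experiences for aug in transform(e)]
        return experiences

    # e.g. a Trainer(..., experience_transforms=[go_symmetries]) could call
    # apply_transforms(...) inside collect before add_experience.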
> If I were you, I would implement a custom class extending Trainer, and overwrite collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

Something like this:
    class GoTrainer(Trainer):
        def collect(self,
            state: CollectionState,
            params: chex.ArrayTree
        ) -> CollectionState:
            step_key, new_key = jax.random.split(state.key)
            eval_output, new_env_state, new_metadata, terminated, rewards = \
                self.step_train(
                    key=step_key,
                    env_state=state.env_state,
                    env_state_metadata=state.metadata,
                    eval_state=state.eval_state,
                    params=params
                )
            buffer_state = self.memory_buffer.add_experience(
                state=state.buffer_state,
                experience=BaseExperience(
                    env_state=state.env_state,
                    policy_mask=state.metadata.action_mask,
                    policy_weights=eval_output.policy_weights,
                    reward=jnp.empty_like(state.metadata.rewards)
                )
            )
            # generate symmetries here and add to replay memory just like above
            buffer_state = jax.lax.cond(
                terminated,
                lambda s: self.memory_buffer.assign_rewards(s, rewards),
                lambda s: s,
                buffer_state
            )
            return state.replace(
                key=new_key,
                eval_state=eval_output.eval_state,
                env_state=new_env_state,
                buffer_state=buffer_state,
                metadata=new_metadata
            )
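For the "generate symmetries here" step, a minimal sketch of the eight dihedral symmetries of a Go position (my assumptions: the state is stored as [19, 19, C] feature planes and the policy as 362 weights -- 361 board moves plus a pass move that is invariant under board symmetries; go_symmetries is a hypothetical helper):

    import jax.numpy as jnp

    # Hypothetical helper (assumes [19, 19, C] planes and a 362-entry policy:
    # 361 board moves + 1 pass move). Returns the 8 dihedral symmetries; the
    # same transform must be applied to policy_mask as well.
    def go_symmetries(planes, policy):
        board_policy = policy[:361].reshape(19, 19)
        pass_weight = policy[361:]
        out = []
        for k in range(4):
            for flip in (False, True):
                p = jnp.rot90(planes, k, axes=(0, 1))
                bp = jnp.rot90(board_policy, k)
                if flip:
                    p = jnp.flip(p, axis=1)
                    bp = jnp.flip(bp, axis=1)
                out.append((p, jnp.concatenate([bp.reshape(-1), pass_weight])))
        return out

Each (planes, policy) pair would then be stored with add_experience just like the original experience in the collect override above.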
Yes, I'm going to try this, but if I can't use multiple GPUs, training will be slow.
I changed it to this:

    eval_state = output.eval_state
    eval_state = jax.vmap(az_evaluator.step)(eval_state, output.action)
But the error occurred before the code reached that point. The error is in this:
    output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)
And I printed state.observation.shape in eval_fn; it shows (19, 19, 17), without a batch dimension. Is this normal?
I think I found out why this error happened.
I am going to close this issue as the original problem has been resolved.
I hope to continue to expand the documentation of this project so that it is more clear as to how to approach unique use-cases.
Thank you for your questions and feedback! Please create another issue if you run into problems, and feel free to email me if you have more questions.
The speed in my tests was slow. Then I tested the https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb example; judging from GPU utilization, it also seems slow.
What does max_nodes mean? I found that max_nodes has a significant impact on speed: the larger max_nodes is, the slower it runs. mctx-az shows the same behavior, and mctx-az was faster than turbozero in my test.
max_nodes reflects the maximum capacity of the tree. Trees cannot be sized dynamically, so a maximum number of nodes must be set prior to collection. It makes sense that performance gets worse as max_nodes increases; this ultimately means operations are on larger matrices.
I am aware that the backend is currently less performant than mctx. It is a priority for me to fix this.
I'm still confused about max_nodes.
If max_nodes=100, does it mean every node has 100 child nodes? Or does it mean at most 100 nodes will be searched in this evaluate?
Why do you suggest setting it larger than num_simulations? That would be very, very slow.
Yes, I look forward to it running faster.
I think the biggest factor affecting speed is CPU-side computation; the GPU is not fully utilized.
max_nodes refers to the maximum capacity of the tree -- so yes, it does mean a tree with max_nodes=100 will evaluate at most 100 distinct game states.
I advise setting it higher than num_iterations because this implementation re-uses subtrees from a previous search -- so most of the time the tree will already be partially populated when a new search is started. Setting max_nodes higher than num_iterations means that there will more often be room for num_iterations new nodes in the tree. I have yet to document out-of-bounds behavior, but it works similarly to https://github.com/lowrollr/mctx-az.
Increasing max_nodes linearly increases the memory footprint of the search tree data structure. It's definitely a trade-off of speed vs. accuracy to set it higher or lower, and its value relative to num_iterations should be problem-dependent (branching factor and number of iterations both matter). Setting max_nodes = num_iterations is fine if you're worried about speed.
> I think the biggest factor affecting speed is CPU-side computation; the GPU is not fully utilized.
Do you have evidence of this? I'm not aware of any CPU-bound portion of the training loop, and in my experiments I've had no issues with GPU utilization.
In my experiments, keeping num_simulations unchanged: when max_nodes = 32 it costs 460 seconds; when max_nodes = 600 it costs 2200 seconds, and GPU Pwr:Usage/Cap is much lower than with max_nodes = 32. Tested with mctx-az; turbozero shows the same behavior.
In this situation, because num_simulations is unchanged, the computational workload of the GPU should remain unchanged, so is the problem in the CPU?
Some more details could be useful here.
What are you running, one call to search?
What is num_simulations set to?
What environment are you using?
I'm not sure this is entirely unexpected behavior. Computational workload is higher when max_nodes is increased. Search operates on tensors of size [num_batches, max_nodes, ...].
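A rough back-of-the-envelope with your numbers (my assumptions: float32 arrays and a 19x19 Go branching factor of 362):

    # Rough scaling estimate, assuming float32 and per-node arrays shaped
    # [batch, max_nodes, branching]; these grow linearly with max_nodes.
    batch, branching = 50, 362
    for max_nodes in (32, 600):
        n_floats = batch * max_nodes * branching
        print(max_nodes, f"~{n_floats * 4 / 1e6:.1f} MB per array")
    # 32 -> ~2.3 MB, 600 -> ~43.4 MB: an 18.75x increase in the size of the
    # tensors each search iteration operates on.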
environment: Go 19x19
model size: like the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410
There is another strange phenomenon: as the number of steps increases, the GPU utilization rate gets lower and lower.
Thank you for letting me know, I will look into what you are describing to see if I can replicate it and diagnose.
Are you running more than one call to search (MCTS.evaluate)? Some of the weird behavior you are describing could be down to JIT-compilation overhead on the first call.
It uses jax.lax.scan, like this:

    jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=410, unroll=1)

JIT compilation should only affect the first call.
environment: Go 19x19
model size: like the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410
When max_nodes = 32 it costs 460 seconds; when max_nodes = 600 it costs 2200 seconds, and GPU Pwr:Usage/Cap is much lower than with max_nodes = 32. Tested with mctx-az; turbozero shows the same behavior.
Thanks for pointing this out, I will see if I can replicate.
Are you familiar with the MuZero algorithm? I have some questions and hope you can help me.
I haven't worked with MuZero specifically as much, but I have read the paper. Feel free to send me an email with your questions and I'll see if I can answer them.
I still plan on looking into the issues you mention but have been very busy and have not had a chance yet.
It seems root_metadata is not used in update_root; the parameters are mismatched in mcts.py.
line 81:

    eval_state = self.update_root(eval_state, env_state, root_metadata, params)

line 97: