wbrenton opened 1 year ago
I don't seem to have access to your repository. Is it private? Also, we might not have time to go hunting for potential bugs in large codebases, so it would be of great help to us if you could try to minimize the example and provide us with a small self-contained snippet.
Sorry about that, it is now public. I will try to create a smaller reproduction of the bug; however, due to its nature of not appearing until several hours into training, that will be difficult.
My best guess is that the bug is something similar to the performance issues caused by using jit inside of pmap. Are there any known bugs like that which occur with a certain ordering of jax primitives?
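One thing I plan to rule out first is silent recompilation, since retracing every epoch would show up as exactly this kind of creeping step time. A minimal sketch of the check I have in mind, using the `jax_log_compiles` flag (the `train_epoch` function below is just a stand-in for my real epoch function):

```python
import jax
import jax.numpy as jnp

# Log every XLA compilation; if messages keep appearing after the first
# epoch, something in the epoch function is being retraced/recompiled
# (e.g. a changing shape, dtype, or Python-level argument).
jax.config.update("jax_log_compiles", True)

@jax.jit
def train_epoch(state):  # stand-in for the real epoch function
    return jax.tree_util.tree_map(lambda x: x + 1, state)

state = {"params": jnp.zeros((4, 4)), "step": jnp.array(0)}
for _ in range(3):
    state = train_epoch(state)  # should log exactly one compilation
```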
I've tried to emulate the training scaffolding used in brax's ppo implementation so that I have a starting point I know has good performance and is free of pmap(jit)-type bugs.
I will continue to try to debug on my own. My current plan is to profile the function that executes a training epoch and see if I can identify the part of the code with the greatest increase in duration from epoch to epoch. Do you have any thoughts on that approach? I'm open to any suggestions.
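To make the per-epoch timings trustworthy, I'm planning to block on the epoch outputs before stopping the clock, so JAX's async dispatch doesn't shift work into the next measurement. Rough sketch (the epoch function and its arguments below are placeholders for my actual code):

```python
import time
import jax

def timed_epoch(epoch_fn, *args):
    """Run one epoch and return (outputs, wall-clock seconds)."""
    start = time.perf_counter()
    out = epoch_fn(*args)
    out = jax.block_until_ready(out)  # wait for async dispatch to finish
    return out, time.perf_counter() - start

# usage (placeholders): out, dt = timed_epoch(train_epoch, train_state, env_states, key)
```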
If you or someone else does find the time, the algorithm is about 500 lines in aggregate, so it's not too bad. I think an open source MuZero with performance matching that in the paper would be extremely beneficial to the community. If you do dig into the code, let me know and I can write a walkthrough in the readme to help you get up to speed.
Have you tried printing the dims of train_state, env_states, and train_metrics in your diagnostic output? Just by glancing quickly at the code, it looks like they might be growing unbounded and may be getting completely recomputed 1..N at each round (with a step up in run times occurring when resources run low). Just a thought, hope it helps.
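Something like this (an untested sketch, using the names from above) would make it easy to spot a leaf that's growing between rounds:

```python
import jax

def leaf_shapes(tree):
    # map every array leaf to its shape so a growing pytree stands out
    return jax.tree_util.tree_map(lambda x: getattr(x, "shape", x), tree)

# inside the training loop, e.g.:
# print(epoch, leaf_shapes(train_state), leaf_shapes(env_states), leaf_shapes(train_metrics))
```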
@ajweiss I do appreciate the input. The shapes are static each round. What in particular led you to this conclusion? I could be missing something.
Description
I'm writing an open source MuZero for continuous action spaces with jax, haiku, mctx, and brax.
As training progresses, the wall time per epoch increases 3x, and I'm having a very difficult time tracking down the source of this issue.
I'm currently executing it on Cloud TPU, see quick_tpu to reproduce my setup.
Any help is greatly appreciated. Will update with info from jax.profiler once I figure out how to use it :D
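From the profiling docs, my understanding is that a trace can be captured roughly like this (a sketch only; the jitted matmul stands in for my real epoch function):

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_epoch(state):  # stand-in for the real epoch function
    return state @ state.T

state = jnp.ones((512, 512))

# capture a few epochs; view the trace with: tensorboard --logdir /tmp/jax-trace
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(3):
        state = train_epoch(state)
    jax.block_until_ready(state)  # make sure the work lands inside the trace
```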
What jax/jaxlib version are you using?
latest
Which accelerator(s) are you using?
TPU
Additional system info
Linux
NVIDIA GPU info
No response