jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Slowed performance as Training Progresses #14602

Open wbrenton opened 1 year ago

wbrenton commented 1 year ago

Description

I'm writing an open source MuZero for continuous action spaces with jax, haiku, mctx, and brax.

As training progresses, the wall time per epoch increases roughly 3x, and I'm having a very difficult time tracking down the source of the issue.

epoch: 10, train steps: 100 loss: 1.8471885919570923, reward: -7.4807448387146 time: 0:12:29.247406 test_reward: -92.07862854003906
epoch: 20, train steps: 200 loss: 1.870937705039978, reward: -8.306370735168457 time: 0:09:58.730033 test_reward: -74.05025482177734
epoch: 30, train steps: 300 loss: 1.7748759984970093, reward: -6.956284523010254 time: 0:11:44.268393 test_reward: -61.988037109375
epoch: 40, train steps: 400 loss: 1.7103341817855835, reward: -6.363528728485107 time: 0:28:23.637063 test_reward: -52.56919860839844
epoch: 50, train steps: 500 loss: 1.7735674381256104, reward: -5.5273051261901855 time: 0:12:33.038024 test_reward: -3.968066692352295
epoch: 60, train steps: 600 loss: 1.7565151453018188, reward: -4.69722843170166 time: 0:09:27.737378 test_reward: -34.667877197265625
epoch: 70, train steps: 700 loss: 1.5296885967254639, reward: -4.434412956237793 time: 0:21:49.107402 test_reward: 17.13837242126465
epoch: 80, train steps: 800 loss: 0.9052988886833191, reward: -0.5253445506095886 time: 0:36:09.626553 test_reward: -79.37960052490234
epoch: 90, train steps: 900 loss: 1.2583162784576416, reward: -7.548539161682129 time: 0:33:14.198018 test_reward: -140.06710815429688
epoch: 100, train steps: 1000 loss: 1.2938309907913208, reward: -14.065295219421387 time: 0:35:49.456317 test_reward: -17.377328872680664
epoch: 110, train steps: 1100 loss: 1.5919981002807617, reward: -6.678733825683594 time: 0:35:29.829538 test_reward: 5.407070159912109
epoch: 120, train steps: 1200 loss: 0.5133532881736755, reward: 2.831914186477661 time: 0:37:31.921805 test_reward: 60.34507751464844

I'm currently running it on a Cloud TPU; see quick_tpu to reproduce my setup.

Any help is greatly appreciated. Will update with info from jax.profiler once I figure out how to use it :D
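For reference, a minimal sketch of capturing a profiler trace around one epoch, viewable in TensorBoard. The `train_epoch`, `train_state`, and `env_states` names are placeholders for the real training-loop function and states, and the log directory is arbitrary:

```python
import jax

# Capture a profiler trace of one epoch for inspection in TensorBoard.
# `train_epoch`, `train_state`, and `env_states` are placeholders for the
# actual training-loop function and states.
with jax.profiler.trace("/tmp/jax-trace"):
    train_state, metrics = train_epoch(train_state, env_states)
    # Block so asynchronous dispatch finishes inside the traced region.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)
```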

What jax/jaxlib version are you using?

latest

Which accelerator(s) are you using?

TPU

Additional system info

Linux

NVIDIA GPU info

No response

apaszke commented 1 year ago

I don't seem to have access to your repository. Is it private? Also, we might not have time to go hunting for potential bugs in large codebases, so it would be of great help to us if you could try to minimize the example and provide us with a small self-contained snippet.

wbrenton commented 1 year ago

Sorry about that, it is now public. I will try to create a smaller reproduction of the bug; however, since it doesn't appear until several hours into training, that will be difficult.

My best guess is that the bug is something similar to using jit inside of pmap and the performance issues that causes. Are there any similar known bugs that occur with a certain ordering of jax primitives?
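For context, a toy sketch of the kind of nesting being asked about (not code from the repo):

```python
import jax
import jax.numpy as jnp

@jax.jit
def update(params, batch):
    # Toy stand-in for a real train step.
    return params - 1e-3 * jnp.mean(batch)

# pmap compiles its function itself, so the inner jit here is redundant;
# this is the nested-transformation pattern referred to above.
pmapped_update = jax.pmap(update)
```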

I've tried to emulate the training scaffolding used in brax's PPO implementation so that I have a starting point that I know has good performance and is free of pmap(jit)-type bugs.

I will continue to try to debug on my own. My current plan is to profile the function that executes a training epoch and see if I can identify the part of the code with the greatest change in duration from epoch to epoch. Do you have any thoughts on that approach? I'm open to any suggestions.
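One way to make that per-epoch timing trustworthy is to block on the outputs before stopping the clock, since JAX dispatches work asynchronously. A sketch, with `train_epoch` again standing in for the real epoch function:

```python
import time
import jax

def timed_epoch(train_state, env_states):
    # `train_epoch` is a placeholder for the function that runs one epoch.
    start = time.perf_counter()
    train_state, env_states, metrics = train_epoch(train_state, env_states)
    # Block on array outputs so async dispatch doesn't hide where time is spent.
    jax.tree_util.tree_map(
        lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x,
        (train_state, metrics),
    )
    return train_state, env_states, metrics, time.perf_counter() - start
```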

If you or someone else does find the time, the algorithm is about 500 lines in aggregate, so it's not too bad. I think an open-source MuZero with performance matching the paper would be extremely beneficial to the community. If you do dig into the code, let me know and I can write a walkthrough in the README to help you get up to speed.

ajweiss commented 1 year ago

Have you tried printing the dims of train_state, env_states, and train_metrics in your diagnostic output? Just by glancing quickly at the code, it looks like they might be growing unbounded and may be getting completely recomputed 1..N at each round (with a step up in run times occurring when resources run low). Just a thought, hope it helps. A quick way to dump those shapes each round is sketched below.
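A sketch of that shape dump, assuming the three objects named above are pytrees of arrays:

```python
import jax

def log_shapes(name, tree):
    # Print the shape of every array leaf so unbounded growth is easy to spot.
    shapes = jax.tree_util.tree_map(lambda x: getattr(x, "shape", None), tree)
    print(name, shapes)

log_shapes("train_state", train_state)
log_shapes("env_states", env_states)
log_shapes("train_metrics", train_metrics)
```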

wbrenton commented 1 year ago

@ajweiss I do appreciate the input. The shapes are static each round. What in particular led you to that conclusion? I could be missing something.