Open wangyian-me opened 1 year ago
It reports the same bug even when I use only one GPU.
Also, I got this warning:

```
/home/vipuser/miniconda3/envs/brax/lib/python3.8/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
  jax.tree_util.register_keypaths(
```

I don't know if it is relevant.
I just realized that it might be because some elements are NaN, and NaN == NaN
is false, so the replication check returns false.
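For what it's worth, this is exactly how a NaN breaks an equality-based replication check. A minimal numpy sketch of the effect (not brax's actual implementation — just the IEEE-754 semantics at play):

```python
import numpy as np

# Two "replicas" of the same parameters; one value went NaN during training.
replica_a = np.array([1.0, float("nan"), 3.0])
replica_b = replica_a.copy()

# An elementwise equality check fails on NaN, because NaN != NaN by
# IEEE-754 semantics -- even though the arrays are bit-identical copies.
naive_equal = bool(np.all(replica_a == replica_b))

# A NaN-aware comparison treats the copies as equal.
nan_aware_equal = bool(np.array_equal(replica_a, replica_b, equal_nan=True))
```

So a failing `assert_is_replicated` does not necessarily mean the devices diverged; a single NaN that was faithfully replicated everywhere is enough to trip it.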
Hi @wangyian-me, indeed when assert_is_replicated fails it's usually because of a NaN in training. So it looks like humanoidstandup trained with APG causes a NaN?
Yeah, it happens when I use the "generalized" backend. I've also tried the "positional" backend, which works without this bug.
Also, I've tried to locate where the NaN is produced. It's after this line. So I guess the gradient might explode during backpropagation with the "generalized" backend. @btaba
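If the gradients really are exploding through the long differentiable rollout, clipping them by global norm before the optimizer update is a common mitigation (`optax.clip_by_global_norm` in a real JAX pipeline); JAX's `jax.config.update("jax_debug_nans", True)` can also help pinpoint the first op that produces a NaN. A plain-numpy sketch of the clipping idea, with a hypothetical `max_norm`:

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is <= max_norm.

    Plain-numpy sketch of the technique behind optax.clip_by_global_norm;
    max_norm is a hypothetical hyperparameter, not a brax default.
    """
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # Leave small gradients untouched; shrink large ones onto the norm ball.
    scale = min(1.0, max_norm / (global_norm + 1e-9))
    return [g * scale for g in grads], global_norm

# An "exploded" gradient with global norm 5.0, clipped down to norm 1.0.
clipped, norm = clip_by_global_norm([np.array([3.0]), np.array([4.0])], max_norm=1.0)
```

Whether this actually fixes the "generalized" backend here is untested; it only bounds the update, it does not remove whatever makes the backward pass ill-conditioned.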
I am hitting the same issue when I use PPO: there are NaNs, but I don't know how to locate them. Could you please help me?
@queenxy are you getting NaNs on humanoidstandup with PPO with the generalized backend (and on which device)? Afaik this was tested on TPU, but it would be good to know. Thanks @wangyian-me for confirming; we'll have to debug. But if you have some time, feel free to dig deeper.
I am getting NaNs on my own environment with the PPO implementation provided by brax. The device is GPU (both multi- and single-GPU runs lead to this issue). I have checked my environment and there seems to be nothing wrong, so I am trying to determine whether the NaN originates in PPO. @btaba
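One way to narrow this down is to scan the training state for NaNs after each update and report which leaves are affected. A minimal sketch with plain nested dicts and numpy (in real JAX code you would run this on `jax.device_get(training_state)`; the state layout below is illustrative, not brax's actual structure):

```python
import numpy as np

def find_nans(tree, path=""):
    """Recursively report which leaves of a nested-dict pytree contain NaNs."""
    hits = []
    if isinstance(tree, dict):
        for key, val in tree.items():
            hits += find_nans(val, f"{path}/{key}")
    else:
        arr = np.asarray(tree)
        if np.isnan(arr).any():
            hits.append(path)
    return hits

# Hypothetical training state with one corrupted parameter leaf.
state = {
    "params": {"w": np.array([0.5, np.nan]), "b": np.zeros(2)},
    "opt_state": np.ones(3),
}
nan_leaves = find_nans(state)
```

Calling this every N training steps tells you which update first corrupts which leaf (policy params, value params, or optimizer state), which usually distinguishes an environment-side NaN from one produced inside the PPO loss.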
When I try to use a 4-GPU machine to run the Analytic Policy Gradients training in parallel, it reports an AssertionError at `brax/training/agents/apg/train.py` line 255. It seems to be because `training_state` becomes different across the devices when it should be replicated. I only made minimal changes to the example training code. To make the error appear sooner, I added `pmap.assert_is_replicated(training_state)` in the training loop of `brax/training/agents/apg/train.py`. The full output is:
If I use `from brax.training.agents.apg.train import train as apgtrain`, the full output instead becomes: