google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

assert_is_replicated in Analytic policy gradients training #328

Open wangyian-me opened 1 year ago

wangyian-me commented 1 year ago

When I try to run the Analytic policy gradients training in parallel on a 4-GPU machine, it raises an AssertionError at brax/training/agents/apg/train.py line 255. It seems this is because training_state becomes different across the devices when it should be replicated.
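
For context, the traceback below shows that assert_is_replicated pmaps a check over all devices with axis_name='i' and asserts its result. A minimal sketch of such a replication check (illustrative only, with hypothetical names; the actual implementation lives in brax/training/pmap.py and may differ in details) could look like this:

import jax
import jax.numpy as jnp
from jax import lax

def assert_is_replicated_sketch(tree):
  # Illustrative sketch: each device compares its local leaves against the
  # cross-device maximum; any mismatch makes the check fail. A NaN leaf also
  # fails, because NaN never compares equal to anything.
  def per_device_check(local_tree):
    leaves = jax.tree_util.tree_leaves(local_tree)
    same = [jnp.allclose(leaf, lax.pmax(leaf, axis_name='i'))
            for leaf in leaves]
    return jnp.all(jnp.stack(same))
  ok = jax.pmap(per_device_check, axis_name='i')(tree)
  assert ok[0], 'pytree is not replicated across devices'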

I only made minimal changes to the example training code.

import functools

from datetime import datetime
# from brax.training.agents.apg.train import train as apgtrain
from train import train as apgtrain
from brax import envs

env_name = 'humanoidstandup'  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name,
                           backend=backend)

train_fn = {
  'humanoidstandup': functools.partial(
      apgtrain,
      episode_length=320,
      action_repeat=1,
      num_envs=16,
      num_eval_envs=4,
      learning_rate=1e-4,
      seed=0,
      max_gradient_norm=1e8,
      num_evals=10,
      normalize_observations=True,
      deterministic_eval=False)
}[env_name]

xdata, ydata = [], []
times = [datetime.now()]

def progress(num_steps, metrics):
  times.append(datetime.now())
  print(num_steps, metrics['eval/episode_reward'])

print("begin")

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

print("end")

To make the error appear sooner, I added pmap.assert_is_replicated(training_state) inside the iteration loop of brax/training/agents/apg/train.py.

  for it in range(num_evals_after_init):
    logging.info('starting iteration %s %s', it, time.time() - xt)

    # optimization
    epoch_key, local_key = jax.random.split(local_key)
    epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
    (training_state,
     training_metrics) = training_epoch_with_timing(training_state, epoch_keys)
    ######################## I add it here #############################
    pmap.assert_is_replicated(training_state)
    ####################################################################
    if process_id == 0:
      # Run evals.
      metrics = evaluator.run_evaluation(
          _unpmap(
              (training_state.normalizer_params, training_state.policy_params)),
          training_metrics)
      logging.info(metrics)
      progress_fn(it + 1, metrics)

And the full output is:

begin
0 2238.8042
1 2367.4116
Traceback (most recent call last):
  File "xxxxx.py", line 36, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/vipuser/playbrax/train.py", line 227, in train
    pmap.assert_is_replicated(training_state)
  File "/home/vipuser/playbrax/brax/brax/training/pmap.py", line 70, in assert_is_replicated
    assert jax.pmap(f, axis_name='i')(x)[0], debug
AssertionError: None

If I use from brax.training.agents.apg.train import train as apgtrain instead, the full output becomes:

begin
0 2233.8481
1 2273.3516
2 2460.1377
3 2319.5432
4 2250.9502
5 2289.2446
6 nan
7 nan
8 nan
9 nan
Traceback (most recent call last):
  File "xxxxx.py", line 36, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/vipuser/playbrax/train.py", line 227, in train
    pmap.assert_is_replicated(training_state)
  File "/home/vipuser/playbrax/brax/brax/training/pmap.py", line 255, in assert_is_replicated
    assert jax.pmap(f, axis_name='i')(x)[0], debug
AssertionError: None

wangyian-me commented 1 year ago

It reports the same bug even when I use only one GPU. Also, I get this warning; I don't know if it is relevant:

/home/vipuser/miniconda3/envs/brax/lib/python3.8/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead. jax.tree_util.register_keypaths(

wangyian-me commented 1 year ago

I just realized that it might be because some elements are NaN, and NaN == NaN is False, so the replication check could return False.
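
A quick way to see this in plain JAX (NaN never compares equal to itself, so any equality-based replication check fails once a NaN appears, even on a single device):

import jax.numpy as jnp

x = jnp.float32(jnp.nan)
print(bool(x == x))              # False: NaN never equals itself
print(bool(jnp.allclose(x, x)))  # also False (equal_nan defaults to False)
print(bool(jnp.isnan(x)))        # True: use isnan to detect NaNs explicitly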

btaba commented 1 year ago

Hi @wangyian-me, indeed when assert_is_replicated fails it's usually because of a NaN in training. So it looks like humanoidstandup trained with APG causes a NaN?

wangyian-me commented 1 year ago

Yeah, it happens when I use the "generalized" backend. I've also tried the "positional" backend, which works without this bug.

wangyian-me commented 1 year ago

Also, I've tried to locate where the NaN is produced: it appears after this line. So I guess the gradient might explode during backpropagation with the "generalized" backend. @btaba
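
If it helps, one common way to pin down where the first NaN is produced is JAX's debug-NaNs mode, which raises a FloatingPointError at the first primitive that outputs a NaN. This is only a sketch of the general technique: it slows training down noticeably, and per the JAX docs it may not work under pmap, so running on a single device (or temporarily swapping pmap for vmap) is a common workaround.

import jax

# Equivalent to setting the environment variable JAX_DEBUG_NANS=1.
jax.config.update('jax_debug_nans', True)

# ...then run the training script as usual; JAX re-runs the offending
# computation de-optimized and raises where the NaN first appears.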

queenxy commented 1 year ago

I'm hitting the same issue when I use PPO: there are NaNs, but I don't know how to locate them. Could you please help me?

btaba commented 1 year ago

@queenxy are you getting NaNs on humanoidstandup with PPO with the generalized backend (and on which device)? AFAIK this was tested on TPU, but it would be good to know. Thanks @wangyian-me for confirming, we'll have to debug. But if you have some time, feel free to dig deeper.

queenxy commented 1 year ago

I am getting NaNs in my own environment with the PPO implementation provided by Brax. The device is GPU (both multi- and single-GPU runs lead to this issue). I have checked my environment and nothing seems to be wrong, so I am trying to determine whether the NaN originates in PPO. @btaba
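
For instance, one way to rule the environment in or out is to roll it out directly with random actions and check that observations and rewards stay finite before involving PPO at all. This is a rough sketch assuming the standard brax envs API; humanoidstandup stands in here for the custom environment.

import jax
import jax.numpy as jnp
from brax import envs

# Placeholder: substitute your own environment name and backend here.
env = envs.get_environment(env_name='humanoidstandup', backend='generalized')
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

state = reset_fn(jax.random.PRNGKey(0))
rng = jax.random.PRNGKey(1)
for i in range(200):
  rng, key = jax.random.split(rng)
  action = jax.random.uniform(key, (env.action_size,), minval=-1.0, maxval=1.0)
  state = step_fn(state, action)
  finite = jnp.all(jnp.isfinite(state.obs)) & jnp.isfinite(state.reward)
  if not bool(finite):
    print('non-finite value first appears at step', i)
    break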