google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.25k stars 249 forks source link

PPO returns nan with multiple GPU #332

Open Daffan opened 1 year ago

Daffan commented 1 year ago

PPO training returns NaN when using multiple GPUs. Forcing it to use a single GPU works fine. I ran exactly the same training code as in Brax Training. Can somebody try to reproduce it? Thanks!

btaba commented 1 year ago

@Daffan Thanks for reporting! Which environment are you using and which backend?

Yunkai-Yu commented 6 months ago

Hello, @btaba. I'm following up to inquire about any progress regarding the issue we discussed earlier. In my recent experiments, I've encountered an unexpected problem with PPO returning NaN values after several iterations on my GPUs.

I've included the code snippet below for reference. The code aims to drive a humanoid model to move gradually to different positions one by one. While I understand that optimization convergence might not be achieved immediately, encountering NaN values seems peculiar and warrants investigation.

#!/usr/bin/env python
import time
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/lib/cuda/'
from datetime import datetime
import functools
import jax
from jax import numpy as jp
import numpy as np
from brax import base
from brax import envs
from brax import actuator
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo
from brax.io import mjcf
from etils import epath

import mujoco
from jax import config
# Turn on JAX's `jax_debug_nans` option so that a NaN produced inside a
# computation is reported as an error instead of silently propagating.
config.update("jax_debug_nans", True)

# Ask the XLA Python client to use up to 95% of GPU memory.
# NOTE(review): this is set after `jax` has already been imported; confirm it
# is still picked up before the GPU backend is initialised.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']=".95"
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

print("xla setted")
# Compact NumPy printing: 3 decimals, no scientific notation, 100-char lines.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

class Humanoid(PipelineEnv):
  """Humanoid environment that is rewarded for reaching a list of targets.

  On every step the first generalized coordinate (`q[0]`) is compared with the
  current entry of `ref_traj`. When the squared error drops below `max_err`
  the step is flagged as converged, a bonus is added to the reward, the
  episode is marked done and the target index stored in `metrics['idx']` is
  advanced by one.
  """

  # pyformat: enable
  def __init__(
      self,
      terminate_when_unhealthy=True,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      backend='generalized',
      **kwargs,
  ):
    """Loads the humanoid model and stores the task configuration.

    Args:
      terminate_when_unhealthy: stored on the instance; no method of this
        class reads it.
      reset_noise_scale: half-width of the uniform noise added to the initial
        positions and used for the initial velocities in `reset`.
      exclude_current_positions_from_observation: if true, the first two
        entries of `q` are dropped from the observation.
      backend: physics backend name forwarded to `PipelineEnv`.
      **kwargs: forwarded to `PipelineEnv`; `n_frames` defaults to 5 when the
        caller does not provide it.
    """
    path = epath.resource_path('brax') / 'envs/assets/humanoid.xml'
    sys = mjcf.load(path)
    n_frames = 5

    if backend == 'mjx':
      # Solver overrides applied only for the mjx backend.
      sys = sys.tree_replace({
          'opt.solver': mujoco.mjtSolver.mjSOL_NEWTON,
          'opt.disableflags': mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
          'opt.iterations': 1,
          'opt.ls_iterations': 4,
      })

    # A caller-supplied n_frames wins over the local default.
    kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

    super().__init__(sys=sys, backend=backend, **kwargs)
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )
    # Successive target values for q[0] and the squared-error threshold below
    # which a target counts as reached.
    self.ref_traj = jp.asarray([1.0,2.0,3.0,4.0])
    self.max_err = jp.array(0.5)

  def reset(self, rng: jax.Array) -> State:
    """Resets the environment to an initial state.

    Args:
      rng: PRNG key; split internally for the position and velocity noise.

    Returns:
      A `State` with zero reward, zero done flag and zeroed metrics.
    """
    rng, rng1,rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.init_q + jax.random.uniform(
        rng1, (self.sys.q_size(),), minval=low, maxval=hi
    )

    qvel = jax.random.uniform(
        rng2, (self.sys.qd_size(),), minval=low, maxval=hi
    )

    pipeline_state = self.pipeline_init(qpos, qvel)

    # No action has been applied yet, so observe with an all-zero action.
    obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'mse_loss': zero,
        # Stored as an int32 array (not a Python int) so that `step` can
        # consume the state returned here with or without `jax.jit`, and so
        # that every metrics leaf is an array.
        'idx': jp.zeros((), dtype=jp.int32),
        'is_converge': zero,
        'done': zero
    }
    return State(pipeline_state, obs, reward, done, metrics)

  def step(self, state: State, action: jax.Array) -> State:
    """Runs one timestep of the environment's dynamics.

    Args:
      state: current environment state; `state.metrics['idx']` selects the
        active target in `ref_traj`.
      action: actuator input passed to the physics pipeline.

    Returns:
      The state after one pipeline step. The reward is the negative squared
      error between `q[0]` and the target, plus 100 on the step where the
      error is below `max_err`; `done` is 1.0 on that step and 0.0 otherwise.
    """
    data0 = state.pipeline_state
    # `jp.asarray` also accepts a plain Python int left in the metrics by
    # older callers.
    idx = jp.asarray(state.metrics['idx'], dtype=jp.int32)
    # `idx` keeps growing after the final target has been reached; clamp the
    # lookup explicitly so the last target stays active instead of relying on
    # implicit out-of-bounds indexing behaviour.
    last_idx = self.ref_traj.shape[0] - 1
    target = self.ref_traj[jp.clip(idx, 0, last_idx)]

    data = self.pipeline_step(data0, action)
    mse_loss = jp.sum(jp.square(data.q[0]-target))

    is_converge = jp.where(mse_loss < self.max_err, 1.0, 0.0)
    idx_new = jp.where(is_converge==1.0, idx+1, idx)

    obs = self._get_obs(data,action)
    done = is_converge

    reward = -mse_loss+is_converge*100
    # Updates the metrics dict of the incoming state in place; the same dict
    # object is carried over by `state.replace` below.
    state.metrics.update(
        mse_loss = mse_loss,
        idx = idx_new,
        is_converge = is_converge,
        done = done
    )
    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, pipeline_state: base.State, action: jax.Array
  ) -> jax.Array:
    """Observes humanoid body position, velocities, and angles.

    Args:
      pipeline_state: physics state to observe.
      action: action used to compute the actuator forces in the observation.

    Returns:
      A 1-D concatenation of: `q` (without its first two entries when
      positions are excluded), `qd`, flattened per-link inertia and mass,
      flattened per-link velocity terms, and the actuator forces.
    """
    position = pipeline_state.q
    velocity = pipeline_state.qd

    if self._exclude_current_positions_from_observation:
      position = position[2:]

    com, inertia, mass_sum, x_i = self._com(pipeline_state)
    # Link inertias transformed by the per-link inertia frames after shifting
    # their positions to be relative to the center of mass.
    cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
    com_inertia = jp.hstack(
        [cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
    )

    # Link motion re-expressed through the offset between each inertia frame
    # and its link frame.
    xd_i = (
        base.Transform.create(pos=x_i.pos - pipeline_state.x.pos)
        .vmap()
        .do(pipeline_state.xd)
    )
    # Mass-weighted linear velocity per link, normalised by the total mass.
    com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum
    com_ang = xd_i.ang
    com_velocity = jp.hstack([com_vel, com_ang])

    qfrc_actuator = actuator.to_tau(
        self.sys, action, pipeline_state.q, pipeline_state.qd)

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        velocity,
        com_inertia.ravel(),
        com_velocity.ravel(),
        qfrc_actuator,
    ])

  def _com(self, pipeline_state: base.State) -> jax.Array:
    """Computes the mass-weighted center of the link inertia frames.

    Note that, despite the annotation, a 4-tuple is returned.

    Args:
      pipeline_state: physics state providing the link transforms `x`.

    Returns:
      `(com, inertia, mass_sum, x_i)`: the mass-weighted mean position, the
      system's link inertia, the summed link mass, and the link transforms
      composed with the inertia transforms.
    """
    inertia = self.sys.link.inertia

    mass_sum = jp.sum(inertia.mass)
    x_i = pipeline_state.x.vmap().do(inertia.transform)
    com = (
        jp.sum(jax.vmap(jp.multiply)(inertia.mass, x_i.pos), axis=0) / mass_sum
    )
    return com, inertia, mass_sum, x_i  # pytype: disable=bad-return-type  # jax-ndarray

# Register the custom Humanoid class under the name 'humanoid'.
# NOTE(review): presumably this replaces Brax's built-in environment of the
# same name in the registry -- confirm that is intended.
envs.register_environment('humanoid', Humanoid)
print("env registered")
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

# Plotting buffers; nothing in the visible script reads or fills them.
x_data = []
y_data = []
ydataerr = []
# Wall-clock timestamps: the first entry is taken here, and `progress` appends
# one more each time it is called.
times = [datetime.now()]

# Not referenced again in the visible script.
max_y, min_y = 13000, 0
def progress(num_steps, metrics):
  """Progress callback: records a timestamp and prints the evaluation reward.

  Appends the current time to the module-level `times` list, then prints the
  two most recent timestamps together with the list length, the step count,
  and `metrics['eval/episode_reward']`.

  Args:
    num_steps: step count reported by the trainer.
    metrics: mapping that contains the key 'eval/episode_reward'.
  """
  now = datetime.now()
  times.append(now)
  latest_two = times[-2:]
  n_recorded = len(times)
  print(latest_two, "len: ", n_recorded)
  print(num_steps)
  episode_reward = metrics['eval/episode_reward']
  print(episode_reward)

# PPO trainer with its hyperparameters pre-bound; the environment and the
# progress callback are supplied at call time below.
# NOTE(review): `10_00` is the integer 1000 (underscores are only digit
# separators), which is smaller than `num_evals=1200` -- confirm both values
# are intended.
train_fn = functools.partial(
    ppo.train, num_timesteps=10_00, num_evals=1200, reward_scaling=0.1,
    episode_length=10, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-6, entropy_cost=1e-3, num_envs=512,
    batch_size=256, seed=0)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
# One jitted reset with a fixed key; `rollout` and `jit_step` are not used
# again in the visible script.
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# Run PPO training; `progress` is passed as the trainer's progress callback.
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

# times[0] is the script-start timestamp and later entries come from
# `progress`, so these lines need at least one callback invocation.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

nan

Thank you for your time.

i1Cps commented 2 months ago

@Yunkai-Yu Any progress made with this?

RuiningLi commented 1 month ago

I encountered the same issue in a customized environment. Any progress on this?