Daffan opened this issue 1 year ago
@Daffan Thanks for reporting! Which environment are you using and which backend?
Hello, @btaba. I'm following up to ask whether there has been any progress on the issue we discussed earlier. In my recent experiments, I've run into an unexpected problem: PPO returns NaN values after several iterations on my GPUs.
I've included the code below for reference. It drives a humanoid model toward a sequence of target positions, one at a time. While I understand that the optimization may not converge immediately, encountering NaN values seems peculiar and warrants investigation.
#!/usr/bin/env python
import time
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/lib/cuda/'
from datetime import datetime
import functools
import jax
from jax import numpy as jp
import numpy as np
from brax import base
from brax import envs
from brax import actuator
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo
from brax.io import mjcf
from etils import epath
import mujoco
from jax import config
config.update("jax_debug_nans", True)
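# jax_debug_nans re-runs the computation that produced a NaN eagerly, op by op,
# and raises an exception there, which helps localize the first NaN at some
# performance cost.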
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = ".95"
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
print("XLA flags set")
np.set_printoptions(precision=3, suppress=True, linewidth=100)
class Humanoid(PipelineEnv):
def __init__(
self,
terminate_when_unhealthy=True,
reset_noise_scale=1e-2,
exclude_current_positions_from_observation=True,
backend='generalized',
**kwargs,
):
path = epath.resource_path('brax') / 'envs/assets/humanoid.xml'
sys = mjcf.load(path)
n_frames = 5
if backend == 'mjx':
sys = sys.tree_replace({
'opt.solver': mujoco.mjtSolver.mjSOL_NEWTON,
'opt.disableflags': mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
'opt.iterations': 1,
'opt.ls_iterations': 4,
})
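      # These MJX options follow the usual GPU-training recipe: the Newton
      # solver with a small iteration budget and EULERDAMP disabled trade some
      # integration accuracy for speed on accelerators.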
kwargs['n_frames'] = kwargs.get('n_frames', n_frames)
super().__init__(sys=sys, backend=backend, **kwargs)
self._terminate_when_unhealthy = terminate_when_unhealthy
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
    self.ref_traj = jp.asarray([1.0, 2.0, 3.0, 4.0])
self.max_err = jp.array(0.5)
def reset(self, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)
low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
)
qvel = jax.random.uniform(
rng2, (self.sys.qd_size(),), minval=low, maxval=hi
)
pipeline_state = self.pipeline_init(qpos, qvel)
obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
reward, done, zero = jp.zeros(3)
metrics = {
'mse_loss': zero,
        'idx': jp.array(0),  # stored as a jax array so the metrics pytree stays consistent under jit
'is_converge': zero,
'done': zero
}
return State(pipeline_state, obs, reward, done, metrics)
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
data0 = state.pipeline_state
idx = state.metrics["idx"].astype(int)
target = self.ref_traj[idx].copy()
data = self.pipeline_step(data0, action)
    mse_loss = jp.sum(jp.square(data.q[0] - target))
    is_converge = jp.where(mse_loss < self.max_err, 1.0, 0.0)
    idx_new = jp.where(is_converge == 1.0, idx + 1, idx)
    obs = self._get_obs(data, action)
    done = is_converge
    reward = -mse_loss + is_converge * 100
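    # Two things to note: jax clamps out-of-bounds gather indices, so once idx
    # moves past the end of ref_traj the last target is silently reused; and
    # since done = is_converge, the episode ends (and the +100 bonus is paid)
    # on the first step that reaches the current target.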
    state.metrics.update(
        mse_loss=mse_loss,
        idx=idx_new,
        is_converge=is_converge,
        done=done,
    )
return state.replace(
pipeline_state=data, obs=obs, reward=reward, done=done
)
def _get_obs(
self, pipeline_state: base.State, action: jax.Array
) -> jax.Array:
"""Observes humanoid body position, velocities, and angles."""
position = pipeline_state.q
velocity = pipeline_state.qd
if self._exclude_current_positions_from_observation:
position = position[2:]
com, inertia, mass_sum, x_i = self._com(pipeline_state)
cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
com_inertia = jp.hstack(
[cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
)
xd_i = (
base.Transform.create(pos=x_i.pos - pipeline_state.x.pos)
.vmap()
.do(pipeline_state.xd)
)
com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum
com_ang = xd_i.ang
com_velocity = jp.hstack([com_vel, com_ang])
qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd)
# external_contact_forces are excluded
return jp.concatenate([
position,
velocity,
com_inertia.ravel(),
com_velocity.ravel(),
qfrc_actuator,
])
def _com(self, pipeline_state: base.State) -> jax.Array:
inertia = self.sys.link.inertia
mass_sum = jp.sum(inertia.mass)
x_i = pipeline_state.x.vmap().do(inertia.transform)
com = (
jp.sum(jax.vmap(jp.multiply)(inertia.mass, x_i.pos), axis=0) / mass_sum
)
return com, inertia, mass_sum, x_i # pytype: disable=bad-return-type # jax-ndarray
envs.register_environment('humanoid', Humanoid)
print("env registered")
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 13000, 0
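# x_data/y_data/ydataerr and max_y/min_y are presumably leftovers from the
# tutorial notebook's plotting cell; they are unused in this repro.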
def progress(num_steps, metrics):
times.append(datetime.now())
  print(times[-2:], "len:", len(times))
print(num_steps)
print(metrics['eval/episode_reward'])
train_fn = functools.partial(
    ppo.train, num_timesteps=1_000, num_evals=1200, reward_scaling=0.1,
    episode_length=10, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-6, entropy_cost=1e-3, num_envs=512,
    batch_size=256, seed=0)
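# Down-scaled configuration (1,000 timesteps, 10-step episodes), presumably
# just enough to reproduce the NaNs rather than a tuned training setup.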
# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
# train the policy
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
Thank you for your time.
@Yunkai-Yu Any progress made with this?
I encountered the same issue in a customized environment. Any progress on this?
PPO training returns NaN when using multiple GPUs; forcing it to use one GPU works fine. I just ran exactly the same training code as in Brax Training. Can somebody try it? Thanks!
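For anyone who wants to try the single-GPU workaround mentioned above, here is a minimal sketch (assuming a CUDA setup; the environment variable must be set before JAX is first imported):

import os
# Hide all but the first GPU before JAX initializes its backends;
# CUDA_VISIBLE_DEVICES has no effect once `import jax` has already run.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jax
print(jax.devices())  # should now list exactly one CUDA device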