Closed: markusheimerl closed this issue 2 months ago.
#@title Import MuJoCo, MJX, and Brax
# Create a virtual display
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()
from datetime import datetime
import functools
import matplotlib.pyplot as plt
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from brax.io import image
from PIL import Image
import io
#@title Humanoid Env
class Humanoid(PipelineEnv):
  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    # Load the MJX humanoid model that ships with MuJoCo.
    path = epath.Path(epath.resource_path('mujoco')) / (
        'mjx/test_data/humanoid'
    )
    mj_model = mujoco.MjModel.from_xml_path(
        (path / 'humanoid.xml').as_posix())
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])
envs.register_environment('humanoid', Humanoid)
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)
# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
# grab a trajectory
for i in range(10):
  ctrl = -0.1 * jp.ones(env.sys.nu)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

img_bytes = image.render(env.sys, rollout)
img = Image.open(io.BytesIO(img_bytes))
img.save('output.png')
train_fn = functools.partial(
    ppo.train, num_timesteps=30_000, num_evals=5, reward_scaling=0.1,
    episode_length=100, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=128,
    batch_size=128, seed=0)
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 13000, 0
def progress(num_steps, metrics):
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.savefig(f'plot_{num_steps}.png')
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
#@title Save Model
model_path = '/tmp/mjx_brax_policy'
model.save_params(model_path, params)
Running this script yields:
markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ /bin/python /home/markusheimerl/reinforcement/env_render.py
Traceback (most recent call last):
File "/home/markusheimerl/reinforcement/env_render.py", line 219, in <module>
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
File "/home/markusheimerl/.local/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 405, in train
metrics = evaluator.run_evaluation(
File "/home/markusheimerl/.local/lib/python3.10/site-packages/brax/training/acting.py", line 127, in run_evaluation
eval_metrics.active_episodes.block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:8:0x630d (from TensorCoreSequencer:8:0x630d): no debugging message found for this tag:pc.
=== Source Location Trace: ===
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:113
markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$
Hi @markusheimerl, thanks for posting!
It looks like the failure happens in the rendering call, so this doesn't appear to be a bug in brax or MuJoCo: glfw is unable to create a window and GL context on that machine. Since it's a TPU VM, you'd want to render on the CPU instead.
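One way to do that, sketched here on the assumption that OSMesa software rendering is installed (e.g. the libosmesa6-dev package on Debian/Ubuntu), is to select a headless MuJoCo rendering backend via the MUJOCO_GL environment variable at the very top of the script:

# Sketch: pick a headless backend for MuJoCo rendering. MUJOCO_GL is
# read when mujoco is imported, so it must be set before the import.
# 'osmesa' renders in software on the CPU; no display or GPU is needed.
import os
os.environ['MUJOCO_GL'] = 'osmesa'

import mujoco  # must come after the environment variable is set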
Consider using the brax HTML viewer instead (see the last cell here); it may not be any slower than CPU rendering anyway. You can optionally open an issue here to see if someone else can help.
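For reference, a minimal sketch of the viewer route, reusing the rollout list from the script above (html.render returns a self-contained HTML string, and the timestep override just makes playback run at the environment's control dt):

from brax.io import html

# Render the trajectory to interactive HTML; no GL context is required,
# so this works on a TPU VM. Open rollout.html in a browser to view it.
html_string = html.render(
    env.sys.tree_replace({'opt.timestep': env.dt}), rollout)
with open('rollout.html', 'w') as f:
  f.write(html_string)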