google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Cannot run simple MJX example on standard v4-8 Cloud TPU VM #475

Closed: markusheimerl closed this issue 2 months ago

markusheimerl commented 2 months ago
import mujoco
from mujoco import mjx

xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # (seconds)
framerate = 60  # (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)

yields

markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ /bin/python /home/markusheimerl/reinforcement/test_mjx.py
/home/markusheimerl/.local/lib/python3.10/site-packages/glfw/__init__.py:914: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'
  warnings.warn(message, GLFWError)
/home/markusheimerl/.local/lib/python3.10/site-packages/glfw/__init__.py:914: GLFWError: (65537) b'The GLFW library is not initialized'
  warnings.warn(message, GLFWError)
Traceback (most recent call last):
  File "/home/markusheimerl/reinforcement/test_mjx.py", line 20, in <module>
    renderer = mujoco.Renderer(mj_model)
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/renderer.py", line 83, in __init__
    self._mjr_context = _render.MjrContext(
mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Exception ignored in: <function Renderer.__del__ at 0x7ff2c084b2e0>
Traceback (most recent call last):
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/renderer.py", line 330, in __del__
    self.close()
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/renderer.py", line 318, in close
    if self._mjr_context:
AttributeError: 'Renderer' object has no attribute '_mjr_context'
markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ 
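
The FatalError above means mjr_makeContext ran without a current OpenGL context. On a headless VM, one common workaround (an assumption here, not something confirmed in this thread) is to select a windowless GL backend via the documented MUJOCO_GL environment variable before importing mujoco:

import os
os.environ['MUJOCO_GL'] = 'egl'  # or 'osmesa'; neither needs an X server

import mujoco  # the GL backend is chosen at import time, so set the variable first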

and, trying a virtual display instead,

# Create a virtual display
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

import mujoco
from mujoco import mjx

xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # (seconds)
framerate = 60  # (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)

yields

markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ /bin/python /home/markusheimerl/reinforcement/test_mjx.py
[0.] <class 'numpy.ndarray'>
[0.] <class 'jaxlib.xla_extension.ArrayImpl'> {TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
Exception ignored in: <function Renderer.__del__ at 0x7feb0a431870>
Traceback (most recent call last):
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/renderer.py", line 330, in __del__
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/renderer.py", line 316, in close
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/glfw/__init__.py", line 35, in free
TypeError: 'NoneType' object is not callable
Exception ignored in: <function GLContext.__del__ at 0x7febab80c4c0>
Traceback (most recent call last):
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/glfw/__init__.py", line 41, in __del__
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/mujoco/glfw/__init__.py", line 35, in free
TypeError: 'NoneType' object is not callable
markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ 
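
This second run appears to complete: the remaining tracebacks fire only during interpreter teardown, after the glfw module has already been unloaded. A sketch of explicit cleanup that should silence them, using the documented mujoco.Renderer.close() and pyvirtualdisplay Display.stop() methods:

# At the end of the script: free the GL context while glfw is still loaded,
# then shut down the virtual display.
renderer.close()
display.stop()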
markusheimerl commented 2 months ago
#@title Import MuJoCo, MJX, and Brax

# Create a virtual display
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

from datetime import datetime
import functools
import matplotlib.pyplot as plt
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct

from ml_collections import config_dict
import mujoco
from mujoco import mjx
from brax.io import image
from PIL import Image
import io

#@title Humanoid Env

class Humanoid(PipelineEnv):

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    path = epath.Path(epath.resource_path('mujoco')) / (
        'mjx/test_data/humanoid'
    )
    mj_model = mujoco.MjModel.from_xml_path(
        (path / 'humanoid.xml').as_posix())
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])

envs.register_environment('humanoid', Humanoid)

# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# grab a trajectory
for i in range(10):
  ctrl = -0.1 * jp.ones(env.sys.nu)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

img_bytes = image.render(env.sys, rollout)
img = Image.open(io.BytesIO(img_bytes))
img.save('output.png')

train_fn = functools.partial(
    ppo.train, num_timesteps=30_000, num_evals=5, reward_scaling=0.1,
    episode_length=100, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=128,
    batch_size=128, seed=0)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]

max_y, min_y = 13000, 0
def progress(num_steps, metrics):
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  plt.ylim([min_y, max_y])

  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={y_data[-1]:.3f}')

  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.savefig(f"plot_{num_steps}.png")

make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

#@title Save Model
model_path = '/tmp/mjx_brax_policy'
model.save_params(model_path, params)

yields

markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ /bin/python /home/markusheimerl/reinforcement/env_render.py
Traceback (most recent call last):
  File "/home/markusheimerl/reinforcement/env_render.py", line 219, in <module>
    make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 405, in train
    metrics = evaluator.run_evaluation(
  File "/home/markusheimerl/.local/lib/python3.10/site-packages/brax/training/acting.py", line 127, in run_evaluation
    eval_metrics.active_episodes.block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:8:0x630d (from TensorCoreSequencer:8:0x630d): no debugging message found for this tag:pc. 
=== Source Location Trace: === 
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:113

markusheimerl@t1v-n-471edefc-w-0:~/reinforcement$ 
btaba commented 2 months ago

Hi @markusheimerl , thanks for posting!

It looks like the failure is in the rendering call rather than a bug in brax/MuJoCo: glfw cannot create a window and GL context on that machine (it's a TPU VM, so you'd want to render on the CPU).

Consider using the brax viewer instead (see the last cell here); it may well be no slower than CPU rendering anyhow. You can optionally create an issue here to see if someone else can help.
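
A minimal sketch of that suggestion, reusing the env and rollout objects from the script above: serialize the trajectory with the brax HTML viewer, which needs no OpenGL context at all:

from brax.io import html

# Write the rollout to a self-contained HTML page and open it in a browser.
with open('rollout.html', 'w') as f:
  f.write(html.render(env.sys, rollout))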