google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

[MJX] Model loading errors in `brax.io`, for features supported by MJX #1442

Closed varadVaidya closed 8 months ago

varadVaidya commented 8 months ago

With the MuJoCo 3.1.2 update, site transmission was added to MJX, and a new pipeline for training RL agents with `PipelineEnv` was added to Brax. However, loading a model through this pipeline throws errors about the RK4 integration scheme and site transmission, even though the pipeline is set to MJX.

btaba commented 8 months ago

Hi @varadVaidya thanks for the bug report. Can you please provide a way to reproduce the issue?

varadVaidya commented 8 months ago

Oh yes, I forgot to add the minimal code to reproduce the error. The Python script is based on the MJX Colab tutorial.

XML File
```XML ```
Python Code
```python
import os

# Use EGL for headless rendering and cap JAX's device-memory preallocation.
# Both must be set before mujoco/jax are imported.
os.environ["MUJOCO_GL"] = "egl"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.4"

import functools
from datetime import datetime

import numpy as np
from jax import numpy as jp
import mujoco
from mujoco import mjx

from brax import envs
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from brax.training.agents.ppo import train as ppo

np.set_printoptions(precision=3, suppress=True, linewidth=100)


def test_mjx_put():
    """Sanity check: load box.xml and copy model/data onto the JAX device."""
    mj_model = mujoco.MjModel.from_xml_path("box.xml")
    mj_data = mujoco.MjData(mj_model)
    renderer = mujoco.Renderer(mj_model, height=1080, width=1920)
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.put_data(mj_model, mj_data)
    print(mj_data.qpos, type(mj_data.qpos))
    print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())


class Box(PipelineEnv):

    def __init__(self, mj_freq=100, control_freq=50, **kwargs):
        path = "box.xml"
        model = mujoco.MjModel.from_xml_path(path)
        model.opt.timestep = 1.0 / mj_freq
        sys = mjcf.load_model(model)

        physics_steps_per_control_step = int(mj_freq / control_freq)
        kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)

        # No `backend` argument is passed, so PipelineEnv uses its default
        # 'generalized' backend; this is the cause of the errors in this issue.
        super().__init__(sys, **kwargs)

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        qpos = self.sys.qpos0
        qvel = jp.zeros(6)
        data = self.pipeline_init(qpos, qvel)
        obs, reward, done, zero = jp.zeros(4)
        metrics = {'x_pos': zero}  # metric names must be string keys
        return State(data, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        data = self.pipeline_step(state.pipeline_state, action)
        obs, reward, done, zero = jp.zeros(4)
        return state.replace(
            pipeline_state=data, obs=obs, reward=reward, done=done
        )


if __name__ == '__main__':
    test_mjx_put()

    envs.register_environment('box', Box)

    # instantiate the environment
    env_name = 'box'
    env = envs.get_environment(env_name)

    train_fn = functools.partial(
        ppo.train, num_timesteps=1e5, num_evals=4, reward_scaling=1,
        num_eval_envs=4, episode_length=600, normalize_observations=False,
        action_repeat=1, unroll_length=10, num_minibatches=32,
        num_updates_per_batch=2, discounting=0.99, learning_rate=2e-4,
        entropy_cost=1e-3, clipping_epsilon=0.2, num_envs=4096,
        batch_size=512, seed=0)

    times = [datetime.now()]

    def progress(num_steps, metrics):
        times.append(datetime.now())
        print("\n################################")
        print("\nCurrent Time", datetime.now())
        print("\nNum Steps:", num_steps)
        print("\nEpisode Reward", metrics['eval/episode_reward'])
        print("\nEpisode Reward STD:", metrics['eval/episode_reward_std'])
        print("\n################################")

    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

    print(f'time to jit: {times[1] - times[0]}')
    print(f'time to train: {times[-1] - times[1]}')
```

Running the above shows:

```
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[0.  0.  0.4 1.  0.  0.  0. ] <class 'numpy.ndarray'>
[0.  0.  0.4 1.  0.  0.  0. ] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 115, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 226, in train
    env_state = reset_fn(key_envs)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 111, in reset
    state = self.env.reset(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 68, in reset
    return jax.vmap(self.env.reset)(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 83, in reset
    state = self.env.reset(rng)
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 74, in reset
    data = self.pipeline_init(qpos, qvel)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/base.py", line 119, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, self._debug)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/generalized/pipeline.py", line 46, in init
    mjcf.validate_model(sys.mj_model)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/io/mjcf.py", line 233, in validate_model
    raise NotImplementedError('Only euler integration is supported.')
NotImplementedError: Only euler integration is supported.
```

Removing RK4 from the XML (i.e., falling back to the default Euler integrator) shows:

```
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[0.  0.  0.4 1.  0.  0.  0. ] <class 'numpy.ndarray'>
[0.  0.  0.4 1.  0.  0.  0. ] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 115, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 226, in train
    env_state = reset_fn(key_envs)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 111, in reset
    state = self.env.reset(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 68, in reset
    return jax.vmap(self.env.reset)(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 83, in reset
    state = self.env.reset(rng)
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 74, in reset
    data = self.pipeline_init(qpos, qvel)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/base.py", line 119, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, self._debug)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/generalized/pipeline.py", line 46, in init
    mjcf.validate_model(sys.mj_model)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/io/mjcf.py", line 247, in validate_model
    raise NotImplementedError(
NotImplementedError: Only joint transmission types are supported for actuators.
```

This shows that even though MJX accepts the RK4 integrator and site transmission (the site-transmission error appears once RK4 is changed to Euler), the Brax pipeline raises errors. I tried a fix of my own, but thought there might be subtle things that could break, hence this issue. Please let me know if anything more is required. Sorry, I forgot to attach all of this when I opened the issue.

btaba commented 8 months ago

Hi @varadVaidya, looking at the traceback, lines 260 and 264 in brax/io/mjcf.py don't raise any exceptions. Can you try updating your version of brax? See https://github.com/google/brax/releases/tag/v0.10.0
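Because the line numbers in a traceback depend on the installed release, it can help to include exact package versions in a report. A minimal sketch using only the standard library's `importlib.metadata`; the distribution names (`mujoco`, `mujoco-mjx`, `brax`, `jax`) are assumed to be the PyPI names, and any that are not installed are reported as such instead of raising.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


versions: Dict[str, Optional[str]] = {
    dist: installed_version(dist) for dist in ("mujoco", "mujoco-mjx", "brax", "jax")
}
for dist, ver in versions.items():
    print(f"{dist}: {ver or 'not installed'}")
```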

varadVaidya commented 8 months ago

Ah, good catch: my server had the updated version, but I wrote the minimal example against an older local version of brax. The error still shows up after updating; I have edited the traceback in my previous comment to reflect the new one.

btaba commented 8 months ago

The traceback suggests you are using the generalized backend. You should set the backend string to "mjx" in `super().__init__` for your `class Box(PipelineEnv)`:

See https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/envs/base.py#L88

varadVaidya commented 8 months ago

Thanks for the help. This solves the problem. Sorry, I missed the `backend="mjx"` detail.