Closed: varadVaidya closed this issue 8 months ago.
Hi @varadVaidya, thanks for the bug report. Can you please provide a way to reproduce the issue?
Oh yes, I forgot to add the minimal code to reproduce the error. The Python script is based on the MJX Colab tutorial.
Running the above shows:
```
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[0. 0. 0.4 1. 0. 0. 0. ] <class 'numpy.ndarray'>
[0. 0. 0.4 1. 0. 0. 0. ] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 115, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 226, in train
    env_state = reset_fn(key_envs)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 111, in reset
    state = self.env.reset(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 68, in reset
    return jax.vmap(self.env.reset)(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 83, in reset
    state = self.env.reset(rng)
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 74, in reset
    data = self.pipeline_init(qpos, qvel)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/base.py", line 119, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, self._debug)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/generalized/pipeline.py", line 46, in init
    mjcf.validate_model(sys.mj_model)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/io/mjcf.py", line 233, in validate_model
    raise NotImplementedError('Only euler integration is supported.')
NotImplementedError: Only euler integration is supported.
```
Removing RK4 from the XML shows:
```
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[0. 0. 0.4 1. 0. 0. 0. ] <class 'numpy.ndarray'>
[0. 0. 0.4 1. 0. 0. 0. ] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 115, in <module>
    make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 226, in train
    env_state = reset_fn(key_envs)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 111, in reset
    state = self.env.reset(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 68, in reset
    return jax.vmap(self.env.reset)(rng)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 83, in reset
    state = self.env.reset(rng)
  File "/home/varad/study/robotics/mujoco_sims/mjx_test/server/test_box.py", line 74, in reset
    data = self.pipeline_init(qpos, qvel)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/envs/base.py", line 119, in pipeline_init
    return self._pipeline.init(self.sys, q, qd, self._debug)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/generalized/pipeline.py", line 46, in init
    mjcf.validate_model(sys.mj_model)
  File "/home/varad/venv/mjx/lib/python3.10/site-packages/brax/io/mjcf.py", line 247, in validate_model
    raise NotImplementedError(
NotImplementedError: Only joint transmission types are supported for actuators.
```
This shows that even though MJX accepts the RK4 integrator and site transmission, the brax pipeline raises an error for each of them (the site transmission error appears once RK4 is changed to Euler). I tried a solution of my own, but I thought there might be subtle things that could break, so I raised this issue instead. Please let me know if anything more is required. Sorry, I forgot to attach all of this when raising the issue.
Hi @varadVaidya,
Looking at the traceback, lines 260 and 264 in brax/io/mjcf.py don't raise any exceptions. Can you try updating your version of brax? See https://github.com/google/brax/releases/tag/v0.10.0
Ah, good catch. My server had the updated version, while I wrote the minimal code against an older local version of brax. The error still shows up, though. I have edited the traceback in the previous comment to reflect the new one.
The traceback suggests you are using the generalized backend. You should set the backend string to "mjx" in super().__init__ for your class Box(PipelineEnv).
See https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/envs/base.py#L88
Thanks for the help, this solves the problem. Sorry, I missed the backend="mjx" detail.
With the MuJoCo 3.1.2 update, site transmission was added to MJX, and a new pipeline for training RL agents using PipelineEnv was added to brax. However, when loading a model through this pipeline, errors regarding the RK4 integration scheme and site transmission are thrown, even though the pipeline is set to mjx.