Closed: smurthas closed this issue 3 weeks ago
This appears to be present only in 3.1.4 and 3.1.5: if I explicitly specify version 3.1.3 when running pip install, the false collision goes away.
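A script can guard against the affected releases with a check along these lines. This is only a sketch: the affected set comes from the report above (3.1.4 and 3.1.5), the helper names are hypothetical, and the pip distribution name "mujoco-mjx" is an assumption about how MJX was installed.

```python
from importlib import metadata

# Releases reported in this thread as showing the false collision.
AFFECTED_VERSIONS = {(3, 1, 4), (3, 1, 5)}


def parse_version(version: str) -> tuple:
    """Parse the leading 'X.Y.Z' of a plain numeric version string."""
    return tuple(int(part) for part in version.split(".")[:3])


def is_affected(version: str) -> bool:
    """True if the version is one of the releases reported as affected."""
    return parse_version(version) in AFFECTED_VERSIONS


def installed_mjx_is_affected(dist_name: str = "mujoco-mjx"):
    """Return True/False, or None if the distribution is not installed.

    The default distribution name is an assumption, not taken from the thread.
    """
    try:
        return is_affected(metadata.version(dist_name))
    except metadata.PackageNotFoundError:
        return None
```

Calling installed_mjx_is_affected() at the top of a repro script makes it obvious which behavior to expect before any stepping happens.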
Hopefully this helps narrow down the search space of the bug (especially since it looks like some recent work was done to refactor collision checks and add condim support)!
Also, the original example had condim="1" set for the "starting plane" body, but I removed that and confirmed that the bug is present even without it, so I'm going to remove it from the example above and the colab notebook.
Thanks @smurthas for the bug report, we have a fix that will be pushed out shortly!
Thanks!
I rolled back to 3.1.3 and have encountered another contact-related issue that is present in MJX in 3.1.3, but NOT present in MJX 3.1.4 and 3.1.5, nor in regular MuJoCo on any of the three versions.
Videos from 3.1.3, comparing MuJoCo and MJX, slowed down 10x:
MuJoCo -- no issue
https://github.com/google-deepmind/mujoco/assets/399496/43c051f9-062f-440e-a12b-83497bd4e0b9
MJX -- box jumps
https://github.com/google-deepmind/mujoco/assets/399496/233d0f62-d845-422e-a36e-45bba1742aca
Python test case function repro (same as colab code):
import copy

import mujoco
from mujoco import mjx
import jax
import mediapy
import numpy as np

block_only = """
<mujoco model="block_only">
  <worldbody>
    <camera name="side" pos="0. -.8 0.1" xyaxes="1 0 0 0 0 1" mode="trackcom"/>
    <body name="starting_plane">
      <geom type="box" size=".2 .2 .01" pos="0.0 0 -.02" rgba=".5 .8 .5 1"/>
    </body>
    <body name="box" pos="0 0 .03">
      <freejoint/>
      <geom type="box" size="0.04 0.04 0.04" rgba=".8 .8 .5 1"/>
    </body>
  </worldbody>
</mujoco>
"""


def test_block_only_env_mujoco_vs_mjx():
    # init MuJoCo
    mj_model = mujoco.MjModel.from_xml_string(block_only)
    mj_data = mujoco.MjData(mj_model)
    renderer = mujoco.Renderer(mj_model, width=640, height=480)
    mujoco.mj_resetData(mj_model, mj_data)

    # init MJX from same model and data
    jit_step = jax.jit(mjx.step)
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.put_data(mj_model, mj_data)

    mj_datas = []
    mjx_mj_datas = []

    # step forward to 0.25 seconds (125 steps at the default 2 ms timestep)
    for _ in range(125):
        # step mujoco and store a snapshot; appending mj_data itself would
        # alias the same object, so every entry would show the final state
        mujoco.mj_step(mj_model, mj_data)
        mj_datas.append(copy.copy(mj_data))

        # step mjx and fetch data
        mjx_data = jit_step(mjx_model, mjx_data)
        mjx_mj_data = mjx.get_data(mj_model, mjx_data)
        mjx_mj_datas.append(mjx_mj_data)

    # render videos
    mujoco_frames = []
    mjx_frames = []
    for mj_data, mjx_mj_data in zip(mj_datas, mjx_mj_datas):
        # render MuJoCo
        renderer.update_scene(mj_data, camera="side")
        mujoco_frames.append(renderer.render())

        # render MJX
        renderer.update_scene(mjx_mj_data, camera="side")
        mjx_frames.append(renderer.render())

    # Display video at a 10x slowdown since it happens quickly
    mediapy.show_video(mujoco_frames, fps=500/10)
    mediapy.show_video(mjx_frames, fps=500/10)

    # assert that the datas match
    for mj_data, mjx_mj_data in zip(mj_datas, mjx_mj_datas):
        np.testing.assert_allclose(
            mj_data.xpos,
            mjx_mj_data.xpos,
            rtol=1e-3,
            atol=1e-3,
        )


test_block_only_env_mujoco_vs_mjx()
I realize you are not going to try to fix issues that are no longer present, but I figured an isolated repro might be useful as a regression test. The changes that introduced the false collision bug seem to have fixed a previous bug, so as you work on the false collision fix, this repro might help ensure you don't inadvertently reintroduce the other bug.
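For a regression test, it can help to report the first step at which the two rollouts disagree rather than only failing an assertion. The helper below is a hypothetical sketch, not part of MuJoCo or MJX: it assumes each trajectory is a sequence of per-step arrays (for example, copies of xpos taken after every step) and uses the same tolerances as the repro.

```python
import numpy as np


def first_divergence(reference, candidate, rtol=1e-3, atol=1e-3):
    """Return the index of the first step where the trajectories differ.

    Both arguments are sequences of per-step arrays of matching shape.
    Returns None if every compared step is within tolerance.
    """
    for step, (ref, cand) in enumerate(zip(reference, candidate)):
        if not np.allclose(cand, ref, rtol=rtol, atol=atol):
            return step
    return None


# Synthetic example: two bodies, five steps, one perturbed entry at step 3.
reference = [np.zeros((2, 3)) for _ in range(5)]
candidate = [arr.copy() for arr in reference]
candidate[3][1, 2] = 0.05
diverged_at = first_divergence(reference, candidate)  # 3
```

In the repro, the two lists would hold the xpos arrays from the MuJoCo and MJX rollouts; a None result means the box did not jump within the compared steps.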
Nice, thank you for the clean repros @smurthas. Looks like we're not regressing to this bug, but please let us know if you find any issues!
With the default settings of MJX and MuJoCo, this model and basic rollout produce a false collision in MJX that is (correctly) not present when the same rollout is executed in MuJoCo directly.
Colab repro; the same code is also pasted below with the model XML inline.
Here is a video of what this looks like (I captured this video while moving it by hand in the Mac MuJoCo GUI, but it looks about the same when rendered in Python by adding a camera and colors):
https://github.com/google-deepmind/mujoco/assets/399496/9316a13a-b06c-484e-9393-e8b096e2eb8c
Note that there is no visible collision -- the objects are well clear of each other.
Here is the minimal repro (same as the colab):