google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

ValueError: vmap got inconsistent sizes for array axes to be mapped #465

Closed naijekux closed 4 months ago

naijekux commented 4 months ago

Hi,

I loaded the XML model like this:

import mujoco
from brax.io import mjcf

# model_path points to the MJCF file shown further below
mj_model = mujoco.MjModel.from_xml_path(model_path)
sys = mjcf.load_model(mj_model)

Then I got the error in the title as soon as this line was called: pipeline_state = self.pipeline_init(qpos, qvel)

The traceback is below:

Cell In[25], [line 51]
     [48]qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=-0.005, maxval=0.005)
     [49]qvel.at[-4:].set(0.0)
---> [51]pipeline_state = self.pipeline_init(qpos, qvel)
     [52]obs = self._get_obs(pipeline_state)
     [53]reward, done, zero = jp.zeros(3)

File [/brax/envs/base.py:119], in PipelineEnv.pipeline_init(self, q, qd)
    [117]def pipeline_init(self, q: jax.Array, qd: jax.Array) -> base.State:
    [118]"""Initializes the pipeline state."""
--> [119]return self._pipeline.init(self.sys, q, qd, self._debug)

File [/brax/generalized/pipeline.py:47], in init(sys, q, qd, debug)
     [45]if sys.mj_model is not None:
     [46]mjcf.validate_model(sys.mj_model)
---> [47]x, xd = kinematics.forward(sys, q, qd)
     [48]state = State.init(q, qd, x, xd)  # pytype: disable=wrong-arg-types  # jax-ndarray
     [49]state = dynamics.transform_com(sys, state)

File [/brax/kinematics.py:85], in forward(sys, q, qd)
     [81]return j, jd
     [83]j, jd = scan.link_types(sys, jcalc, 'qdd', 'l', q, qd, sys.dof.motion)
---> [85]anchor = Transform.create(rot=j.rot).vmap().do(sys.link.joint)
     [86]j = j.replace(pos=j.pos + sys.link.joint.pos - anchor.pos)  # joint pos offset
     [87]j = sys.link.transform.vmap().do(j)  # link transform

    [... skipping hidden 2 frame]

File [/jax/_src/api.py:1350], in _mapped_axis_size(fn, tree, vals, dims, name)
   [1348] else:
   [1349]  msg.append(f"  * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
-> [1350]raise ValueError(''.join(msg)[:-2])

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (3 of them) had size 9, e.g. axis 0 of argument self.pos of type float32[9,3];
  * one axis had size 11: axis 0 of argument o.rot of type float32[11,4]

I believe the three axes of type float32[9,3] are the 'pos' in the transform and the 'ang' and 'vel' in the motion, while the axis of type float32[11,4] is the 'rot'. The XML model I use has 11 bodies in total inside the world body, so I cannot understand why 2 bodies are missing.
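For reference, the same class of error can be reproduced outside brax. The snippet below is only a minimal sketch: the arrays `pos` and `rot` and the function `combine` are hypothetical stand-ins that copy the shapes from the message above, not brax code. It shows that `jax.vmap` refuses to map over leading axes of different sizes.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins with the shapes reported in the error message.
pos = jnp.zeros((9, 3))   # 9 entries along the mapped axis
rot = jnp.zeros((11, 4))  # 11 entries along the mapped axis

def combine(p, r):
    # Any per-element function works; the body is never reached here,
    # because vmap checks the mapped axis sizes before tracing it.
    return p.sum() + r.sum()

try:
    jax.vmap(combine)(pos, rot)
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

So the mismatch of 9 versus 11 has to come from how the arrays were built before the vmapped call, not from the call itself.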

Here is the XML model I used, downloaded from MuJoCo's model gallery and slightly modified:

<mujoco model="panda nohand">
  <compiler angle="radian" meshdir="assets" autolimits="true"/>

  <option timestep="0.01" gravity="0 0 0" iterations="20"/>

  <default>
    <default class="panda">
      <material specular="0.5" shininess="0.25"/>
      <joint armature="0.1" damping="1" axis="0 0 1" range="-2.8973 2.8973"/>
      <general dyntype="none" biastype="affine" ctrlrange="-2.8973 2.8973" forcerange="-87 87" ctrllimited="true"/>

      <default class="visual">
        <geom type="mesh" contype="0" conaffinity="0" group="2"/>
      </default>
      <default class="collision">
        <geom type="mesh" contype="0" conaffinity="1" group="3"/>
      </default>
      <site size="0.001" rgba="0.5 0.5 0.5 0.3" group="4"/>
    </default>
  </default>

  <asset>
    <material class="panda" name="white" rgba="1 1 1 1"/>
    <material class="panda" name="off_white" rgba="0.901961 0.921569 0.929412 1"/>
    <material class="panda" name="dark_grey" rgba="0.25 0.25 0.25 1"/>
    <material class="panda" name="green" rgba="0 1 0 1"/>
    <material class="panda" name="light_blue" rgba="0.039216 0.541176 0.780392 1"/>
    <!--
    <material class="panda" name="red" rgba="1 0 0 1"/>
    -->

    <!-- Collision meshes -->
    <mesh name="link0_c" file="link0.stl"/>
    <mesh name="link1_c" file="link1.stl"/>
    <mesh name="link2_c" file="link2.stl"/>
    <mesh name="link3_c" file="link3.stl"/>
    <mesh name="link4_c" file="link4.stl"/>
    <mesh name="link5_c0" file="link5_collision_0.obj"/>
    <mesh name="link5_c1" file="link5_collision_1.obj"/>
    <mesh name="link5_c2" file="link5_collision_2.obj"/>
    <mesh name="link6_c" file="link6.stl"/>
    <mesh name="link7_c" file="link7.stl"/>

    <!-- Visual meshes -->
    <mesh file="link0_0.obj"/>
    <mesh file="link0_1.obj"/>
    <mesh file="link0_2.obj"/>
    <mesh file="link0_3.obj"/>
    <mesh file="link0_4.obj"/>
    <mesh file="link0_5.obj"/>
    <mesh file="link0_7.obj"/>
    <mesh file="link0_8.obj"/>
    <mesh file="link0_9.obj"/>
    <mesh file="link0_10.obj"/>
    <mesh file="link0_11.obj"/>
    <mesh file="link1.obj"/>
    <mesh file="link2.obj"/>
    <mesh file="link3_0.obj"/>
    <mesh file="link3_1.obj"/>
    <mesh file="link3_2.obj"/>
    <mesh file="link3_3.obj"/>
    <mesh file="link4_0.obj"/>
    <mesh file="link4_1.obj"/>
    <mesh file="link4_2.obj"/>
    <mesh file="link4_3.obj"/>
    <mesh file="link5_0.obj"/>
    <mesh file="link5_1.obj"/>
    <mesh file="link5_2.obj"/>
    <mesh file="link6_0.obj"/>
    <mesh file="link6_1.obj"/>
    <mesh file="link6_2.obj"/>
    <mesh file="link6_3.obj"/>
    <mesh file="link6_4.obj"/>
    <mesh file="link6_5.obj"/>
    <mesh file="link6_6.obj"/>
    <mesh file="link6_7.obj"/>
    <mesh file="link6_8.obj"/>
    <mesh file="link6_9.obj"/>
    <mesh file="link6_10.obj"/>
    <mesh file="link6_11.obj"/>
    <mesh file="link6_12.obj"/>
    <mesh file="link6_13.obj"/>
    <mesh file="link6_14.obj"/>
    <mesh file="link6_15.obj"/>
    <mesh file="link6_16.obj"/>
    <mesh file="link7_0.obj"/>
    <mesh file="link7_1.obj"/>
    <mesh file="link7_2.obj"/>
    <mesh file="link7_3.obj"/>
    <mesh file="link7_4.obj"/>
    <mesh file="link7_5.obj"/>
    <mesh file="link7_6.obj"/>
    <mesh file="link7_7.obj"/>
  </asset>

  <worldbody>
    <light name="top" pos="0 0 2" mode="trackcom"/>

    <geom name="table" type="plane" pos="0 0.5 -0.325" size="1 1 0.1" contype="1" conaffinity="0"/>

    <body name="link0" childclass="panda" pos="1.1 -0.5 -0.2" quat="0 0 0 1">
      <inertial mass="0.629769" pos="-0.041018 -0.00014 0.049974"
        fullinertia="0.00315 0.00388 0.004285 8.2904e-7 0.00015 8.2299e-6"/>
      <geom mesh="link0_0" material="off_white" class="visual"/>
      <geom mesh="link0_1" material="dark_grey" class="visual"/>
      <geom mesh="link0_2" material="off_white" class="visual"/>
      <geom mesh="link0_3" material="dark_grey" class="visual"/>
      <geom mesh="link0_4" material="off_white" class="visual"/>
      <geom mesh="link0_5" material="dark_grey" class="visual"/>
      <geom mesh="link0_7" material="white" class="visual"/>
      <geom mesh="link0_8" material="white" class="visual"/>
      <geom mesh="link0_9" material="dark_grey" class="visual"/>
      <geom mesh="link0_10" material="off_white" class="visual"/>
      <geom mesh="link0_11" material="white" class="visual"/>
      <geom mesh="link0_c" class="collision"/>

      <body name="link1" pos="0 0 0.333">
        <inertial mass="4.970684" pos="0.003875 0.002081 -0.04762"
          fullinertia="0.70337 0.70661 0.0091170 -0.00013900 0.0067720 0.019169"/>
        <joint name="joint1"/>
        <geom material="white" mesh="link1" class="visual"/>
        <geom mesh="link1_c" class="collision"/>

        <body name="link2" quat="1 -1 0 0">
          <inertial mass="0.646926" pos="-0.003141 -0.02872 0.003495"
            fullinertia="0.0079620 2.8110e-2 2.5995e-2 -3.925e-3 1.0254e-2 7.04e-4"/>
          <joint name="joint2" range="-1.7628 1.7628"/>
          <geom material="white" mesh="link2" class="visual"/>
          <geom mesh="link2_c" class="collision"/>

          <body name="link3" pos="0 -0.316 0" quat="1 1 0 0">
            <joint name="joint3"/>
            <inertial mass="3.228604" pos="2.7518e-2 3.9252e-2 -6.6502e-2"
              fullinertia="3.7242e-2 3.6155e-2 1.083e-2 -4.761e-3 -1.1396e-2 -1.2805e-2"/>
            <geom mesh="link3_0" material="white" class="visual"/>
            <geom mesh="link3_1" material="white" class="visual"/>
            <geom mesh="link3_2" material="white" class="visual"/>
            <geom mesh="link3_3" material="dark_grey" class="visual"/>
            <geom mesh="link3_c" class="collision"/>

            <body name="link4" pos="0.0825 0 0" quat="1 1 0 0">
              <inertial mass="3.587895" pos="-5.317e-2 1.04419e-1 2.7454e-2"
                fullinertia="2.5853e-2 1.9552e-2 2.8323e-2 7.796e-3 -1.332e-3 8.641e-3"/>
              <joint name="joint4" range="-3.0718 -0.0698"/>
              <geom mesh="link4_0" material="white" class="visual"/>
              <geom mesh="link4_1" material="white" class="visual"/>
              <geom mesh="link4_2" material="dark_grey" class="visual"/>
              <geom mesh="link4_3" material="white" class="visual"/>
              <geom mesh="link4_c" class="collision"/>

              <body name="link5" pos="-0.0825 0.384 0" quat="1 -1 0 0">
                <inertial mass="1.225946" pos="-1.1953e-2 4.1065e-2 -3.8437e-2"
                  fullinertia="3.5549e-2 2.9474e-2 8.627e-3 -2.117e-3 -4.037e-3 2.29e-4"/>
                <joint name="joint5"/>
                <geom mesh="link5_0" material="dark_grey" class="visual"/>
                <geom mesh="link5_1" material="white" class="visual"/>
                <geom mesh="link5_2" material="white" class="visual"/>
                <geom mesh="link5_c0" class="collision"/>
                <geom mesh="link5_c1" class="collision"/>
                <geom mesh="link5_c2" class="collision"/>

                <body name="link6" quat="1 1 0 0">
                  <inertial mass="1.666555" pos="6.0149e-2 -1.4117e-2 -1.0517e-2"
                    fullinertia="1.964e-3 4.354e-3 5.433e-3 1.09e-4 -1.158e-3 3.41e-4"/>
                  <joint name="joint6" range="-0.0175 3.7525"/>
                  <geom mesh="link6_0" material="off_white" class="visual"/>
                  <geom mesh="link6_1" material="white" class="visual"/>
                  <geom mesh="link6_2" material="dark_grey" class="visual"/>
                  <geom mesh="link6_3" material="white" class="visual"/>
                  <geom mesh="link6_4" material="white" class="visual"/>
                  <geom mesh="link6_5" material="white" class="visual"/>
                  <geom mesh="link6_6" material="white" class="visual"/>
                  <geom mesh="link6_7" material="light_blue" class="visual"/>
                  <geom mesh="link6_8" material="light_blue" class="visual"/>
                  <geom mesh="link6_9" material="dark_grey" class="visual"/>
                  <geom mesh="link6_10" material="dark_grey" class="visual"/>
                  <geom mesh="link6_11" material="white" class="visual"/>
                  <geom mesh="link6_12" material="green" class="visual"/>
                  <geom mesh="link6_13" material="white" class="visual"/>
                  <geom mesh="link6_14" material="dark_grey" class="visual"/>
                  <geom mesh="link6_15" material="dark_grey" class="visual"/>
                  <geom mesh="link6_16" material="white" class="visual"/>
                  <geom mesh="link6_c" class="collision"/>

                  <body name="link7" pos="0.088 0 0" quat="1 1 0 0">
                    <inertial mass="7.35522e-01" pos="1.0517e-2 -4.252e-3 6.1597e-2"
                      fullinertia="1.2516e-2 1.0027e-2 4.815e-3 -4.28e-4 -1.196e-3 -7.41e-4"/>
                    <joint name="joint7"/>
                    <geom mesh="link7_0" material="white" class="visual"/>
                    <geom mesh="link7_1" material="dark_grey" class="visual"/>
                    <geom mesh="link7_2" material="dark_grey" class="visual"/>
                    <geom mesh="link7_3" material="dark_grey" class="visual"/>
                    <geom mesh="link7_4" material="dark_grey" class="visual"/>
                    <geom mesh="link7_5" material="dark_grey" class="visual"/>
                    <geom mesh="link7_6" material="dark_grey" class="visual"/>
                    <geom mesh="link7_7" material="white" class="visual"/>
                    <geom mesh="link7_c" class="collision"/>

                    <body name="attachment" pos="0 0 0.107" quat="0.3826834 0 0 0.9238795">
                      <site name="attachment_site"/>
                    </body>
                  </body>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>

    <body name="object" pos="0.65 -0.15 -0.275" >
      <geom rgba="1 1 1 0" type="sphere" size="0.05 0.05 0.05" density="1000" conaffinity="0"/>
      <geom rgba="1 1 1 1" type="box" size="0.05 0.05 0.05" density="1000" contype="1" conaffinity="0"/>
      <joint name="obj_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="obj_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    </body>    
    <body name="goal" pos="0.65 -0.15 -0.3240">
      <geom rgba="1 0 0 1" type="box" size="0.08 0.08 0.001" density='0.00001' contype="0" conaffinity="0"/>
      <joint name="goal_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="goal_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    </body>
  </worldbody>

  <actuator>
    <general class="panda" name="actuator1" joint="joint1"/>
    <general class="panda" name="actuator2" joint="joint2" ctrlrange="-1.7628 1.7628"/>
    <general class="panda" name="actuator3" joint="joint3" />
    <general class="panda" name="actuator4" joint="joint4" ctrlrange="-3.0718 -0.0698"/>
    <general class="panda" name="actuator5" joint="joint5" forcerange="-12 12"/>
    <general class="panda" name="actuator6" joint="joint6" forcerange="-12 12" ctrlrange="-0.0175 3.7525"/>
    <general class="panda" name="actuator7" joint="joint7" forcerange="-12 12"/>
  </actuator>

  <keyframe>
    <key name="home" qpos="0 0 0 -1.57079 0 1.57079 -0.7853 0 0 0 0"/>
  </keyframe>
</mujoco>

Could someone please help me? Thanks

naijekux commented 4 months ago

The two 'missing' bodies are link0 and attachment in the MJCF. After deleting these two, the error is no longer raised. I recommend reading https://github.com/google/brax/issues/382 (still unsolved, though), where this issue has been discussed.
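In the model above, link0 and attachment are also the only two bodies without a joint element, and 11 bodies minus these 2 gives the size 9 from the error message. Whether the missing joints are the actual cause is an assumption here, not something confirmed in this thread. A small sketch using only the Python standard library can list such bodies before a model is loaded. The `MJCF` string is a reduced, hypothetical skeleton of the model, not the full file.

```python
import xml.etree.ElementTree as ET

# Reduced, hypothetical skeleton of the model: geoms, meshes and
# inertials are omitted because only bodies and joints matter here.
MJCF = """
<mujoco>
  <worldbody>
    <body name="link0">
      <body name="link1">
        <joint name="joint1"/>
        <body name="attachment"/>
      </body>
    </body>
    <body name="object">
      <joint name="obj_slidey" type="slide"/>
      <joint name="obj_slidex" type="slide"/>
    </body>
  </worldbody>
</mujoco>
"""

root = ET.fromstring(MJCF)

# iter() walks all nested <body> elements in document order.
bodies = list(root.find("worldbody").iter("body"))

# find("joint") only looks at direct children, so a joint that belongs
# to a nested child body does not count for the parent.
jointless = [b.get("name") for b in bodies if b.find("joint") is None]

print(len(bodies), "bodies, without a joint:", jointless)
```

To check a real file, replace `ET.fromstring(MJCF)` with `ET.parse(path).getroot()`, where `path` is the location of the XML file.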