google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars 234 forks source link

ValueError: vmap got inconsistent sizes for array axes to be mapped #465

Closed naijekux closed 4 months ago

naijekux commented 4 months ago


I loaded the XML model in the following way:

mj_model = mujoco.MjModel.from_xml_path(model_path)
sys = mjcf.load_model(mj_model)

Then I got the error described in the title as soon as this line of code was called: pipeline_state = self.pipeline_init(qpos, qvel)

The error traceback is below:

Cell In[25], [line 51]
     [48]qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=-0.005, maxval=0.005)
---> [51]pipeline_state = self.pipeline_init(qpos, qvel)
     [52]obs = self._get_obs(pipeline_state)
     [53]reward, done, zero = jp.zeros(3)

File [/brax/envs/], in PipelineEnv.pipeline_init(self, q, qd)
    [117]def pipeline_init(self, q: jax.Array, qd: jax.Array) -> base.State:
    [118]"""Initializes the pipeline state."""
--> [119]return self._pipeline.init(self.sys, q, qd, self._debug)

File [/brax/generalized/], in init(sys, q, qd, debug)
     [45]if sys.mj_model is not None:
---> [47]x, xd = kinematics.forward(sys, q, qd)
     [48]state = State.init(q, qd, x, xd)  # pytype: disable=wrong-arg-types  # jax-ndarray
     [49]state = dynamics.transform_com(sys, state)

File [/brax/], in forward(sys, q, qd)
     [81]return j, jd
     [83]j, jd = scan.link_types(sys, jcalc, 'qdd', 'l', q, qd, sys.dof.motion)
---> [85]anchor = Transform.create(rot=j.rot).vmap().do(
     [86]j = j.replace(pos=j.pos + - anchor.pos)  # joint pos offset
     [87]j =  # link transform

    [... skipping hidden 2 frame]

File [/jax/_src/], in _mapped_axis_size(fn, tree, vals, dims, name)
   [1348] else:
   [1349]  msg.append(f"  * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
-> [1350]raise ValueError(''.join(msg)[:-2])

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (3 of them) had size 9, e.g. axis 0 of argument self.pos of type float32[9,3];
  * one axis had size 11: axis 0 of argument o.rot of type float32[11,4]

I believe that the 3 axes of type float32[9,3] stand for the 'pos' in the transform and the 'ang' and 'vel' in the motion, while the axis of type float32[11,4] is the 'rot'. There are 11 bodies in total inside the world body in the XML model I use. I cannot understand why 2 bodies are missing.

Here is the XML model I used, downloaded from MuJoCo's model gallery with a few modifications:

<mujoco model="panda nohand">
  <compiler angle="radian" meshdir="assets" autolimits="true"/>

  <option timestep="0.01" gravity="0 0 0" iterations="20"/>

    <default class="panda">
      <material specular="0.5" shininess="0.25"/>
      <joint armature="0.1" damping="1" axis="0 0 1" range="-2.8973 2.8973"/>
      <general dyntype="none" biastype="affine" ctrlrange="-2.8973 2.8973" forcerange="-87 87" ctrllimited="true"/>

      <default class="visual">
        <geom type="mesh" contype="0" conaffinity="0" group="2"/>
      <default class="collision">
        <geom type="mesh" contype="0" conaffinity="1" group="3"/>
      <site size="0.001" rgba="0.5 0.5 0.5 0.3" group="4"/>

    <material class="panda" name="white" rgba="1 1 1 1"/>
    <material class="panda" name="off_white" rgba="0.901961 0.921569 0.929412 1"/>
    <material class="panda" name="dark_grey" rgba="0.25 0.25 0.25 1"/>
    <material class="panda" name="green" rgba="0 1 0 1"/>
    <material class="panda" name="light_blue" rgba="0.039216 0.541176 0.780392 1"/>
    <material class="panda" name="red" rgba="1 0 0 1"/>

    <!-- Collision meshes -->
    <mesh name="link0_c" file="link0.stl"/>
    <mesh name="link1_c" file="link1.stl"/>
    <mesh name="link2_c" file="link2.stl"/>
    <mesh name="link3_c" file="link3.stl"/>
    <mesh name="link4_c" file="link4.stl"/>
    <mesh name="link5_c0" file="link5_collision_0.obj"/>
    <mesh name="link5_c1" file="link5_collision_1.obj"/>
    <mesh name="link5_c2" file="link5_collision_2.obj"/>
    <mesh name="link6_c" file="link6.stl"/>
    <mesh name="link7_c" file="link7.stl"/>

    <!-- Visual meshes -->
    <mesh file="link0_0.obj"/>
    <mesh file="link0_1.obj"/>
    <mesh file="link0_2.obj"/>
    <mesh file="link0_3.obj"/>
    <mesh file="link0_4.obj"/>
    <mesh file="link0_5.obj"/>
    <mesh file="link0_7.obj"/>
    <mesh file="link0_8.obj"/>
    <mesh file="link0_9.obj"/>
    <mesh file="link0_10.obj"/>
    <mesh file="link0_11.obj"/>
    <mesh file="link1.obj"/>
    <mesh file="link2.obj"/>
    <mesh file="link3_0.obj"/>
    <mesh file="link3_1.obj"/>
    <mesh file="link3_2.obj"/>
    <mesh file="link3_3.obj"/>
    <mesh file="link4_0.obj"/>
    <mesh file="link4_1.obj"/>
    <mesh file="link4_2.obj"/>
    <mesh file="link4_3.obj"/>
    <mesh file="link5_0.obj"/>
    <mesh file="link5_1.obj"/>
    <mesh file="link5_2.obj"/>
    <mesh file="link6_0.obj"/>
    <mesh file="link6_1.obj"/>
    <mesh file="link6_2.obj"/>
    <mesh file="link6_3.obj"/>
    <mesh file="link6_4.obj"/>
    <mesh file="link6_5.obj"/>
    <mesh file="link6_6.obj"/>
    <mesh file="link6_7.obj"/>
    <mesh file="link6_8.obj"/>
    <mesh file="link6_9.obj"/>
    <mesh file="link6_10.obj"/>
    <mesh file="link6_11.obj"/>
    <mesh file="link6_12.obj"/>
    <mesh file="link6_13.obj"/>
    <mesh file="link6_14.obj"/>
    <mesh file="link6_15.obj"/>
    <mesh file="link6_16.obj"/>
    <mesh file="link7_0.obj"/>
    <mesh file="link7_1.obj"/>
    <mesh file="link7_2.obj"/>
    <mesh file="link7_3.obj"/>
    <mesh file="link7_4.obj"/>
    <mesh file="link7_5.obj"/>
    <mesh file="link7_6.obj"/>
    <mesh file="link7_7.obj"/>

    <light name="top" pos="0 0 2" mode="trackcom"/>

    <geom name="table" type="plane" pos="0 0.5 -0.325" size="1 1 0.1" contype="1" conaffinity="0"/>

    <body name="link0" childclass="panda" pos="1.1 -0.5 -0.2" quat="0 0 0 1">
      <inertial mass="0.629769" pos="-0.041018 -0.00014 0.049974"
        fullinertia="0.00315 0.00388 0.004285 8.2904e-7 0.00015 8.2299e-6"/>
      <geom mesh="link0_0" material="off_white" class="visual"/>
      <geom mesh="link0_1" material="dark_grey" class="visual"/>
      <geom mesh="link0_2" material="off_white" class="visual"/>
      <geom mesh="link0_3" material="dark_grey" class="visual"/>
      <geom mesh="link0_4" material="off_white" class="visual"/>
      <geom mesh="link0_5" material="dark_grey" class="visual"/>
      <geom mesh="link0_7" material="white" class="visual"/>
      <geom mesh="link0_8" material="white" class="visual"/>
      <geom mesh="link0_9" material="dark_grey" class="visual"/>
      <geom mesh="link0_10" material="off_white" class="visual"/>
      <geom mesh="link0_11" material="white" class="visual"/>
      <geom mesh="link0_c" class="collision"/>

      <body name="link1" pos="0 0 0.333">
        <inertial mass="4.970684" pos="0.003875 0.002081 -0.04762"
          fullinertia="0.70337 0.70661 0.0091170 -0.00013900 0.0067720 0.019169"/>
        <joint name="joint1"/>
        <geom material="white" mesh="link1" class="visual"/>
        <geom mesh="link1_c" class="collision"/>

        <body name="link2" quat="1 -1 0 0">
          <inertial mass="0.646926" pos="-0.003141 -0.02872 0.003495"
            fullinertia="0.0079620 2.8110e-2 2.5995e-2 -3.925e-3 1.0254e-2 7.04e-4"/>
          <joint name="joint2" range="-1.7628 1.7628"/>
          <geom material="white" mesh="link2" class="visual"/>
          <geom mesh="link2_c" class="collision"/>

          <body name="link3" pos="0 -0.316 0" quat="1 1 0 0">
            <joint name="joint3"/>
            <inertial mass="3.228604" pos="2.7518e-2 3.9252e-2 -6.6502e-2"
              fullinertia="3.7242e-2 3.6155e-2 1.083e-2 -4.761e-3 -1.1396e-2 -1.2805e-2"/>
            <geom mesh="link3_0" material="white" class="visual"/>
            <geom mesh="link3_1" material="white" class="visual"/>
            <geom mesh="link3_2" material="white" class="visual"/>
            <geom mesh="link3_3" material="dark_grey" class="visual"/>
            <geom mesh="link3_c" class="collision"/>

            <body name="link4" pos="0.0825 0 0" quat="1 1 0 0">
              <inertial mass="3.587895" pos="-5.317e-2 1.04419e-1 2.7454e-2"
                fullinertia="2.5853e-2 1.9552e-2 2.8323e-2 7.796e-3 -1.332e-3 8.641e-3"/>
              <joint name="joint4" range="-3.0718 -0.0698"/>
              <geom mesh="link4_0" material="white" class="visual"/>
              <geom mesh="link4_1" material="white" class="visual"/>
              <geom mesh="link4_2" material="dark_grey" class="visual"/>
              <geom mesh="link4_3" material="white" class="visual"/>
              <geom mesh="link4_c" class="collision"/>

              <body name="link5" pos="-0.0825 0.384 0" quat="1 -1 0 0">
                <inertial mass="1.225946" pos="-1.1953e-2 4.1065e-2 -3.8437e-2"
                  fullinertia="3.5549e-2 2.9474e-2 8.627e-3 -2.117e-3 -4.037e-3 2.29e-4"/>
                <joint name="joint5"/>
                <geom mesh="link5_0" material="dark_grey" class="visual"/>
                <geom mesh="link5_1" material="white" class="visual"/>
                <geom mesh="link5_2" material="white" class="visual"/>
                <geom mesh="link5_c0" class="collision"/>
                <geom mesh="link5_c1" class="collision"/>
                <geom mesh="link5_c2" class="collision"/>

                <body name="link6" quat="1 1 0 0">
                  <inertial mass="1.666555" pos="6.0149e-2 -1.4117e-2 -1.0517e-2"
                    fullinertia="1.964e-3 4.354e-3 5.433e-3 1.09e-4 -1.158e-3 3.41e-4"/>
                  <joint name="joint6" range="-0.0175 3.7525"/>
                  <geom mesh="link6_0" material="off_white" class="visual"/>
                  <geom mesh="link6_1" material="white" class="visual"/>
                  <geom mesh="link6_2" material="dark_grey" class="visual"/>
                  <geom mesh="link6_3" material="white" class="visual"/>
                  <geom mesh="link6_4" material="white" class="visual"/>
                  <geom mesh="link6_5" material="white" class="visual"/>
                  <geom mesh="link6_6" material="white" class="visual"/>
                  <geom mesh="link6_7" material="light_blue" class="visual"/>
                  <geom mesh="link6_8" material="light_blue" class="visual"/>
                  <geom mesh="link6_9" material="dark_grey" class="visual"/>
                  <geom mesh="link6_10" material="dark_grey" class="visual"/>
                  <geom mesh="link6_11" material="white" class="visual"/>
                  <geom mesh="link6_12" material="green" class="visual"/>
                  <geom mesh="link6_13" material="white" class="visual"/>
                  <geom mesh="link6_14" material="dark_grey" class="visual"/>
                  <geom mesh="link6_15" material="dark_grey" class="visual"/>
                  <geom mesh="link6_16" material="white" class="visual"/>
                  <geom mesh="link6_c" class="collision"/>

                  <body name="link7" pos="0.088 0 0" quat="1 1 0 0">
                    <inertial mass="7.35522e-01" pos="1.0517e-2 -4.252e-3 6.1597e-2"
                      fullinertia="1.2516e-2 1.0027e-2 4.815e-3 -4.28e-4 -1.196e-3 -7.41e-4"/>
                    <joint name="joint7"/>
                    <geom mesh="link7_0" material="white" class="visual"/>
                    <geom mesh="link7_1" material="dark_grey" class="visual"/>
                    <geom mesh="link7_2" material="dark_grey" class="visual"/>
                    <geom mesh="link7_3" material="dark_grey" class="visual"/>
                    <geom mesh="link7_4" material="dark_grey" class="visual"/>
                    <geom mesh="link7_5" material="dark_grey" class="visual"/>
                    <geom mesh="link7_6" material="dark_grey" class="visual"/>
                    <geom mesh="link7_7" material="white" class="visual"/>
                    <geom mesh="link7_c" class="collision"/>

                    <body name="attachment" pos="0 0 0.107" quat="0.3826834 0 0 0.9238795">
                      <site name="attachment_site"/>

    <body name="object" pos="0.65 -0.15 -0.275" >
      <geom rgba="1 1 1 0" type="sphere" size="0.05 0.05 0.05" density="1000" conaffinity="0"/>
      <geom rgba="1 1 1 1" type="box" size="0.05 0.05 0.05" density="1000" contype="1" conaffinity="0"/>
      <joint name="obj_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="obj_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>
    <body name="goal" pos="0.65 -0.15 -0.3240">
      <geom rgba="1 0 0 1" type="box" size="0.08 0.08 0.001" density='0.00001' contype="0" conaffinity="0"/>
      <joint name="goal_slidey" type="slide" pos="0 0 0" axis="0 1 0" range="-10.3213 10.3" damping="0.5"/>
      <joint name="goal_slidex" type="slide" pos="0 0 0" axis="1 0 0" range="-10.3213 10.3" damping="0.5"/>

    <general class="panda" name="actuator1" joint="joint1"/>
    <general class="panda" name="actuator2" joint="joint2" ctrlrange="-1.7628 1.7628"/>
    <general class="panda" name="actuator3" joint="joint3" />
    <general class="panda" name="actuator4" joint="joint4" ctrlrange="-3.0718 -0.0698"/>
    <general class="panda" name="actuator5" joint="joint5" forcerange="-12 12"/>
    <general class="panda" name="actuator6" joint="joint6" forcerange="-12 12" ctrlrange="-0.0175 3.7525"/>
    <general class="panda" name="actuator7" joint="joint7" forcerange="-12 12"/>

    <key name="home" qpos="0 0 0 -1.57079 0 1.57079 -0.7853 0 0 0 0"/>

Could someone please help me? Thanks

naijekux commented 4 months ago

These two 'missing' bodies are link0 and attachment in the MJCF. After deleting these two, the error is no longer raised. A related thread where this issue was discussed is recommended reading (it is still not solved, though).