google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Joints and Actuators are Not Aligned #165

Open namheegordonkim opened 2 years ago

namheegordonkim commented 2 years ago

Here's my humanoid character: https://pastebin.com/7Ym4aqy6

Now, borrowing from the existing environment code, if I use something like

    # Angles of the 1-, 2- and 3-DoF joint groups (velocities unused here).
    joint_1d_angle, joint_1d_vel = self.sys.joints[0].angle_vel(qp)
    joint_2d_angle, joint_2d_vel = self.sys.joints[1].angle_vel(qp)
    joint_3d_angle, joint_3d_vel = self.sys.joints[2].angle_vel(qp)
    # Concatenate every axis of every group into one flat vector.
    joint_angles = jnp.concatenate([
        joint_1d_angle[0],
        joint_2d_angle[0],
        joint_2d_angle[1],
        joint_3d_angle[0],
        joint_3d_angle[1],
        joint_3d_angle[2],
    ])

then I get an array of size self.sys.num_joint_dof. The actuation (torque) passed into self.sys.step() has the same size. However, the actuators and the joints are not indexed in a consistent order. For example, in my environment act[-1] is the actuation of left_elbow, but joint_angles[-1] is not the angle of left_elbow.
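One way to see the layout of joint_angles is to replace the arrays returned by angle_vel with string labels. This sketch assumes, as the snippet above implies, that angle_vel returns one array per axis, each holding that axis's angle for every joint in the group; the joint names are hypothetical. It shows that the vector is grouped by DoF and that, inside a multi-DoF group, the axes of different joints interleave (axis by axis, not joint by joint).

```python
# Hypothetical joint names for each DoF group, in the order the group stores them.
joints_1d = ["left_knee", "left_elbow"]
joints_2d = ["left_shoulder", "right_shoulder"]
joints_3d = ["abdomen"]


def axis_tuple(names, dof):
    """Stand-in for angle_vel(qp)[0]: one list per axis, one label per joint."""
    return tuple([f"{name}[{axis}]" for name in names] for axis in range(dof))


angle_1d = axis_tuple(joints_1d, 1)
angle_2d = axis_tuple(joints_2d, 2)
angle_3d = axis_tuple(joints_3d, 3)

# Same concatenation order as the snippet: group by DoF, then axis by axis.
joint_angles = (angle_1d[0] + angle_2d[0] + angle_2d[1]
                + angle_3d[0] + angle_3d[1] + angle_3d[2])

print(joint_angles[2:6])
# ['left_shoulder[0]', 'right_shoulder[0]', 'left_shoulder[1]', 'right_shoulder[1]']
```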

I suspect this comes from the indices for joints and actuators being computed separately. The actuators appear to be ordered by their appearance in my config, while the joints are grouped by degrees of freedom (I am not sure what the order within each group is; perhaps order of appearance?). Because the joint angles and the torques do not line up index by index, I cannot apply a PD controller directly.
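The two orderings described above can be modelled with a toy config. This is a sketch under stated assumptions, not brax's actual index code: the joint names and DoFs are hypothetical, actuators are taken to follow config order with each joint's axes contiguous, and joints are taken to be grouped by DoF with config order inside each group (the suborder the issue is unsure about) and axes interleaved as in the concatenation snippet.

```python
# Hypothetical joints in config order, each with its number of DoFs.
config = [("abdomen", 3), ("right_knee", 1), ("left_shoulder", 2), ("left_elbow", 1)]

# Assumed actuator order: config order, each joint's axes contiguous.
act_labels = [f"{name}[{axis}]" for name, dof in config for axis in range(dof)]

# Assumed joint order: grouped by DoF (config order inside a group),
# then axis by axis, matching the jnp.concatenate snippet.
joint_labels = []
for dof in (1, 2, 3):
    names = [name for name, d in config if d == dof]
    for axis in range(dof):
        joint_labels += [f"{name}[{axis}]" for name in names]

print(act_labels[-1], joint_labels[-1])  # left_elbow[0] abdomen[2]

# Slots where a torque and an angle with the same index refer to different DoFs.
mismatched = [i for i, (a, j) in enumerate(zip(act_labels, joint_labels)) if a != j]
print(mismatched)
```

In this toy layout every slot is mismatched, so a PD law computed index by index would, for instance, send the abdomen's error to the left elbow's actuator.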

In actuators.py: [screenshot not preserved]

In joints.py: [screenshot not preserved]

The current workaround is to take self.sys.actuators[0].act_index and so on, concatenate them, and then reorder joint_angles so that the joint indices match the actuator indices. However, it would be much more convenient if joints and actuators were aligned by default.
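Concatenating the flattened act_index arrays gives a permutation: entry k says which actuator slot entry k of the joint-angle vector belongs to. A minimal NumPy sketch of the reordering, assuming every joint DoF is actuated (so the permutation covers each slot exactly once) and that the joint-angle vector is laid out joint by joint, which is what the .T in the workaround produces, rather than axis by axis as in the first snippet. The act_index values and angles are hypothetical.

```python
import numpy as np

# Hypothetical per-group act_index arrays, each shaped (num_joints, dof).
act_index_1d = np.array([[3], [6]])
act_index_2d = np.array([[4, 5]])
act_index_3d = np.array([[0, 1, 2]])

# perm[k] = actuator slot of entry k of the joint-major angle vector.
perm = np.concatenate(
    [a.reshape(-1) for a in (act_index_1d, act_index_2d, act_index_3d)]
)

# Hypothetical joint angles, laid out joint by joint within each DoF group.
joint_vec = np.array([0.3, 0.6, 0.4, 0.5, 0.0, 0.1, 0.2])

# Option 1: scatter each entry into its actuator slot.
aligned_scatter = np.zeros_like(joint_vec)
aligned_scatter[perm] = joint_vec

# Option 2: gather with the inverse permutation, which can be computed once.
inv_perm = np.argsort(perm)
aligned_gather = joint_vec[inv_perm]

print(aligned_gather)
```

In jax.numpy the scatter becomes jnp.zeros(n).at[perm].set(joint_vec), as in the workaround, while the gather stays joint_vec[inv_perm]; precomputing inv_perm once avoids rebuilding the index on every step.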

namheegordonkim commented 2 years ago

To elaborate on the workaround, something like this works for now:

    def get_joint_angles(self, qp: QP):
        """
        Return joint angles in the same order as actuator order.
        """
        # act_index of each actuator group: slots of its DoFs in the action vector.
        joint_1d_act_index = self.sys.actuators[0].act_index
        joint_2d_act_index = self.sys.actuators[1].act_index
        joint_3d_act_index = self.sys.actuators[2].act_index

        # Velocities are not needed here.
        joint_1d_angle, _ = self.sys.joints[0].angle_vel(qp)
        joint_2d_angle, _ = self.sys.joints[1].angle_vel(qp)
        joint_3d_angle, _ = self.sys.joints[2].angle_vel(qp)

        # Stack the per-axis tuples and transpose to (num_joints, dof) so the
        # flattened (joint-major) order matches the flattened act_index.
        # For 1-DoF joints the transpose makes no difference.
        joint_1d_angle = jnp.stack(joint_1d_angle)
        joint_2d_angle = jnp.stack(joint_2d_angle).T
        joint_3d_angle = jnp.stack(joint_3d_angle).T

        # Scatter each group's angles into its actuator slots.
        ret = jnp.zeros(self.sys.num_joint_dof)
        ret = ret.at[joint_1d_act_index.reshape(-1)].set(joint_1d_angle.reshape(-1))
        ret = ret.at[joint_2d_act_index.reshape(-1)].set(joint_2d_angle.reshape(-1))
        ret = ret.at[joint_3d_act_index.reshape(-1)].set(joint_3d_angle.reshape(-1))
        return ret
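Since the goal was a PD controller, the velocities need the same reordering as the angles; the workaround above returns angles only. A minimal NumPy sketch, assuming velocities come back from angle_vel in the same per-axis layout as the angles and have already been stacked to (num_joints, dof) as in the workaround. The helper names align_to_actuators and pd_torque, the gains, targets, and toy arrays are all hypothetical.

```python
import numpy as np


def align_to_actuators(per_group_values, per_group_act_index, num_dof):
    """Scatter per-group (num_joints, dof) arrays into one actuator-ordered vector."""
    out = np.zeros(num_dof)
    for values, act_index in zip(per_group_values, per_group_act_index):
        out[np.asarray(act_index).reshape(-1)] = np.asarray(values).reshape(-1)
    return out


def pd_torque(angle, vel, target, kp, kd):
    """Proportional-derivative torque; every argument uses actuator ordering."""
    return kp * (target - angle) - kd * vel


# Toy layout: one 1-DoF joint (actuator slot 2) and one 2-DoF joint (slots 0, 1).
act_index = [np.array([[2]]), np.array([[0, 1]])]
angles = [np.array([[0.5]]), np.array([[0.1, 0.2]])]
vels = [np.array([[0.0]]), np.array([[1.0, -1.0]])]

angle = align_to_actuators(angles, act_index, 3)
vel = align_to_actuators(vels, act_index, 3)
torque = pd_torque(angle, vel, target=np.zeros(3), kp=10.0, kd=1.0)
print(torque)
```

Under JAX the in-place assignment in align_to_actuators would be replaced by .at[...].set(...), exactly as in get_joint_angles.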