google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Are Contact Forces supported? #353

Open Kallinteris-Andreas opened 1 year ago

Kallinteris-Andreas commented 1 year ago

The reference Ant and the two Humanoid implementations do not include them, but this issue indicates that they might be implemented: https://github.com/google/brax/issues/254

Which is it?

If they are not implemented, is there an approximate ETA?

Thanks!

btaba commented 1 year ago

Contact impulses are not populated in the State, although the generalized backend populates qf_constraint. #254 references a previous version of brax

RE: ETA, what are you needing the contact forces for?

Kallinteris-Andreas commented 1 year ago

I am a developer of Gymnasium and maintainer of Gymnasium-Robotics. I'm trying to find out whether Brax can load the existing (mujoco-env) models and support all of their features.

btaba commented 1 year ago

Hi @Kallinteris-Andreas, we have examples here loading a whole host of mujoco envs. I believe contact forces are the only missing feature, but they can be added. Are you planning to add brax to Gymnasium any time soon?

Kallinteris-Andreas commented 1 year ago

For now, I am evaluating if brax fulfills the requirements of the gymnasium/mujoco envs.

From what I can tell, brax is missing cfrc_ext and tendon support.

The inclusion of brax in Gymnasium will likely happen 3 months after those requirements are met.

cdagher commented 1 year ago

Is there any update on this? I am a PhD student working on training robot controllers, and would like to use Brax. One of the requirements I have is using the contact forces as inputs to the NN during training, so the ability to get the contact forces from the system would allow me to use Brax.

namheegordonkim commented 1 year ago

@btaba

"https://github.com/google/brax/issues/254 references a previous version of brax"

In that issue, you promised that v2 would have contact forces...

"I think you'll much prefer the next release of v2 brax, closing this for now"

Now I'm not sure whether info.contact_normal and such still work as before since the v2 update.

namheegordonkim commented 1 year ago

Ok, at least in PBD, it seems one can use the penetration field of Contact objects to (1) get fuzzy binary contact states by choosing a small threshold for detecting contact and (2) guesstimate how much reaction force might result from interpenetration.
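
A minimal sketch of that idea, written with plain NumPy so it runs without brax. The penetration values, the threshold, and the spring constant are all made-up assumptions for illustration, as is the sign convention (positive means the bodies overlap); none of them come from brax itself.

```python
import numpy as np

# Hypothetical values standing in for the `penetration` field of a batch of
# contacts. Assumed convention: positive = overlapping, negative = separated.
penetration = np.array([-0.02, 0.0005, 0.004, -0.001, 0.012])

CONTACT_EPS = 1e-3  # assumed threshold in metres; tune per model
STIFFNESS = 5.0e4   # assumed spring constant in N/m; purely illustrative

# (1) Fuzzy binary contact state: deeper than the threshold counts as contact.
in_contact = penetration > CONTACT_EPS

# (2) Crude reaction-force guess: a linear spring on the positive depth only.
force_guess = STIFFNESS * np.clip(penetration, 0.0, None)

print(in_contact)
print(force_guess)
```

The spring estimate is only a rough proxy; it is not what the PBD solver computes internally.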

Kallinteris-Andreas commented 11 months ago

With the release of mjx I suppose it works now, so I am closing the issue

btaba commented 11 months ago

Until we port mjx over, this isn't currently implemented in brax. We'll want to add contact forces for the positional backend as well

willthibault commented 7 months ago

In v0.10.0, MJX is now used in the backend. With MjxEnv being replaced by PipelineEnv, how can contacts and their corresponding forces be accessed? It seems like it can no longer be done similar to the method mentioned in this comment.

I also noticed the Contact class being used in pipeline initialization and stepping, but it seems a bit limiting. Are there any plans to include more data in this class?

btaba commented 7 months ago

Hi @willthibault, notice that the Contact class inherits from mjx.Contact, so you should have all the same data as before. Please let us know otherwise.

Edit: narrowed down the diff compared to the older version, we'll push out a fix. For now, you can set debug=True on the PipelineEnv.__init__

willthibault commented 7 months ago

Thanks for the followup @btaba! If we have the same data as before then we can use mj_contactForce, but only after converting the data from MJX to MuJoCo. Is there a better way to do this? I found that trying to do this inside a Brax training step only produces errors. I used something like this:

import mujoco
import numpy as np

def get_contact_force(m, d):
  # Look up the geom ids once instead of on every loop iteration
  floor_id = mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "floor")
  right_foot_id = mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "right_foot_collision")
  left_foot_id = mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "left_foot_collision")
  contact_twist = np.zeros(6, dtype=np.float64)  # filled in place by mj_contactForce
  right_foot_force = 0.0
  left_foot_force = 0.0
  for i in range(d.ncon):
    contact = d.contact[i]
    if contact.geom1 != floor_id:
      continue
    if contact.geom2 == right_foot_id:
      mujoco.mj_contactForce(m, d, i, contact_twist)
      right_foot_force += np.linalg.norm(contact_twist)
    elif contact.geom2 == left_foot_id:
      mujoco.mj_contactForce(m, d, i, contact_twist)
      left_foot_force += np.linalg.norm(contact_twist)
  return right_foot_force, left_foot_force

where d came from mjx.device_get_into(d, dx). This works well inside the MJX viewer test, but not inside Brax training. Is this more of a missing MJX feature?

AlexS28 commented 7 months ago

@willthibault by the way, did you have any luck / progress on getting this to work in brax/MJX? I am also generally finding that while things work outside brax, within brax I encounter errors with receiving certain data types, or trying to convert from mujoco to mjx etc within the training step.

willthibault commented 7 months ago

Hi @AlexS28,

Thanks for following up on this.

The function I posted above doesn't work in training steps, as you mentioned, because it isn't JAX-friendly and results in JAX tracer errors. I think an MJX equivalent of this function doesn't exist at the moment. Also, MJX has limited contact functionality (only 3-dimensional pyramidal contacts, as mentioned in the feature parity list). It looks like other contact options should be coming in an update, but for now this is all we have. I have found a sort of workaround that seems to work well enough for now. I too am working on bipedal walking for humanoids, so maybe this is of help to you.

In an MJX data structure, contact forces are contained in efc_force. This link explains a little more about how it works and the relation to mj_contactForce. Unfortunately, these forces are in a more complex format and a different frame because of the pyramidal contact model. However, if you set your robot in a stable standing position in the simulator and read these forces, you'll notice that the array is very sparse when contacts are limited, for example to only the feet and the floor. Summing the appropriate indices of these forces (the non-zero entries) gave me a force roughly equal to the expected weight of my robot, and summing certain groupings matched the normal forces on the feet that I obtained from mj_contactForce. Try running the function I shared above inside the loop of viewer.py while viewing efc_force, summing the right and left foot contacts like right_foot_force += contact_twist[0] and left_foot_force += contact_twist[0] (index zero is the normal force of the contact):

get_contact_force(m, d) # generates left and right contact force from mj_contactForce
print(dx.efc_force)
print(np.sum(dx.efc_force)) # weight of robot
print(np.sum(dx.efc_force[12:24])) # right foot force
print(np.sum(dx.efc_force[28:40])) # left foot force

I got a reading that looked something like this:

Right foot force: 199.7961597442627
Left foot force: 209.16509103775024
[ 0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
 22.21031   21.744446  21.656994  22.297768  13.072264  12.513793
 12.504297  13.081738  15.409214  14.948065  14.884776  15.472495
  0.         0.         0.         0.        20.487473  29.70405
 22.465816  27.725697   7.990457  16.79486    9.899548  14.885785
 10.17236   19.43331   12.2858095 17.319925   0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.       ]
408.96124
199.79616
209.16508

Accessing dx.efc_force[12:24] and dx.efc_force[28:40] in the Brax training has worked without issues so far. I imagine that these forces may not always be that accurate given the contacts made and the pyramidal representation, but it has worked well enough for my robot to start taking steps while considering contact forces. I am currently using MJX v3.1.1 and Brax v0.9.4 with the MjxEnv, but expect that this should work in MJX v3.1.2 and Brax v0.10.0 with the PipelineEnv too.
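
To make the grouping explicit, here is a small sketch that regroups the efc_force values printed above. It assumes every contact is a dim-3 pyramidal contact occupying 4 consecutive rows and that the contact rows start at index 0; both are assumptions that happen to fit this printout and may not hold for other models.

```python
import numpy as np

# efc_force values copied from the printout above (60 entries).
efc_force = np.array(
    [0.0] * 12
    + [22.21031, 21.744446, 21.656994, 22.297768,
       13.072264, 12.513793, 12.504297, 13.081738,
       15.409214, 14.948065, 14.884776, 15.472495]
    + [0.0] * 4
    + [20.487473, 29.70405, 22.465816, 27.725697,
       7.990457, 16.79486, 9.899548, 14.885785,
       10.17236, 19.43331, 12.2858095, 17.319925]
    + [0.0] * 20
)

# Assumption: 4 pyramid rows per contact, so each row of this view is one contact.
per_contact = efc_force.reshape(-1, 4)

# For a pyramidal cone the normal force is the sum of the pyramid edge forces.
normals = per_contact.sum(axis=1)

right_foot = normals[3:6].sum()   # rows 12..23 of efc_force
left_foot = normals[7:10].sum()   # rows 28..39 of efc_force
print(normals)
print(right_foot, left_foot, normals.sum())
```

Under these assumptions each foot has three active contacts, and the per-foot sums agree with the 199.796 and 209.165 readings above.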

Hopefully further MJX/Brax updates will make accessing contacts easier. Even having elliptic contacts would make reading efc_force easier, but for pyramidal contacts some form of mj_contactForce would be useful.

AlexS28 commented 7 months ago

thanks @willthibault for the information. Much appreciated!! Your code worked for me as well.

hansihe commented 4 months ago

If this is useful for anyone else, I implemented a jittable version of contact force calculation.

No guarantees that this is correct, obviously, but the output matches for all inputs I have tested it with. The 1-dim code path is untested, but is very trivial.

import mujoco
from jax import numpy as jp

# Given an mjx model `s` and mjx state `d`, calculates forces for all contacts.
def get_contact_forces(s, d):
    assert s.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL  # only pyramidal cones are handled

    # mju_decodePyramid
    # 1: force: result
    # 2: pyramid: d.efc_force + contact.efc_address
    # 3: mu: contact.friction
    # 4: dim: contact.dim

    contact = d.contact
    cnt = d.ncon

    # Generate 2d array of efc_force indexed by efc_address containing the maximum
    # number of potential elements (10).
    # This enables us to operate on each contact force pyramid rowwise.
    efc_argmap = jp.linspace(
      contact.efc_address,
      contact.efc_address + 9,
      10, dtype=jp.int32
    ).T
    # OOB access clamps in jax, this is safe
    pyramid = d.efc_force[efc_argmap.reshape((efc_argmap.size))].reshape(efc_argmap.shape)

    # Calculate normal forces
    # force[0] = 0
    # for (int i=0; i < 2*(dim-1); i++) {
    #   force[0] += pyramid[i];
    # }
    index_matrix = jp.repeat(jp.arange(10)[None, :], cnt, axis=0)
    force_normal_mask = index_matrix < (2 * (contact.dim - 1)).reshape((cnt, 1))
    force_normal = jp.sum(jp.where(force_normal_mask, pyramid, 0), axis=1)

    # Calculate tangent forces
    # for (int i=0; i < dim-1; i++) {
    #   force[i+1] = (pyramid[2*i] - pyramid[2*i+1]) * mu[i];
    # }
    pyramid_indexes = jp.arange(5) * 2
    force_tan_all = (pyramid[:, pyramid_indexes] - pyramid[:, pyramid_indexes + 1]) * contact.friction
    # Keep only components i < dim-1, matching the loop bound in the comment above
    force_tan = jp.where(jp.arange(5) < (contact.dim - 1).reshape((cnt, 1)), force_tan_all, 0)

    # Full force array
    forces = jp.concatenate((force_normal.reshape((cnt, 1)), force_tan), axis=1)

    # Special case frictionless contacts
    # if (dim == 1) {
    #   force[0] = pyramid[0];
    #   return;
    # }
    frictionless_mask = contact.dim == 1
    frictionless_forces = jp.concatenate((pyramid[:,0:1], jp.zeros((pyramid.shape[0], 5))), axis=1)
    return jp.where(
        frictionless_mask.reshape((cnt, 1)),
        frictionless_forces,
        forces
    )
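
For checking a vectorized version like the one above, a scalar reference can be handy. The sketch below is a direct Python port of the C loops quoted in the comments, not MuJoCo's own code, and the input numbers are invented for illustration.

```python
def decode_pyramid(pyramid, mu, dim):
    """Scalar port of the loops quoted in the comments above."""
    force = [0.0] * dim
    # Frictionless contact: the single pyramid entry is the normal force.
    if dim == 1:
        force[0] = pyramid[0]
        return force
    # Normal force: sum of all 2*(dim-1) pyramid edge forces.
    for i in range(2 * (dim - 1)):
        force[0] += pyramid[i]
    # Tangential components: difference of each edge pair, scaled by friction.
    for i in range(dim - 1):
        force[i + 1] = (pyramid[2 * i] - pyramid[2 * i + 1]) * mu[i]
    return force


# Made-up dim-3 contact: 4 pyramid entries and 2 friction coefficients.
print(decode_pyramid([3.0, 1.0, 2.5, 2.0], mu=[1.0, 0.5], dim=3))  # [8.5, 2.0, 0.25]
```

Running it per contact on `d.efc_force[address:address + 2 * (dim - 1)]` and comparing against the batched output is one way to exercise the dim values other than 3.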

btaba commented 4 months ago

Nice @hansihe , also see https://github.com/google-deepmind/mujoco/commit/c6b1293e58fe9dbbcd144e4cbf9bed423439f473

hansihe commented 4 months ago

Ah, I could have used that instead then, probably should have looked around more beforehand :)

willthibault commented 3 months ago

For those interested in picking out forces associated with specific contacts (such as the feet collision contacts mentioned earlier), something like this works well:

def get_feet_forces(m, dx, forces):
  # Identifiers for the floor, right foot, and left foot
  floor_id = mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "floor")
  right_foot_id = mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "right_foot_collision")
  left_foot_id = mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "left_foot_collision")

  # Find contacts that involve both the floor and the respective foot
  # This assumes dx.contact.geom contains two entries per contact, one for each of the two contacting geometries
  right_contacts = dx.contact.geom == jp.array([floor_id, right_foot_id])
  left_contacts = dx.contact.geom == jp.array([floor_id, left_foot_id])

  right_contact_ids = jp.where(jp.all(right_contacts, axis=1))[0]
  left_contact_ids = jp.where(jp.all(left_contacts, axis=1))[0]

  # Sum forces for the identified contacts
  total_right_forces = jp.sum(forces[right_contact_ids], axis=0)
  total_left_forces = jp.sum(forces[left_contact_ids], axis=0)

  return total_right_forces, total_left_forces

where the forces provided are the output of @hansihe's function. For use during training, the ids could instead be set during the initialization of the environment.
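
One caveat with the function above: single-argument where returns an array whose length depends on the data, which JAX generally cannot trace under jit. A possible alternative is to multiply by a boolean mask instead of indexing, so every shape stays fixed. The sketch below uses NumPy so it runs anywhere; the same calls exist in jax.numpy. The helper name, geom ids, and force values are hypothetical.

```python
import numpy as np


def sum_forces_for_pair(contact_geom, forces, geom_a, geom_b):
    """Sum per-contact forces over contacts between geom_a and geom_b (either order)."""
    pair = np.array([geom_a, geom_b])
    match = (np.all(contact_geom == pair, axis=1)
             | np.all(contact_geom == pair[::-1], axis=1))
    # Mask instead of index: the output shape never depends on the data.
    return np.sum(forces * match[:, None], axis=0)


# Hypothetical data: 4 contacts; geom ids 0 = floor, 5 = right foot, 7 = left foot.
contact_geom = np.array([[0, 5], [0, 7], [5, 0], [2, 3]])
forces = np.array([[10.0, 1.0, 0.0],
                   [20.0, 0.0, 2.0],
                   [5.0, 0.5, 0.0],
                   [9.0, 9.0, 9.0]])

right = sum_forces_for_pair(contact_geom, forces, 0, 5)
left = sum_forces_for_pair(contact_geom, forces, 0, 7)
print(right, left)
```

Matching both orderings of the pair is a design choice here; if the geom order within a contact is known to be fixed, the second comparison can be dropped.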