google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

[MJX] Seeking Assistance with Implementing Contact Force in MJX #1555

Closed: breakds closed this issue 4 months ago

breakds commented 5 months ago

Hi MuJoCo Team,

I am a research engineer focusing on robotics, and I've been leveraging MJX, the GPU-accelerated version of MuJoCo, for simulating quadrupedal robots. First off, I want to express my gratitude for open-sourcing MJX. It's been a fantastic tool, and I'm truly enjoying working with it!

Historically, we've relied on mj_contactForce to obtain contact forces in our simulations. However, this functionality appears to be absent in MJX. Looking for a workaround, I dove into the engine_support.c code, specifically the implementation details mentioned here. The logic behind mj_contactForce is quite straightforward, but a crucial piece of data, efc_address, is missing from mjx.Data.contact, which makes a direct port difficult.

Given this, I am contemplating the development of a JAX version of this functionality. Before proceeding, I wanted to reach out to see if you might have any suggestions for a workaround or any advice on implementing this in MJX. Any guidance or pointers you could provide would be immensely helpful and appreciated.

Thank you so much for your support and for making such a powerful tool available to the research community. I look forward to any suggestions you might have.

erikfrey commented 5 months ago

Hi @breakds,

Fortunately, efc_address is pretty simple to calculate. See this line here:

https://github.com/google-deepmind/mujoco/blob/a8db22f0d077aee86f771808fd24ea60a148c93f/mjx/mujoco/mjx/_src/io.py#L241

With a bit of futzing you should be able to get back from a contact id to the correct efc_address in a JAX-friendly way. Please let us know if you have any trouble.

That should work for your own needs. It may also be the right call to just add efc_address into MJX. I'll look into this, or if you'd like to have a go at a PR, I'd be happy to take a look.
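A minimal sketch of the idea, not MJX's actual code (that is at the io.py line linked above). It assumes contact rows are laid out contiguously after all non-contact constraint rows. Each contact of dimension `dim` takes 1 row if frictionless, 2*(dim-1) rows under a pyramidal cone, or `dim` rows under an elliptic cone. Under those assumptions, each contact's address is an exclusive cumulative sum. `efc_rows_per_contact`, `efc_addresses`, and `contact_row_offset` (the number of rows before the first contact) are hypothetical names.

```python
import jax.numpy as jnp


def efc_rows_per_contact(dim, pyramidal=True):
    """Number of efc rows each contact occupies (hypothetical helper)."""
    dim = jnp.asarray(dim)
    if pyramidal:
        # frictionless contacts use one row; otherwise two pyramid edges per friction direction
        return jnp.where(dim == 1, 1, 2 * (dim - 1))
    # elliptic cone: one row per contact dimension
    return dim


def efc_addresses(dim, contact_row_offset, pyramidal=True):
    """Start row of each contact, assuming contacts are stored back to back."""
    rows = efc_rows_per_contact(dim, pyramidal)
    # exclusive prefix sum: each contact starts where the previous ones end
    return contact_row_offset + jnp.cumsum(rows) - rows
```

For example, contacts with dims `[3, 3, 1, 4]` and 5 preceding rows get addresses `[5, 9, 13, 14]` under a pyramidal cone. Because everything is array arithmetic, this works under `jax.jit` as long as the contact dims are known.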

breakds commented 5 months ago

Thanks for the pointer @erikfrey! I'll try contributing this to MJX.

breakds commented 5 months ago

I have a very early draft at https://github.com/google-deepmind/mujoco/pull/1561/files, which tries to mimic the logic at

https://github.com/google-deepmind/mujoco/blob/a8db22f0d077aee86f771808fd24ea60a148c93f/mjx/mujoco/mjx/_src/io.py#L299

I am not sure what nc stands for here or how to compute it. @erikfrey, could you give me a hint?

Thanks a lot!

breakds commented 5 months ago

I have updated the draft https://github.com/google-deepmind/mujoco/pull/1561/files

Can you help me check whether this is on the right track?

erikfrey commented 4 months ago

Hi @breakds, I realized afterwards that efc_address needed to change along with a few other fields, so the whole thing turned into a significant refactor. That has now been pushed in a4df912018d32ea56eeaf6baa946785d59b859bf.

I'm going to close your PR, but thank you for the prod to get this into MJX and for opening the PR. It ended up helping even though we didn't merge it directly.

So as of today, you can access efc_address in mjx.Data.contact at HEAD, and it will be included in the next release. Thanks again!

breakds commented 4 months ago

Thanks!

JeyRunner commented 3 months ago

In case anyone wants to get contact forces with the new MJX version, here is a little helper I wrote (it does the same thing as mj_contactForce). I'm not sure it is all correct, though...


import jax.numpy as jnp
from jaxtyping import Array, Float
from mujoco import mjx


def _contact_force_decode_pyramid(con_dim: int, efc_force_pyramid: Float[Array, "nrows"],
                                  friction_mu: Float[Array, "5"]) -> Float[Array, "con_dim"]:
    """
    Convert the pyramidal representation to a contact force (mirrors mju_decodePyramid).
    See https://github.com/google-deepmind/mujoco/blob/main/src/engine/engine_util_misc.c#L775
    :param con_dim: contact dimensionality (1, 3, 4 or 6)
    :param efc_force_pyramid: efc_force rows starting at this contact's efc_address
    :param friction_mu: friction coefficients of the contact
    :return: force in the contact frame, shape (con_dim,)
    """
    force = jnp.zeros(con_dim)
    # frictionless contact: a single row holds the normal force
    if con_dim == 1:
        return force.at[0].set(efc_force_pyramid[0])

    # force_normal = sum(pyramid0_i + pyramid1_i)
    for i in range(2 * (con_dim - 1)):
        force = force.at[0].add(efc_force_pyramid[i])

    # force_tangent_i = (pyramid0_i - pyramid1_i) * mu_i
    for i in range(con_dim - 1):
        force = force.at[i + 1].set((efc_force_pyramid[2 * i] - efc_force_pyramid[2 * i + 1]) * friction_mu[i])
    return force


def get_contact_force(sys: mjx.Model, state_sim: mjx.Data,
                      contact_id: int,
                      transform_in_world_frame=False) -> Float[Array, "con_dim"]:
    """
    Get the 6D force:torque for one contact, in the contact frame.
    If the contact's con_dim is 3, only the 3D force is returned.
    See: https://github.com/google-deepmind/mujoco/blob/main/src/engine/engine_support.c#L1707

    :param sys: the mjx.Model
    :param state_sim: the mjx.Data (the original version took a brax.mjx.base.State here)
    :param contact_id: the index of the contact, e.g. in state_sim.contact.geom[contact_id, :]
    :param transform_in_world_frame: if true, the force is rotated into the world frame (only works for con_dim == 3).
    :return: force (and torque) for this contact
    """
    # contact.dim and contact.efc_address must be concrete (not traced) values,
    # because they drive Python-level branching and slicing below
    con_dim = int(state_sim.contact.dim[contact_id])
    efc_addr = int(state_sim.contact.efc_address[contact_id])
    if sys.opt.cone == mjx.ConeType.PYRAMIDAL:
        forces = _contact_force_decode_pyramid(con_dim, efc_force_pyramid=state_sim.efc_force[efc_addr:],
                                               friction_mu=state_sim.contact.friction[contact_id])
    else:
        # elliptic cone: efc_force already holds the contact-frame force, starting at efc_address
        forces = state_sim.efc_force[efc_addr:efc_addr + con_dim]

    if transform_in_world_frame:
        assert con_dim == 3
        # the rows of contact.frame are the contact axes in world coordinates,
        # so frame.T maps contact-frame forces into the global frame
        # see: https://github.com/google-deepmind/mujoco/blob/main/src/engine/engine_vis_visualize.c#L230C19-L230C22
        frame_rot_mat = state_sim.contact.frame[contact_id]
        forces = frame_rot_mat.T @ forces
    return forces

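The helper above only rotates 3D contacts into the world frame. Here is a possible extension, sketched here and not taken from MJX's own contact_force. mj_contactForce reports a 6-vector (force, then torque, with unused entries zero), so one can pad the result to length 6 and rotate the force and torque halves separately with the same frame. It assumes `frame` is the 3x3 contact frame whose rows are the contact axes, as in `state_sim.contact.frame[contact_id]`. `contact_force_to_world` is a hypothetical name.

```python
import jax.numpy as jnp


def contact_force_to_world(force, frame):
    """Rotate a contact-frame force[:torque] (length 1..6) into the world frame.

    force: (con_dim,) contact-frame force, e.g. from get_contact_force
    frame: (3, 3) contact frame, rows are the contact axes in world coordinates
    """
    # pad to the 6D force:torque layout used by mj_contactForce
    f6 = jnp.zeros(6).at[: force.shape[0]].set(force)
    # row-vector form of frame.T @ v, applied to the force and torque halves
    return (f6.reshape(2, 3) @ frame).reshape(6)
```

For a 3D contact the first three entries match `frame.T @ forces` from the helper, and the torque half is zero.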
yuvaltassa commented 3 months ago

Nice! Care to submit this as a PR?

JeyRunner commented 3 months ago

Thanks to @btaba, it seems this is already implemented :)

willthibault commented 2 months ago

Hi,

It's great to see a contact force function added to MJX!

I'm wondering whether there is a more JAX-native way to write this function, so that it can compute several contact forces at once, e.g. when given a JAX array of contact_id values. The motivating case is getting all the contact forces on one object, as mentioned here, where a box foot is in contact with the floor during bipedal locomotion.

It could look something like what is shared here.

Thanks!

btaba commented 2 months ago

Hi @willthibault, in the referenced example they're getting all contact forces. It sounds like you want a function that returns all forces? Would something like this work for now:

jp.array([contact_force(m, d, i) for i in range(d.ncon)])

An alternative is to group contacts by type (elliptic, pyramidal, etc.), and call the underlying routine to generate the force in a vmap for each group, but this can be a future optimization.
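Here is a rough sketch of that grouping idea. It is not MJX's implementation, just one way it could look. Take a group of contacts that all share the same `dim` under a pyramidal cone. The number of efc rows per contact is then static, so `jax.lax.dynamic_slice` can gather each contact's rows inside `jax.vmap`. The function name is hypothetical, and it assumes the caller has already selected the group's `efc_address` and `friction` entries (e.g. from `d.contact`).

```python
import jax
import jax.numpy as jnp


def pyramidal_group_forces(efc_force, efc_address, friction, dim):
    """Contact-frame forces for a group of contacts that all share `dim`.

    efc_force:   (nefc,) constraint forces, e.g. d.efc_force
    efc_address: (n,) int start row of each contact in the group
    friction:    (n, 5) friction coefficients of each contact
    dim:         static Python int shared by the whole group
    """
    # rows per contact under a pyramidal cone; static, so the slice size is fixed
    nrows = 1 if dim == 1 else 2 * (dim - 1)

    def one_contact(addr, mu):
        # gather this contact's pyramid rows (start may be traced, size is static)
        rows = jax.lax.dynamic_slice(efc_force, (addr,), (nrows,))
        if dim == 1:
            return rows  # frictionless: the single row is the normal force
        pairs = rows.reshape(dim - 1, 2)  # (pyramid0_i, pyramid1_i) per direction
        normal = pairs.sum()  # normal force = sum of all pyramid forces
        tangent = (pairs[:, 0] - pairs[:, 1]) * mu[: dim - 1]
        return jnp.concatenate([normal[None], tangent])

    return jax.vmap(one_contact)(efc_address, friction)
```

Since `dim` is a Python int, a jitted version would mark it static, e.g. `jax.jit(pyramidal_group_forces, static_argnums=3)`. One call per distinct `dim` then covers every pyramidal contact.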

willthibault commented 2 months ago

Hi @btaba,

Yes, that works, and it's what I'm doing for now. I was thinking of something more along the lines of a vmap, but I understand this is probably a future development of this function. I thought I'd mention it anyway, since it could be a performance improvement down the line.