ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Refactor contact forces sum in `api.ode.system_velocity_dynamics` #180

Closed flferretti closed 2 weeks ago

flferretti commented 3 weeks ago

This PR refactors the contact forces sum in `system_velocity_dynamics`, potentially improving readability and performance.



flferretti commented 2 weeks ago

> What were the original desiderata of this PR? Runtime performance, memory footprint, readability, etc.?
>
> It seems to me that readability is now much worse than before, and in a quick test I ran, the new logic seems wrong. Did you double-check it?

I answered you with a comment in the code.

As for the logic, in which example did you get the error? I tested it by simulating spheres in vmap and had no problems.

diegoferigo commented 2 weeks ago

> As for the logic, in which example did you get the error? I tested it by simulating spheres in vmap and had no problems.

Check the following snippet: I get different results with the new logic. Did I make a mistake?

```python
import jax
import jax.numpy as jnp
import jaxsim.api as js
import resolve_robotics_uri_py
import rod

# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)

# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
)

# Get the parent body of the collidable points.
parent_link_index_of_collidable_points = jnp.array(
    model.kin_dyn_parameters.contact_parameters.body
)

# Create the 6D forces of collidable points.
nc = len(parent_link_index_of_collidable_points)
W_f_Ci = jnp.ones(shape=(nc, 6))

# =========
# Old logic
# =========

f_old = jax.vmap(
    lambda nc: (
        jnp.vstack(jnp.equal(parent_link_index_of_collidable_points, nc).astype(int))
        * W_f_Ci
    ).sum(axis=0)
)(jnp.arange(model.number_of_links()))

# =========
# New logic
# =========

f_new = jnp.where(
    (
        parent_link_index_of_collidable_points[:, None]
        == jnp.arange(model.number_of_links())
    ).any(axis=-1, keepdims=True),
    W_f_Ci,
    jnp.zeros_like(W_f_Ci),
).sum(axis=0)
```

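The difference can probably be reproduced without loading a robot model. The sketch below repeats the two reductions on toy data; `n_links`, `parent`, and the all-ones forces are hypothetical stand-ins, not values from the ergoCub model. Under these assumptions the old logic returns one 6D force per link, while the new expression sums over every collidable point and returns a single 6D vector. The `jax.ops.segment_sum` call at the end is only one possible compact formulation of the per-link sum, not necessarily what this PR adopted.

```python
import jax
import jax.numpy as jnp

# Toy stand-ins (hypothetical values, not taken from the ergoCub model).
n_links = 4
parent = jnp.array([0, 0, 2, 2, 2])  # parent link index of each collidable point
nc = parent.shape[0]
W_f_Ci = jnp.ones((nc, 6))  # one 6D force per collidable point

# Old logic: for each link, mask the points attached to it and sum their forces.
f_old = jax.vmap(
    lambda link: ((parent == link)[:, None] * W_f_Ci).sum(axis=0)
)(jnp.arange(n_links))

# New logic as posted: the mask is reduced over the link axis, so every point
# with a valid parent is kept and the sum runs over all points at once.
f_new = jnp.where(
    (parent[:, None] == jnp.arange(n_links)).any(axis=-1, keepdims=True),
    W_f_Ci,
    jnp.zeros_like(W_f_Ci),
).sum(axis=0)

# One possible alternative (an assumption, not the PR's code): a segment sum
# that groups the point forces by parent link index.
f_seg = jax.ops.segment_sum(W_f_Ci, parent, num_segments=n_links)

print(f_old.shape, f_new.shape, f_seg.shape)  # (4, 6) (6,) (4, 6)
```

With these inputs, `f_old` and `f_seg` agree and keep the per-link structure (links 0 and 2 receive the forces of two and three points, links 1 and 3 receive none), whereas `f_new` has already collapsed the link dimension.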
flferretti commented 2 weeks ago

I found another way to further simplify the code; if you prefer, I could merge the two lines. Ready for review, @diegoferigo. Thanks!