ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Inconsistent data types for `body` attribute in `FrameParameters` and `ContactParameters` #271

Closed: flferretti closed this issue 6 days ago

flferretti commented 1 week ago

The body attribute in FrameParameters and ContactParameters currently uses different data types: one is a jnp.array and the other a tuple.

This inconsistency complicates operations that use jax.vmap and can hide indexing bugs in JIT-compiled functions: with JAX's default semantics, an out-of-bounds traced index into a jnp.array is clamped instead of raising an IndexError. To unify the logic for extracting a frame's parent link index and ensure consistency, we should standardize the data type of the body attribute.
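To make the "silent" failure mode concrete, here is a minimal sketch. The `body_array`/`body_tuple` values and both function names are hypothetical, not jaxsim code. Under `jax.jit`, a traced out-of-bounds index into a `jnp.array` cannot be detected and is clamped by default, while indexing a tuple requires a concrete index (made static here), so an out-of-range value raises `IndexError` at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical parent-link indices for three frames.
body_array = jnp.array([0, 1, 2])
body_tuple = (0, 1, 2)


@jax.jit
def parent_from_array(i):
    # `i` is traced, so an out-of-bounds value is not detected while tracing;
    # JAX's default indexing semantics clamp it to the last valid position.
    return body_array[i]


@partial(jax.jit, static_argnums=0)
def parent_from_tuple(i):
    # `i` is static (a concrete Python int), so plain tuple indexing applies
    # and an out-of-bounds index raises IndexError during tracing.
    return body_tuple[i]


print(parent_from_array(10))  # no error: the index is clamped

try:
    parent_from_tuple(10)
except IndexError as err:
    print("IndexError:", err)
```

The tuple only gives a loud error because the index is concrete; a traced index cannot be used to index a tuple at all.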

Proposed Solutions

1. Make them both jnp.arrays:

This would facilitate operations using jax.vmap. However, since arrays are not hashable, we would need to carefully handle the __hash__ and __eq__ methods, e.g. using HashedNumpyArray.
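A minimal sketch of both sides of this option, assuming a hypothetical `body` array. With a `jnp.array`, a batch of frame indices can be looked up with `jax.vmap` directly. But NumPy and JAX arrays are not hashable, so a container that must be hashable needs a wrapper that hashes and compares the array contents. `HashedArraySketch` is a hypothetical illustration of that idea, not jaxsim's `HashedNumpyArray`.

```python
import jax
import jax.numpy as jnp
import numpy as np

body = jnp.array([0, 1, 1, 2])  # hypothetical parent-link index per frame

# With an array, a batched lookup is a one-liner under vmap.
parents = jax.vmap(lambda i: body[i])(jnp.array([3, 0, 2]))

try:
    hash(body)
except TypeError:
    print("JAX arrays are not hashable")


class HashedArraySketch:
    """Hypothetical wrapper: hash and compare an array by its contents."""

    def __init__(self, array):
        self.array = np.asarray(array)

    def __hash__(self):
        # dtype is part of the hash, so int32 and int64 arrays differ.
        return hash((self.array.shape, self.array.dtype.str, self.array.tobytes()))

    def __eq__(self, other):
        if not isinstance(other, HashedArraySketch):
            return NotImplemented
        return (
            self.array.shape == other.array.shape
            and self.array.dtype == other.array.dtype
            and np.array_equal(self.array, other.array)
        )


a = HashedArraySketch(body)
b = HashedArraySketch(jnp.array([0, 1, 1, 2]))
print(a == b, hash(a) == hash(b))
```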

2. Make them both tuples:

This would ensure consistency, and an out-of-range index would raise an IndexError instead of failing silently in JIT-compiled functions.
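Conversely, a tuple is hashable, so it can be passed to `jax.jit` as a static argument and baked into the compiled function (a `jnp.array` in the same position would be rejected as unhashable). The trade-off is that a tuple cannot be indexed with a traced value, so a batched lookup has to convert it to an array at the point of use. A minimal sketch; `parents_of` and the `body` values are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def parents_of(body, frame_offsets):
    # `body` is static: allowed because tuples are hashable, and its values
    # become constants in the compiled function.
    # body[i] with a traced `i` (as under vmap) would fail, so convert the
    # tuple to an array before indexing.
    return jax.vmap(lambda i: jnp.asarray(body)[i])(frame_offsets)


body = (0, 1, 1, 2)  # hypothetical parent-link index per frame
parents = parents_of(body, jnp.array([3, 0, 2]))
print(parents)
```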

Example Code Changes

If we decide to use tuples, we could unify the index extraction logic, as done in https://github.com/ami-iit/jaxsim/pull/272/commits/5905a040f3a4075b0d92d86b8d3afd7466861748:

```diff
--- a/src/jaxsim/api/frame.py
+++ b/src/jaxsim/api/frame.py
@@ -32,17 +32,14 @@ def idx_of_parent_link(
     """

     n_l = model.number_of_links()
-    n_f = len(model.frame_names())

     exceptions.raise_value_error_if(
-        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
+        condition=frame_index < n_l,
         msg="Invalid frame index '{idx}'",
         idx=frame_index,
     )

-    return model.kin_dyn_parameters.frame_parameters.body[
-        frame_index - model.number_of_links()
-    ]
+    return model.kin_dyn_parameters.frame_parameters.body[frame_index - n_l]
```

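
A self-contained sketch of the simplified logic in the diff, using plain Python stand-ins: `ToyModel` is hypothetical and replaces the real model, and a plain `raise` replaces `exceptions.raise_value_error_if`. It assumes `frame_index` is a concrete Python int, which is what makes tuple indexing (and its `IndexError`) possible. It also shows why the lower-bound check must stay: a negative offset would otherwise wrap around silently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyModel:
    """Hypothetical stand-in for a model: links come first, then frames."""

    link_names: tuple[str, ...]
    frame_body: tuple[int, ...]  # parent-link index of each frame

    def number_of_links(self) -> int:
        return len(self.link_names)


def idx_of_parent_link(model: ToyModel, frame_index: int) -> int:
    n_l = model.number_of_links()
    # Lower bound: without this, frame_index - n_l < 0 would silently pick
    # an element from the end of the tuple (Python negative indexing).
    if frame_index < n_l:
        raise ValueError(f"Invalid frame index '{frame_index}'")
    # Upper bound: an index past the last frame raises IndexError by itself.
    return model.frame_body[frame_index - n_l]


model = ToyModel(link_names=("base", "arm"), frame_body=(0, 1, 1))
print(idx_of_parent_link(model, 2), idx_of_parent_link(model, 4))
```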
diegoferigo commented 1 week ago

> This inconsistency complicates operations using jax.vmap and introduces potential issues with silent IndexErrors in JIT-compiled functions. To unify the logic of extracting the frame parent link index and ensure consistency, we should standardize the data type of the body attribute.

Do you have any specific example in mind for this?

I agree that we can unify how we handle this field. If it does not affect any existing logic, I'd prefer the tuple option. I can't think of any use case that needs to dynamically change the parent link of a frame, just as it makes no sense for contact points.

flferretti commented 1 week ago

> Do you have any specific example in mind for this?

Yes, you can refer to the example in the issue description. If body were a tuple, we could drop the condition frame_index >= n_l + n_f from the check, since an out-of-range index would raise an IndexError anyway, whereas indexing a jnp.array raises no error.

> I agree that we can unify how we handle this field. If it does not affect any existing logic, I'd prefer the tuple option. I can't think of any use case that needs to dynamically change the parent link of a frame, just as it makes no sense for contact points.

I totally agree with making it a tuple! It will also make the compilation a bit leaner.
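One way to see the "leaner" part, as a minimal sketch with hypothetical classes (`FramesWithArray`, `FramesWithTuple`): if `body` is a tuple stored in a pytree's static auxiliary data, it is not a leaf, so `jax.jit` does not trace it as an input and uses its values as constants. A `jnp.array` field is a traced leaf instead.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class FramesWithArray:
    body: jax.Array  # hypothetical: parent indices as a traced leaf


@dataclasses.dataclass(frozen=True)
class FramesWithTuple:
    body: tuple[int, ...]  # hypothetical: parent indices as static metadata


jax.tree_util.register_pytree_node(
    FramesWithArray,
    lambda f: ((f.body,), None),  # body is a child, i.e. a leaf
    lambda aux, children: FramesWithArray(*children),
)
jax.tree_util.register_pytree_node(
    FramesWithTuple,
    lambda f: ((), f.body),  # body lives in the hashable aux data
    lambda aux, children: FramesWithTuple(aux),
)

with_array = FramesWithArray(body=jnp.array([0, 1, 1]))
with_tuple = FramesWithTuple(body=(0, 1, 1))

print(len(jax.tree_util.tree_leaves(with_array)))  # 1 traced leaf
print(len(jax.tree_util.tree_leaves(with_tuple)))  # 0 leaves


@jax.jit
def parent_of_first_frame(frames):
    return frames.body[0]
```

The flip side is that aux data is part of the jit cache key, so each distinct tuple triggers its own compilation; that is acceptable here because, as discussed above, parent links do not change at runtime.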