ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Refactor velocity representations as integers #160

Open flferretti opened 4 months ago

flferretti commented 4 months ago

This pull request refactors the velocity representation in the JaxSim API. The changes include avoiding the use of enum in VelRepr and making velocity_representation a non-static argument in some methods of jaxsim.api.JaxSimModelData. These changes improve compatibility with JAX without introducing breaking changes to the API, and could make it possible to use jax.vmap over different velocity representations.
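A minimal sketch of why integer representations help with jax.vmap, assuming the three representations are encoded as the integers 0, 1, 2 (hypothetical codes `BODY`, `MIXED`, `INERTIAL`) and dispatched with `jax.lax.switch`. The branch bodies are placeholders, not the actual JaxSim conversions:

```python
import jax
import jax.numpy as jnp

# Hypothetical integer codes standing in for the removed enum members.
BODY, MIXED, INERTIAL = 0, 1, 2


def transform_velocity(vel_repr, v):
    # Placeholder per-representation transformations (not the JaxSim math).
    branches = (
        lambda v: v,        # BODY: identity
        lambda v: 2.0 * v,  # MIXED: scaling
        lambda v: -v,       # INERTIAL: sign flip
    )
    # vel_repr can be a traced integer, so this works under jit and vmap.
    return jax.lax.switch(vel_repr, branches, v)


v = jnp.array([1.0, 2.0, 3.0])
reprs = jnp.array([BODY, MIXED, INERTIAL])

# Batch over the representation index while sharing the same velocity.
out = jax.vmap(transform_velocity, in_axes=(0, None))(reprs, v)
print(out)
```

With an enum passed as a static argument, the same batching would require a separate trace per member; with an integer, the index is just another array that vmap can map over.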


📚 Documentation preview 📚: https://jaxsim--160.org.readthedocs.build//160/

flferretti commented 3 months ago

The AD tests take a bit longer to complete. This could be due to the JAX traceback used in JaxSimModelReferences, which can have this effect when computing gradients.

diegoferigo commented 3 months ago

Awesome, thanks @flferretti for this PR. I'd expect the following consequences:

Not yet sure if my intuition is correct.

This being said, I'd like to tag a release with all the previous improvements. I don't expect surprises here, but since this is such a large change touching pretty much our entire API surface, I prefer to be cautious and include this PR in the following release (v0.4.0). I'll be reviewing this shortly.

diegoferigo commented 3 months ago

> The AD tests take a bit longer to complete.

I guess that now AD needs to propagate gradients through all possible branches instead of just one. And yes, this might take longer.
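To illustrate the point about gradients and branches, here is a small sketch (with placeholder branch functions, not JaxSim code) that differentiates through `jax.lax.switch` with a traced integer index. Because the index is not known at trace time, the staged program has to carry every branch, including its derivative:

```python
import jax
import jax.numpy as jnp


def branch_body(x):
    return jnp.sin(x)


def branch_mixed(x):
    return x**2


def branch_inertial(x):
    return jnp.exp(x)


def f(vel_repr, x):
    # With a traced integer index, lax.switch stages out all three branches.
    return jax.lax.switch(vel_repr, (branch_body, branch_mixed, branch_inertial), x)


# Differentiate w.r.t. x (argnums=1); the integer index is not differentiated.
df = jax.jit(jax.grad(f, argnums=1))

for r in range(3):
    print(r, df(r, 0.5))

# The jaxpr of the gradient should contain `cond` equations holding all branches.
print(jax.make_jaxpr(jax.grad(f, argnums=1))(0, 0.5))
```

The per-call runtime cost is usually dominated by the selected branch (unless the switch is vmapped, which turns it into a select over all branches), but tracing and compilation have to process every branch, which is consistent with the longer AD test times.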

flferretti commented 3 months ago

I agree with your intuitions: the IR should now include three branches, one for each velocity representation, and recompilation should no longer be triggered when we use switch_velocity_representation, from_other_to_inertial, or from_inertial_to_other.
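The recompilation claim can be checked with a sketch that counts traces, assuming the previous code passed the representation as a static argument (as the PR description implies). The functions `convert_static` and `convert_traced` are hypothetical and use placeholder math; a Python side effect only runs while JAX is tracing, so it counts (re)compilations:

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = {"static": 0, "traced": 0}


@partial(jax.jit, static_argnames="vel_repr")
def convert_static(v, vel_repr):
    trace_count["static"] += 1  # Runs only when JAX (re)traces the function.
    # Python branching on a static value, similar in spirit to match-case.
    if vel_repr == 0:
        return v
    elif vel_repr == 1:
        return 2.0 * v
    return -v


@jax.jit
def convert_traced(v, vel_repr):
    trace_count["traced"] += 1  # Runs only when JAX (re)traces the function.
    # The integer index is traced, so all branches live in a single program.
    return jax.lax.switch(vel_repr, (lambda v: v, lambda v: 2.0 * v, lambda v: -v), v)


v = jnp.ones(3)
for r in (0, 1, 2, 0):
    convert_static(v, vel_repr=r)
    convert_traced(v, r)

print(trace_count)  # {'static': 3, 'traced': 1}
```

The static version retraces once per distinct value (the repeated `0` hits the cache), while the traced version compiles a single program that covers all three representations.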

> I'd like to tag a release with all the previous improvements. I don't expect surprises here, but since this is such a large change touching pretty much our entire API surface, I prefer to be cautious and include this PR in the following release (v0.4.0). I'll be reviewing this shortly.

I totally agree; this could be disruptive, so I'd also prefer to be cautious and eventually rebase this onto #172.

flferretti commented 3 months ago

> Most of the match-case statements seem to me 1:1 with the new lax.switch functions. Are there any sections in which you had to modify the original code? I checked the diff, but since it's quite large, I could have missed something.

No, they should be equivalent.

> In most cases, this PR does not significantly alter the readability of the code, and I like that. We already had an extra indentation level due to the match-case statements, and the new functions use the same. There are cases, however, where the logic became much more complex and indented, like in JaxSimModelReferences. Let's keep this in mind in case we want to refactor it with a simpler approach.

Yes, the logic got more complex since I needed to somehow check the values of some parameters. We can think of a smarter solution to handle that in the future.
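A general sketch of why checking parameter values gets harder once the representation is traced (this is the common JAX pattern, not the actual JaxSimModelReferences code; `check_python` and `check_traced` are hypothetical). A Python `if` on a traced integer raises a concretization error, so the check has to be expressed with array operations such as `jnp.where` or `jax.lax.cond`:

```python
import jax
import jax.numpy as jnp


def check_python(vel_repr, x):
    # A plain Python `if` needs a concrete value; it fails on a tracer.
    if vel_repr == 2:
        return x
    return -x


raised = False
try:
    jax.jit(check_python)(2, jnp.ones(2))
except jax.errors.ConcretizationTypeError as e:
    raised = True
    print("Python if on a tracer:", type(e).__name__)


@jax.jit
def check_traced(vel_repr, x):
    # Compare on-device and select the result; both operands are computed.
    return jnp.where(vel_repr == 2, x, -x)


print(check_traced(2, jnp.ones(2)), check_traced(0, jnp.ones(2)))
```

Each value-dependent check becomes a nested array-level selection, which is one plausible source of the extra indentation mentioned above.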

> Do you think it would be helpful to add a new jaxsim.typing.VelRepr variable that points to int? Since the removed enums are part of the core public APIs, I'd prefer to make clear that those are not generic integer types.

Totally yes! I'll add a commit for that.

flferretti commented 2 months ago

I'll rebase this right before the next release @diegoferigo

flferretti commented 1 month ago

Checks are failing due to a CI timeout.