Open flferretti opened 4 months ago
The AD tests take a bit longer to complete. This may be due to the JAX traceback used in `JaxSimModelReferences`, which can cause this slowdown when computing the gradients.
Awesome, thanks @flferretti for this PR. I'd expect the following consequences:
TL;DR:
- Slower JIT compilations for a single call, but no recompilations in case multiple representations are needed.
- Now, a smaller binary size, which can also mean less memory used, especially on GPU.

Not yet sure if my intuition is correct.
This being said, I'd like to tag a release with all the previous improvements. I don't expect surprises here, but since this is such a large change touching pretty much all our API surface, I prefer to be cautious and include this PR in the following release (v0.4.0). I'll be reviewing this shortly.
> The AD tests take a bit longer to complete.
I guess that now AD needs to propagate gradients through all possible branches instead of just one. And yes, this might take longer.
I agree with your intuition: the IR should now include the three branches for each velocity representation, and recompilation should not be triggered when we use `switch_velocity_representation`, `from_other_to_inertial`, or `from_inertial_to_other`.
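To make the recompilation point concrete, here is a minimal sketch. It is not the JaxSim code: the integer codes, the `convert` function, and the placeholder transforms are all assumptions made for illustration. It only shows that, when the representation is a traced integer dispatched through `jax.lax.switch`, a jitted function is traced once and then reused for every representation.

```python
import jax
import jax.numpy as jnp

# Hypothetical integer codes standing in for the three velocity representations.
INERTIAL, BODY, MIXED = 0, 1, 2

trace_count = 0  # incremented every time JAX traces (and hence compiles) `convert`


@jax.jit
def convert(velocity, representation):
    """Toy conversion: `representation` is a traced (non-static) integer."""
    global trace_count
    trace_count += 1  # Python side effect, runs only while tracing

    # Placeholder transforms (assumptions); all three end up in the compiled IR.
    branches = (
        lambda v: v,
        lambda v: 2.0 * v,
        lambda v: v + 1.0,
    )
    return jax.lax.switch(representation, branches, velocity)


v = jnp.ones(6)
out_inertial = convert(v, INERTIAL)
out_body = convert(v, BODY)
out_mixed = convert(v, MIXED)

print(trace_count)  # a single trace serves all three representations
```

With a static argument (for example an enum passed via `static_argnums`), each new value would instead trigger a fresh trace and compilation, which is the trade-off discussed above.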
> I'd like to tag a release with all the previous improvements. I don't expect surprises here, but being such a large change touching pretty much all our API surface, I prefer being cautious and include this PR in the following release (v0.4.0). I'll be reviewing this shortly.
I totally agree, this can be potentially disruptive, so I'd also prefer to be cautious and eventually rebase this onto #172
Most of the match-case statements seem to me to map 1:1 to the new `lax.switch` functions. Are there any sections in which you had to modify the original code? I checked the diff, but since it's quite large, I could have missed something.
No, they should be equivalent
In most cases, this PR does not significantly alter the readability of the code, and I like that. We already had an extra indentation level due to the match-case statements, and the new functions use the same. There are cases, however, where the logic got much more complex and indented, like in `JaxSimModelReferences`. Let's keep this in mind in case we want to refactor it with a simpler approach.
Yes, the logic got more complex since I needed to somehow check the values of some parameters. We can think of a smarter solution to handle that in the future.
Do you think it could be helpful to add a new `jaxsim.typing.VelRepr` variable that points to `int`? Since the removed enums are part of the core public API, I'd prefer to make clear that those are not generic integer types.
Totally yes! I'll add a commit for that.
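As a rough sketch of what such an alias could look like (the names `VelReprType`, the class layout, the integer values, and `representation_name` are assumptions for illustration, not the actual JaxSim definitions):

```python
from typing import ClassVar

# Hypothetical alias: documents intent in signatures, but is a plain int at runtime.
VelReprType = int


class VelRepr:
    """Illustrative namespace of integer codes (the values are assumptions)."""

    Inertial: ClassVar[VelReprType] = 0
    Body: ClassVar[VelReprType] = 1
    Mixed: ClassVar[VelReprType] = 2


def representation_name(representation: VelReprType) -> str:
    """Return a readable name for a representation code."""
    names = {
        VelRepr.Inertial: "inertial",
        VelRepr.Body: "body",
        VelRepr.Mixed: "mixed",
    }
    return names[representation]


print(representation_name(VelRepr.Body))
```

An alias like this does not add any runtime checking; it only signals in type hints that the argument is a velocity representation code rather than an arbitrary integer.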
I'll rebase this right before the next release @diegoferigo
Checks are failing due to a timeout of the CI
This pull request refactors the velocity representation in the JaxSim API. The changes include avoiding the use of `enum` in `VelRepr` and making `velocity_representation` a non-static argument in some methods of `jaxsim.api.JaxSimModelData`. These changes improve compatibility with JAX while avoiding breaking changes in the API, potentially making it possible to use `jax.vmap` over different velocity representations.

📚 Documentation preview 📚: https://jaxsim--160.org.readthedocs.build//160/
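
For illustration only, the `jax.vmap` use case mentioned in the description could look like the sketch below. The function `to_representation`, the integer codes, and the placeholder transforms are hypothetical and do not come from JaxSim; the point is just that an integer representation index can be batched, which a static enum argument could not.

```python
import jax
import jax.numpy as jnp


def to_representation(velocity, representation):
    """Toy conversion selected by a traced integer index (assumed codes 0, 1, 2)."""
    # Placeholder transforms (assumptions), one per representation.
    branches = (
        lambda v: v,
        lambda v: -v,
        lambda v: 0.5 * v,
    )
    return jax.lax.switch(representation, branches, velocity)


velocity = jnp.arange(6.0)
representations = jnp.array([0, 1, 2])

# Map over the representation index only; the velocity is shared across the batch.
batched = jax.vmap(to_representation, in_axes=(None, 0))(velocity, representations)

print(batched.shape)  # one converted velocity per representation
```

Note that under `jax.vmap` with a batched index, `jax.lax.switch` is lowered to a select over the branch outputs, so all branches are evaluated for every batch element.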