ami-iit / jaxsim

A differentiable physics engine and multibody dynamics library for control and robot learning.
https://jaxsim.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Implement RBDA with `jax.lax.scan` #12

Closed diegoferigo closed 1 year ago

diegoferigo commented 1 year ago

All our Rigid Body Dynamics Algorithms (RBDAs) have been implemented using either plain for loops (which are unrolled during JIT compilation, resulting in long build times) or jax.experimental.loops (which provided nice syntactic sugar over the low-level jax.lax.{scan|fori_loop|while_loop}). Unfortunately, jax.experimental.loops was removed in https://github.com/google/jax/pull/11607 and is no longer part of JAX starting from v0.3.16.

This issue tracks the activity of updating all usages of the removed jax.experimental.loops. At this point, it seems wise to use jax.lax.scan throughout the algorithms. Readability will definitely be affected, but at least this way we can also ensure that the code is differentiable in both forward and reverse mode (#4).
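To give an idea of what the port could look like, here is a minimal sketch on a toy problem. It is not taken from the jaxsim code base: the `parent` array, the scalar "velocities", and the function names are all hypothetical stand-ins for a real RBDA forward pass, where each link's quantity is computed from its parent's. The first function uses a plain Python loop (unrolled at trace time), the second carries the full array through `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (not jaxsim's data structures):
# parent[i] is the index of the parent of link i; link 0 is the root.
parent = jnp.array([0, 0, 1, 1, 3])
joint_vel = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])


def forward_pass_loop(parent, joint_vel):
    # Plain Python loop: under jax.jit every iteration is unrolled
    # into the compiled program, so build time grows with the link count.
    v = jnp.zeros_like(joint_vel)
    for i in range(1, len(joint_vel)):
        v = v.at[i].set(v[parent[i]] + joint_vel[i])
    return v


@jax.jit
def forward_pass_scan(parent, joint_vel):
    # Same recursion with jax.lax.scan: the loop body is traced once.
    def body(v, x):
        i, p, qd = x
        # The parent's value is already in the carry because parents
        # are assumed to have a lower index than their children.
        v = v.at[i].set(v[p] + qd)
        return v, None

    idx = jnp.arange(1, joint_vel.shape[0])
    v0 = jnp.zeros_like(joint_vel)
    v, _ = jax.lax.scan(body, v0, (idx, parent[1:], joint_vel[1:]))
    return v


# Reverse-mode gradient of a scalar function of the scan output.
grad_sum = jax.grad(lambda qd: forward_pass_scan(parent, qd).sum())(joint_vel)
```

The sketch relies on the links being ordered so that each parent precedes its children, which is what lets a single sequential scan replace the loop. The carry is the whole array of per-link values, so the body can read any already-computed parent entry.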

diegoferigo commented 1 year ago

Reminder: as soon as we purge all jax.experimental.loops invocations, we can remove the following pinning:

https://github.com/ami-iit/jaxsim/blob/4ad8f12e5ba9f348cff4c24cd51c740107777980/setup.cfg#L59

traversaro commented 1 year ago

cc @fl-ferr