Closed diegoferigo closed 1 year ago
Reminder: as soon as we purge all `jax.experimental.loops` invocations, we can remove the following pinning:
https://github.com/ami-iit/jaxsim/blob/4ad8f12e5ba9f348cff4c24cd51c740107777980/setup.cfg#L59
cc @fl-ferr
All our Rigid Body Dynamics Algorithms have been implemented using either plain `for` loops (which are unrolled during JIT compilation, resulting in long build times) or `jax.experimental.loops` (which provided nice syntactic sugar over the low-level `jax.lax.{scan|fori_loop|while_loop}`). Unfortunately, `jax.experimental.loops` was removed in https://github.com/google/jax/pull/11607 and is no longer part of JAX starting from v0.3.16.

This issue tracks the activity of updating all usages of the removed `jax.experimental.loops`. At this point, it seems wise to use `jax.lax.scan` throughout the algorithms. Readability will definitely be affected, but at least this way we can also ensure that the code is forward and backward differentiable (#4).