navdeepkk-polymagelabs closed this 9 months ago.
@StijnWoestenborghs Are you open to integrating this?
Pinging @cevheck as well! It would be great to also update https://control.limited/mojo-does-give-superpowers with the "correct" JAX implementation.
@navdeepkk-polymagelabs @joker-eph Thanks for pointing this out! I will merge the fix in the coming days and redo the tests and plots in the blog post accordingly. Much appreciated!
Sounds good! Thanks @StijnWoestenborghs @joker-eph.
The current JAX implementation was written using multiple `jax.lax.scan` operators and seemed to be slow. Add a better implementation that uses one fewer `jax.lax.scan` operator and relies on other `jax.numpy` operators instead. The new implementation runs faster than the current one when benchmarked on a machine with an AMD EPYC 9554 (64 cores).
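To illustrate the general idea of dropping a `jax.lax.scan` in favour of `jax.numpy` operators, here is a minimal sketch. It is not the code from this PR. The objective (fitting points to a target matrix of squared pairwise distances) and all function names (`pairwise_sq_dist_nested_scan`, `pairwise_sq_dist_vectorized`, `loss`, `descend`) are hypothetical. The nested-scan version computes one scalar per scan step. The vectorized version computes the same matrix with broadcasting, which leaves a single `jax.lax.scan` over the gradient-descent steps.

```python
import functools

import jax
import jax.numpy as jnp


def pairwise_sq_dist_nested_scan(x):
    # Hypothetical "before": outer scan over rows i, inner scan over columns j,
    # producing one squared distance per inner step.
    def row(_, xi):
        def col(_, xj):
            return None, jnp.sum((xi - xj) ** 2)
        _, r = jax.lax.scan(col, None, x)
        return None, r
    _, d = jax.lax.scan(row, None, x)
    return d


def pairwise_sq_dist_vectorized(x):
    # Hypothetical "after": the same (n, n) matrix via broadcasting, no scans.
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


def loss(x, target):
    # Assumed objective: squared error against a target squared-distance matrix.
    return jnp.sum((pairwise_sq_dist_vectorized(x) - target) ** 2)


@functools.partial(jax.jit, static_argnames=("steps",))
def descend(x0, target, lr=1e-4, steps=100):
    # The only remaining scan iterates over gradient-descent steps.
    grad_fn = jax.grad(loss)

    def step(x, _):
        x = x - lr * grad_fn(x, target)
        return x, loss(x, target)

    x_final, history = jax.lax.scan(step, x0, None, length=steps)
    return x_final, history
```

Both distance functions return the same matrix. The vectorized form hands XLA one batched computation instead of n² tiny sequential steps, which is usually where the speedup from removing a scan comes from.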