StijnWoestenborghs / gradi-mojo


Add proper JAX implementation (using JAX operators) #5

Closed: navdeepkk-polymagelabs closed this 9 months ago

navdeepkk-polymagelabs commented 10 months ago

The current JAX implementation was written using multiple jax.lax.scan operators and seemed to be slow. This PR adds a better implementation that uses one fewer jax.lax.scan operator and relies on jax.numpy operators for the rest.

This new implementation runs faster than the current one, as benchmarked on a machine with an AMD EPYC 9554 (64 cores).

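The PR's code is not reproduced in this thread, so here is a minimal sketch of the general idea, assuming a gradi-mojo-style problem: gradient descent on point positions `X` (shape `(n, dim)`) so that pairwise squared distances match a target matrix `D`. The loss, the names `loss_scan`, `loss_vectorized`, `gradient_descent`, `time_fn`, and the data are all hypothetical illustrations, not the PR's actual implementation. It contrasts a loop-style loss built from two nested `jax.lax.scan` calls with the same loss written using broadcasting `jax.numpy` operators, which leaves a single `jax.lax.scan` over the descent iterations.

```python
import time

import jax
import jax.numpy as jnp


def loss_scan(X, D):
    """Hypothetical loop-style loss: two nested jax.lax.scan calls over indices i and j."""
    n = X.shape[0]
    zero = jnp.zeros((), dtype=X.dtype)  # carry with a fixed dtype so scan types match

    def outer(acc, i):
        def inner(acc_i, j):
            diff = X[i] - X[j]
            err = jnp.dot(diff, diff) - D[i, j] ** 2  # squared distance vs. squared target
            return acc_i + err ** 2, None

        row_sum, _ = jax.lax.scan(inner, zero, jnp.arange(n))
        return acc + row_sum, None

    total, _ = jax.lax.scan(outer, zero, jnp.arange(n))
    return total


def loss_vectorized(X, D):
    """Same hypothetical loss using broadcasting jax.numpy operators; no scan needed."""
    diff = X[:, None, :] - X[None, :, :]   # (n, n, dim) pairwise differences
    sq_dist = jnp.sum(diff ** 2, axis=-1)  # (n, n); no sqrt, so gradients stay finite at i == j
    return jnp.sum((sq_dist - D ** 2) ** 2)


def gradient_descent(loss_fn, X0, D, lr, steps):
    """Plain gradient descent; a single jax.lax.scan runs over the iterations."""
    grad_fn = jax.grad(loss_fn)  # gradient w.r.t. the first argument, X

    def step(X, _):
        X_new = X - lr * grad_fn(X, D)
        return X_new, loss_fn(X_new, D)  # carry the positions, record the loss per step

    return jax.lax.scan(step, X0, None, length=steps)


def time_fn(fn, *args, repeats=10):
    """Average seconds per call of jax.jit(fn), excluding the compiling warm-up call."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up: triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        jitted(*args).block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / repeats


X = jax.random.normal(jax.random.PRNGKey(0), (8, 2))  # hypothetical: 8 points in 2-D
D = jnp.ones((8, 8)) - jnp.eye(8)                      # hypothetical: unit target distances
X_final, losses = gradient_descent(loss_vectorized, X, D, 1e-4, 200)
print("scan:", time_fn(loss_scan, X, D), "vectorized:", time_fn(loss_vectorized, X, D))
```

Both losses compute the same value; the vectorized one materializes an `(n, n, dim)` array, trading memory for work that XLA can run in parallel instead of as `n * n` sequential scan steps. How much faster it is in practice depends on problem size and hardware.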
navdeepkk-polymagelabs commented 10 months ago

@StijnWoestenborghs Are you open to integrating this?

joker-eph commented 9 months ago

Pinging @cevheck as well! It would be great to also update https://control.limited/mojo-does-give-superpowers with the "correct" JAX implementation.

StijnWoestenborghs commented 9 months ago

@navdeepkk-polymagelabs @joker-eph Thanks for pointing this out! I will merge the fix in the coming days and redo the tests and plots in the blog post accordingly. Much appreciated!

navdeepkk-polymagelabs commented 9 months ago

Sounds good! Thanks @StijnWoestenborghs @joker-eph.