facebookresearch / fmmax

Fourier modal method with Jax
MIT License

Use jax.pure_callback #83

mfschubert closed this 7 months ago

mfschubert commented 7 months ago

Updates the eigensolve to use jax.pure_callback instead of jax.experimental.host_callback.call. Unlike host_callback.call, pure_callback supports jax.vmap, so fmmax calculations can now be vmapped. Also adds a test that exercises vmap.

smartalecH commented 7 months ago

Wow, I was just looking at this today, thanks! This should also make pmap and xmap work better, right?

mfschubert commented 7 months ago

Haha, coincidence. Yeah, I expect these to work seamlessly now.