sandialabs / WecOptTool

WEC Design Optimization Toolbox
https://sandialabs.github.io/WecOptTool/
GNU General Public License v3.0

autograd -> JAX #118

Open ryancoe opened 1 year ago

ryancoe commented 1 year ago

Ever since we originally adopted autograd, we've been concerned that most of the development energy behind autograd has moved to JAX. Besides being actively developed, JAX also has more complete functionality (e.g., for FFTs).
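Here is a minimal sketch of the FFT point. It assumes a recent JAX release, and the objective `spectral_energy` is a hypothetical example, not part of WecOptTool. JAX can differentiate straight through `jnp.fft.fft`. By Parseval's theorem, the summed squared FFT magnitude of a real length-N signal equals N times its summed squares, so the gradient should come out as 2N·x:

```python
import jax
import jax.numpy as jnp

def spectral_energy(x):
    # Hypothetical objective: total energy in the (unnormalized) FFT of a real signal.
    X = jnp.fft.fft(x)
    # Use Re(X * conj(X)) rather than |X|**2 so the derivative stays smooth
    # even where an FFT bin is exactly zero.
    return jnp.sum(jnp.real(X * jnp.conj(X)))

# Reverse-mode gradient through the complex-valued FFT.
grad_energy = jax.grad(spectral_energy)

x = jnp.linspace(-1.0, 1.0, 8)
g = grad_energy(x)

# Parseval: sum_k |X_k|^2 = N * sum_n x_n^2, so the gradient is 2 * N * x.
N = x.shape[0]
print(bool(jnp.allclose(g, 2 * N * x, atol=1e-3)))
```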

We did not use JAX initially because its support for MS Windows is limited: users must either compile it themselves, use a third-party binary, or use Windows Subsystem for Linux (WSL) (https://github.com/google/jax#installation).

Given that more direct JAX support for MS Windows does not seem imminent, there are two major hurdles preventing us from transitioning:

michaelcdevin commented 8 months ago

@ryancoe @cmichelenstrofer it appears a Windows-compatible pip install is now available for JAX as of v0.4.13. That's one of the major implementation hurdles out of the way!

cmichelenstrofer commented 8 months ago

We should try it at some point, but it does require changing the source, so it won't be a small task.
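This is a minimal sketch of the kind of source changes involved, assuming JAX's NumPy-like API. Two differences commonly come up when porting autograd code. First, JAX defaults to 32-bit floats unless 64-bit mode is enabled. Second, JAX arrays are immutable, so in-place index assignment has to become `.at[...].set(...)`. The functions `set_first` and `objective` are hypothetical and only illustrate the pattern:

```python
import jax
import jax.numpy as jnp   # replaces `import autograd.numpy as np`
from jax import grad      # replaces `from autograd import grad`

# JAX uses float32 by default; enable float64 (the autograd/NumPy default)
# before any arrays are created.
jax.config.update("jax_enable_x64", True)

def set_first(x, value):
    # autograd/NumPy style would be `x[0] = value`, which raises on JAX arrays.
    # .at[].set() returns an updated copy and leaves x untouched.
    return x.at[0].set(value)

def objective(x):
    y = set_first(x, 0.0)
    return jnp.sum(y ** 2)

x = jnp.array([1.0, 2.0, 3.0])
dx = grad(objective)(x)  # expected gradient: [0, 4, 6] (x[0] is overwritten)
print(x.dtype, dx)
```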

ryancoe commented 8 months ago

@michaelcdevin - as a first step, can you quickly check how well the JAX Windows install works?

michaelcdevin commented 8 months ago

pip install jax works without a hitch on Windows. I tested some basic jax.numpy and jax.grad operations, and everything seems to work as expected.
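For anyone repeating this check, a smoke test along the same lines might look like the sketch below. It assumes `pip install jax` has completed and is illustrative only, not the exact set of operations tested above. It compares `jax.grad` (plain and under `jax.jit`) against the analytic derivative of a simple function:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Simple scalar function: sum of sin^2 over the input vector.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 3.0, 5)

# Automatic derivative vs. analytic derivative: d/dx sin^2(x) = sin(2x).
auto = jax.grad(f)(x)
auto_jit = jax.jit(jax.grad(f))(x)
exact = jnp.sin(2.0 * x)

print("jax version:", jax.__version__)
print("devices:", jax.devices())
print("grad matches:", bool(jnp.allclose(auto, exact, atol=1e-5)))
print("jit(grad) matches:", bool(jnp.allclose(auto_jit, exact, atol=1e-5)))
```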

michaelcdevin commented 1 week ago

NumPy v2.0 was released four days ago with various breaking changes. Because autograd is no longer maintained, it is incompatible with current and future releases of NumPy.

It looks like JAX was proactive about maintaining compatibility with NumPy v2.0. I added a NumPy version limitation in 902c17b as a stopgap so WecOptTool doesn't break, but switching from autograd to JAX is now a higher priority so we don't fall behind on NumPy versions.
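The following is illustrative only and does not reproduce the actual change in 902c17b. In effect, a stopgap like this keeps NumPy below 2.0 until the JAX port lands. The sketch uses only the standard library, and the helper `numpy_predates_v2` is hypothetical. It checks whether a given (or the installed) NumPy version falls below that limit:

```python
from importlib.metadata import version

def numpy_predates_v2(version_string=None):
    """Return True if the given (or installed) NumPy version is below 2.0.

    Hypothetical helper for illustration; not part of WecOptTool.
    """
    if version_string is None:
        # Raises importlib.metadata.PackageNotFoundError if NumPy is absent.
        version_string = version("numpy")
    # Only the major component matters for a "< 2.0" limit, so pre-release
    # tags such as "2.0.0rc1" are still classified by their leading "2".
    major = int(version_string.split(".")[0])
    return major < 2

print(numpy_predates_v2("1.26.4"))
print(numpy_predates_v2("2.0.0"))
```

A real project would more typically express this as a dependency pin in its packaging metadata rather than as a runtime check.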

@cmichelenstrofer @ryancoe