adam-hartshorne closed this issue 2 years ago.
Interesting. I assume this is OS-dependent behaviour, as I don't get the same error on Linux (running under WSL2).
How are you running JAX on Windows? Have you built jaxlib from source yourself?
I have tried both a jaxlib I built myself and the one from here: https://github.com/cloudhan/jax-windows-builder
I should also note that I am currently using JAX in x64 mode with the built-in ODE solver without any issues.
I vaguely recall running into a similar problem with JAX itself a while back, and that may well also have been an OS-specific bug.
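For reference, a minimal sketch of that kind of working setup, assuming a JAX version that provides `jax.experimental.ode.odeint`: it enables 64-bit mode before any arrays are created, then integrates a Lotka-Volterra system with the built-in solver. The parameter values, initial conditions and time grid are illustrative assumptions, not taken from the benchmark.

```python
import jax

# Enable 64-bit mode; this should happen before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.experimental.ode import odeint


def lotka_volterra(y, t, alpha, beta, gamma, delta):
    # y = [prey, predator]; odeint calls func(y, t, *args).
    prey, predator = y
    d_prey = alpha * prey - beta * prey * predator
    d_predator = -gamma * predator + delta * prey * predator
    return jnp.stack([d_prey, d_predator])


y0 = jnp.array([10.0, 10.0])          # illustrative initial populations
ts = jnp.linspace(0.0, 10.0, 101)     # illustrative time grid
# Illustrative parameters (alpha, beta, gamma, delta), passed as extra args.
ys = odeint(lotka_volterra, y0, ts, 0.1, 0.02, 0.4, 0.02)
print(ys.dtype, ys.shape)
```

With x64 enabled, the solution should come back as `float64` rather than JAX's default `float32`.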
Right, I've tracked down the root cause: https://github.com/google/jax/issues/9574
This is something that can be worked around on our end; I'll include a fix in the upcoming v0.0.3 release.
Thanks for your fast response.
Version 0.0.3, which should fix this, is now available on PyPI.
Running any of the examples that put JAX into 64-bit mode (e.g. the lotka_volterra benchmark) produces the following error when using Python 3.9, jax 0.2.27, jaxlib 0.1.75 and equinox 0.1.5 on a Windows 10 box.