Closed Shono1 closed 1 month ago
having the same issue
I downgraded to python==3.8, jax==0.4.10, and jaxlib==0.4.10, and the error seems to be gone for now.
Ah, surprisingly this actually isn't a JAX version issue, but a NumPy one instead -- see https://github.com/numpy/numpy/issues/26421, https://github.com/numpy/numpy/issues/15349, or the NumPy 2.0.0 Release Notes for context. I'm guessing that when it was working you were on some NumPy 1.XX version, but upgraded to NumPy 2.XX at some point, which broke the package as you observed.
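To tell which side of the NumPy 1/2 split an environment is on, you can read the installed version from the package metadata. This is only a minimal sketch; the helper names major_version and installed_numpy_major are made up for illustration and are not part of this package.

```python
from importlib import metadata


def major_version(version):
    """Return the leading integer of a version string such as '1.26.4'."""
    return int(version.split(".")[0])


def installed_numpy_major():
    """Return the major version of the installed NumPy, or None if absent."""
    try:
        return major_version(metadata.version("numpy"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints 1 or 2 depending on the environment, or None without NumPy.
    print("NumPy major version:", installed_numpy_major())
```

If this reports 2 in an environment that has a release of the package from before #16, that would be consistent with the explanation above.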
Apologies for the significant delay in fixing this issue. I'm not sure if you're still working with this package (hopefully your REU went well!), but after #16 this should now work with both NumPy 1 and 2.
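For context, one backward-incompatible change listed in the NumPy 2.0.0 release notes is how np.linalg.solve(a, b) interprets b when it has one fewer dimension than a. Whether that is the exact call behind the diff_coefficients error, and whether #16 addresses it this way, are assumptions here rather than something confirmed in this thread. The sketch below only shows the version-independent form that the release notes describe, using a hypothetical helper named solve_stacked_vectors.

```python
import numpy as np


def solve_stacked_vectors(a, b):
    """Solve a[i] @ x[i] = b[i] for a stack of matrices and a stack of vectors.

    a has shape (..., M, M) and b has shape (..., M). Adding a trailing axis
    makes b an explicit stack of (M, 1) column matrices, which NumPy 1 and
    NumPy 2 interpret the same way; the axis is dropped again afterwards.
    """
    return np.linalg.solve(a, b[..., None])[..., 0]


# Three copies of the 2x2 system 2*I @ x = [1, 1] (illustrative data only).
a = 2.0 * np.broadcast_to(np.eye(2), (3, 2, 2))
b = np.ones((3, 2))
x = solve_stacked_vectors(a, b)
print(x.shape)  # (3, 2)
```

Passing b directly with shape (3, 2) is the ambiguous case: NumPy 2 treats any b that is not exactly 1-dimensional as a stack of matrices, so the shapes no longer line up.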
I've recently installed this library and have run into an issue when trying to run the starter code with GPU-accelerated JAX. For all solver resolutions except low, I get an operand dimension mismatch error thrown from diff_coefficients. Here's the traceback from trying to run hj.step on the unmodified quickstart.ipynb:
When I knock the resolution down to low, however, there is no error.
I also had success running high and very high resolution calculations on the CPU-only version of JAX.
Update: I started a new CPU-only environment and it actually threw the same error as I was getting with the GPU one. I'm presently unsure how I ever got this to work.
My current setup:
I've also implemented a fix suggested by @mattkiim (#14).
Let me know if you need any more info or need me to run any tests.