nstarman closed this 2 months ago
Hmm, I think the test failures are probably due to a change in JAX. If you are able to fix that then I'd be happy to merge this.
I can't reproduce the error locally. Something in `zeros`.
Which version of JAX are you using? I suspect it's probably due to an update on their end. (I haven't checked this myself yet.)
macOS
Python 3.12
jax[cpu] 0.4.25
I tried bumping to the most recent JAX in a clean environment.
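To make environment reports like this easy to compare across machines, a small script can collect the same details (OS, Python version, JAX version). This is a minimal sketch; the helper name `environment_summary` is hypothetical, while `jax.__version__`, `jax.default_backend()` and `jax.print_environment_info()` are part of JAX itself.

```python
import platform

import jax


def environment_summary() -> dict:
    # Hypothetical helper: gather the details asked for in this thread
    return {
        "os": platform.system(),                  # e.g. "Darwin" on macOS
        "python": platform.python_version(),      # e.g. "3.12.x"
        "jax": jax.__version__,                   # installed JAX version
        "backend": jax.default_backend(),         # "cpu" for jax[cpu]
    }


if __name__ == "__main__":
    print(environment_summary())
    # JAX's own, more detailed report (jax, jaxlib, numpy, python, platform)
    jax.print_environment_info()
```

Running this once against the old and once against the new JAX install gives two reports that can be pasted side by side when bisecting a version-related failure.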
Okay, so possibly an issue with the latest version of JAX! If this passes with an earlier version then that will be the culprit, and we'll need to figure out how to adapt to JAX's changes.
Now that JAX has a key type that isn't just an Array and can be used in annotations, this PR isn't necessary.
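For context, one reading of the key type mentioned here is JAX's typed PRNG keys: `jax.random.key` returns a scalar key array with a dedicated key dtype, whereas the legacy `jax.random.PRNGKey` returns a plain `uint32` array of shape `(2,)`. The sketch below shows how the two can be told apart by dtype; the helper `is_typed_key` is hypothetical, and this is not necessarily how the PR or any annotation library performs the check.

```python
import jax


def is_typed_key(x) -> bool:
    # Hypothetical helper: True only for new-style typed keys,
    # which carry a key dtype instead of raw uint32 bits
    return isinstance(x, jax.Array) and jax.dtypes.issubdtype(
        x.dtype, jax.dtypes.prng_key
    )


new_key = jax.random.key(0)      # typed key: shape (), key dtype
old_key = jax.random.PRNGKey(0)  # legacy key: shape (2,), dtype uint32

# The underlying bits of a typed key are still available as uint32 data
raw = jax.random.key_data(new_key)

# Splitting preserves the typed key dtype
k1, k2 = jax.random.split(new_key)

if __name__ == "__main__":
    print(new_key.shape, new_key.dtype)
    print(old_key.shape, old_key.dtype)
    print(is_typed_key(new_key), is_typed_key(old_key))
```

Because the distinction lives in the dtype rather than in the shape, a check like this no longer has to guess whether an arbitrary `uint32` array of shape `(2,)` is meant to be a key.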
I think the test failures are unrelated.