patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0

feat: convenience construction of a threefry #14

Closed. nstarman closed this 2 months ago.

nstarman commented 6 months ago

I think the test failures are unrelated to this change.

patrick-kidger commented 6 months ago

Hmm, I think the test failures are probably due to a change in JAX. If you are able to fix that then I'd be happy to merge this.

nstarman commented 6 months ago

I can't reproduce the error locally. The failure seems to be somewhere in zeros.

patrick-kidger commented 6 months ago

Which version of JAX are you using? I suspect it's probably due to an update on their end. (I haven't checked this myself yet.)

nstarman commented 6 months ago

> Which version of JAX are you using? I suspect it's probably due to an update on their end. (I haven't checked this myself yet.)

macOS, Python 3.12, jax[cpu] 0.4.25. I also tried bumping to the most recent JAX in a clean environment.

patrick-kidger commented 6 months ago

Okay, so possibly an issue with the latest version of JAX! If the tests pass with an earlier version, then that will be the culprit, and we'll need to figure out how to adapt to JAX's changes.

nstarman commented 2 months ago

Now that JAX has a key type that isn't just an Array and can be used for annotations, this PR is no longer necessary.
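For context, here is a minimal sketch of the new-style typed keys this comment most likely refers to. That is an assumption, since the thread doesn't name the exact API. In recent JAX versions, jax.random.key creates a typed key (threefry2x32 by default) whose dtype is a subdtype of jax.dtypes.prng_key, while the legacy jax.random.PRNGKey still returns a plain uint32[2] array. The helper is_typed_key and the function sample are hypothetical illustrations, not part of JAX or quax.

```python
import jax
import jax.numpy as jnp


def is_typed_key(x) -> bool:
    """Hypothetical helper: True for new-style typed PRNG keys.

    Typed keys are still jax.Array instances, but their dtype is an
    extended key dtype that is a subdtype of jax.dtypes.prng_key.
    """
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jax.dtypes.prng_key)


# New-style typed key, explicitly requesting the threefry2x32 implementation.
key = jax.random.key(0, impl="threefry2x32")

# Legacy-style key: a plain uint32 array of shape (2,).
raw = jax.random.PRNGKey(0)

# The underlying threefry state of a typed key is a uint32[2] array.
data = jax.random.key_data(key)

# Raw key data can be wrapped back into a typed key.
rewrapped = jax.random.wrap_key_data(data, impl="threefry2x32")


def sample(key: jax.Array, n: int) -> jax.Array:
    # Hypothetical example function: typed keys are jax.Array instances,
    # so they satisfy this annotation at runtime.
    return jax.random.normal(key, (n,))


# Splitting a typed key yields a typed key array of shape (2,).
subkeys = jax.random.split(key, 2)
x = sample(subkeys[0], 3)
```

Because the key's type information lives in its dtype, code can tell typed keys apart from raw uint32 arrays without a dedicated wrapper class, which is roughly what a convenience threefry constructor would otherwise have had to provide.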