Closed by nstarman 1 month ago
@patrick-kidger, the test failure appears to be related to jax v0.4.31 and not this PR.
Also, this directly passes a kwarg to plum, which tests this feature. Is a test necessary here?
LGTM! Merged. :) (I don't think a test is necessary, this is normal plum stuff.)
I've written a fix for the sharding issue with JAX 0.4.31 here: https://github.com/patrick-kidger/quax/pull/25
It looks like there's still another outstanding failure in test_sparse.py. I haven't tried tracking down what's going on with that; I'd be happy to take a PR on it.
For when rules have some ambiguity.
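For readers skimming the thread, here is a rough sketch of what "ambiguity" between rules means and how a forwarded keyword could resolve it. The thread does not name the kwarg, so the `precedence` parameter below is an assumption, and `ToyDispatcher`, `Base`, and `Special` are hypothetical names for illustration only. This is plain Python, not plum's or quax's actual implementation.

```python
class AmbiguousRuleError(Exception):
    """Raised when several rules match and nothing breaks the tie."""


class ToyDispatcher:
    """Hypothetical, minimal multiple dispatcher (not plum's real code)."""

    def __init__(self):
        self._rules = []  # list of (types, precedence, function)

    def register(self, *types, precedence=0):
        # `precedence` is an assumed tie-breaker: higher wins among ambiguous rules.
        def decorator(fn):
            self._rules.append((types, precedence, fn))
            return fn
        return decorator

    def __call__(self, *args):
        matches = [
            rule for rule in self._rules
            if len(rule[0]) == len(args)
            and all(isinstance(a, t) for a, t in zip(args, rule[0]))
        ]
        if not matches:
            raise TypeError("no matching rule")

        def dominates(a, b):
            # `a` is strictly more specific than `b` in every position it differs.
            return a != b and all(issubclass(x, y) for x, y in zip(a, b))

        # Keep only rules that no other matching rule is more specific than.
        candidates = [
            m for m in matches
            if not any(dominates(o[0], m[0]) for o in matches)
        ]
        if len(candidates) > 1:
            # Ambiguity: fall back to the highest precedence value.
            top = max(prec for _, prec, _ in candidates)
            candidates = [c for c in candidates if c[1] == top]
        if len(candidates) != 1:
            raise AmbiguousRuleError(f"{len(candidates)} rules tie for {args!r}")
        return candidates[0][2](*args)


class Base:
    pass


class Special(Base):
    pass


combine = ToyDispatcher()


@combine.register(Base, Special)
def _(x, y):
    return "base-special"


@combine.register(Special, Base, precedence=1)
def _(x, y):
    return "special-base"


# Both rules match (Special, Special) and neither is more specific,
# so the rule registered with the higher precedence is chosen.
result = combine(Special(), Special())
```

Without the `precedence=1` argument, the final call would raise `AmbiguousRuleError` in this sketch, which is the situation the PR description refers to.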