patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0
100 stars, 2 forks

feat: pass precedence to plum dispatcher #24

Closed: nstarman closed this 1 month ago

nstarman commented 1 month ago

Passes precedence through to the plum dispatcher, for cases where the registered rules are ambiguous.

nstarman commented 1 month ago

@patrick-kidger, the test failure appears to be caused by jax v0.4.31, not by this PR. Also, this PR just passes a kwarg straight through to plum, which already tests this feature. Is a test necessary here?

patrick-kidger commented 1 month ago

LGTM! Merged. :) (I don't think a test is necessary, this is normal plum stuff.)

I've written a fix for the sharding issue with JAX 0.4.31 here: https://github.com/patrick-kidger/quax/pull/25

It looks like there's still one more outstanding failure, in test_sparse.py. I've not tried tracking down what's going on with that; I'd be happy to take a PR for it.