jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

custom_linear_solve missing shard_map rule #18977

Open PhilipVinc opened 11 months ago

PhilipVinc commented 11 months ago

Sorry @mattjj, I've got another one for you!

jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: No replication rule for custom_linear_solve. As a workaround, pass the check_rep=False argument to shard_map. To get this fixed, open an issue at https://github.com/google/jax/issues
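A minimal sketch of the workaround the error suggests, assuming a batch-sharded linear solve as the trigger (as far as I know, jnp.linalg.solve is built on lax.custom_linear_solve). The mesh axis name "batch" and the names local_solve, batched_solve and NO_REP_CHECK are hypothetical. The flag is check_rep on jax 0.4.23, as the error says; newer releases appear to have renamed it check_vma, so the sketch passes whichever name the installed shard_map accepts. Disabling the check only skips shard_map's replication validation, so it is a stopgap rather than a fix.

```python
import inspect

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # e.g. jax 0.4.23
    from jax.experimental.shard_map import shard_map

# Turn off the replication check under whichever flag name this version uses
# (check_rep on jax 0.4.23; assumed to be check_vma in later releases).
_params = inspect.signature(shard_map).parameters
NO_REP_CHECK = {"check_vma": False} if "check_vma" in _params else {"check_rep": False}

# One mesh axis spanning all local devices; the batch axis is sharded over it.
mesh = Mesh(np.array(jax.devices()), ("batch",))


def local_solve(a, b):
    # Runs on each device's slice of the batch. jnp.linalg.solve is assumed
    # to lower to lax.custom_linear_solve, the primitive named in the error.
    return jnp.linalg.solve(a, b)


batched_solve = shard_map(
    local_solve,
    mesh=mesh,
    in_specs=(P("batch"), P("batch")),
    out_specs=P("batch"),
    **NO_REP_CHECK,
)

# Usage: the batch size must be divisible by the number of devices.
n_dev = len(jax.devices())
rng = np.random.default_rng(0)
a = (rng.normal(size=(2 * n_dev, 3, 3)) + 3 * np.eye(3)).astype(np.float32)
b = rng.normal(size=(2 * n_dev, 3, 1)).astype(np.float32)  # column vectors
x = batched_solve(a, b)
```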

(In case that's a hard one: this arises from taking the slogdet of a batch-sharded 3-tensor, i.e. a stack of square matrices sharded along the batch axis, which should work.)
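A hedged reproduction of the slogdet case, again with the replication check disabled: each shard computes slogdet and the gradient of sum(logabsdet) for its slice of the batch. My assumption (not confirmed in this thread) is that custom_linear_solve enters through slogdet's JVP rule, which calls jnp.linalg.solve, so the forward pass alone may not hit it. The helper names are hypothetical, and the same check_rep/check_vma shim is used so the block runs on its own.

```python
import inspect

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # e.g. jax 0.4.23
    from jax.experimental.shard_map import shard_map

# check_rep on jax 0.4.23; assumed to be check_vma in later releases.
_params = inspect.signature(shard_map).parameters
NO_REP_CHECK = {"check_vma": False} if "check_vma" in _params else {"check_rep": False}

mesh = Mesh(np.array(jax.devices()), ("batch",))


def local_slogdet_and_grad(block):
    """slogdet of a local (b, n, n) block plus d(sum logabsdet)/d(block)."""
    sign, logabsdet = jnp.linalg.slogdet(block)
    # Differentiating slogdet goes through its JVP rule, which (assumed here)
    # calls jnp.linalg.solve and therefore custom_linear_solve.
    grad = jax.grad(lambda m: jnp.linalg.slogdet(m)[1].sum())(block)
    return sign, logabsdet, grad


sharded_slogdet = shard_map(
    local_slogdet_and_grad,
    mesh=mesh,
    in_specs=(P("batch"),),
    out_specs=(P("batch"), P("batch"), P("batch")),
    **NO_REP_CHECK,
)

# Usage: a batch-sharded 3-tensor of well-conditioned 3x3 matrices.
n_dev = len(jax.devices())
rng = np.random.default_rng(0)
x = (rng.normal(size=(2 * n_dev, 3, 3)) + 3 * np.eye(3)).astype(np.float32)
sign, logabsdet, grad = sharded_slogdet(x)
```

For a real matrix A, the gradient of log|det A| is the inverse transpose of A, which gives a cheap correctness check against the unsharded result.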

This is on jax/jaxlib 0.4.23.

PhilipVinc commented 5 months ago

@mattjj can I ping you again on this one?