jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: No replication rule for custom_linear_solve. As a workaround, pass the check_rep=False argument to shard_map. To get this fixed, open an issue at https://github.com/google/jax/issues
Sorry @mattjj, I've got another one for you!
(In case that's a hard one, this arises from taking the `slogdet` of a batch-sharded 3-tensor, which should work...)

That's on jax/jaxlib 0.4.23.
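For reference, here is a minimal sketch of the kind of computation described: `slogdet` applied inside `shard_map` to a 3-tensor sharded along its batch axis, with the `check_rep=False` workaround from the error message applied. It is an assumption, not the original repro: the mesh layout, the axis name `"batch"`, the `(2 * n_devices, 4, 4)` shape, and the helper names `local_slogdet` and `total_logabsdet` are all illustrative. It targets the jax 0.4.23 API (`jax.experimental.shard_map.shard_map` with `check_rep`), which may differ in later releases. The report does not say whether the error comes from the forward pass or from differentiation, so the sketch exercises both.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical 1-D mesh over every visible device; the axis name "batch" is illustrative.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

def local_slogdet(x):
    # Each device sees a (local_batch, n, n) block; slogdet already batches over
    # the leading axis, so no collectives are needed here.
    sign, logabsdet = jnp.linalg.slogdet(x)
    return sign, logabsdet  # plain tuple so it matches the out_specs structure

sharded_slogdet = shard_map(
    local_slogdet,
    mesh=mesh,
    in_specs=P("batch"),                 # shard the leading (batch) axis
    out_specs=(P("batch"), P("batch")),  # both outputs stay batch-sharded
    check_rep=False,                     # the workaround named in the error message
)

def total_logabsdet(x):
    # Scalar loss so reverse-mode differentiation through shard_map is also exercised.
    _, logabsdet = sharded_slogdet(x)
    return jnp.sum(logabsdet)

n_dev = len(jax.devices())
# Batch size is a multiple of the device count so the batch axis divides evenly.
x = jax.random.normal(jax.random.PRNGKey(0), (2 * n_dev, 4, 4))

sign, logabsdet = jax.jit(sharded_slogdet)(x)
grad_x = jax.jit(jax.grad(total_logabsdet))(x)
```

Disabling the check means `shard_map` no longer verifies that outputs declared as replicated really are replicated. In this sketch that is harmless, because every output is declared sharded along `"batch"` rather than replicated.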