Open PhilipVinc opened 1 week ago
see https://github.com/google/jax/discussions/20883 @PhilipVinc
?
The issue you linked would address this problem, yes. But declaring a custom partitioning would address it too.
I see the custom partitioning as a more elegant (and better supported) way to achieve the same thing that the current sharding decorator does.
Right now the sharding of operators is enforced by the decorator `@nk.jax.sharding.sharding_decorator`, which internally uses `shard_map` to ensure that different shards are kept on different devices. However, the decorator is brittle and breaks when the input is not sharded, for example when converting our operators to dense/sparse format.
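For readers unfamiliar with `shard_map`, here is a minimal sketch of the mechanism the decorator relies on. It is not NetKet's decorator: the function `local_sq_norm`, the mesh axis name `"i"` and the array shapes are all made up for illustration, and the import fallback assumes that `shard_map` lives either at `jax.shard_map` or in `jax.experimental.shard_map`, depending on the installed JAX version.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases only ship the experimental module
    from jax.experimental.shard_map import shard_map

# One-dimensional mesh over all available devices.
# The axis name "i" is an arbitrary choice for this sketch.
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))


def local_sq_norm(x_block):
    # Hypothetical per-shard function: it only ever sees the block of rows
    # that lives on one device.
    return jnp.sum(x_block**2, axis=-1)


# Split the leading axis of the input across the mesh axis "i", apply the
# function to each block, and stitch the per-device results back together
# along the same axis.
sharded_sq_norm = jax.jit(
    shard_map(local_sq_norm, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
)

n = 4 * len(jax.devices())  # leading axis must be divisible by the device count
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
y = sharded_sq_norm(x)
```

In this sketch the call is wrapped in `jax.jit`, so an input that is not yet sharded is laid out according to `in_specs` before the per-shard function runs. The decorator discussed here adds its own logic on top of `shard_map`, which this example does not reproduce.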
For example, all our tests that can be run with
or with
will fail when calling to_sparse.
The solution would be either to remove the calls to the sharding decorator here (assuming it still shards correctly) or to use JAX custom partitioning: https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html
In theory, we should use this everywhere...