netket / netket

Machine learning algorithms for many-body quantum systems
https://www.netket.org
Apache License 2.0

Properly support sharding of operators #1921

Open PhilipVinc opened 1 week ago

PhilipVinc commented 1 week ago

Right now the sharding of operators is enforced by the decorator @nk.jax.sharding.sharding_decorator, which internally uses shard_map to ensure that different shards are kept on different devices.
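For reference, a minimal sketch of the shard_map pattern such a decorator is built on, assuming a 1D device mesh; the axis name "S" and the toy `scale_samples` function are illustrative, not NetKet's actual implementation:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One 1D mesh over all available devices; "S" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("S",))

@partial(shard_map, mesh=mesh, in_specs=P("S"), out_specs=P("S"))
def scale_samples(samples):
    # Inside shard_map, each device sees only its local shard of `samples`.
    return 2.0 * samples

# Works when the input is sharded along the expected axis:
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("S")))
y = scale_samples(x)
```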

However, the decorator is brittle and breaks when the input is not sharded, for example when converting our operators to dense/sparse format.

For example, all our tests that can be run with

hatch --verbose test -i distributed=sharding test/operator/test_operator.py -- 

or with

mpirun -np 2 pytest -n0 --jax-distributed-gloo test/operator/test_operator.py -x

will fail when calling to_sparse.

The solution would be either to remove the calls to the sharding decorator here (assuming it still shards correctly) or to use JAX custom partitioning: https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html

In theory, we should use this everywhere...
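A hedged sketch of what a custom-partitioning rule could look like for one of our ops; `op_apply`, its element-wise body, and the sharding choices are placeholders, not NetKet's actual operator code:

```python
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def op_apply(x):
    # Global semantics: what the op computes on the full array. This also
    # runs unchanged when the input is not sharded.
    return jnp.sin(x)  # placeholder for the real per-operator computation

def partition(mesh, arg_shapes, result_shape):
    # Keep whatever shardings the inputs arrive with and run the same
    # body per shard (valid here because the body is element-wise), so
    # XLA inserts no extra communication.
    arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)

    def lower_fn(x):
        return jnp.sin(x)

    return mesh, lower_fn, result_shape.sharding, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    # Shard the result like the first operand.
    return arg_shapes[0].sharding

op_apply.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
)
```

The relevant property is that the partitioning rule only kicks in under jit when the operands are actually sharded; with replicated or unsharded inputs the decorated function should simply run as written, which is exactly the case that currently breaks the decorator.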

inailuig commented 1 week ago

see https://github.com/google/jax/discussions/20883 @PhilipVinc

PhilipVinc commented 1 week ago

?

PhilipVinc commented 1 week ago

The issue you linked would address this problem, yes, but so would declaring a custom partitioning.

I see custom partitioning as a more elegant (and better-supported) way to achieve the same thing the current sharding decorator does.