google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

[replica-parallel] Avoid jax.sharding.Sharding#devices_indices_map #1358

Open · gspschmid opened 2 days ago

gspschmid commented 2 days ago

(Follow-up to #1320)

Removes the dependency on jax.sharding.Sharding#devices_indices_map, which ties replica-parallel's runtime to the size of the global device mesh.

We currently use _get_replica_counts to determine the degree to which an array is replicated (https://github.com/google/orbax/pull/1320/files#diff-7e3a46c4514a95afc4a9eced7b337cfd749a1cf4c66d68020b9663684590a970R132-R143). Unfortunately, this materializes a dict whose size is proportional to the number of devices the array is placed on.
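For illustration, here is a hedged sketch of the costly pattern: materializing the full device-to-index map and grouping devices that receive identical index slices. The helper name replica_count_via_map is hypothetical; the PR's _get_replica_counts is more involved.

```python
# Hypothetical sketch of the costly pattern (replica_count_via_map is not
# Orbax code). devices_indices_map returns one entry per device the array
# is placed on, so the dict is O(num_devices).
from collections import Counter

import jax


def replica_count_via_map(arr: jax.Array) -> int:
  # O(num_devices) dict: device -> tuple of index slices into the array.
  index_map = arr.sharding.devices_indices_map(arr.shape)
  # Devices assigned identical index slices hold replicas of the same shard.
  counts = Counter(
      tuple((s.start, s.stop) for s in idx) for idx in index_map.values()
  )
  return min(counts.values())
```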

Here's a simple microbenchmark that demonstrates the issue on a mock topology of 1024 nodes with 8 GPUs each: sharding_metadata_bench.py (a speculative reconstruction appears after the output below)

$ python3 sharding_metadata_bench.py 1024
num_nodes=1024 num_devices=8192
[num_devices=8192 num_partitions=8192 num_replicas=1] 0.55s / 100iters ~ 0.005460s / iter
[num_devices=8192 num_partitions=4096 num_replicas=2] 0.48s / 100iters ~ 0.004814s / iter
[num_devices=8192 num_partitions=2 num_replicas=4096] 0.39s / 100iters ~ 0.003917s / iter
[num_devices=8192 num_partitions=1 num_replicas=8192] 0.39s / 100iters ~ 0.003870s / iter
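The actual sharding_metadata_bench.py is not shown above; the following is a speculative reconstruction of such a benchmark, assuming XLA's --xla_force_host_platform_device_count flag to mock the topology. Shapes, axis names, and the iteration scheme are illustrative guesses, not the author's script.

```python
# Hypothetical reconstruction of sharding_metadata_bench.py. Mocks a large
# topology with host (CPU) devices and times Sharding#devices_indices_map.
import os
import sys
import time

num_nodes = int(sys.argv[1]) if len(sys.argv) > 1 else 1024
num_devices = num_nodes * 8  # 8 mock "GPUs" per node

# Must be set before importing jax so XLA exposes the mock devices.
os.environ["XLA_FLAGS"] = (
    f"--xla_force_host_platform_device_count={num_devices}"
)

import jax  # noqa: E402
import numpy as np  # noqa: E402
from jax.sharding import Mesh, NamedSharding, PartitionSpec  # noqa: E402

print(f"num_nodes={num_nodes} num_devices={num_devices}")
devices = np.asarray(jax.devices())

for num_partitions in (num_devices, num_devices // 2, 2, 1):
  num_replicas = num_devices // num_partitions
  mesh = Mesh(
      devices.reshape(num_partitions, num_replicas), ("shard", "replica")
  )
  # Shard the leading array dimension; the "replica" axis is unused, so
  # every shard is replicated num_replicas times.
  sharding = NamedSharding(mesh, PartitionSpec("shard"))
  iters = 100
  start = time.perf_counter()
  for i in range(iters):
    # Vary the global shape so JAX's internal caching of the map does not
    # short-circuit repeated calls.
    sharding.devices_indices_map((num_devices, 1024 + i))
  elapsed = time.perf_counter() - start
  print(
      f"[num_devices={num_devices} num_partitions={num_partitions} "
      f"num_replicas={num_replicas}] {elapsed:.2f}s / {iters}iters "
      f"~ {elapsed / iters:.6f}s / iter"
  )
```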

Note that every checkpoint for a realistic model, such as GPT in PAXML, might save ~100 arrays. At ~5ms per call (see above), invoking devices_indices_map once per array at a scale of ~10k devices adds on the order of 0.5s per save, which can come to dominate the blocking time spent on a checkpoint save.

This PR avoids materializing the dict and instead computes num_replicas directly, using the same logic that underlies devices_indices_map.
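As a rough illustration of the idea (not the PR's actual code): for a NamedSharding, the replication degree can be read off the mesh and partition spec alone, since every mesh axis that does not partition any array dimension contributes its full size to num_replicas. The helper below is hypothetical and ignores edge cases such as jax.sharding.PartitionSpec.UNCONSTRAINED.

```python
# Hedged sketch: derive num_replicas from the mesh and spec without
# materializing the device->index map. Hypothetical helper, not Orbax code.
import math

from jax.sharding import NamedSharding


def num_replicas(sharding: NamedSharding) -> int:
  used_axes = set()
  for entry in sharding.spec:
    if entry is None:
      continue
    # A PartitionSpec entry is a mesh axis name or a tuple of axis names.
    axes = entry if isinstance(entry, tuple) else (entry,)
    used_axes.update(axes)
  # Devices along axes that partition the array hold distinct shards ...
  num_partitions = math.prod(sharding.mesh.shape[a] for a in used_axes)
  # ... and the remaining mesh axes replicate each shard.
  return sharding.mesh.size // num_partitions
```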

gspschmid commented 2 days ago

@cpgaffney1

gspschmid commented 1 day ago

@cpgaffney1 Rebased and re-enabled use_replica_parallel by default.