(Follow-up to #1320)
Removes the dependency on `jax.sharding.Sharding#devices_indices_map`, which ties replica-parallel's runtime to the size of the global device mesh.

We currently use `_get_replica_counts` to determine the degree to which an array is replicated (https://github.com/google/orbax/pull/1320/files#diff-7e3a46c4514a95afc4a9eced7b337cfd749a1cf4c66d68020b9663684590a970R132-R143). This unfortunately materializes a dict proportional in size to the number of devices the array is placed on.

Here's a simple microbenchmark that demonstrates the issue on a mock topology of 1024 nodes, each of which has 8 GPUs: sharding_metadata_bench.py
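The actual sharding_metadata_bench.py is not reproduced here; the following is only a minimal sketch of that kind of measurement, with an illustrative mock device count, mesh layout, and array shape, and an old-style replica count derived from the materialized dict:

```python
# Sketch of a microbenchmark in the spirit of sharding_metadata_bench.py
# (illustrative only). XLA_FLAGS must be set before importing jax so the CPU
# backend exposes enough mock devices to stand in for a large topology.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8192"

import collections
import time

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mock topology: 1024 "nodes" x 8 devices each, all backed by host CPU devices.
devices = np.array(jax.devices()).reshape(1024, 8)
mesh = Mesh(devices, axis_names=("node", "device"))
# Shard the last array axis across the 8 per-node devices; replicate across nodes.
sharding = NamedSharding(mesh, P(None, "device"))
global_shape = (1024, 1024)

start = time.perf_counter()
# Old-style computation: materialize the device->index dict (one entry per
# device) and count how many devices hold each distinct shard index.
index_counts = collections.Counter(
    tuple((s.start, s.stop, s.step) for s in index)
    for index in sharding.devices_indices_map(global_shape).values())
num_replicas = next(iter(index_counts.values()))
elapsed = time.perf_counter() - start

print(f"num_replicas={num_replicas}; devices_indices_map over "
      f"{mesh.size} devices took {elapsed:.3f}s")
```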
Note that every checkpoint for a realistic model such as GPT in PAXML might save ~100 arrays, so calling `devices_indices_map` at a scale of ~10k devices might start taking up the vast majority of blocking time spent on a checkpoint save.

This PR avoids materializing the dict and instead directly computes `num_replicas` based on the same logic underlying `devices_indices_map`.
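As a rough illustration of the idea (a sketch only; `_compute_num_replicas` is a hypothetical name, and this is not necessarily the PR's exact logic), for a `NamedSharding` the replication factor can be read off the mesh axes that the `PartitionSpec` leaves unused:

```python
# Hedged sketch (not the PR's actual implementation): mesh axes that do not
# appear in the PartitionSpec only replicate the array, so the replica count is
# the product of their sizes. No per-device dict is built.
import math
from jax.sharding import NamedSharding


def _compute_num_replicas(sharding: NamedSharding) -> int:  # hypothetical helper
  used_axes = set()
  for entry in sharding.spec:
    if entry is None:
      continue
    # A PartitionSpec entry is either a single axis name or a tuple of names.
    used_axes.update(entry if isinstance(entry, tuple) else (entry,))
  # Mesh.shape maps axis name -> size; axes not named in the spec contribute
  # only replication.
  return math.prod(
      size for name, size in sharding.mesh.shape.items()
      if name not in used_axes)
```

For the sharding in the benchmark sketch above (a 1024x8 mesh with `P(None, "device")`), this returns 1024 without touching `devices_indices_map`, so the cost no longer scales with the number of devices.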