@cpgaffney1
Apologies, I did not get a chance to fully review this today -- will take a closer look tomorrow.
By garbage collection you are referring to gc.collect(), right? (not the Orbax CheckpointManager garbage collection of old steps, that should not be happening here)
It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.
By garbage collection you are referring to gc.collect(), right? (not the Orbax CheckpointManager garbage collection of old steps, that should not be happening here)
Correct
It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.
Agreed, will try to get to this later this week.
Rebased on the newer version of https://github.com/google/orbax/pull/1319 and addressed comments -- will rebase once more once https://github.com/google/orbax/pull/1319 is merged and ping you.
Ran this in a single-host but multi-process setting today: gist
Timings end up very similar to the single-process example above, i.e. ~100ms/save with baseline Orbax (single-replica) and ~50ms for replica-parallel.
Note that the adapted microbenchmark only creates a single CheckpointManager and doesn't operate on a new jax.Array for every iteration. That seems to have reduced the GC overhead significantly (only ~30ms now).
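Roughly, the adapted loop looks like the following (a minimal sketch, not the gist itself -- the path, array shape, and step count are placeholders, and it assumes the standard CheckpointManager/StandardSave APIs):

```python
# Minimal sketch of the adapted microbenchmark loop: a single CheckpointManager
# and a single fully-replicated jax.Array are created once and reused across
# iterations, so per-step timings mostly reflect the blocking portion of save().
import time

import numpy as np
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

NUM_STEPS = 10  # placeholder

mngr = ocp.CheckpointManager('/tmp/ckpt-bench')  # created once, outside the loop

# Fully-replicated array, allocated once (avoids per-iteration GC pressure).
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('d',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
array = jax.device_put(jnp.ones((1024, 1024), dtype=jnp.float32), replicated)

for step in range(NUM_STEPS):
  mngr.wait_until_finished()  # don't let the previous async save skew the timing
  start = time.perf_counter()
  mngr.save(step, args=ocp.args.StandardSave({'x': array}))
  print(f'step {step}: blocking save took {(time.perf_counter() - start) * 1e3:.1f}ms')

mngr.wait_until_finished()
```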
Besides that, I'm noticing that we're spending a fair amount of time in the various sync_global_devices calls -- with replica-parallel these end up taking ~50% of the exposed save time (~30ms):
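For reference, assuming these barriers correspond to jax.experimental.multihost_utils.sync_global_devices, a bare barrier can be timed in isolation to get a rough floor for what each call costs (a standalone sketch, not part of the gist):

```python
# Rough probe of the barrier primitive itself, outside Orbax: every process
# blocks in sync_global_devices until all processes have reached it, so the
# measured time is a lower bound on what each such call can add to a save.
import time

import jax
from jax.experimental import multihost_utils

# jax.distributed.initialize() would be required first in a multi-process run.
start = time.perf_counter()
multihost_utils.sync_global_devices('barrier_cost_probe')  # name is arbitrary
elapsed_ms = (time.perf_counter() - start) * 1e3
print(f'process {jax.process_index()}: barrier took {elapsed_ms:.1f}ms')
```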
(I still intend to run a real benchmark on a cluster, but this might have to wait til next week.)
Please rebase onto head and I will take a final look at this CL before merging internally.
On the sync_global_devices question, sometimes it looks like the barrier is taking a long time even though it is really just a non-primary process waiting while the leader does real work. I'm not familiar enough with your profiler to know whether that is the case.
We can take a look internally at this, since it's not the first time the possibility has been raised. Intuitively though there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices (one major advantage of using a single-controller model). There's a number of things we could move to background threads to minimize the overall number of barriers though, like directory creations - that is something we're starting to work on.
On the sync_global_devices question, sometimes it looks like the barrier is taking a long time even though it is really just a non-primary process waiting while the leader does real work. I'm not familiar enough with your profiler to know whether that is the case.
The top (i.e. highlighted) process in the screenshot above should be the primary (Xla:#global=0).
Intuitively though there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices
There's a number of things we could move to background threads to minimize the overall number of barriers though, like directory creations - that is something we're starting to work on.
Agreed, and indeed it seems like the lowest hanging fruit might be to elide some of these barriers -- based on the above profile there are five per save. So it sounds great that you're already looking into that!
Please rebase onto head and I will take a final look at this CL before merging internally.
Will do (likely on Monday)! Thanks again for helping push this through :-)
Rebased on main, PTAL, @cpgaffney1 !
LGTM - if I don't finish merging today, will finish tomorrow.
I finally got around to running some multi-node benchmarks confirming that replica-parallel helps reduce blocking save time as data-parallelism increases.
The smaller configuration I tried was GPT-5B on PAXML with TP=8 (i.e., tensor parallelism across the GPUs within a single node) and DP=#nodes. Take-aways:
I also ran a 175B model with DP={1,2,4}, FSDP=16, TP=8 observing the same general trends. The benefit diminishes more quickly than for the 5B model, but that's in line with earlier experiments on 175B that scaled FSDP.
This is a somewhat ad-hoc analysis, but it makes me confident that replica-parallel in its current state is useful. It certainly does seem plausible that we'll eventually want to support non-evenly-divisible dimensions (or splitting across multiple dimensions). Another question is whether we should make replica-parallel slicing aware of layouts (prioritizing major axes).
(Follow-up to https://github.com/google/orbax/pull/1319)
Adds "replica-parallel" saving in which each replica of a shard saves an equally-sized slice. In effect, this drives down the time spent on saving checkpoints as data-parallelism increases.
Motivation: Depending on their sharding and replication, JAX arrays may consist of multiple shards. In case of replication each shard carries a distinct replica_id, distinguishing the copies of the same logical shard from one another. Orbax's current behavior is to save the same replica_id-copy for all shards of all arrays ("single-replica" saving). In the presence of replication this is suboptimal, since the work could be parallelized across all replicas.

This PR adapts ArrayHandler to operate in "replica-parallel" mode (use_replica_parallel=True). In this mode we determine the first axis for which an array's shards are evenly-divisible by the number of replicas R. If such an axis exists, we will then assign each replica to save one R-th of the overall shard. Otherwise we fall back to "single-replica" saving.

While this patch is mostly relevant to large-scale training across many hosts, the following self-contained example illustrates the difference between existing "single-replica" saving and "replica-parallel" saving: https://gist.github.com/gspschmid/41a78da35c0b14fafaff7bed3d52c5bc
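A minimal sketch of the slicing rule described above (illustrative only, not the PR's actual ArrayHandler code): pick the first shard axis evenly divisible by the replica count R, and have replica r save the r-th slice along that axis; if no axis qualifies, fall back to single-replica saving.

```python
# Illustrative sketch of the "replica-parallel" slicing rule described above;
# not the PR's actual implementation. Given a shard's local shape and R
# replicas, pick the first axis whose size divides evenly by R. Replica r then
# saves the r-th slice along that axis; otherwise single-replica saving is kept.
from typing import Optional, Tuple


def replica_parallel_axis(shard_shape: Tuple[int, ...], num_replicas: int) -> Optional[int]:
  """Returns the first axis evenly divisible by num_replicas, or None."""
  for axis, size in enumerate(shard_shape):
    if size % num_replicas == 0:
      return axis
  return None


def replica_slice(shard_shape: Tuple[int, ...], num_replicas: int, replica_id: int):
  """Returns the sub-slice of the shard that this replica would save."""
  axis = replica_parallel_axis(shard_shape, num_replicas)
  if axis is None:
    return None  # no evenly-divisible axis -> fall back to single-replica saving
  step = shard_shape[axis] // num_replicas
  slices = [slice(None)] * len(shard_shape)
  slices[axis] = slice(replica_id * step, (replica_id + 1) * step)
  return tuple(slices)


# Example: a (1024, 512) shard with R=8 replicas -> replica 3 saves rows 384:512.
print(replica_slice((1024, 512), num_replicas=8, replica_id=3))
```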
For simplicity, I ran this example in a single process controlling 8 GPUs to save a fully-replicated array of ~4.3GBs. Attached are the output and profile for the last iteration.
Single-replica (Orbax v0.9.0):
Replica-parallel (this PR):
A few observations:
Replica-parallel splits the single large ts[index].write(...) into 8 smaller ones. The microbenchmark uses tensorstore's memory:// backend, so it remains to be seen whether this improvement holds up in realistic use cases.
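For completeness, opting a script like the microbenchmark into the new mode presumably amounts to registering an ArrayHandler constructed with use_replica_parallel=True; a hedged sketch (the call below uses Orbax's type_handlers registry, and the gist may wire this up differently):

```python
# Hedged sketch: register an ArrayHandler with the flag added by this PR so that
# subsequent saves of jax.Array values use "replica-parallel" slicing. Exact
# wiring may differ from the gist; the registry call is Orbax's standard one.
import jax
from orbax.checkpoint import type_handlers

type_handlers.register_type_handler(
    jax.Array,
    type_handlers.ArrayHandler(use_replica_parallel=True),
    override=True,  # replace the default single-replica ArrayHandler
)
```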