google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

[replica-parallel] Add replica-parallel saving #1320

Closed: gspschmid closed this 1 week ago

gspschmid commented 2 weeks ago

(Follow-up to https://github.com/google/orbax/pull/1319)

Adds "replica-parallel" saving in which each replica of a shard saves an equally-sized slice. In effect, this drives down the time spent on saving checkpoints as data-parallelism increases.

Motivation: Depending on their sharding and replication, JAX arrays may consist of multiple shards. In the case of replication, each copy of a shard carries a distinct replica_id, distinguishing the copies of the same logical shard from one another. Orbax's current behavior is to save only the copy with one particular replica_id for every shard of every array ("single-replica" saving). In the presence of replication this is suboptimal, since the work could be parallelized across all replicas.
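
As a quick illustration of the replica_id bookkeeping (a minimal sketch using public JAX APIs, not Orbax internals; the mesh axis name is arbitrary):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Fully replicate a small array across all local devices: every device holds
# an identical shard, and the copies are told apart by their replica_id.
mesh = Mesh(jax.devices(), ('replica',))
x = jax.device_put(jnp.arange(16.0), NamedSharding(mesh, P()))

for shard in x.addressable_shards:
    # Same index range on every device, but a distinct replica_id per copy.
    print(shard.device, shard.index, shard.replica_id)
```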

This PR adapts ArrayHandler to operate in "replica-parallel" mode (use_replica_parallel=True). In this mode we determine the first axis along which an array's shards are evenly divisible by the number of replicas R. If such an axis exists, each replica is assigned one R-th of every shard to save. Otherwise we fall back to "single-replica" saving.
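
A minimal sketch of that slicing rule (illustrative only; the real logic lives in the adapted ArrayHandler and its helpers):

```python
def replica_slice(shard_shape, replica_id, num_replicas):
    """Pick the first axis evenly divisible by the number of replicas and
    assign this replica a 1/R slice along it. Returns None if no such axis
    exists, i.e. fall back to single-replica saving."""
    for axis, dim in enumerate(shard_shape):
        if dim % num_replicas == 0:
            step = dim // num_replicas
            start = replica_id * step
            index = [slice(None)] * len(shard_shape)
            index[axis] = slice(start, start + step)
            return tuple(index)
    return None

# Example: a (1024, 512) shard with R=8 replicas -> replica 3 saves rows 384:512.
print(replica_slice((1024, 512), replica_id=3, num_replicas=8))
```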


While this patch is mostly relevant to large-scale training across many hosts, the following self-contained example illustrates the difference between existing "single-replica" saving and "replica-parallel" saving: https://gist.github.com/gspschmid/41a78da35c0b14fafaff7bed3d52c5bc
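
Roughly, the gist boils down to saving one large fully-replicated array. The sketch below is illustrative (paths, sizes, and the exact save call are not taken verbatim from the gist, and the save API may differ across Orbax versions):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import orbax.checkpoint as ocp

# 32Ki x 32Ki float32 ~= 4.3 GB, fully replicated across the 8 local GPUs.
mesh = Mesh(jax.devices(), ('replica',))
x = jax.device_put(jnp.ones((32 * 1024, 32 * 1024), jnp.float32),
                   NamedSharding(mesh, P()))

# With this PR, constructing the type handler as
# ArrayHandler(use_replica_parallel=True) enables the replica-parallel path;
# handler registration details are omitted here.
ckptr = ocp.PyTreeCheckpointer()
# Note: the target directory must not already exist.
ckptr.save('/tmp/replica_parallel_demo', args=ocp.args.PyTreeSave({'x': x}))
```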

For simplicity, I ran this example in a single process controlling 8 GPUs, saving a fully-replicated array of ~4.3 GB. Attached are the output and profile for the last iteration.

Single-replica (Orbax v0.9.0):

=== iteration 5/5 ===
[gc] 0.91s
[save (offload)] 0.09s
[save (persist)] 5.42s
(profile screenshot)

Replica-parallel (this PR):

=== iteration 5/5 ===
[gc] 1.01s
[save (offload)] 0.05s
[save (persist)] 2.27s
(profile screenshot)

A few observations:

gspschmid commented 2 weeks ago

@cpgaffney1

cpgaffney1 commented 2 weeks ago

Apologies I did not get a chance to fully review this today, will take a closer look tomorrow.

By garbage collection you are referring to gc.collect(), right? (not the Orbax CheckpointManager garbage collection of old steps, that should not be happening here)

It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.

gspschmid commented 2 weeks ago

> By garbage collection you are referring to gc.collect(), right? (not the Orbax CheckpointManager garbage collection of old steps, that should not be happening here)

Correct

> It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.

Agreed, will try to get to this later this week.

gspschmid commented 2 weeks ago

Rebased on the newer version of https://github.com/google/orbax/pull/1319 and addressed comments -- will rebase once more once https://github.com/google/orbax/pull/1319 is merged and ping you.

gspschmid commented 2 weeks ago

Ran this in a single-host, multi-process setting today: gist

Timings end up very similar to the single-process example above, i.e. ~100ms/save with baseline Orbax (single-replica) and ~50ms for replica-parallel.

Note that the adapted microbenchmark only creates a single CheckpointManager and doesn't operate on a new jax.Array for every iteration. That seems to have reduced the GC overhead significantly (only ~30ms now).
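
For reference, the reuse pattern looks roughly like this (illustrative sketch; the directory, state, and step count are placeholders rather than the gist's actual code):

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

state = {'w': jnp.ones((1024, 1024))}  # stand-in for the real train state

# Create the manager once and reuse it for every save; together with reusing
# the same jax.Array, this keeps the per-iteration gc.collect() cost low.
mngr = ocp.CheckpointManager('/tmp/replica_parallel_demo_mp')
for step in range(5):
    mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()
```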

Besides that, I'm noticing that we're spending a fair amount of time in the various sync_global_devices calls -- with replica-parallel these end up taking ~50% of the exposed save time (~30ms):

(profiler screenshot)

(I still intend to run a real benchmark on a cluster, but this might have to wait til next week.)

cpgaffney1 commented 2 weeks ago

Please rebase onto head and I will take a final look at this CL before merging internally.

On the sync_global_devices question, sometimes it looks like the barrier is taking a long time even though it is really just a non-primary process waiting while the leader does real work. I'm not familiar enough with your profiler to know whether that is the case.

We can take a look internally at this, since it's not the first time the possibility has been raised. Intuitively, though, there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices (one major advantage of using a single-controller model). There are a number of things we could move to background threads to minimize the overall number of barriers, though, like directory creations - that is something we're starting to work on.
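
(For context, the barrier under discussion is jax.experimental.multihost_utils.sync_global_devices, a named barrier across all processes; a trivial usage sketch with an arbitrary barrier name:)

```python
from jax.experimental import multihost_utils

# Every process blocks here until all processes reach a barrier with the same
# name; Orbax issues several such barriers over the course of a save.
multihost_utils.sync_global_devices('example:after_save')
```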

gspschmid commented 2 weeks ago

> On the sync_global_devices question, sometimes it looks like the barrier is taking a long time even though it is really just a non-primary process waiting while the leader does real work. I'm not familiar enough with your profiler to know whether that is the case.

The top (i.e. highlighted) process in the screenshot above should be the primary (Xla:#global=0).

> Intuitively, though, there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices

> There are a number of things we could move to background threads to minimize the overall number of barriers, though, like directory creations - that is something we're starting to work on.

Agreed, and indeed it seems like the lowest hanging fruit might be to elide some of these barriers -- based on the above profile there are five per save. So it sounds great that you're already looking into that!

> Please rebase onto head and I will take a final look at this CL before merging internally.

Will do (likely on Monday)! Thanks again for helping push this through :-)

gspschmid commented 1 week ago

Rebased on main, PTAL, @cpgaffney1 !

cpgaffney1 commented 1 week ago

LGTM - if I don't finish merging today, will finish tomorrow.

gspschmid commented 1 week ago

I finally got around to running some multi-node benchmarks confirming that replica-parallel helps reduce blocking save time as data-parallelism increases.

The smaller configuration I tried was GPT-5B on PAXML with TP=8 (TP spanning the GPUs within a single node) and DP=#nodes. Take-aways:

I also ran a 175B model with DP={1,2,4}, FSDP=16, TP=8, observing the same general trends. The benefit diminishes more quickly than for the 5B model, but that's in line with earlier experiments on 175B that scaled FSDP.

This is a somewhat ad-hoc analysis, but it makes me confident that replica-parallel in its current state is useful. It certainly seems plausible that we'll eventually want to support non-evenly-divisible dimensions (or splitting across multiple dimensions). Another question is whether we should make replica-parallel slicing aware of layouts (prioritizing major axes).