openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Inadequate memory consumption when using HSDP without gradient accumulation #18090

Open qGentry opened 1 month ago

qGentry commented 1 month ago

Hi, I'm training a transformer model with Hybrid Sharded Data Parallelism (HSDP). This setup is similar to FSDP/ZeRO-3, where params are all-gathered for each layer's forward/backward pass and dropped afterwards. However, instead of sharding both the model params and the optimizer state over all GPUs in the cluster, I shard the model params only over a subset of devices (usually within a single node, for fast all-gathers over NVLink) and shard the optimizer state over all GPUs (similar to FSDP/ZeRO-1/2/3).

Concretely, I have a mesh (param_groups, model), and for each param tensor P of shape (X, Y) I shard P with partition spec (model, None), while the corresponding optimizer state P_o, of the same shape (X, Y), is sharded with partition spec (model, param_groups).

When the mesh (param_groups, model) has shape:

  1. (1, N_GPUs): this is essentially FSDP/ZeRO-3.
  2. (N, N_GPUs / N), N > 1: HSDP.
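To make the two layouts concrete, here is a small framework-free sketch that computes the per-device shard shape implied by each partition spec. The helper `shard_shape`, the dict-based mesh, and the parameter shape (4096, 1024) are hypothetical illustrations, not JAX APIs; the mesh axis names and specs are the ones described above, on 8 GPUs.

```python
from math import prod


def shard_shape(global_shape, spec, mesh):
    """Per-device shard shape of a tensor laid out with `spec` on a named mesh.

    `spec` has one entry per tensor dimension: a mesh axis name (that dim is
    split across that axis) or None (that dim is not split).
    `mesh` maps mesh axis name -> axis size (hypothetical dict representation).
    """
    shape = []
    for dim, axis in zip(global_shape, spec):
        n = 1 if axis is None else mesh[axis]
        if dim % n:
            raise ValueError(f"dim {dim} is not divisible by axis {axis!r} of size {n}")
        shape.append(dim // n)
    return tuple(shape)


X, Y = 4096, 1024                      # hypothetical parameter shape
PARAM_SPEC = ("model", None)           # param P
OPT_SPEC = ("model", "param_groups")   # optimizer state P_o

for name, mesh in [("FSDP", {"param_groups": 1, "model": 8}),
                   ("HSDP", {"param_groups": 2, "model": 4})]:
    p = shard_shape((X, Y), PARAM_SPEC, mesh)
    o = shard_shape((X, Y), OPT_SPEC, mesh)
    print(f"{name}: param shard {p} ({prod(p)} elems), "
          f"opt-state shard {o} ({prod(o)} elems)")
```

With these assumed numbers the optimizer-state shard is X*Y/8 elements per device in both layouts, while the resident parameter shard grows from X*Y/8 (FSDP) to X*Y/4 (HSDP), because params are only split across the 4-way `model` axis.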

I also have gradient accumulation implemented: the input batch is split into chunks, the forward/backward pass is computed for each chunk independently, and the chunks' gradients are summed.
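A minimal sketch of the accumulation scheme described above, using a hypothetical scalar linear model with a *summed* squared-error loss (an assumption; with a mean loss the chunk gradients would need rescaling). Because the loss is a sum over examples, summing per-chunk gradients reproduces the full-batch gradient exactly. In a compiled JAX step, a loop like this would typically be written with `jax.lax.scan` or `jax.lax.fori_loop`, both of which lower to an HLO while loop.

```python
def grad_sum_sq(w, xs, ys):
    """d/dw of sum_i (w * x_i - y_i) ** 2 for a scalar linear model (hypothetical)."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys))


def accumulated_grad(w, xs, ys, accum):
    """Split the batch into `accum` equal chunks, compute each chunk's
    gradient independently, and sum them (gradient accumulation)."""
    if len(xs) % accum:
        raise ValueError("batch size must be divisible by the accumulation factor")
    c = len(xs) // accum
    total = 0
    for k in range(accum):
        chunk = slice(k * c, (k + 1) * c)
        total += grad_sum_sq(w, xs[chunk], ys[chunk])
    return total


xs = list(range(8))
ys = [2 * x + 1 for x in xs]
print([accumulated_grad(3, xs, ys, a) for a in (1, 2, 4)])
```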

With a gradient accumulation factor of N (the batch is split into N chunks that are processed independently) and a sequence length of S, peak memory usage should equal that of a setup with an accumulation factor of 2N and a sequence length of 2S. This is because the resulting per-chunk input tensor of shape [B / 2, 2 * S] has the same number of elements as a tensor of shape [B, S], where B is the per-chunk batch size at factor N.

This holds exactly for the FSDP setup with mesh shape (1, N_GPUs): for every gradient accumulation factor I've tested, peak memory usage is identical. With HSDP, however, something strange happens.

With a gradient accumulation factor N > 1, peak memory usage is as expected, BUT as soon as I set it to 1, peak memory usage increases substantially.

Here is a toy model with mesh shape (2, 4), a total batch size of 64, and three setups:

  1. gradient accumulation factor = 1, seq_len = 512
  2. gradient accumulation factor = 2, seq_len = 1024
  3. gradient accumulation factor = 4, seq_len = 2048
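As a quick arithmetic check of the equal-size argument, the following sketch computes the number of elements in one chunk's input tensor, of shape [64 // accum, seq_len], for the three setups listed above. The helper name `tokens_per_chunk` is hypothetical.

```python
TOTAL_BATCH = 64
SETUPS = [(1, 512), (2, 1024), (4, 2048)]  # (grad accumulation factor, seq_len)


def tokens_per_chunk(total_batch, accum, seq_len):
    """Elements in one chunk input of shape [total_batch // accum, seq_len]."""
    if total_batch % accum:
        raise ValueError("total batch must be divisible by the accumulation factor")
    return (total_batch // accum) * seq_len


for accum, seq_len in SETUPS:
    print(accum, seq_len, tokens_per_chunk(TOTAL_BATCH, accum, seq_len))
```

All three setups give 32768 elements per chunk ([64, 512], [32, 1024], [16, 2048]), so each forward/backward pass sees an input of the same size.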

The second and third setups consume practically identical amounts of memory (~50 GB per GPU), while the first one consumes considerably more: 61 GB.

Here are the HLOs of the first and second setups: compiled_train_fn_grad_accum=1.txt, compiled_train_fn_grad_accum=2.txt

JAX issue - https://github.com/jax-ml/jax/issues/24208

ptoulme-aws commented 1 month ago

Grep for the number of all-gathers in both text files. The one with the while loop has 271 all-gathers, while the one without has 175. Most likely your all-gathers are being CSE'd away (common subexpression elimination) in the accum=1 case, or the GSPMD partitioner is partitioning differently when a while loop is present.
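A plain line-count grep for `all-gather` also matches lines that merely use an all-gather result as an operand (e.g. `%all-gather.3`), so a slightly stricter count can help when comparing the two dumps. The sketch below is a hypothetical helper, not an XLA tool: it counts only the opcode position of HLO instructions, assuming the usual text form `%name = shape opcode(...)`, and treats an `all-gather-start`/`all-gather-done` pair as one logical all-gather.

```python
import re
import sys
from collections import Counter

# Matches the opcode position, e.g. "... bf16[8,4]{1,0} all-gather(" or
# "... all-gather-start(". The required leading space skips names such as
# "%all-gather.3", which are preceded by "%".
_OPCODE = re.compile(r" (all-gather(?:-start|-done)?)\(")


def count_all_gathers(hlo_text):
    """Count all-gather opcodes (sync, async start, async done) in HLO text."""
    return Counter(m.group(1) for m in _OPCODE.finditer(hlo_text))


def logical_all_gathers(hlo_text):
    """Sync all-gathers plus async starts (each start is paired with a done)."""
    c = count_all_gathers(hlo_text)
    return c["all-gather"] + c["all-gather-start"]


if __name__ == "__main__":
    # Usage: python count_ag.py compiled_train_fn_grad_accum=1.txt ...
    for path in sys.argv[1:]:
        with open(path) as f:
            print(path, logical_all_gathers(f.read()))
```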

ptoulme-aws commented 1 month ago

Try adding this in JAX:

from functools import partial
import jax

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')