qGentry opened this issue 1 month ago
Grep the number of all-gathers in both text files. The one with the while loop has 271 all-gathers, while the one without has 175. Most likely your all-gathers are being CSE'd away in the accum=1 case, or the GSPMD partitioner is partitioning differently when a while loop is present.
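For example, a rough count over the HLO text dumps (a minimal sketch; the filenames are taken from the attachments below, and the substring match may need adjusting for async all-gather-start/done pairs):

```python
# Rough count of all-gather instructions in the dumped HLO text files.
# Filenames assumed from the attachments below; adjust the pattern if needed.
for path in ["compiled_train_fn_grad_accum=1.txt",
             "compiled_train_fn_grad_accum=2.txt"]:
    with open(path) as f:
        count = sum(line.count("all-gather(") for line in f)
    print(path, count)
```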
Try adding this in JAX:

```python
from functools import partial
import jax

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict(params, inputs):  # placeholder: decorate your own forward/loss fn
    ...
```
Hi, I'm training a transformer model with Hybrid Sharded Data Parallelism. This setup is similar to FSDP/ZeRO-3, where params are all-gathered for each layer's forward/backward pass and dropped afterwards. However, instead of sharding both the model params and the optimizer state over all GPUs in the cluster, I shard the model params only over a subset of devices (usually within a single node, for fast all-gathers over NVLink) and shard the optimizer state over all GPUs (similar to FSDP/ZeRO-1/2/3).
Basically, I have a mesh (param_groups, model), and for each param tensor P of shape (X, Y) I shard P with partition spec (model, None) and the corresponding optimizer state P_o of the same shape (X, Y) with partition spec (model, param_groups).
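Roughly, the layout looks like this (a minimal sketch using the (2, 4) mesh from the experiment below; the variable names and the 8-device assumption are illustrative, not my actual code):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative (2, 4) mesh: 2 param groups x 4 model shards, assuming 8 devices.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("param_groups", "model"))

# Param P of shape (X, Y): sharded along "model" only, replicated across
# param groups (so all-gathers stay within a node, over NVLink).
param_sharding = NamedSharding(mesh, P("model", None))

# Optimizer state P_o of the same shape: additionally sharded over
# "param_groups", i.e. spread over all GPUs.
opt_state_sharding = NamedSharding(mesh, P("model", "param_groups"))
```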
When the mesh (param_groups, model) size is (1, N_GPUs), this reduces to the plain FSDP setup.
I also have gradient accumulation implemented: the input batch is split into chunks, the forward/backward pass is computed for each chunk independently, and the resulting gradients are summed.
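The accumulation follows roughly this pattern (a simplified sketch, not my exact code; `loss_fn` is a stand-in for the real forward/loss, and a `lax.scan` like this is what lowers to the while loop in the HLO when the factor is > 1):

```python
import jax
import jax.numpy as jnp

def accumulate_grads(loss_fn, params, batch, num_chunks):
    # Split the leading batch dim into num_chunks microbatches
    # (assumes it is divisible by num_chunks).
    microbatches = jax.tree_util.tree_map(
        lambda x: x.reshape(num_chunks, x.shape[0] // num_chunks, *x.shape[1:]),
        batch)

    def step(grad_acc, microbatch):
        grads = jax.grad(loss_fn)(params, microbatch)
        return jax.tree_util.tree_map(jnp.add, grad_acc, grads), None

    zero = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(step, zero, microbatches)
    return grad_sum
```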
When using gradient accumulation with a factor of N (the batch is split into N chunks, each processed independently) and sequence length S, peak memory usage should be equal to that of a setup with accumulation factor 2*N and sequence length 2*S. This is because the resulting per-chunk input tensor of shape [B / 2, 2 * S] has the same numel as a tensor of shape [B, S].
And this holds for the FSDP setup with mesh size (1, N_GPUs): for every gradient accumulation factor I've tested, peak memory usage is identical. But when I try to use HSDP, something weird happens.
When I use a gradient accumulation factor of N > 1, peak memory usage is exactly as expected, BUT as soon as I set it to 1, peak memory usage greatly increases.
Here, I have a toy model with a (2, 4) mesh, a total batch size of 64, and 3 setups:
The second and third setups consume a practically identical amount of memory (~50 GB on each GPU), while the first one consumes way more: 61 GB.
Here are the HLOs of the first and second setups: compiled_train_fn_grad_accum=2.txt compiled_train_fn_grad_accum=1.txt
JAX issue - https://github.com/jax-ml/jax/issues/24208