NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Checkpoint saving is slow for zarr backend + distributed optimizer #834

Closed · chotzen closed this issue 2 months ago

chotzen commented 4 months ago

Describe the bug
The distributed optimizer state is saved inefficiently when zarr is used as the checkpointing backend. This causes slowdowns like the following when writing checkpoints to a local SSD (holding everything else constant, for a 12 layer x 128 head_dim x 12 head llama-style transformer):

[Image: checkpoint save time comparison on a local SSD]

After profiling the checkpoint saving workload, it looks like this is what happens for each parameter being saved in the optimizer state:

This full process takes 450 ms and is repeated many times: once per parameter in the distributed optimizer.
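
One plausible contributor is zarr's fixed per-array cost (each array gets its own directory, .zarray metadata file, and chunk files), which would be paid once per parameter. The snippet below is a standalone illustration of that overhead using the zarr-python 2.x API directly; it is not Megatron code, and the array counts and sizes are made up for the example.

```python
# Standalone illustration (not Megatron code): writing the same bytes as many
# small zarr arrays is much slower than writing them as one large array,
# because each array pays a fixed setup cost (directory, metadata, chunk files).
import time
import numpy as np
import zarr  # zarr-python 2.x API assumed


def timed_write(root, arrays):
    """Write each (name, ndarray) pair as its own zarr array; return seconds."""
    group = zarr.open_group(root, mode='w')
    start = time.perf_counter()
    for name, arr in arrays:
        group.create_dataset(name, data=arr)
    return time.perf_counter() - start


n_params, numel = 512, 16_384  # arbitrary sizes for the illustration
many = [(f"param_{i}", np.ones(numel, dtype=np.float32)) for i in range(n_params)]
one = [("flat_buffer", np.ones(n_params * numel, dtype=np.float32))]

print(f"{n_params} small arrays: {timed_write('/tmp/zarr_many', many):.3f}s")
print(f"1 large array: {timed_write('/tmp/zarr_one', one):.3f}s")
```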

To Reproduce

Spawn a GPTModel (with e.g. 12 layers, 128 head dim, 12 heads) across a 2 x 2 x 2 pipeline x tensor x data parallel layout, then try to save the distributed optimizer as a checkpoint using the zarr backend with the fully_sharded_bucket_space sharding type.
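
A minimal sketch of that save path, assuming `model` and `optimizer` are an already-initialized MCore GPTModel and DistributedOptimizer under this 2 x 2 x 2 layout; the (backend, version) tuple and the sharding_type value are taken from this thread, but exact function signatures may differ between MCore versions.

```python
# Hedged sketch: save model + DistributedOptimizer state with the zarr backend
# and the fully_sharded_bucket_space sharding type. `model` and `optimizer` are
# assumed to already exist; keyword names may differ across MCore versions.
from megatron.core import dist_checkpointing

model_sd = model.sharded_state_dict()
optim_sd = optimizer.sharded_state_dict(
    model_sd, sharding_type='fully_sharded_bucket_space'
)

dist_checkpointing.save(
    {'model': model_sd, 'optimizer': optim_sd},
    '/path/to/ckpt_zarr',            # checkpoint directory
    sharded_strategy=('zarr', 1),    # (backend, version) tuple
)
```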

Expected behavior
Saving should be nearly as fast as saving the non-distributed optimizer, and the difference should not grow faster than the number of model parameters.

Environment (please complete the following information):

Proposed fix N/A

Additional context N/A

deepakn94 commented 4 months ago

Have you tried the torch_dist (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training/arguments.py#L1252) distributed checkpoint format?
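
For comparison with the sketch in the bug report, switching to torch_dist would only change the save strategy (or, when going through the training scripts, the distributed-checkpoint format argument defined in the arguments.py line linked above); the names below are assumptions based on this thread and may vary by MCore version.

```python
# Same hedged sketch as above, but with the torch_dist backend suggested here.
dist_checkpointing.save(
    {'model': model_sd, 'optimizer': optim_sd},
    '/path/to/ckpt_torch_dist',
    sharded_strategy=('torch_dist', 1),  # (backend, version) tuple from this thread
)
```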

chotzen commented 4 months ago

Yes, I've tried it; we ran into a separate issue with saving the optimizer step when dp_partitions >= 2. I'll file another bug for that when I have a chance to reproduce it.

chotzen commented 4 months ago

Hi @deepakn94, which kinds of checkpoint resharding are meant to be supported for the torch_dist backend? I'm unable to load a (D, P, T) = (2, 2, 2) checkpoint into a (2, 1, 2) partitioning scheme with any combination of the (torch_dist, 1) strategy and sharding_type="dp_zero_gather_scatter" (set or unset).
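
For reference, the load side of that resharding attempt would look roughly like the hedged sketch below, where `model` and `optimizer` are rebuilt under the target (D, P, T) = (2, 1, 2) layout so that their sharded_state_dict describes the new partitioning; `is_loading` and the other keyword names are assumptions and may differ by MCore version.

```python
# Hedged sketch of the load/reshard attempt: the sharded_state_dict built by the
# new (2, 1, 2) job describes the target partitioning, and load() is expected to
# fill it from the (2, 2, 2) checkpoint. Keyword names are assumptions.
from megatron.core import dist_checkpointing

model_sd = model.sharded_state_dict()
optim_sd = optimizer.sharded_state_dict(
    model_sd, is_loading=True, sharding_type='dp_zero_gather_scatter'
)

state_dict = dist_checkpointing.load(
    {'model': model_sd, 'optimizer': optim_sd},
    '/path/to/ckpt_torch_dist',
)
```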

mikolajblaz commented 4 months ago

Hi @chotzen, please use the recommended torch_dist backend, especially for the DistributedOptimizer: zarr backend saving is very slow for DistOpt-like sharding types.

I'm unable to load a (D, P, T) = (2, 2, 2) checkpoint into a (2, 1, 2) partitioning scheme

Changing TP x PP is not supported with DistOpt yet (only DP resharding for now); it will be supported in the near future (target is MCore v0.8).

github-actions[bot] commented 2 months ago

Marking as stale. No activity in 60 days.