bigscience-workshop / Megatron-DeepSpeed


Universal checkpoints and MP states #380

Closed aitorormazabal closed 1 year ago

aitorormazabal commented 1 year ago

Hello,

We are trying to fine-tune BLOOM on a 3D parallelism configuration different from the original one. The best path for this currently seems to be reshaping into a universal checkpoint and resuming training from it to create a reshaped checkpoint (see the "Checkpoint reshaping" section in https://github.com/bigscience-workshop/bigscience/tree/master/train/tr11-176B-ml ). I have two questions:

* Does this imply that the trainer has logic to detect if the provided checkpoint is universal instead of sharded, and to load into the empty partitioned parameters (including ZeRO shards) from that? Is this done in a way to ensure not too much CPU memory is used at a time on each node?

* Instead of consolidating a partitioned checkpoint myself, I plan to use the universal one at https://huggingface.co/bigscience/bloom-optimizer-states/tree/global_step95000_universal/global_step95000_universal. This includes files of the form "mp_rank_XX_model_states.pt" on top of the weights, which seem to correspond to the MP degree. I ran a quick script to check these dicts recursively, and they match exactly except for args.local_rank, args.rank and a single element in rng_tracker_states.model-parallel-rng. I suppose these aren't actually used when loading a universal checkpoint and are just an artifact of the consolidation script, am I correct?

Also, out of curiosity (@stas00), is there a reason why such heavy use of pipeline parallelism was made, as opposed to e.g. metaseq and OPT, where AFAIK no PP was used and instead a higher ZeRO degree (or its equivalent in their FSDP) was combined with Megatron? Is this known to be more efficient, or was it just a design choice?

I'd appreciate it if someone on the team could clear this up!

stas00 commented 1 year ago
* Does this imply that the trainer has logic to detect if the provided checkpoint is universal instead of sharded, and to load into the empty partitioned parameters (including ZeRO shards) from that? Is this done in a way to ensure not too much CPU memory is used at a time on each node?

When you restart from a universal checkpoint, Meg-DS has a --universal_checkpoint flag, so loading goes down a different path:

https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/e52bdabbde3c6895aceb76c1bced295c2646121f/megatron/training.py#L421-L422
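For context, the gist of that branch is sketched below. This is a simplified illustration rather than the actual Meg-DS code, and the two loader functions are hypothetical stand-ins; the flag is added to the training command as --universal-checkpoint, which argparse exposes as args.universal_checkpoint.

```python
# Simplified illustration of the flag-gated load path; this is NOT the
# actual Meg-DS code, and the two loader functions below are hypothetical
# stand-ins for the real sharded/universal loading routines.
import argparse

def load_universal(path):
    # universal checkpoints are shape-agnostic: on load, the consolidated
    # states get re-partitioned into the current 3D layout
    print(f"loading universal checkpoint from {path}")

def load_sharded(path):
    # regular DeepSpeed checkpoints must match the topology they were
    # saved with (same TP/PP/DP degrees)
    print(f"loading sharded checkpoint from {path}")

parser = argparse.ArgumentParser()
parser.add_argument("--load", required=True)
parser.add_argument("--universal-checkpoint", action="store_true")
args = parser.parse_args()

if args.universal_checkpoint:
    load_universal(args.load)
else:
    load_sharded(args.load)
```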

* Instead of consolidating a partitioned checkpoint myself, I plan to use the universal one at https://huggingface.co/bigscience/bloom-optimizer-states/tree/global_step95000_universal/global_step95000_universal. This includes files of the form "mp_rank_XX_model_states.pt" on top of the weights, which seem to correspond to the MP degree. I ran a quick script to check these dicts recursively, and they match exactly except for args.local_rank, args.rank and a single element in rng_tracker_states.model-parallel-rng. I suppose these aren't actually used when loading a universal checkpoint and are just an artifact of the consolidation script, am I correct?

It's been a long time, so my memory is hazy. Most likely those are files saved by DeepSpeed and copied over during checkpoint conversion, or perhaps they are updated during the conversion. They contain all sorts of metadata needed to load that checkpoint.

If you think they are redundant, delete them and see if things still work.
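If anyone wants to repeat that kind of check, here is a minimal sketch of a recursive comparison over two of those files; the file names and the ignore set are illustrative, not part of Meg-DS:

```python
# Minimal sketch: recursively diff two checkpoint dicts, skipping fields
# that are expected to differ (e.g. args.local_rank / args.rank). The
# file names and the ignore set are illustrative, not part of Meg-DS.
import argparse
import torch

IGNORED = {"local_rank", "rank"}

def diff(a, b, path=""):
    """Yield the paths at which the two (possibly nested) objects differ."""
    if isinstance(a, argparse.Namespace) and isinstance(b, argparse.Namespace):
        # compare the args Namespace field by field
        yield from diff(vars(a), vars(b), path)
    elif isinstance(a, dict) and isinstance(b, dict):
        for key in sorted(set(a) | set(b), key=str):
            if key in IGNORED:
                continue
            if key not in a or key not in b:
                yield f"{path}/{key} (present on one side only)"
            else:
                yield from diff(a[key], b[key], f"{path}/{key}")
    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        if a.shape != b.shape or a.dtype != b.dtype or not torch.equal(a, b):
            yield f"{path} (tensor mismatch)"
    elif type(a) is not type(b) or a != b:
        yield f"{path}: {a!r} != {b!r}"

# weights_only=False because the files contain pickled argparse.Namespace objects
ckpt_a = torch.load("mp_rank_00_model_states.pt", map_location="cpu", weights_only=False)
ckpt_b = torch.load("mp_rank_01_model_states.pt", map_location="cpu", weights_only=False)
for mismatch in diff(ckpt_a, ckpt_b):
    print(mismatch)
```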

Also, out of curiosity (@stas00), is there a reason why such heavy use of pipeline parallelism was made, as opposed to e.g. metaseq and OPT, where AFAIK no PP was used and instead a higher ZeRO degree (or its equivalent in their FSDP) was combined with Megatron? Is this known to be more efficient, or was it just a design choice?

ZeRO-3 requires a lot of inter-node bandwidth, and where BLOOM was trained the network was 50Gbps, about 8x less than the ~400Gbps minimum needed for OK performance. So we had to seek out a more network-lean solution, which is why a combination of TP and PP was used (in addition to DP).
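To make the tradeoff concrete, here is a rough sketch of the two setups expressed as DeepSpeed config fragments plus the Megatron parallelism flags. The numbers are illustrative (loosely a BLOOM-style layout vs. an OPT-style one), not a verbatim copy of either training setup.

```python
# Rough sketch of the two parallelism strategies. The ZeRO stage lives in
# the DeepSpeed config; the TP/PP degrees are Megatron command-line flags.
# All numbers below are illustrative, not prescriptive.

# Network-lean setup (BLOOM-style): TP + PP split the model, and ZeRO-1
# only shards optimizer states across the DP replicas.
#   --tensor-model-parallel-size 4 --pipeline-model-parallel-size 12
bloom_style_ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 1},
    "bf16": {"enabled": True},
}

# Bandwidth-hungry alternative (OPT/metaseq-style): no PP; parameters and
# gradients are fully sharded, so each layer's weights are all-gathered
# over the network on every forward/backward pass.
#   --tensor-model-parallel-size 8 --pipeline-model-parallel-size 1
zero3_style_ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": True},
}
```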

aitorormazabal commented 1 year ago

Thank you! Very valuable to know about the ZeRO bandwidth requirements.