bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

support mqa in checkpoint-merging tools #40

Closed RaymondLi0 closed 1 year ago

RaymondLi0 commented 1 year ago

I had to add `mp.set_start_method('spawn')` here: https://github.com/bigcode-project/Megatron-LM/blob/mqa-checkpoint-utils/tools/checkpoint_util.py#L106 I'm not sure whether this is just an issue with my environment, though.
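
A minimal sketch of the workaround, assuming the tool launches its loader/saver processes via `torch.multiprocessing` (the process/queue layout below is illustrative, not the exact `checkpoint_util.py` code):

```python
import torch.multiprocessing as mp


def saver_process(queue):
    # placeholder for the saver side of the tool, which consumes messages
    # produced by the checkpoint loader
    msg = queue.get()
    print('saver received:', msg)


def main():
    # must be called once, before any Process or Queue is constructed;
    # 'spawn' avoids inheriting state (e.g. a CUDA context) through fork,
    # which can fail in some environments
    mp.set_start_method('spawn')

    queue = mp.Queue()
    saver = mp.Process(target=saver_process, args=(queue,))
    saver.start()
    queue.put({'done': True})
    saver.join()


if __name__ == '__main__':
    main()
```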

RaymondLi0 commented 1 year ago

Added a `--use-distributed-optimizer` flag to the checkpoint_util script so that it can use the correct checkpoint naming scheme when loading.
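
As a rough illustration (the actual argument wiring in the PR may differ), the flag could be exposed through the script's argument parser along these lines:

```python
# Hypothetical sketch: expose --use-distributed-optimizer on the
# checkpoint_util command line so the loader can pick the right checkpoint
# naming scheme. Names besides the flag itself are illustrative.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='Megatron checkpoint utility')
    parser.add_argument('--use-distributed-optimizer', action='store_true',
                        help='The checkpoint to load was saved with the '
                             'distributed optimizer (per data-parallel-rank '
                             'optimizer shards).')
    return parser.parse_args()
```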

The checkpoint loader uses the data-parallel rank to build the name of the optimizer file to load. The problem is that the data-parallel rank is not initialized in this script (and initializing it is not required). Since only the model file is used, we circumvent the issue by loading the 0-th optimizer shard, which won't be used anyway. https://github.com/bigcode-project/Megatron-LM/pull/40/commits/a8e64f6f79cd5ca31db5f11336af78baeaa5282c#diff-122925dfa160fba3c00803abba3577ef0d5aa5ab48989032a63d41c91f2a8002R122
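
A simplified illustration of the problem (not Megatron-LM's actual `get_checkpoint_name` logic): with the distributed optimizer, the optimizer shard path depends on the data-parallel rank, which the merge tool never initializes, so the workaround in the commit above pins it to 0:

```python
# Illustrative only: per-DP-rank optimizer shards mean the path cannot be
# built without a data-parallel rank. Pinning dp_rank=0 yields a valid path;
# the shard itself is never used by the merge tool.
import os


def optimizer_shard_path(ckpt_dir, iteration, tp_rank, pp_rank, dp_rank=0):
    # hypothetical naming scheme, shown only to illustrate the dependency
    subdir = f'iter_{iteration:07d}/mp_rank_{tp_rank:02d}_{pp_rank:03d}'
    return os.path.join(ckpt_dir, subdir, f'distrib_optim_dp_{dp_rank:03d}.pt')
```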

EDIT: changed to not load any optimizer state at all, instead of arbitrarily loading the 0-th shard
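
A minimal sketch of that final approach (function and key names are illustrative, not the exact Megatron-LM API): read only the model weights and ignore optimizer state entirely:

```python
# Illustrative sketch: for merging, only the model weights are needed, so no
# optimizer shard (whose filename depends on the data-parallel rank) has to
# be resolved or loaded at all.
import torch


def load_model_weights_only(model_ckpt_path, model):
    state_dict = torch.load(model_ckpt_path, map_location='cpu')
    # 'model' as the key for the weights is an assumption about the
    # checkpoint layout, used here only for illustration.
    model.load_state_dict(state_dict['model'])
    return model
```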