CERC-AAI / multimodal

An implementation of model parallel autoregressive transformers on GPUs, based on the DeepSpeed library.
Apache License 2.0

Pythia Checkpoint Loading #4

Closed kshitijkg closed 1 year ago

kshitijkg commented 1 year ago

We need to load Pythia checkpoints for MAGMA training. Main issue: the weights in the checkpoint do not match the weights in the MAGMA model. Sources of the mismatch:

  1. Naming change due to the Attention module being reassigned to the AdapterWrapper (https://github.com/floatingsnake/gpt-neox/blob/magma/megatron/model/adapter.py#L141), which changes weight names, for example from 2.attention.query_key_value.weight to 2.attention.attn_block.query_key_value.weight (see the sketch below).
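
A minimal sketch of why the names change, using toy nn.Module classes rather than the actual gpt-neox Attention/AdapterWrapper: assigning the original attention module to an attn_block attribute inside a wrapper inserts that attribute name into every state_dict key.

```python
import torch.nn as nn

class Attention(nn.Module):
    """Stand-in for the original attention module (not the gpt-neox class)."""
    def __init__(self):
        super().__init__()
        self.query_key_value = nn.Linear(8, 24)

class AdapterWrapper(nn.Module):
    """Stand-in wrapper: keeps the original module under `attn_block` and adds adapter weights."""
    def __init__(self, attn_block):
        super().__init__()
        self.attn_block = attn_block
        self.adapter = nn.Linear(8, 8)

attn = Attention()
print(list(attn.state_dict()))
# ['query_key_value.weight', 'query_key_value.bias']

layer = nn.ModuleDict({"attention": AdapterWrapper(attn)})
print(list(layer.state_dict()))
# ['attention.attn_block.query_key_value.weight', 'attention.attn_block.query_key_value.bias',
#  'attention.adapter.weight', 'attention.adapter.bias']
```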

Proposed solutions that do not change the names in the Pythia checkpoint:

  1. Add the adapters after loading the checkpoint, so the restructuring happens after the weights have already been loaded (see the sketch after this list). Disadvantages: adapter weights would have to be loaded separately, and the code would be duplicated and not clean.
  2. Get the class from the module (https://github.com/floatingsnake/gpt-neox/blob/magma/megatron/model/adapter.py#L129), then inherit from it and override the init and forward functions to include the adapters. The structure would remain the same, but this does not work because we do not easily have the initialization arguments needed to recreate the attention module; we only have the already-initialized object.
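
A rough sketch of option 1. The helpers build_base_model and add_adapters are placeholders for whatever MAGMA uses to construct the plain transformer and to attach adapters; they are not real gpt-neox functions, so this only illustrates the ordering.

```python
import torch

def load_then_adapt(checkpoint_path, build_base_model, add_adapters):
    """Option 1: build the unmodified transformer first so the module names
    still match the Pythia checkpoint keys, load the weights strictly, and
    only then wrap attention/MLP blocks with adapters."""
    model = build_base_model()
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)   # names match, so a strict load works
    model = add_adapters(model)         # adapter weights are freshly initialized
                                        # (or loaded from a separate checkpoint)
    return model
```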

Proposed solutions that change the names in the Pythia checkpoint:

  3. Rename the weights from attention to attention.attn_block and from mlp to mlp.attn_block, save the converted checkpoint, and use the new checkpoint (a sketch of the conversion is shown after this list).
  4. Override loading with a custom load function that does the renaming on the fly: https://github.com/floatingsnake/gpt-neox/pull/3. This will not work once we use pipeline parallelism, because custom_load_fn is not supported with pipeline parallelism.
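
A minimal sketch of the renaming in option 3, assuming a single torch-loadable checkpoint file. The actual conversion in the linked PR also has to deal with however DeepSpeed shards the Pythia checkpoint, so the "module" nesting handled below is an assumption.

```python
import sys
import torch

def rename_keys(state_dict):
    """Insert the AdapterWrapper's attn_block level into every attention/mlp
    key so the checkpoint matches the MAGMA module tree."""
    return {
        key.replace("attention.", "attention.attn_block.")
           .replace("mlp.", "mlp.attn_block."): value
        for key, value in state_dict.items()
    }

if __name__ == "__main__":
    src, dst = sys.argv[1], sys.argv[2]
    ckpt = torch.load(src, map_location="cpu")
    if isinstance(ckpt, dict) and "module" in ckpt:
        ckpt["module"] = rename_keys(ckpt["module"])   # weights nested under "module" (assumption)
    else:
        ckpt = rename_keys(ckpt)                       # checkpoint is the state_dict itself
    torch.save(ckpt, dst)
```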

Mismatch Source 2:

  1. Additional weights in MAGMA, due to the image prefix and the adapters. Proposed solution: set strict = False when loading the checkpoint. This is not the best solution and can be risky, but the plan is to quickly verify that all mismatched weights really are only the image prefix and adapters, and then train with strict = False once the first mismatch source has been fixed (see the sketch below). We can find a better solution once everyone is able to use the code to port their changes and do test runs.
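
A sketch of that verification. The prefix names (image_prefix, adapter) for the newly added MAGMA parameters are assumptions, not the exact names in the code; load_state_dict with strict=False returns the missing and unexpected keys, which can then be checked explicitly.

```python
import torch

EXPECTED_NEW_PREFIXES = ("image_prefix", "adapter")   # assumed naming of the new MAGMA weights

def load_pythia_into_magma(magma_model, checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    result = magma_model.load_state_dict(state_dict, strict=False)

    # Keys present in MAGMA but absent from the Pythia checkpoint should only
    # be the newly added image-prefix / adapter parameters.
    unexpected_missing = [
        k for k in result.missing_keys
        if not any(p in k for p in EXPECTED_NEW_PREFIXES)
    ]
    assert not unexpected_missing, f"Unmatched transformer weights: {unexpected_missing}"
    assert not result.unexpected_keys, f"Checkpoint keys with no target: {result.unexpected_keys}"
    return magma_model
```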
kshitijkg commented 1 year ago

Current solution: number 3. Rename the weights from attention to attention.attn_block and from mlp to mlp.attn_block, save the converted checkpoint, and use the new checkpoint. PR: https://github.com/floatingsnake/gpt-neox/pull/10

We just need to run the checkpoint conversion script and load the converted checkpoint.

Additionally, we set strict = False so that the image prefix and adapter weights are ignored. I have manually checked whether any other weights have mismatched names, and everything looks correct (one way to reproduce this check is sketched below).
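
For reference, the manual check can be reproduced by diffing the key sets of the converted checkpoint and the MAGMA model. The "module" nesting and the expectation that the leftover model-only keys are exactly the image prefix and adapter weights are assumptions based on the discussion above.

```python
import torch

def diff_keys(magma_model, converted_ckpt_path):
    ckpt = torch.load(converted_ckpt_path, map_location="cpu")
    ckpt_keys = set(ckpt.get("module", ckpt))          # weights may be nested under "module" (assumption)
    model_keys = set(magma_model.state_dict())

    print("only in checkpoint:", sorted(ckpt_keys - model_keys))   # should be empty
    print("only in model:     ", sorted(model_keys - ckpt_keys))   # should be image prefix / adapters only
```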

Lastly, this requires another change in the DeeperSpeed code; use the following branch: https://github.com/EleutherAI/DeeperSpeed/tree/robin_summit