facebookresearch / metaseq

Repo for external large-scale work

Convert opt to megatron-lm #135

Open · appleeji opened this issue 2 years ago

appleeji commented 2 years ago

🚀 Feature Request

Convert OPT checkpoints to the Megatron-LM or FasterTransformer format.

Motivation

I am currently trying to use OPT in a production environment. However, the 175B model is too large to fit entirely on a single A100 GPU. I tried Hugging Face's Accelerate, but found that it gave no latency benefit.

From what I have found, intra-layer parallelism (tensor parallelism) is the key to improving inference speed in a multi-GPU environment.
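To make the idea concrete, here is a minimal numpy sketch of column-parallel splitting, the basic building block of Megatron-style tensor parallelism. The sizes and variable names are illustrative only, not taken from Megatron-LM.

    import numpy as np

    # Illustrative column-parallel split of a single linear layer.
    hidden, ffn = 8, 16
    x = np.random.randn(2, hidden)       # a batch of activations
    w = np.random.randn(hidden, ffn)     # full weight matrix

    # Split the weight columns across two "GPUs"; each rank computes its slice.
    w0, w1 = np.split(w, 2, axis=1)
    y0, y1 = x @ w0, x @ w1              # partial results, computed in parallel

    # Concatenating the partial outputs (an all-gather in practice)
    # recovers the result of the full matmul.
    assert np.allclose(np.concatenate([y0, y1], axis=1), x @ w)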

So I tried to serve OPT with FasterTransformer. However, converting the OPT checkpoint to the Megatron-LM / FasterTransformer format did not work well. I hope this feature is supported so that OPT can be used much more widely.

Additional context

Here's what I tried for the OPT -> Megatron-LM conversion:

  1. Compare and match the keys in each checkpoint; a sketch of this renaming (together with step 3) follows the list:

| Megatron-LM | OPT (Hugging Face) |
| --- | --- |
| word_embeddings/weight | decoder.embed_tokens.weight |
| position_embeddings/weight | decoder.embed_positions.weight |
| transformer/layers.{i}.input_layernorm.weight | decoder.layers.{i}.self_attn_layer_norm.weight |
| transformer/layers.{i}.input_layernorm.bias | decoder.layers.{i}.self_attn_layer_norm.bias |
| transformer/layers.{i}.attention.query_key_value.weight | decoder.layers.{i}.self_attn.q_proj.weight, decoder.layers.{i}.self_attn.k_proj.weight, decoder.layers.{i}.self_attn.v_proj.weight |
| transformer/layers.{i}.attention.query_key_value.bias | decoder.layers.{i}.self_attn.q_proj.bias, decoder.layers.{i}.self_attn.k_proj.bias, decoder.layers.{i}.self_attn.v_proj.bias |
| transformer/layers.{i}.attention.dense.weight | decoder.layers.{i}.self_attn.out_proj.weight |
| transformer/layers.{i}.attention.dense.bias | decoder.layers.{i}.self_attn.out_proj.bias |
| transformer/layers.{i}.post_attention_layernorm.weight | decoder.layers.{i}.final_layer_norm.weight |
| transformer/layers.{i}.post_attention_layernorm.bias | decoder.layers.{i}.final_layer_norm.bias |
| transformer/layers.{i}.mlp.dense_h_to_4h.weight | decoder.layers.{i}.fc1.weight |
| transformer/layers.{i}.mlp.dense_h_to_4h.bias | decoder.layers.{i}.fc1.bias |
| transformer/layers.{i}.mlp.dense_4h_to_h.weight | decoder.layers.{i}.fc2.weight |
| transformer/layers.{i}.mlp.dense_4h_to_h.bias | decoder.layers.{i}.fc2.bias |
| transformer/final_layernorm.weight | X (no counterpart; see step 3) |
| transformer/final_layernorm.bias | X (no counterpart; see step 3) |
  2. Concatenate q, k, v in OPT into the single fused query_key_value tensor that Megatron uses:

    import numpy as np

    # Fuse the separate q/k/v bias tensors and dump them as a raw binary file.
    with open(f"model.layers.{i}.attention.query_key_value.bias.0.bin", "wb") as f:
        val = np.concatenate((model[f'decoder.layers.{i}.self_attn.q_proj.bias'].numpy(),
                              model[f'decoder.layers.{i}.self_attn.k_proj.bias'].numpy(),
                              model[f'decoder.layers.{i}.self_attn.v_proj.bias'].numpy()),
                             axis=0)
        val.tofile(f)
  3. Create placeholder weights for the final_layernorm that does not exist in the OPT checkpoint: fill final_layernorm.weight with ones and final_layernorm.bias with zeros.
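For reference, here is a minimal sketch of the renaming in step 1 and the placeholder layer norm in step 3 (the q/k/v fusion of step 2 is shown above). The mapping simply mirrors the table; names such as LAYER_KEY_MAP, rename_key, and megatron_extra are illustrative choices rather than metaseq or FasterTransformer APIs, and the 12288 hidden size assumes the 175B configuration.

    import re
    import numpy as np

    # Per-layer key mapping, mirroring the table in step 1 (illustrative names).
    LAYER_KEY_MAP = {
        "self_attn_layer_norm.weight": "input_layernorm.weight",
        "self_attn_layer_norm.bias": "input_layernorm.bias",
        "self_attn.out_proj.weight": "attention.dense.weight",
        "self_attn.out_proj.bias": "attention.dense.bias",
        "final_layer_norm.weight": "post_attention_layernorm.weight",
        "final_layer_norm.bias": "post_attention_layernorm.bias",
        "fc1.weight": "mlp.dense_h_to_4h.weight",
        "fc1.bias": "mlp.dense_h_to_4h.bias",
        "fc2.weight": "mlp.dense_4h_to_h.weight",
        "fc2.bias": "mlp.dense_4h_to_h.bias",
    }

    def rename_key(opt_key: str) -> str:
        """Map an OPT (Hugging Face) state-dict key to its Megatron-style name."""
        if opt_key == "decoder.embed_tokens.weight":
            return "word_embeddings/weight"
        if opt_key == "decoder.embed_positions.weight":
            return "position_embeddings/weight"
        m = re.match(r"decoder\.layers\.(\d+)\.(.+)", opt_key)
        if m and m.group(2) in LAYER_KEY_MAP:
            return f"transformer/layers.{m.group(1)}.{LAYER_KEY_MAP[m.group(2)]}"
        return opt_key  # q/k/v keys are handled by the fusion in step 2

    # Step 3: placeholder final layer norm (identity transform) sized to the model dim.
    hidden_size = 12288  # assumed 175B hidden size; adjust for smaller OPT variants
    megatron_extra = {
        "transformer/final_layernorm.weight": np.ones(hidden_size, dtype=np.float32),
        "transformer/final_layernorm.bias": np.zeros(hidden_size, dtype=np.float32),
    }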

With the above procedure the converted OPT model did run under FasterTransformer, but the completions it generated were nonsensical.

malteos commented 2 years ago

I'm trying to do the same but have also failed so far.

Regarding 3.: I'm not sure whether filling in placeholder weights makes sense here. I'd remove the final layer norm from the Megatron implementation instead.

Update:

appleeji commented 2 years ago

@malteos thanks for sharing.

I have also tried various things based on the Hugging Face OPT code, but I couldn't get the right results so far.

What I have tried:

Perhaps there is a difference between Hugging Face OPT and Megatron in how the query, key, and value weights are laid out.
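If that is the cause, one concrete difference to check is the layout of the fused QKV tensor: the flat [q; k; v] concatenation from step 2 is not the same as the per-head interleaved layout that some Megatron / FasterTransformer versions expect. Below is a hedged sketch of re-packing the biases per head to test that hypothesis; fuse_qkv_per_head is an illustrative helper, the 96-head / 12288-dim geometry assumes the 175B model, and whether interleaving is actually required depends on the target implementation.

    import numpy as np

    def fuse_qkv_per_head(q, k, v, num_heads):
        """Re-pack flat q, k, v (each of shape [hidden]) into a per-head
        interleaved [h0_q, h0_k, h0_v, h1_q, ...] layout instead of [q; k; v].
        Whether this matches the target Megatron/FT version is an assumption to check."""
        head_dim = q.shape[0] // num_heads
        stacked = np.stack(
            [q.reshape(num_heads, head_dim),
             k.reshape(num_heads, head_dim),
             v.reshape(num_heads, head_dim)],
            axis=1,                      # shape: [num_heads, 3, head_dim]
        )
        return stacked.reshape(-1)       # flat fused bias of length 3 * hidden

    # Example with the assumed 175B geometry (96 heads, hidden size 12288):
    q = np.zeros(12288); k = np.ones(12288); v = np.full(12288, 2.0)
    fused = fuse_qkv_per_head(q, k, v, num_heads=96)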

stephenroller commented 2 years ago

Our original weights should be closer to the Megatron weights than the HF ones.

As one example: https://github.com/facebookresearch/metaseq/blob/fc037c7abd3dbeae333480585ce84722869f767f/metaseq/distributed/stitch_fsdp_ckpt.py#L192-L200