EleutherAI / gpt-neox

An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries
https://www.eleuther.ai/
Apache License 2.0

KeyError when converting DPO weights from GPT-NeoX format to HuggingFace Llama in post-training documentation #1317

Closed: jacobthebanana closed this issue 1 week ago

jacobthebanana commented 2 weeks ago

The command in the post-training example fails with the following error when converting the GPT-NeoX weights of a Llama 3.2 model back to the HuggingFace format:

https://github.com/EleutherAI/gpt-neox/blob/59a5236ddaf721890e3d6ef98fb8ca66c2266ce0/post-training/README.md?plain=1#L56

Auto-detecting precision to save model into...
['sequential.0.word_embeddings.weight', 'sequential.2.attention.dense.weight', 'sequential.2.attention.query_key_value.weight', 'sequential.2.mlp.linear1.weight', 'sequential.2.mlp.linear2.weight', 'sequential.2.input_layernorm.scale', 'sequential.2.post_attention_layernorm.scale', 'sequential.3.attention.dense.weight', 'sequential.3.attention.query_key_value.weight', 'sequential.3.mlp.linear1.weight', 'sequential.3.mlp.linear2.weight', 'sequential.3.input_layernorm.scale', 'sequential.3.post_attention_layernorm.scale', ... , 'sequential.19.norm.scale', 'sequential.20.final_linear.weight']
Detected MLP naming convention: new

  0%|          | 0/16 [00:00<?, ?it/s]
  0%|          | 0/16 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "gpt-neox/tools/ckpts/convert_neox_to_hf.py", line 906, in <module>
    main()
  File "gpt-neox/tools/ckpts/convert_neox_to_hf.py", line 856, in main
    hf_model = convert(
  File "gpt-neox/tools/ckpts/convert_neox_to_hf.py", line 609, in convert
    get_state(
  File "gpt-neox/tools/ckpts/convert_neox_to_hf.py", line 198, in get_state
    return [state_dict["module"][key] for state_dict in state_dicts]
  File "gpt-neox/tools/ckpts/convert_neox_to_hf.py", line 198, in <listcomp>
    return [state_dict["module"][key] for state_dict in state_dicts]
KeyError: 'sequential.2.input_layernorm.weight'
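
The key listing at the top of the log already shows that the norm parameters in this checkpoint are saved under ".scale". To confirm directly from a checkpoint shard, a minimal sketch (the global_stepXXXX directory and the shard filename mp_rank_00_model_states.pt are assumptions based on DeepSpeed's default checkpoint layout; adjust the path to the checkpoint being converted):

import torch

# Load one model-parallel shard of the DeepSpeed checkpoint on CPU.
# Path is a placeholder; point it at the checkpoint being converted.
sd = torch.load("global_stepXXXX/mp_rank_00_model_states.pt", map_location="cpu")

# The converter indexes sd["module"] (see get_state in the traceback above).
print(sorted(k for k in sd["module"] if "layernorm" in k))
# For this checkpoint the norm keys end in ".scale"; there is no
# "sequential.2.input_layernorm.weight", hence the KeyError.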

https://github.com/EleutherAI/gpt-neox/blob/59a5236ddaf721890e3d6ef98fb8ca66c2266ce0/tools/ckpts/convert_neox_to_hf.py#L469-L478
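
For illustration, a minimal sketch of the failing lookup, with key names taken verbatim from the output above (a simplification, not the converter's actual code):

# The Llama-derived checkpoint stores its RMSNorm parameters under ".scale"...
state_dict = {"module": {"sequential.2.input_layernorm.scale": "tensor"}}

# ...but the converter, presumably falling back to the default GPT-NeoX naming,
# builds LayerNorm-style key names ending in ".weight", so the lookup fails.
key = "sequential.2.input_layernorm.weight"
value = state_dict["module"][key]  # raises KeyError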

It appears that, like the reward model, the DPO model in this example is also converted from the HF Llama architecture. When converting from the GPT-NeoX format back to Llama, wouldn't it therefore also be necessary to specify --architecture llama?
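
A hedged sketch of the adjusted invocation, where <existing README arguments> stands in for the unchanged arguments from the linked post-training command:

python tools/ckpts/convert_neox_to_hf.py <existing README arguments> --architecture llama

With the flag set, the converter would presumably look up the ".scale"-suffixed RMSNorm keys that this checkpoint actually contains.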