erfanzar / EasyDeL

Accelerate and optimize model performance with streamlined training and serving options in JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Error converting easydel checkpoint to huggingface model. #132

Closed IvoryTower800 closed 6 months ago

IvoryTower800 commented 6 months ago

Describe the bug Hi, I saved a checkpoint and want to convert it to safetensors format. This was successful for the Phi-2 model, but when I tried the Gemma model, the error below was raised. My code follows. The checkpoint file itself should be fine, because it can be loaded for continued training.

Thank you.

To Reproduce

# Imports assumed for this snippet (paths inferred from the traceback below):
import jax
from transformers import GemmaForCausalLM
from EasyDel import EasyDelState
from EasyDel.transform import easystate_to_huggingface_model

with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(
            "/root/gemma-2b-S16971.easy"
        ),
        base_huggingface_module=GemmaForCausalLM,
        config=model.config  # `model` here is a previously loaded model in the notebook session
    )

model.half()
model.save_pretrained('gemma-2b-S16971', safe_serialization=True)
Converting EasyDelState to torch state_dict: 100%|██████████| 164/164 [00:08<00:00, 18.73it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 2
      1 with jax.default_device(jax.devices("cpu")[0]):
----> 2     model = easystate_to_huggingface_model(
      3         state=EasyDelState.load_state(
      4             "/root/writer-gemma-2b-S16971.easy"
      5         ),
      6         base_huggingface_module=GemmaForCausalLM,
      7         config=model.config
      8     )

File /usr/local/lib/python3.10/site-packages/EasyDel/transform/easydel_transform.py:333, in easystate_to_huggingface_model(state, config, base_huggingface_module, base_huggingface_module_kwarguments, dtype, transpose_needed, transpose_not_needed, select_params_field, rnn_based_or_rwkv, auto_correct)
    321 state_dict = easystate_to_torch(
    322     state=state,
    323     dtype=dtype,
   (...)
    327     rnn_based_or_rwkv=rnn_based_or_rwkv
    328 )
    329 model = base_huggingface_module(
    330     config=config,
    331     **base_huggingface_module_kwarguments
    332 )
--> 333 model.load_state_dict(state_dict)
    334 return model

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:2152, in Module.load_state_dict(self, state_dict, strict, assign)
   2147         error_msgs.insert(
   2148             0, 'Missing key(s) in state_dict: {}. '.format(
   2149                 ', '.join(f'"{k}"' for k in missing_keys)))
   2151 if len(error_msgs) > 0:
-> 2152     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2154 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for GemmaForCausalLM:
    Missing key(s) in state_dict: "lm_head.weight". 
erfanzar commented 6 months ago

You are converting a model that uses tie_word_embeddings, and by default EasyDeL drops the lm_head to save memory. Use this code to pass the state to the converter:

state = EasyDelState.load_state(
    output.checkpoint_path
)

state_new_params = {
    "params" : state.params["params"] | {
        "lm_head" : {
            "kernel" : state.params["params"]["model"]["embed_tokens"]["embedding"].T
        }
    }
}

state = state.replace(params=state_new_params)

with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=state,
        base_huggingface_module=GemmaForCausalLM,
        config=model.config
    )

model.half()
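The `.T` above is needed because a Flax `Dense` kernel is stored as `(in_features, out_features)`, while an embedding table is stored as `(vocab_size, hidden_size)`. A minimal NumPy sketch of that shape relationship (the sizes here are illustrative, not Gemma's real dimensions):

```python
import numpy as np

vocab_size, hidden_size = 8, 4  # toy sizes, not Gemma's real dimensions

# Embedding table as stored under model/embed_tokens/embedding: (vocab, hidden)
embedding = np.random.rand(vocab_size, hidden_size)

# A Flax Dense kernel is (in_features, out_features), so the tied lm_head
# kernel is the transposed embedding table: (hidden, vocab)
lm_head_kernel = embedding.T

# Projecting a batch of hidden states onto the vocabulary
hidden_states = np.random.rand(2, hidden_size)
logits = hidden_states @ lm_head_kernel

print(lm_head_kernel.shape)  # (4, 8)
print(logits.shape)          # (2, 8)
```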
IvoryTower800 commented 6 months ago

Thank you! It worked.
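For context on why "lm_head.weight" was missing in the first place: with tied word embeddings, the output projection reuses the input embedding matrix, so a checkpoint only needs to store one copy. A minimal PyTorch sketch of weight tying (an illustration, not EasyDeL's or transformers' actual code):

```python
import torch.nn as nn

vocab_size, hidden_size = 8, 4  # toy sizes

embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie the weights: lm_head now shares the embedding matrix instead of
# owning a separate parameter.
lm_head.weight = embed_tokens.weight

# Both parameters point at the same underlying storage, so a checkpoint
# only needs the embedding; a strict load_state_dict that still expects a
# separate "lm_head.weight" entry reports it as missing.
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```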