Closed · IvoryTower800 closed 6 months ago
You are converting a model that uses `tie_word_embeddings`, and EasyDeL by default omits the `lm_head` to save memory. Use this code to pass the state to the converter:
```python
import jax

from transformers import GemmaForCausalLM
# Import paths may vary across EasyDeL versions.
from easydel import EasyDelState, easystate_to_huggingface_model

# Load the saved EasyDeL checkpoint (`output` comes from the earlier training run).
state = EasyDelState.load_state(
    output.checkpoint_path
)

# With tie_word_embeddings, no lm_head was stored in the checkpoint.
# Rebuild it from the transposed embedding matrix so the converter
# sees a complete parameter tree.
state_new_params = {
    "params": state.params["params"] | {
        "lm_head": {
            "kernel": state.params["params"]["model"]["embed_tokens"]["embedding"].T
        }
    }
}
state = state.replace(params=state_new_params)

# Run the conversion on CPU to avoid allocating the full model on an accelerator.
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=state,
        base_huggingface_module=GemmaForCausalLM,
        config=model.config,
    )
model.half()
```
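Since the end goal is a safetensors checkpoint, the converted model can then be written out with the standard `transformers` API; a minimal sketch, where the output directory name `gemma-converted` is a hypothetical placeholder:

```python
# Persist the converted model as safetensors (directory name is hypothetical).
model.save_pretrained("gemma-converted", safe_serialization=True)
```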
Thank you! It worked.
**Describe the bug**
Hi, I saved a checkpoint and want to convert it to the safetensors format. This was successful for the Phi-2 model, but when I try the Gemma model, an error is raised. Below is my code. The checkpoint file itself should be fine, because it can be loaded and used to continue training.
Thank you.
**To Reproduce**