unslothai / unsloth

Finetune Llama 3.1, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
14.79k stars 979 forks

RuntimeError when saving checkpoints #278

Open JhonDan1999 opened 5 months ago

JhonDan1999 commented 5 months ago

I'm encountering a RuntimeError when attempting to save checkpoints during fine-tuning the "unsloth/gemma-2b-it-bnb-4bit" model. Below is a breakdown of my setup and the error encountered.

Model: unsloth/gemma-2b-it-bnb-4bit

Relevant training configuration:

trainer = Trainer(
   # other trainer arguments
   save_steps=50  # Saving checkpoint every 50 steps 
)
trainer.train()

The training process halts with the following RuntimeError:

 RuntimeError: 
     Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'model.base_model.model.model.embed_tokens.weight', 'model.base_model.model.lm_head.weight'}].
     A potential way to correctly save your model is to use `save_model`.
     More information at https://huggingface.co/docs/safetensors/torch_shared_tensors 
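For context: safetensors refuses to serialize a state dict in which two entries alias the same underlying storage, and Gemma ties `lm_head.weight` to `embed_tokens.weight`, which is exactly the pair named in the error. A simplified, torch-free sketch of that duplicate-storage check (the real library compares storage pointers; object identity stands in here, and the function name is illustrative, not the actual safetensors internals):

```python
# Simplified illustration of why safetensors raises: it groups state-dict
# entries by their backing buffer and complains when one buffer is
# referenced under more than one name.
from collections import defaultdict

def find_shared_groups(state_dict):
    """Return sets of parameter names that share the same backing buffer."""
    by_storage = defaultdict(set)
    for name, tensor in state_dict.items():
        by_storage[id(tensor)].add(name)
    return [names for names in by_storage.values() if len(names) > 1]

# Tied embeddings: one buffer referenced under two names, as in Gemma.
embed = [0.1, 0.2, 0.3]  # stand-in for the embedding weight matrix
state = {
    "model.embed_tokens.weight": embed,
    "lm_head.weight": embed,  # tied: the very same object
    "model.layers.0.mlp.weight": [0.4, 0.5],
}
shared = find_shared_groups(state)
# shared -> [{"model.embed_tokens.weight", "lm_head.weight"}]
```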
danielhanchen commented 5 months ago

Hmm Gemma has tied embeddings - do you know if other tied embeddings and lm_head models have the same issue?

JhonDan1999 commented 5 months ago

Thank you for your prompt response. I just experimented with both Gemma and the "unsloth/llama-2-7b-bnb-4bit" model, and I'm encountering this issue only with Gemma.

danielhanchen commented 5 months ago

Yep! It seems like the tied embeddings issue - apologies, it was a super busy period so I didn't have time to look into this - apologies again!!

thesven commented 1 month ago

Was there ever a resolution to this issue? I'm now running into the same error with the 9B variant of Gemma 2.

eganlau commented 1 month ago

I ran into the same error while trying to push the merged model to HF, but adding safe_serialization=False works around it, i.e.:

model.push_to_hub_merged("hf_account/model_name", tokenizer, save_method="merged_4bit_forced", token="hf_token", safe_serialization=False)
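For anyone hitting this through the Trainer checkpoint path rather than a merged push, the analogous knob should be `save_safetensors=False` on `TrainingArguments`, which falls back to torch's pickle format where tied tensors are fine. A sketch, assuming a reasonably recent transformers release (verify the flag exists in yours):

```python
from transformers import TrainingArguments

# Periodic checkpointing hits the same safetensors shared-tensor check;
# disabling safetensors serialization makes the Trainer write
# pytorch_model.bin instead, which tolerates tied weights.
args = TrainingArguments(
    output_dir="outputs",
    save_steps=50,
    save_safetensors=False,
)
```

Pass these args to your existing `Trainer(...)` call in place of the bare `save_steps=50`.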