unslothai / unsloth

Finetune Llama 3.1, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Infinite Generation #970

Open MuhammadBilalKhan267 opened 2 weeks ago

MuhammadBilalKhan267 commented 2 weeks ago

Qwen2_7b_Instruct_Potato_Dataset.zip

Hello, I am fine-tuning Qwen 2 7B on an Urdu dataset. However, it seems to be generating infinite random tokens from the training set, even when asked basic questions. I don't know what I am doing wrong, as I don't have much experience fine-tuning LLMs. I am using a pretty small dataset (105 examples) for training, so that could possibly be the problem. The code is attached. The model doesn't even work well on the evaluation set.

danielhanchen commented 2 weeks ago

Oh yep, a small dataset might be the culprit. Hmm, are you using the multi-language notebook we shared?

MuhammadBilalKhan267 commented 2 weeks ago

No, I only modified the notebook for Llama 3 Instruct, as I was previously working with it. I also tried using a larger dataset (600+ examples), but to no avail. The examples are very distinctive, though, so could that be a problem? Could the model size be an issue (I am using Qwen 2 7B Instruct 4-bit)? I would appreciate any tips.

Ammar-Alnagar commented 1 week ago

600 examples is still a small dataset for a 7B model. I suggest lowering the learning rate to better capture the features, and fewer epochs to avoid overfitting. Also make sure you are appending an EOS token to each training example (see the sketch below) — a missing EOS token is a classic cause of generation that never stops. Very unlikely, but another possible cause: you are using the wrong chat template at inference time.

If you can share your settings it might be of help.
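For reference, here is a minimal sketch of what appending the EOS token looks like in an unsloth-style formatting step, following the pattern in the unsloth notebooks. The `instruction`/`output` column names and the prompt string are placeholders — adapt them to your dataset schema:

```python
# Sketch of appending the EOS token during dataset formatting.
# Column names and the prompt layout are illustrative placeholders.
prompt = """### Instruction:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token  # without this, the model never learns to stop

def formatting_prompts_func(examples):
    texts = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        # Append EOS so the model sees an explicit end-of-answer signal
        texts.append(prompt.format(instruction, output) + EOS_TOKEN)
    return {"text": texts}

dataset = dataset.map(formatting_prompts_func, batched = True)
```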

MuhammadBilalKhan267 commented 1 week ago

The infinite generation problem is solved; the cause was the small dataset. This time I used a 7,000-example question-answer dataset from Hugging Face. While the model did well on a few questions related to the training data, overall performance wasn't satisfactory, and it also suffered from catastrophic forgetting. Qwen 2 has been pre-trained on Urdu data, so it can form Urdu sentences properly; however, its domain knowledge may not be sufficient. Should I do continued pre-training on the whole dataset (45,000 examples) before instruction-tuning again with the 7,000 examples (which is what I am currently doing), or is the problem the rank, the number of epochs, or the learning rate? My code is attached. Any tips are appreciated!

Qwen2_7b_Instruct_QA_2.zip
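If you do try the continued pre-training route first, a rough sketch of what that stage could look like is below. It reuses the same SFTTrainer setup already in this thread, but trains on raw text; the `raw_text_dataset` name and all hyperparameters are illustrative starting points, not tested values:

```python
# Rough sketch of a continued pre-training pass on the full 45,000-example
# corpus before instruction tuning. Hyperparameters are illustrative only.
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

cpt_trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = raw_text_dataset,  # placeholder: plain Urdu text in a "text" column
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    packing = True,  # raw text has no turn boundaries, so packing helps throughput
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 8,
        num_train_epochs = 1,    # one pass is a common starting point for this stage
        learning_rate = 5e-5,    # lower than instruction tuning, to limit forgetting
        warmup_ratio = 0.1,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        optim = "adamw_8bit",
        lr_scheduler_type = "linear",
        output_dir = "cpt_outputs",
    ),
)
cpt_trainer.train()
```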

Ammar-Alnagar commented 1 week ago

I am glad you fixed the infinite generation issue. After taking a look at your settings:

```python
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "embed_tokens", "lm_head"],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
```

```python
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_set,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 4,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 500,
        num_train_epochs = 5, # experiment with this number after watching the loss log
        learning_rate = 4e-4, # a good start; adjust if training is too slow or the model still isn't capturing the features
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)
```
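Also, since the wrong template at inference can mimic these symptoms, here is a minimal inference sketch that uses the tokenizer's own chat template. It assumes you trained on chat-formatted text; the example message is a placeholder:

```python
# Minimal inference sketch using the tokenizer's built-in chat template.
from unsloth import FastLanguageModel

FastLanguageModel.for_inference(model)  # enable unsloth's faster inference path

messages = [{"role": "user", "content": "Your question in Urdu here"}]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True,  # appends the assistant header so generation starts correctly
    return_tensors = "pt",
).to(model.device)

outputs = model.generate(input_ids = inputs, max_new_tokens = 256, use_cache = True)
# Strip the prompt tokens and decode only the generated answer
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens = True))
```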

let me know if you need further help :)