jiaweizzhao / GaLore

GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection
Apache License 2.0

(Question) About glue tasks #52

Open ZhichaoWang091732 opened 5 months ago

ZhichaoWang091732 commented 5 months ago

Hello, thanks for your inspiring and excellent work!

I want to run full fine-tuning as a baseline to compare against GaLore, so I have disabled GaLore. However, I'm running into a problem: when I run a GLUE task (e.g. MRPC) to fully fine-tune RoBERTa, the eval accuracy doesn't change at all as training progresses. I have ruled out overfitting, and I would like to ask the authors or anyone else whether there is a known solution.

[screenshot: eval accuracy staying flat across training]

jiaweizzhao commented 4 months ago

Hi, thanks for your question. Were you using the hyperparameters and settings provided in the appendix of our paper?

mzf666 commented 2 months ago

I have the same issue. I have checked that the gradient norm and the learning rate are not zero. In the original code, the metric is initialized once, never refreshed, and keeps accumulating new predictions across the whole training process. Hence, I manually reload the metric with metric = evaluate.load('glue', args.task_name) before each evaluation.
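
For reference, a minimal sketch of where the reload goes, assuming an evaluation loop shaped like the HuggingFace run_glue_no_trainer.py example (model, eval_dataloader, accelerator, and args are placeholders here, not names taken from this repo):

    import evaluate
    import torch

    def evaluate_epoch(model, eval_dataloader, accelerator, args):
        # Re-load the metric at every evaluation so predictions from earlier
        # epochs are not silently accumulated into the running score.
        metric = evaluate.load("glue", args.task_name)
        model.eval()
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            metric.add_batch(
                predictions=accelerator.gather(predictions),
                references=accelerator.gather(batch["labels"]),
            )
        return metric.compute()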

However, after fixing this potential bug, it seems that while the eval loss of the fine-tuned model does change, the accuracy and F1 metrics still remain the same.

python run_glue_geomlrk.py \
    --model_name_or_path roberta-base \
    --task_name mrpc \
    --max_length 512 \
    --seed=1234 \
    --lora_r 8 \
    --lora_all_modules \
    --per_device_train_batch_size 16 \
    --num_train_epochs 30 \
    --learning_rate 1e-4 \
    --lr_scheduler_type linear \
    --weight_decay 0.1 

[screenshot: results with learning_rate 1e-4]

python run_glue_geomlrk.py \
    --model_name_or_path roberta-base \
    --task_name mrpc \
    --max_length 512 \
    --seed=1234 \
    --lora_r 8 \
    --lora_all_modules \
    --per_device_train_batch_size 16 \
    --num_train_epochs 30 \
    --learning_rate 1e-2 \
    --lr_scheduler_type linear \
    --weight_decay 0.1 

[screenshot: results with learning_rate 1e-2]

mzf666 commented 2 months ago

I spent several hours adjusting the hyperparameters and found that the AdamW optimizer does work with a suitable learning rate. You can try this launch command:

CUDA_VISIBLE_DEVICES=$cuda_idx python run_glue_geomlrk.py \
    --model_name_or_path roberta-base \
    --task_name mrpc \
    --max_length 512 \
    --seed=1234 \
    --lora_r 4 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 30 \
    --learning_rate 1e-5 \
    --lr_scheduler_type linear \
    --weight_decay 0.1 \

This leads to the following result:

[screenshot: results with learning_rate 1e-5]

It seems that an improper learning rate may drive the model into mode collapse, i.e. it assigns the same logits to every input sequence. The accuracy and F1 score therefore remain unchanged, because the model is effectively making one fixed guess.
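
A quick way to check for this kind of collapse is to look at the histogram of predicted labels on the eval set. The sketch below is a hypothetical diagnostic, not code from this repo, and assumes a standard sequence-classification model and eval dataloader:

    from collections import Counter
    import torch

    def prediction_histogram(model, eval_dataloader):
        # If the model has collapsed, this will contain a single class,
        # e.g. Counter({1: 408}) on the MRPC dev set.
        model.eval()
        counts = Counter()
        for batch in eval_dataloader:
            with torch.no_grad():
                logits = model(**batch).logits
            counts.update(logits.argmax(dim=-1).tolist())
        return counts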

MaeChd commented 2 days ago

Hello, thank you for your answer to this question. However, when I tried to reproduce LoRA fine-tuning of RoBERTa on the GLUE tasks, I ran into the same phenomenon: the training loss was almost unchanged and the evaluation metrics stayed constant. Have you tried anything that addresses this?