yuhuixu1993 / qa-lora

Official PyTorch implementation of QA-LoRA
MIT License

Encounter data type problem in train #24

Closed — wenjingk-xilinx closed this issue 7 months ago

wenjingk-xilinx commented 10 months ago

Hi, I tried to run the code with this script:

python -m pdb qalora.py  \
    --model_path ./llama-7b/ \
    --bits 4 \
    --output_dir ./output \
    --dataset alpaca \
    --do_train True \
    --do_eval True \
    --do_mmlu_eval True \
    --source_max_len 384 \
    --target_max_len 128 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --logging_steps 10 \
    --max_steps 10000 \
    --save_strategy steps \
    --data_seed 42 \
    --save_steps 1000 \
    --save_total_limit 40 \
    --evaluation_strategy steps \
    --eval_dataset_size 1024 \
    --max_eval_samples 1000 \
    --eval_steps 1000 \
    --optim paged_adamw_32bit 
adding LoRA modules...
/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/auto_gptq/utils/peft_utils.py:369: UserWarning: You can just ignore this warning if the peft type you use isn't in ['LORA', 'ADALORA'].
LlamaGPTQForCausalLM supports injecting fused attention but not enables this time. If you are training adapters, you must also disable fused attention injection when loading quantized base model at inference time, otherwise adapters
 may not be added to base model properly. If you are loading adapters to do inference, you can reference to adapter's config file to check whether the adapters are trained using base model that not enable fused attention injection.
  warnings.warn(
trainable params: 11163648.0 || all params: 1162579968 || trainable: 0.960247751318557
loaded model
Using pad_token, but it is not set yet.
Splitting train dataset in train and validation according to `eval_dataset_size`
Generating eval split: 1531 examples [00:00, 30180.38 examples/s]
Generating test split: 14042 examples [00:00, 51912.29 examples/s]
torch.float16 263512064 0.22665985519756196
torch.int32 809500672 0.6962918597072243
torch.float32 89575424 0.07704828509521378
  0%|          | 0/10000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/pdb.py", line 1705, in main
    pdb._runscript(mainpyfile)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/pdb.py", line 1573, in _runscript
    self.run(statement)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/bdb.py", line 580, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/proj/ossdataset1/wenjingk/peft/qa-lora/qalora.py", line 4, in <module>
    from collections import defaultdict
  File "/proj/ossdataset1/wenjingk/peft/qa-lora/qalora.py", line 788, in train
    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/peft/peft_model.py", line 922, in forward
    return self.base_model(
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 824, in forward
    logits = self.lm_head(hidden_states)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Float but found Half

`F.linear(input, self.weight, self.bias)` raises the error because `input.dtype` is torch.float32 while `self.weight.dtype` is torch.float16.
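For context, here is a minimal sketch of the mismatch with toy tensors, plus one possible workaround; the `.float()` cast on the lm_head is my own assumption, not necessarily the fix used in this repo.

import torch
import torch.nn.functional as F

# Reproduce the dtype mismatch with toy tensors (shapes are illustrative only):
hidden_states = torch.randn(1, 8, 4096, dtype=torch.float32)    # fp32 activations
lm_head_weight = torch.randn(32000, 4096, dtype=torch.float16)  # fp16 lm_head weight

try:
    F.linear(hidden_states, lm_head_weight)
except RuntimeError as e:
    print(e)  # dtype mismatch, e.g. "expected scalar type Float but found Half"

# One possible workaround (assumption, not the repo's official fix): keep the
# lm_head in fp32 so it matches the fp32 hidden states produced upstream, e.g.
#   model.lm_head = model.lm_head.float()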

Then I added the `--fp16` flag to the running script, and the error becomes:

trainable params: 11163648.0 || all params: 1162579968 || trainable: 0.960247751318557
loaded model
Using pad_token, but it is not set yet.
Splitting train dataset in train and validation according to `eval_dataset_size`
torch.float16 263512064 0.22665985519756196
torch.int32 809500672 0.6962918597072243
torch.float32 89575424 0.07704828509521378
  0%|          | 0/10000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/pdb.py", line 1705, in main
    pdb._runscript(mainpyfile)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/pdb.py", line 1573, in _runscript
    self.run(statement)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/bdb.py", line 580, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/proj/ossdataset1/wenjingk/peft/qa-lora/qalora.py", line 4, in <module>
    from collections import defaultdict
  File "/proj/ossdataset1/wenjingk/peft/qa-lora/qalora.py", line 788, in train
    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/transformers/trainer.py", line 1867, in _inner_training_loop
    self.accelerator.clip_grad_norm_(
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/accelerate/accelerator.py", line 1925, in clip_grad_norm_
    self.unscale_gradients()
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/accelerate/accelerator.py", line 1888, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 284, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "/proj/ossdataset1/wenjingk/anaconda3/envs/qalora/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 212, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.
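For reference, this error usually means some trainable parameters (here the LoRA adapters) are themselves stored in fp16, which AMP's GradScaler refuses to unscale. Below is a rough sketch of the usual remedy, assuming `model` is the already-loaded PEFT model; this is my guess at a workaround, not the repo's official fix.

import torch

# Keep master copies of trainable (adapter) weights in fp32; autocast still
# runs the forward pass in fp16, but the scaler can then unscale fp32 grads.
for name, param in model.named_parameters():
    if param.requires_grad and param.dtype == torch.float16:
        param.data = param.data.float()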
xxw11 commented 8 months ago

Hi, the issue with data type conflicts has been resolved. You can simply use the default parameters in qalora.py for fine-tuning.

wenjingk-xilinx commented 8 months ago

Hi @xxw11, thank you for the update. I see you've changed the data format to fp32. Training is much slower with fp32 than with fp16. Have you compared the accuracy between these two data types?

xxw11 commented 8 months ago

Yes, you're right, the current code is slower. In my experiments, using FP32 results in higher accuracy, especially on Llama 1.