hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

QLoRA error #5482

Closed: dayuyang1999 closed this issue 1 month ago

dayuyang1999 commented 1 month ago

System Info

- `llamafactory` version: 0.9.1.dev0
- Platform: Linux-5.19.0-0_fbk12_zion_11583_g0bef9520ca2b-x86_64-with-glibc2.34
- Python version: 3.12.5
- PyTorch version: 2.4.1+cu121 (GPU)
- Transformers version: 4.44.2
- Datasets version: 2.21.0
- Accelerate version: 0.34.2
- PEFT version: 0.12.0
- TRL version: 0.9.6
- GPU type: NVIDIA H100

Reproduction

YAML configuration:

report_to: wandb
run_name: mywork-70b-qlora

### model
model_name_or_path: /saved_models/models--meta-llama--Meta-Llama-3.1-70B-Instruct/snapshots/llama3-70B-instruct
quantization_bit: 8
quantization_method: bitsandbytes

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: mywork_50
template: llama3
cutoff_len: 4096
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/mywork/50/
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
eval_dataset: mywork_eval
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

Run:

llamafactory-cli train examples/train_lora/train_70b_qlora.yaml

Error:

/home/myname/.conda/envs/llm_env/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
cuBLAS API failed with status 15
A: torch.Size([4096, 8192]), B: torch.Size([8192, 8192]), C: (4096, 8192); (lda, ldb, ldc): (c_int(131072), c_int(262144), c_int(131072)); (m, n, k): (c_int(4096), c_int(8192), c_int(8192))
Traceback (most recent call last):
  File "/home/dayuyang/.conda/envs/llm_env/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/dayuyang/CRS/LLaMA-Factory/src/llamafactory/cli.py", line 111, in main
    run_exp()
  File "/home/dayuyang/CRS/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/dayuyang/CRS/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 96, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/peft/peft_model.py", line 1577, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 989, in forward
    layer_outputs = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/CRS/LLaMA-Factory/src/llamafactory/model/model_utils/checkpointing.py", line 93, in custom_gradient_checkpointing_func
    return gradient_checkpointing_func(func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 255, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 617, in forward
    query_states = self.q_proj(hidden_states)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/peft/tuners/lora/bnb.py", line 221, in forward
    result = self.base_layer(x, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/bitsandbytes/nn/modules.py", line 817, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py", line 556, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py", line 395, in forward
    out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dayuyang/.conda/envs/llm_env/lib/python3.12/site-packages/bitsandbytes/functional.py", line 2341, in igemmlt
    raise Exception("cublasLt ran into an error!")
Exception: cublasLt ran into an error!

Expected behavior

Thanks!

Others

No response

hiyouga commented 1 month ago

Use 4-bit LoRA instead.
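
For reference, a minimal sketch of that change, assuming the rest of the YAML above stays the same (both keys already appear in the user's config):

### model
model_name_or_path: /saved_models/models--meta-llama--Meta-Llama-3.1-70B-Instruct/snapshots/llama3-70B-instruct
# switch from 8-bit to 4-bit bitsandbytes quantization; the failing call in the
# traceback (MatMul8bitLt -> igemmlt -> cublasLt) is specific to the 8-bit path
quantization_bit: 4
quantization_method: bitsandbytes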