hiyouga / LLaMA-Factory

Unify Efficient Fine-Tuning of 100+ LLMs
Apache License 2.0
25.26k stars 3.13k forks

Get "RuntimeError: 'weight' must be 2-D" Error when finetuning llama3-8b using ZeRO3 and customised dataset #4557

Closed NeWive closed 2 days ago

NeWive commented 2 days ago

Reminder

System Info

Reproduction

Run command:

llamafactory-cli train /home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory/examples/train_lora/llama3_lora_sft_ds3_emr.yaml

Configuration files

llama3_lora_sft_ds3_emr.yaml:

### model
model_name_or_path: /home/work/workspace/zyw/project/emr/llama3/llama3-8B-hf

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: emr
template: llama3
cutoff_len: 8192
max_samples: 111
overwrite_cache: true
preprocessing_num_workers: 4

### output
output_dir: saves/llama3-8b/lora/sft9
logging_steps: 1
save_steps: 2
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 16
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 1

ds_z3_config.json:

{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}

stacktrace:

[INFO|trainer.py:2078] 2024-06-26 18:20:43,472 >> ***** Running training *****
[INFO|trainer.py:2079] 2024-06-26 18:20:43,472 >>   Num examples = 99
[INFO|trainer.py:2080] 2024-06-26 18:20:43,472 >>   Num Epochs = 16
[INFO|trainer.py:2081] 2024-06-26 18:20:43,472 >>   Instantaneous batch size per device = 4
[INFO|trainer.py:2084] 2024-06-26 18:20:43,473 >>   Total train batch size (w. parallel, distributed & accumulation) = 4
[INFO|trainer.py:2085] 2024-06-26 18:20:43,473 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2086] 2024-06-26 18:20:43,473 >>   Total optimization steps = 400
[INFO|trainer.py:2087] 2024-06-26 18:20:43,476 >>   Number of trainable parameters = 20,971,520
  0%|          | 0/400 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/work/workspace/zyw/.conda/envs/rag/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory-for-update/LLaMA-Factory/src/llamafactory/cli.py", line 110, in main
    run_exp()
  File "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory-for-update/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory-for-update/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 88, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 3238, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/trainer.py", line 3264, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/peft/peft_model.py", line 1430, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1164, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 925, in forward
    inputs_embeds = self.embed_tokens(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/home/work/workspace/zyw/.conda/envs/rag/lib/python3.11/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: 'weight' must be 2-D
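The "Running training" banner at the top of this log already indicates how many processes took part in the run. The sketch below recomputes that from the logged values. The helper names are hypothetical, and the step formula assumes the last partial batch is kept and that one optimizer step is taken per batch (gradient accumulation is 1 here).

```python
import math

# Values copied from the "Running training" banner in the log above.
NUM_EXAMPLES = 99
NUM_EPOCHS = 16
PER_DEVICE_BATCH = 4
GRAD_ACCUM_STEPS = 1
TOTAL_BATCH = 4
REPORTED_STEPS = 400


def infer_world_size(total_batch, per_device_batch, grad_accum_steps):
    """Total batch = per-device batch * accumulation steps * number of processes."""
    per_process = per_device_batch * grad_accum_steps
    if total_batch % per_process != 0:
        raise ValueError("total batch is not a multiple of the per-process batch")
    return total_batch // per_process


def expected_steps(num_examples, total_batch, num_epochs):
    """Assumes the final partial batch of each epoch is kept, not dropped."""
    return math.ceil(num_examples / total_batch) * num_epochs


world_size = infer_world_size(TOTAL_BATCH, PER_DEVICE_BATCH, GRAD_ACCUM_STEPS)
steps = expected_steps(NUM_EXAMPLES, TOTAL_BATCH, NUM_EPOCHS)

print(f"inferred number of processes: {world_size}")
print(f"expected optimization steps: {steps} (log reports {REPORTED_STEPS})")
```

Under these assumptions the banner corresponds to a single training process, and the recomputed step count agrees with the logged value.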

Expected behavior

With the ZeRO-0 configuration, llama3 fine-tunes successfully. After switching to the ZeRO-3 configuration, training fails with the error above. The DeepSpeed ZeRO config files themselves were not modified.

Others

No response

hiyouga commented 2 days ago

It looks like you are fine-tuning the model on a single GPU. You can either remove the deepspeed config or set the FORCE_TORCHRUN=1 environment variable.
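A minimal sketch of the second option, reusing the config path from the original report. The `build_launch` helper is hypothetical and only assembles the command and environment. Nothing is launched unless the final `subprocess.run` line is uncommented.

```python
import os
import shlex

# Config path taken from the run command in the original report.
CONFIG_PATH = (
    "/home/work/workspace/zyw/project/emr/llama3/LLaMA-Factory/"
    "examples/train_lora/llama3_lora_sft_ds3_emr.yaml"
)


def build_launch(config_path, base_env=None):
    """Return (cmd, env) for the same train command with FORCE_TORCHRUN=1 set."""
    env = dict(os.environ if base_env is None else base_env)
    env["FORCE_TORCHRUN"] = "1"
    cmd = ["llamafactory-cli", "train", config_path]
    return cmd, env


cmd, env = build_launch(CONFIG_PATH)

# Equivalent one-line shell form of the launch.
print("FORCE_TORCHRUN=1 " + shlex.join(cmd))

# To actually start training (requires llamafactory-cli on PATH):
# import subprocess
# subprocess.run(cmd, env=env, check=True)
```

The other option needs no code: delete the `deepspeed:` line from the YAML config and rerun the original command.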

NeWive commented 2 days ago

Solved, many thanks