是否已有关于该错误的issue或讨论? | Is there an existing issue / discussion for this?
[X] 我已经搜索过已有的issues和讨论 | I have searched the existing issues / discussions
该问题是否在FAQ中有解答? | Is there an existing answer for this in FAQ?
[X] 我已经搜索过FAQ | I have searched FAQ
当前行为 | Current Behavior
When running the LoRA fine-tuning script provided by the project, I get the following error:
[2024-10-28 09:13:26,129] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
warnings.warn(
[2024-10-28 09:13:28,108] [INFO] [comm.py:652:init_distributed] cdb=None
[2024-10-28 09:13:28,108] [INFO] [comm.py:683:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2024-10-28 09:13:38,415] [INFO] [config.py:733:__init__] Config mesh_device None world_size = 1
[2024-10-28 09:13:51,071] [INFO] [partition_parameters.py:348:__exit__] finished initializing model - num_params = 791, num_elems = 8.11B
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:06<00:00, 1.53s/it]
Currently using LoRA for fine-tuning the MiniCPM-V model.
{'Total': 9169527776, 'Trainable': 1070352624}
llm_type=qwen2
Loading data...
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs
Using /home/kuuga/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Emitting ninja build file /home/kuuga/.cache/torch_extensions/py310_cu121/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 2.237880229949951 seconds
Parameter Offload: Total persistent parameters: 3039200 in 759 params
[2024-10-28 09:14:13,013] [WARNING] [lr_schedules.py:683:get_lr] Attempting to get learning rate from scheduler before it has started
0%| | 0/10000 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/kuuga/MiniCPM-V/finetune/finetune.py", line 299, in <module>
[rank0]: train()
[rank0]: File "/home/kuuga/MiniCPM-V/finetune/finetune.py", line 289, in train
[rank0]: trainer.train()
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
[rank0]: return inner_training_loop(
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs)
[rank0]: File "/home/kuuga/MiniCPM-V/finetune/trainer.py", line 211, in training_step
[rank0]: self.accelerator.backward(loss)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/accelerate/accelerator.py", line 2238, in backward
[rank0]: self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 186, in backward
[rank0]: self.engine.backward(loss, **kwargs)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2020, in backward
[rank0]: self.optimizer.backward(loss, retain_graph=retain_graph)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2250, in backward
[rank0]: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank0]: scaled_loss.backward(retain_graph=retain_graph)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
[rank0]: return user_fn(self, *args)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 511, in decorate_bwd
[rank0]: return bwd(*args, **kwargs)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 80, in backward
[rank0]: input, weight, bias = ctx.saved_tensors
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1129, in unpack_hook
[rank0]: frame.check_recomputed_tensors_match(gid)
[rank0]: File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 903, in check_recomputed_tensors_match
[rank0]: raise CheckpointError(
[rank0]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank0]: tensor at position 4:
[rank0]: saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 6:
[rank0]: saved metadata: {'shape': torch.Size([3584, 3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 7:
[rank0]: saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 14:
[rank0]: saved metadata: {'shape': torch.Size([512, 3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 15:
[rank0]: saved metadata: {'shape': torch.Size([512]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 22:
[rank0]: saved metadata: {'shape': torch.Size([512, 3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 23:
[rank0]: saved metadata: {'shape': torch.Size([512]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 41:
[rank0]: saved metadata: {'shape': torch.Size([3584, 3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 51:
[rank0]: saved metadata: {'shape': torch.Size([3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 53:
[rank0]: saved metadata: {'shape': torch.Size([18944, 3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 56:
[rank0]: saved metadata: {'shape': torch.Size([18944, 3584]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=0)}
0%| | 0/10000 [00:04<?, ?it/s]
E1028 09:14:26.291000 886172 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 886612) of binary: /home/kuuga/.conda/envs/mental-health/bin/python
Traceback (most recent call last):
File "/home/kuuga/.conda/envs/mental-health/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
run(args)
File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
elastic_launch(
File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/kuuga/.conda/envs/mental-health/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
finetune.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-10-28_09:14:26
host : localhost
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 886612)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
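Judging from the traceback, the error is raised by torch.utils.checkpoint's non-reentrant recomputation check: the tensors saved during the forward pass have full shapes (e.g. torch.Size([3584, 3584])), while the recomputed ones have torch.Size([0]), which looks like ZeRO-3-partitioned parameters not being gathered again when the checkpointed block is recomputed. For reference, the sketch below shows the kind of setting I would experiment with. It assumes that finetune.py's training arguments accept and forward gradient_checkpointing_kwargs to the model, which I have not verified; it is not a confirmed fix.

```python
# Minimal sketch, not a confirmed fix: explicitly request the reentrant
# gradient-checkpointing implementation so torch.utils.checkpoint skips the
# non-reentrant recomputation metadata check that fails above.
# Assumption: finetune.py builds (a subclass of) TrainingArguments and the
# Trainer forwards gradient_checkpointing_kwargs to
# model.gradient_checkpointing_enable(); I have not verified this.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output/output__lora",
    gradient_checkpointing=True,
    # use_reentrant=True selects the reentrant implementation; the traceback
    # above comes from the non-reentrant path, though I am not sure where
    # use_reentrant is being set in this code path.
    gradient_checkpointing_kwargs={"use_reentrant": True},
)
```

If the option does not reach the model this way, the same experiment could be tried by calling model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) inside finetune.py, but again this is only a direction, not a verified fix.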
期望行为 | Expected Behavior
No response
复现方法 | Steps To Reproduce
finetune_lora.sh
#!/bin/bash
GPUS_PER_NODE=1
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="/home/kuuga/MiniCPM-V-2_6" # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
# ATTENTION: specify the path to your training data, which should be a JSON file consisting of a list of conversations.
# See the finetuning section of the README for more information (an illustrative sketch of this format appears after the script).
DATA="./train_data.json"
EVAL_DATA="./eval_data.json"
LLM_TYPE="qwen2"
# if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
# if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
MODEL_MAX_Length=1200 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--llm_type $LLM_TYPE \
--data_path $DATA \
--eval_data_path $EVAL_DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--do_train \
--do_eval \
--tune_vision true \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output__lora \
--logging_dir output/output_lora \
--logging_strategy "steps" \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--gradient_checkpointing true \
--deepspeed ds_config_zero3.json \
--report_to "tensorboard" # wandb
Other files referenced in this run: ds_config_zero2.json, ds_config_zero3.json, train_data.json, eval_data.json (the script above uses ds_config_zero3.json).
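For reference, the snippet below is only an illustrative sketch of the conversation-list layout I used for train_data.json and eval_data.json, as I understand it from the finetuning section of the README; the field names ("id", "image", "conversations", "role", "content") and the "<image>" placeholder are my reading of that README, not copied from my actual data.

```python
# Illustrative sketch only: writes a one-sample train_data.json in the
# conversation-list layout I believe the finetune README describes.
# The field names and the "<image>" placeholder are assumptions from my
# reading of the README; my real data has the same structure, different content.
import json

sample = [
    {
        "id": "0",
        "image": "images/example_0.jpg",  # path to the image for this sample
        "conversations": [
            {"role": "user", "content": "<image>\nDescribe this picture."},
            {"role": "assistant", "content": "A short reference answer."},
        ],
    }
]

with open("train_data.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)
```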
运行环境 | Environment
- OS: Linux (kernel 5.4.0, per the log above)
- Python: 3.10 (conda env mental-health)
- CUDA: 12.1 (PyTorch extensions built under py310_cu121)
- GPUs: 1 (world_size = 1)
- Other: DeepSpeed ZeRO-3; a transformers version recent enough to emit the v4.46 `evaluation_strategy` deprecation warning
备注 | Anything else?
No response