modelscope / ms-swift

Use PEFT or Full-parameter to finetune 400+ LLMs or 100+ MLLMs. (LLM: Qwen2.5, Llama3.2, GLM4, Internlm2.5, Yi1.5, Mistral, Baichuan2, DeepSeek, Gemma2, ...; MLLM: Qwen2-VL, Qwen2-Audio, Llama3.2-Vision, Llava, InternVL2, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL, Phi3.5-Vision, ...)
https://swift.readthedocs.io/zh-cn/latest/Instruction/index.html
Apache License 2.0

longlora finetuning llama3.1-8b-instruct: error about positional embeddings #2431

Open: xtchen96 opened this issue 1 week ago

xtchen96 commented 1 week ago

Describe the bug

Fine-tuning llama-3.1-8b-instruct on 4x A100 GPUs fails (llama2-13b-ms fails with the same error). CLI command:

CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=4 \
LOCAL_WORLD_SIZE=4 \
swift sft \
    --model_id_or_path LLM-Research/Meta-Llama-3.1-8B-Instruct \
    --sft_type longlora \
    --dataset /home/ms-swift/swift/llm/data/alpaca_8k.json \
    --output_dir output \
    --deepspeed zero3-offload \
    --max_steps 200

The error is as follows:

[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/ms-swift/swift/cli/sft.py", line 5, in <module>
[rank3]:     sft_main()
[rank3]:   File "/home/ms-swift/swift/utils/run_utils.py", line 32, in x_main
[rank3]:     result = llm_x(args, **kwargs)
[rank3]:   File "/home/ms-swift/swift/llm/sft.py", line 546, in llm_sft
[rank3]:     return trainer_train(args, model, template, train_dataset, val_dataset, callbacks=callbacks, msg=msg)
[rank3]:   File "/home/ms-swift/swift/llm/sft.py", line 496, in trainer_train
[rank3]:     trainer.train(training_args.resume_from_checkpoint)
[rank3]:   File "/home/ms-swift/swift/trainers/mixin.py", line 493, in train
[rank3]:     res = super().train(resume_from_checkpoint, *args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/transformers/trainer.py", line 2123, in train
[rank3]:     return inner_training_loop(
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
[rank3]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/transformers/trainer.py", line 3579, in training_step
[rank3]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank3]:   File "/home/ms-swift/swift/trainers/trainers.py", line 161, in compute_loss
[rank3]:     outputs = model(**inputs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank3]:     ret_val = func(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank3]:     loss = self.module(*inputs, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
[rank3]:     return self.base_model(
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
[rank3]:     return self.model.forward(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
[rank3]:     outputs = self.model(
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 933, in forward
[rank3]:     layer_outputs = self._gradient_checkpointing_func(
[rank3]:   File "/home/ms-swift/swift/llm/utils/model.py", line 7139, in <lambda>
[rank3]:     _old_checkpoint(*args, use_reentrant=use_reentrant, **kwargs))
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
[rank3]:     return CheckpointFunction.apply(function, preserve, *args)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 264, in forward
[rank3]:     outputs = run_function(*args)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 676, in forward
[rank3]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/anaconda3/envs/swift/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]: TypeError: sdpa_forward() got an unexpected keyword argument 'position_embeddings'
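The last frames show the decoder layer in transformers 4.46.2 (modeling_llama.py, line 676) calling `self.self_attn(...)` with a `position_embeddings` keyword, while the attention forward actually installed, `sdpa_forward`, does not accept it. That pattern is typical of a monkey-patched forward written against an older calling convention. The sketch below reproduces the mechanism in plain Python and shows an `inspect.signature` check that could flag the mismatch before training starts. `old_sdpa_forward`, `patched_sdpa_forward`, and `call_like_newer_decoder_layer` are hypothetical stand-ins, not ms-swift or transformers code.

```python
import inspect
from typing import Any, Callable


def accepts_kwarg(fn: Callable[..., Any], name: str) -> bool:
    """True if `fn` can be called with keyword `name` (explicitly or via **kwargs)."""
    for p in inspect.signature(fn).parameters.values():
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Hypothetical stand-in for an attention forward written against an older signature.
def old_sdpa_forward(hidden_states, attention_mask=None, position_ids=None,
                     past_key_value=None, output_attentions=False, use_cache=False):
    return hidden_states


# Hypothetical stand-in for a patch that tolerates both calling conventions.
def patched_sdpa_forward(hidden_states, attention_mask=None, position_ids=None,
                         position_embeddings=None, **kwargs):
    # A real patch would use the caller's (cos, sin) when given, otherwise
    # compute rotary embeddings from position_ids itself.
    return hidden_states


def call_like_newer_decoder_layer(attn_forward):
    # Hypothetical caller that passes precomputed (cos, sin) as `position_embeddings`,
    # the keyword the traceback shows the decoder layer passing.
    return attn_forward("h", position_ids=[0, 1], position_embeddings=("cos", "sin"))


if __name__ == "__main__":
    try:
        call_like_newer_decoder_layer(old_sdpa_forward)
    except TypeError as err:
        print("old signature:", err)
    print("patched accepts it:", accepts_kwarg(patched_sdpa_forward, "position_embeddings"))
```

Running the old-signature stand-in raises the same kind of `TypeError` as in the log; a check like `accepts_kwarg` is one way a patch could fail fast with a clearer message.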

Your hardware and system info

Ubuntu 22.04, torch 2.5.1, CUDA 12.4

Package                       Version     Editable project location
----------------------------- ----------- -------------------------
absl-py                       2.1.0
accelerate                    1.1.1
addict                        2.4.0
aiofiles                      23.2.1
aiohappyeyeballs              2.4.3
aiohttp                       3.10.10
aiosignal                     1.3.1
aliyun-python-sdk-core        2.16.0
aliyun-python-sdk-kms         2.16.5
annotated-types               0.7.0
anyio                         4.6.2.post1
async-timeout                 4.0.3
attrdict                      2.0.1
attrs                         24.2.0
binpacking                    1.5.2
certifi                       2024.8.30
cffi                          1.17.1
charset-normalizer            3.4.0
click                         8.1.7
contourpy                     1.3.0
cpm-kernels                   1.0.11
crcmod                        1.7
cryptography                  43.0.3
cycler                        0.12.1
dacite                        1.8.1
datasets                      2.21.0
deepspeed                     0.15.4
dill                          0.3.8
distro                        1.9.0
docstring_parser              0.16
einops                        0.8.0
exceptiongroup                1.2.2
fastapi                       0.115.4
ffmpy                         0.4.0
filelock                      3.16.1
fonttools                     4.54.1
frozenlist                    1.5.0
fsspec                        2024.6.1
future                        1.0.0
gradio                        5.5.0
gradio_client                 1.4.2
grpcio                        1.67.1
h11                           0.14.0
hjson                         3.1.0
httpcore                      1.0.6
httpx                         0.27.2
huggingface-hub               0.26.2
idna                          3.10
importlib_metadata            8.5.0
jieba                         0.42.1
Jinja2                        3.1.4
jiter                         0.7.0
jmespath                      0.10.0
joblib                        1.4.2
kiwisolver                    1.4.7
Markdown                      3.7
markdown-it-py                3.0.0
MarkupSafe                    2.1.5
matplotlib                    3.9.2
mdurl                         0.1.2
modelscope                    1.18.1
mpi4py                        3.1.4
mpmath                        1.3.0
ms-swift                      2.6.0.dev0  /home/ms-swift
msgpack                       1.1.0
multidict                     6.1.0
multiprocess                  0.70.16
networkx                      3.4.2
ninja                         1.11.1.1
nltk                          3.9.1
numpy                         1.26.4
nvidia-cublas-cu12            12.4.5.8
nvidia-cuda-cupti-cu12        12.4.127
nvidia-cuda-nvrtc-cu12        12.4.127
nvidia-cuda-runtime-cu12      12.4.127
nvidia-cudnn-cu12             9.1.0.70
nvidia-cufft-cu12             11.2.1.3
nvidia-curand-cu12            10.3.5.147
nvidia-cusolver-cu12          11.6.1.9
nvidia-cusparse-cu12          12.3.1.170
nvidia-ml-py                  12.560.30
nvidia-nccl-cu12              2.21.5
nvidia-nvjitlink-cu12         12.4.127
nvidia-nvtx-cu12              12.4.127
openai                        1.54.3
orjson                        3.10.11
oss2                          2.19.1
packaging                     24.2
pandas                        2.2.3
peft                          0.12.0
pillow                        11.0.0
pip                           24.2
propcache                     0.2.0
protobuf                      5.28.3
psutil                        6.1.0
py-cpuinfo                    9.0.0
pyarrow                       18.0.0
pycparser                     2.22
pycryptodome                  3.21.0
pydantic                      2.9.2
pydantic_core                 2.23.4
pydub                         0.25.1
Pygments                      2.18.0
pyparsing                     3.2.0
python-dateutil               2.9.0.post0
python-multipart              0.0.12
pytz                          2024.2
PyYAML                        6.0.2
regex                         2024.11.6
requests                      2.32.3
rich                          13.9.4
rouge                         1.0.1
ruff                          0.7.3
safehttpx                     0.1.1
safetensors                   0.4.5
scipy                         1.14.1
semantic-version              2.10.0
sentencepiece                 0.2.0
setuptools                    69.5.1
shellingham                   1.5.4
shtab                         1.7.1
simplejson                    3.19.3
six                           1.16.0
sniffio                       1.3.1
sortedcontainers              2.4.0
starlette                     0.41.2
sympy                         1.13.1
tensorboard                   2.18.0
tensorboard-data-server       0.7.2
tiktoken                      0.8.0
tokenizers                    0.20.3
tomlkit                       0.12.0
torch                         2.5.1
tqdm                          4.67.0
transformers                  4.46.2
transformers-stream-generator 0.0.5
triton                        3.1.0
trl                           0.11.4
typer                         0.13.0
typing_extensions             4.12.2
tyro                          0.8.14
tzdata                        2024.2
urllib3                       2.2.3
uvicorn                       0.32.0
websockets                    12.0
Werkzeug                      3.1.3
wheel                         0.44.0
xxhash                        3.5.0
yarl                          1.17.1
zipp                          3.21.0
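The packages most likely to matter for this traceback are transformers (4.46.2), peft (0.12.0), torch (2.5.1), deepspeed (0.15.4), and ms-swift (2.6.0.dev0). A minimal sketch for reporting just those with the standard-library `importlib.metadata`; the `RELEVANT` selection is an assumption, not an official list.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed selection of the distributions relevant to this failure.
RELEVANT = ("transformers", "peft", "torch", "deepspeed", "ms-swift")


def installed_versions(names: Iterable[str] = RELEVANT) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:<14} {version or 'not installed'}")
```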


xtchen96 commented 1 week ago

With --sft_type lora there is no error, so the problem appears to be specific to longlora.
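One plausible reason only longlora fails: LongLoRA's shifted sparse attention (S2-Attn) needs a custom attention forward, whereas plain LoRA only adds adapter weights to linear layers and leaves the attention forward untouched. The sketch below shows only the token-grouping pattern, assuming the scheme described in the LongLoRA paper (half of the heads attend within fixed groups, the other half within groups shifted by half a group). It is not ms-swift's implementation, and `group_tokens` is a hypothetical helper.

```python
from typing import List


def group_tokens(seq_len: int, group_size: int, shifted: bool) -> List[List[int]]:
    """Return the original token indices that attend to each other within each group."""
    if seq_len % group_size != 0:
        raise ValueError("seq_len must be a multiple of group_size")
    shift = group_size // 2 if shifted else 0
    # Rolling the sequence left by `shift` (like torch.roll(x, -shift) on the
    # sequence dim) puts original token (j + shift) % seq_len at position j.
    rolled = [(j + shift) % seq_len for j in range(seq_len)]
    return [rolled[k:k + group_size] for k in range(0, seq_len, group_size)]


if __name__ == "__main__":
    print("unshifted heads:", group_tokens(8, 4, shifted=False))
    print("shifted heads:  ", group_tokens(8, 4, shifted=True))
```

Because the shifted groups straddle the boundaries of the unshifted ones, information can flow between neighbouring groups. Implementing this requires replacing the model's attention forward, which is where a signature mismatch with a newer transformers release would surface.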