hiyouga / LLaMA-Factory

Efficiently Fine-Tune 100+ LLMs in WebUI (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

Phi-3-small exploding gradient issue. #3881

Open HideLord opened 4 months ago

HideLord commented 4 months ago

Reproduction

To reproduce the issue, I trained Phi-3-small with LoRA on 4 GPUs using DeepSpeed with ds_z2_config.json. Full YAML config:

### model
model_name_or_path: microsoft/Phi-3-small-8k-instruct

### method
stage: sft
do_train: true
finetuning_type: lora
low_cpu_mem_usage: true
flash_attn: fa2

### lora
lora_rank: 128
lora_alpha: 256
lora_dropout: 0.05
lora_target: all

### ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z2_config.json

### dataset
dataset: my_dataset
dataset_dir: data
template: phi
data_seed: 66
seed: 66
cutoff_len: 2000
preprocessing_num_workers: 16
use_fast_tokenizer: true

### output
output_dir: saves/lora/Phi3small_test_128
logging_steps: 5
save_steps: 98
overwrite_output_dir: true
load_best_model_at_end: true
run_name: Phi3small_test_128

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 0.0000175
num_train_epochs: 1.0
lr_scheduler_type: polynomial
bf16: true
max_grad_norm: 1.0
warmup_steps: 50
weight_decay: 0.005

### eval
val_size: 0.05
per_device_eval_batch_size: 2
evaluation_strategy: steps
eval_steps: 98
save_total_limit: 4

To trigger it:

#!/bin/bash

NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train.py examples/lora_multi_gpu/my_config.yaml
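
For reference, the config and launch script above imply a fairly small global batch. A minimal sketch of the arithmetic, assuming plain data parallelism (which is what ZeRO stage 2 uses); the function name is just illustrative:

```python
def global_batch_size(per_device_batch: int, grad_accum_steps: int, num_processes: int) -> int:
    """Samples consumed per optimizer step under plain data parallelism."""
    return per_device_batch * grad_accum_steps * num_processes


# Values taken from the YAML config (per_device_train_batch_size,
# gradient_accumulation_steps) and the launch script (NPROC_PER_NODE).
print(global_batch_size(per_device_batch=2, grad_accum_steps=1, num_processes=4))  # prints 8
```

So each optimizer step sees 8 sequences of up to 2000 tokens (cutoff_len).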

The gradient explodes at some point: image
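
To pin down where the spike starts without reading it off the plot, something like the following could be run over the logged values. This is only a sketch: it assumes the logs are a list of dicts with `step` and `grad_norm` keys (as I believe recent transformers versions write to `log_history` in `trainer_state.json`), and the `factor` and `window` thresholds as well as the sample values are arbitrary, not taken from my run.

```python
import math
from statistics import median


def find_grad_norm_spikes(log_history, factor=10.0, window=5):
    """Return (step, grad_norm) pairs that look like spikes.

    An entry counts as a spike if its grad_norm is non-finite, or if it
    exceeds `factor` times the median of the last `window` non-spike values.
    """
    spikes = []
    recent = []  # rolling baseline of recent non-spike norms
    for entry in log_history:
        norm = entry.get("grad_norm")
        if norm is None:  # e.g. evaluation entries carry no grad_norm
            continue
        is_spike = not math.isfinite(norm) or (
            len(recent) == window and norm > factor * median(recent)
        )
        if is_spike:
            spikes.append((entry.get("step"), norm))
        else:
            recent = (recent + [norm])[-window:]
    return spikes


# Hypothetical values for illustration only.
example_log = [
    {"step": 5, "grad_norm": 0.8},
    {"step": 10, "grad_norm": 0.9},
    {"step": 15, "grad_norm": 0.7},
    {"step": 20, "grad_norm": 0.85},
    {"step": 25, "grad_norm": 0.75},
    {"step": 30, "grad_norm": 42.0},
    {"step": 35, "grad_norm": 0.8},
]
print(find_grad_norm_spikes(example_log))  # prints [(30, 42.0)]
```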

Training also crashes with an out-of-memory (OOM) error:

[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/hidelord/LLaMA-Factory/src/train.py", line 14, in <module>
[rank3]:     main()
[rank3]:   File "/home/hidelord/LLaMA-Factory/src/train.py", line 5, in main
[rank3]:     run_exp()
[rank3]:   File "/home/hidelord/LLaMA-Factory/src/llamafactory/train/tuner.py", line 34, in run_exp
[rank3]:     run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
[rank3]:   File "/home/hidelord/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 73, in run_sft
[rank3]:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
[rank3]:     return inner_training_loop(
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
[rank3]:     tr_loss_step = self.training_step(model, inputs)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/transformers/trainer.py", line 3147, in training_step
[rank3]:     self.accelerator.backward(loss)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/accelerate/accelerator.py", line 2117, in backward
[rank3]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
[rank3]:     self.engine.backward(loss, **kwargs)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank3]:     ret_val = func(*args, **kwargs)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
[rank3]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2051, in backward
[rank3]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank3]:     scaled_loss.backward(retain_graph=retain_graph)
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank3]:     torch.autograd.backward(
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/home/hidelord/miniconda3/envs/llama_fac/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.29 GiB. GPU  has a total capacity of 23.69 GiB of which 1007.81 MiB is free. Including non-PyTorch memory, this process has 22.69 GiB memory in use. Of the allocated memory 20.62 GiB is allocated by PyTorch, and 1.65 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
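
The numbers in that message can be checked quickly. A rough sketch that parses the three relevant figures and compares them; the regexes are written against the message format shown above and may not match other PyTorch versions:

```python
import re

GIB_IN_MIB = 1024.0


def to_mib(value: str, unit: str) -> float:
    """Convert a '<number> GiB|MiB' pair to MiB."""
    return float(value) * (GIB_IN_MIB if unit == "GiB" else 1.0)


def oom_summary(message: str) -> dict:
    """Extract requested, free, and reserved-but-unallocated memory (MiB)."""
    requested = re.search(r"Tried to allocate ([\d.]+) (GiB|MiB)", message)
    free = re.search(r"of which ([\d.]+) (GiB|MiB) is free", message)
    slack = re.search(
        r"([\d.]+) (GiB|MiB) is reserved by PyTorch but unallocated", message
    )
    if not (requested and free and slack):
        raise ValueError("message does not match the expected OOM format")
    requested_mib = to_mib(*requested.groups())
    free_mib = to_mib(*free.groups())
    return {
        "requested_mib": requested_mib,
        "free_mib": free_mib,
        "shortfall_mib": requested_mib - free_mib,
        "reserved_unallocated_mib": to_mib(*slack.groups()),
    }


# Excerpt of the error message from the traceback above.
msg = (
    "CUDA out of memory. Tried to allocate 1.29 GiB. GPU  has a total capacity "
    "of 23.69 GiB of which 1007.81 MiB is free. Of the allocated memory "
    "20.62 GiB is allocated by PyTorch, and 1.65 GiB is reserved by PyTorch "
    "but unallocated."
)
summary = oom_summary(msg)
print({name: round(value, 2) for name, value in summary.items()})
```

The request misses by roughly 313 MiB, while about 1690 MiB is reserved but unallocated, so the `expandable_segments:True` hint in the message may be worth trying. Whether it actually avoids the OOM is untested here.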

For comparison, I ran the same configuration with mistral-7b: image

Maybe it's related to the fact that Phi-3-small uses a different architecture from the rest of the family (Phi3SmallForCausalLM vs. Phi3ForCausalLM)?

Expected behavior

The gradient should remain stable throughout training.

System Info

accelerate==0.30.1 addict==2.4.0 aiofiles==23.2.1 aiohttp==3.9.5 aiosignal==1.3.1 aliyun-python-sdk-core==2.15.1 aliyun-python-sdk-kms==2.16.3 altair==5.3.0 annotated-types==0.6.0 anyio==4.3.0 aqlm==1.1.5 async-timeout==4.0.3 attrs==23.2.0 auto_gptq==0.7.1 autoawq==0.2.5 autoawq_kernels==0.0.6 bitsandbytes==0.43.1 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 cloudpickle==3.0.0 cmake==3.29.3 coloredlogs==15.0.1 contourpy==1.2.1 crcmod==1.7 cryptography==42.0.7 cycler==0.12.1 datasets==2.18.0 deepspeed==0.14.0 dill==0.3.8 diskcache==5.6.3 distro==1.9.0 dnspython==2.6.1 docker-pycreds==0.4.0 docstring_parser==0.16 einops==0.8.0 email_validator==2.1.1 exceptiongroup==1.2.1 fastapi==0.111.0 fastapi-cli==0.0.3 ffmpy==0.3.2 filelock==3.14.0 fire==0.6.0 flash-attn==2.5.8 fonttools==4.51.0 frozenlist==1.4.1 fsspec==2024.2.0 gast==0.5.4 gekko==1.1.1 gitdb==4.0.11 GitPython==3.1.43 gradio==4.31.4 gradio_client==0.16.4 h11==0.14.0 hjson==3.1.0 httpcore==1.0.5 httptools==0.6.1 httpx==0.27.0 huggingface-hub==0.23.0 humanfriendly==10.0 idna==3.7 importlib_metadata==7.1.0 importlib_resources==6.4.0 iniconfig==2.0.0 interegular==0.3.3 jieba==0.42.1 Jinja2==3.1.4 jmespath==0.10.0 joblib==1.4.2 jsonschema==4.22.0 jsonschema-specifications==2023.12.1 kiwisolver==1.4.5 lark==1.1.9 -e git+https://github.com/hiyouga/LLaMA-Factory.git@419d47c101eab27dbffc0e3bd646c7e43d036fb3#egg=llamafactory llmtuner==0.6.3.dev0 llvmlite==0.42.0 lm-format-enforcer==0.9.8 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 mdurl==0.1.2 modelscope==1.14.0 mpmath==1.3.0 msgpack==1.0.8 multidict==6.0.5 multiprocess==0.70.16 nest-asyncio==1.6.0 networkx==3.3 ninja==1.11.1.1 nltk==3.8.1 numba==0.59.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 
nvidia-cusparse-cu12==12.1.0.106 nvidia-ml-py==12.550.52 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 openai==1.30.1 optimum==1.19.2 orjson==3.10.3 oss2==2.18.5 outlines==0.0.34 packaging==24.0 pandas==2.2.2 peft==0.11.1 pillow==10.3.0 platformdirs==4.2.2 pluggy==1.5.0 prometheus-fastapi-instrumentator==7.0.0 prometheus_client==0.20.0 protobuf==4.25.3 psutil==5.9.8 py-cpuinfo==9.0.0 pyarrow==16.1.0 pyarrow-hotfix==0.6 pycparser==2.22 pycryptodome==3.20.0 pydantic==2.7.1 pydantic_core==2.18.2 pydub==0.25.1 Pygments==2.18.0 pynvml==11.5.0 pyparsing==3.1.2 pytest==8.2.1 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 python-multipart==0.0.9 pytz==2024.1 PyYAML==6.0.1 ray==2.22.0 referencing==0.35.1 regex==2024.5.15 requests==2.31.0 rich==13.7.1 rouge==1.0.1 rouge-chinese==1.0.3 rpds-py==0.18.1 ruff==0.4.4 safetensors==0.4.3 scipy==1.13.0 semantic-version==2.10.0 sentencepiece==0.2.0 sentry-sdk==2.2.0 setproctitle==1.3.3 shellingham==1.5.4 shtab==1.7.1 simplejson==3.19.2 six==1.16.0 smmap==5.0.1 sniffio==1.3.1 sortedcontainers==2.4.0 sse-starlette==2.1.0 starlette==0.37.2 sympy==1.12 termcolor==2.4.0 tiktoken==0.6.0 tokenizers==0.19.1 tomli==2.0.1 tomlkit==0.12.0 toolz==0.12.1 torch==2.3.0 tqdm==4.66.4 transformers==4.41.1 transformers-stream-generator==0.0.5 triton==2.3.0 trl==0.8.6 typer==0.12.3 typing_extensions==4.11.0 tyro==0.8.4 tzdata==2024.1 ujson==5.10.0 urllib3==2.2.1 uvicorn==0.29.0 uvloop==0.19.0 vllm==0.4.2 vllm-nccl-cu12==2.18.1.0.4.0 wandb==0.17.0 watchfiles==0.21.0 websockets==11.0.3 xformers==0.0.26.post1 xxhash==3.4.1 yapf==0.40.2 yarl==1.9.4 zipp==3.18.2 zstandard==0.22.0

HideLord commented 4 months ago

Just tested Phi-3-medium, and it works as expected: image

The one difference is that I ran Phi-3-medium with ds_z3_config.json.

SUNJIMENG commented 4 months ago

Observed the same issue.

nblt commented 3 months ago

Hi, I have observed the same issue. Has it been resolved?

maksimstw commented 2 months ago

When I chat with Phi-3-small, the model often fails to predict the stop token. Perhaps the chat template for Phi-3-small is wrong? A similar issue can be found here: #4712