QwenLM / Qwen2.5

Qwen2.5 is the large language model series developed by the Qwen team, Alibaba Cloud.

[Bug]: Qwen2 moe out of memory #954

Open FL77N opened 2 days ago

FL77N commented 2 days ago

Model Series

Qwen2

What are the models used?

Qwen2-57B-A14B

What is the scenario where the problem happened?

train with transformers

Is this a known issue?

Information about environment

accelerate 0.33.0
addict 2.4.0
aiohappyeyeballs 2.3.6
aiohttp 3.10.3
aiosignal 1.3.1
altair 5.4.0
annotated-types 0.7.0
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
arxiv 2.1.3
asttokens 2.0.5
astunparse 1.6.3
async-lru 2.0.4
async-timeout 4.0.3
attrs 23.1.0
babel 2.16.0
backcall 0.2.0
beautifulsoup4 4.12.2
bitsandbytes 0.43.3
bleach 6.1.0
blinker 1.8.2
boltons 23.0.0
Brotli 1.0.9
cachetools 5.4.0
certifi 2023.7.22
cffi 1.15.1
chardet 4.0.0
charset-normalizer 2.0.4
click 8.1.7
colorama 0.4.6
comm 0.2.2
conda 23.9.0
conda-build 3.27.0
conda-content-trust 0.2.0
conda_index 0.3.0
conda-libmamba-solver 23.7.0
conda-package-handling 2.2.0
conda_package_streaming 0.9.0
contourpy 1.2.1
cryptography 41.0.3
cycler 0.12.1
datasets 2.21.0
debugpy 1.8.5
decorator 5.1.1
deepspeed 0.14.5
defusedxml 0.7.1
dill 0.3.8
distro 1.9.0
dnspython 2.4.2
dropout-layer-norm 0.1
duckduckgo_search 5.3.1b1
einops 0.8.0
et-xmlfile 1.1.0
exceptiongroup 1.0.4
executing 0.8.3
expecttest 0.1.6
fastjsonschema 2.20.0
feedparser 6.0.11
filelock 3.9.0
flash-attn 2.3.5
fonttools 4.53.1
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2023.10.0
func-timeout 4.3.5
gitdb 4.0.11
GitPython 3.1.43
gmpy2 2.1.2
griffe 0.49.0
h11 0.14.0
h2 4.1.0
hjson 3.1.0
hpack 4.0.0
httpcore 1.0.5
httpx 0.27.0
huggingface-hub 0.24.5
hyperframe 6.0.1
hypothesis 6.88.4
idna 3.4
imageio 2.35.0
importlib_metadata 8.2.0
ipykernel 6.29.5
ipython 8.15.0
ipywidgets 8.1.3
isoduration 20.11.0
jedi 0.18.1
Jinja2 3.1.2
json5 0.9.25
jsonpatch 1.32
jsonpointer 2.1
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter_client 8.6.2
jupyter-console 6.6.3
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.2
jupyter_server_terminals 0.5.3
jupyterlab 4.2.4
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
jupyterlab_widgets 3.0.11
kiwisolver 1.4.5
lagent 0.2.3
lazy_loader 0.4
libarchive-c 2.9
libmambapy 1.5.1
markdown-it-py 3.0.0
MarkupSafe 2.1.1
matplotlib 3.9.2
matplotlib-inline 0.1.6
mdurl 0.1.2
mistune 3.0.2
mkl-fft 1.3.8
mkl-random 1.2.4
mkl-service 2.4.0
mmengine 0.10.4
modelscope 1.17.1
more-itertools 8.12.0
mpi4py_mpich 3.1.5
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
narwhals 1.4.2
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.1
ninja 1.11.1.1
notebook 7.2.1
notebook_shim 0.2.4
numpy 1.26.0
nvidia-ml-py 12.560.30
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
openpyxl 3.1.5
overrides 7.7.0
packaging 23.1
pandas 2.2.2
pandocfilters 1.5.1
parso 0.8.3
peft 0.12.0
pexpect 4.8.0
phx-class-registry 4.1.0
pickleshare 0.7.5
Pillow 10.0.1
pip 23.3
pkginfo 1.9.6
platformdirs 4.2.2
pluggy 1.0.0
prometheus_client 0.20.0
prompt-toolkit 3.0.36
protobuf 5.27.3
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 17.0.0
pycosat 0.6.6
pycparser 2.21
pydantic 2.8.2
pydantic_core 2.20.1
pydeck 0.9.1
Pygments 2.15.1
pyOpenSSL 23.2.0
pyparsing 3.1.2
PySocks 1.7.1
python-dateutil 2.9.0.post0
python-etcd 0.4.5
python-json-logger 2.0.7
pytz 2023.3.post1
PyYAML 6.0.1
pyzmq 26.1.0
qtconsole 5.5.2
QtPy 2.4.1
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rpds-py 0.20.0
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.6
safetensors 0.4.4
scikit-image 0.24.0
scipy 1.14.0
Send2Trash 1.8.3
sentencepiece 0.2.0
setuptools 68.0.0
sgmllib3k 1.0.0
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
socksio 1.0.0
sortedcontainers 2.4.0
soupsieve 2.5
stack-data 0.2.0
streamlit 1.37.1
sympy 1.11.1
tenacity 8.5.0
termcolor 2.4.0
terminado 0.18.1
tifffile 2024.8.10
tiktoken 0.7.0
timeout-decorator 0.5.0
tinycss2 1.3.0
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 2.1.1
torchaudio 2.1.1
torchelastic 0.2.2
torchvision 0.16.1
tornado 6.4.1
tqdm 4.66.5
traitlets 5.7.1
transformers 4.42.1
transformers-stream-generator 0.0.5
triton 2.1.0
truststore 0.8.0
types-dataclasses 0.6.6
types-python-dateutil 2.9.0.20240316
typing_extensions 4.12.2
tzdata 2024.1
uri-template 1.3.0
urllib3 1.26.18
watchdog 4.0.2
wcwidth 0.2.5
webcolors 24.8.0
webencodings 0.5.1
websocket-client 1.8.0
wheel 0.41.2
widgetsnbextension 4.0.11
xxhash 3.4.1
yapf 0.40.2
yarl 1.9.4
zipp 3.20.0
zstandard 0.19.0

Log output

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 1 has a total capacty of 79.33 GiB of which 11.81 MiB is free. Process 3080904 has 79.30 GiB memory in use. Of the allocated memory 77.98 GiB is allocated by PyTorch, and 212.06 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLO
[2024-09-24 10:50:59,051] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 35) of binary: /opt/conda/bin/python

Description

Hi! When I use Transformers to run SFT on Qwen2-57B-A14B with 32 A100 GPUs and an input length of 2048, it runs out of memory (OOM) during the backward pass. Is there something wrong with my setup?

jklj077 commented 2 days ago

For reference, full-parameter fine-tuning of Qwen2-57B-A14B should be possible on 2×8 80GB GPUs with a 4K sequence length (estimated minimum). However, you should enable a mixture of tensor/expert/pipeline parallelism, e.g., pp4tp4 or pp2ep8. 4×8 80GB GPUs are preferable.

It is recommended to check whether the training framework you use supports those kinds of configurations. Ultimately, it would be best to take a closer look at your own code to identify any issues.
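To see where a "two nodes of eight 80GB GPUs" minimum comes from, here is a rough back-of-the-envelope sketch. It assumes mixed-precision Adam with the 16-bytes-per-parameter accounting from the ZeRO paper (bf16 weights and gradients plus fp32 master weights, momentum, and variance), assumes a total of 57B parameters taken from the model name, and ignores activations, buffers, and fragmentation entirely. It is an illustration, not the estimate the maintainer actually used.

```python
# Rough "model state" memory for full-parameter fine-tuning with
# mixed-precision Adam (ZeRO-paper accounting: 2 + 2 + 12 bytes/param).
# Activations, communication buffers, and allocator fragmentation are
# NOT included, so real per-GPU usage is always higher than this.

# bf16 weights (2) + bf16 grads (2) + fp32 master weights, momentum, variance (12)
BYTES_PER_PARAM = 2 + 2 + 12


def model_state_gb(num_params: float) -> float:
    """Total model-state memory in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM / 1e9


def per_gpu_gb(num_params: float, num_gpus: int) -> float:
    """Model-state memory per GPU if all states are sharded evenly across GPUs."""
    return model_state_gb(num_params) / num_gpus


if __name__ == "__main__":
    params = 57e9  # assumption: total parameter count read off the model name
    print(f"total model states: {model_state_gb(params):.1f} GB")
    for gpus in (16, 32, 64):
        print(f"{gpus} GPUs: {per_gpu_gb(params, gpus):.2f} GB per GPU before activations")
```

Under this accounting, 16 GPUs leave roughly 23 GB per 80GB card for everything else, and 32 GPUs leave roughly 51 GB. That headroom still has to hold activations for the sequence length in use, any parameters gathered on the fly (as in ZeRO-3), and communication buffers, so these numbers are lower bounds on memory use, not a prediction of whether a given setup fits.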

FL77N commented 2 days ago

Thanks for your quick reply! My training framework is Transformers, which only supports data parallelism and the DeepSpeed ZeRO-3 strategy. However, when I run SFT on 8×8 80GB GPUs with a 2K sequence length, it still runs out of memory. Maybe it is best to train it with a framework like Megatron-LM.
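For anyone who wants to stay on Transformers + DeepSpeed ZeRO-3, a common way to trade training speed for GPU memory is to offload optimizer states and parameters to CPU and to enable gradient checkpointing. The config below is a hypothetical sketch that has not been tested on Qwen2-57B-A14B. The file name `ds_zero3_offload.json` is made up for illustration. The `"auto"` values are meant to be filled in by the Hugging Face Trainer's DeepSpeed integration from `TrainingArguments`. CPU offload needs a large amount of host RAM on each node and slows each step, and it does not add tensor, expert, or pipeline parallelism, so it may still not be enough.

```python
import json

# Hypothetical ZeRO-3 config with CPU offload of optimizer states and parameters.
# "auto" entries are resolved by the Hugging Face Trainer's DeepSpeed integration.
ds_config = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 3,
        # Keep fp32 optimizer states (the bulk of model-state memory) in pinned CPU RAM.
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        # Keep sharded parameters in CPU RAM between uses.
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        # Gather full 16-bit weights when saving so the checkpoint is directly loadable.
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
}

if __name__ == "__main__":
    # Save the config so it can be referenced from the training script.
    with open("ds_zero3_offload.json", "w") as f:
        json.dump(ds_config, f, indent=2)
```

The saved path (or the dict itself) would then be passed as `TrainingArguments(deepspeed="ds_zero3_offload.json", gradient_checkpointing=True, ...)`. Whether this fits Qwen2-57B-A14B at a 2K sequence length on a given cluster is not established here.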