Open FL77N opened 2 days ago
For reference, full-parameter finetuning of Qwen2-57B-A14B should be possible on 2 × 8 × 80GB GPUs with a 4K sequence length (estimated minimum). However, you should enable a mixture of tensor/expert/pipeline parallelism, e.g., pp4tp4 or pp2ep8. 4 × 8 × 80GB GPUs would be preferable.
It is recommended to check whether the training framework you adopted supports those kinds of configurations. Ultimately, it would be best if you could take a closer look at your own code to identify any issues.
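To see why a single 8 × 80GB node is not enough, a back-of-the-envelope estimate helps. The figures below are assumptions, not from this thread: mixed-precision Adam typically keeps bf16 weights (2 B/param), bf16 gradients (2 B/param), and fp32 master weights, momentum, and variance (4 B/param each), i.e. roughly 16 bytes per parameter, before counting any activations:

```python
# Rough floor on memory for full-parameter finetuning with mixed-precision Adam.
# Assumed breakdown (illustrative): bf16 weights (2 B) + bf16 grads (2 B)
# + fp32 master weights, momentum, variance (4 B each) = 16 bytes/param.
# Activations, comm buffers, and fragmentation are ignored.

def state_memory_gb(n_params: float, bytes_per_param: int = 16) -> float:
    """Model + gradient + optimizer state in GiB, excluding activations."""
    return n_params * bytes_per_param / 1024**3

total_params = 57e9  # Qwen2-57B-A14B: ~57B total parameters
need = state_memory_gb(total_params)
print(f"states alone: {need:.0f} GiB")        # ~849 GiB
print(f"1 x 8 x 80GB = {1*8*80} GB")          # 640 GB: states don't even fit
print(f"2 x 8 x 80GB = {2*8*80} GB")          # 1280 GB: feasible with parallelism
```

This is why the estimated minimum is two nodes, and why some combination of tensor/expert/pipeline parallelism is needed to actually distribute those states.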
Thanks for your quick update! My training framework is transformers; it only supports data parallelism and the DeepSpeed ZeRO-3 strategy. However, when I SFT it on 8 × 8 × 80GB GPUs with a 2K sequence length, it still runs out of memory. Maybe it is best to train it with a training framework like Megatron-LM.
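Before switching frameworks, it may be worth trying ZeRO-3 with CPU offload, which transformers supports through its DeepSpeed integration. A minimal ds_config.json sketch is below; all values are illustrative rather than taken from this thread, and "auto" lets the HF Trainer fill the value in from its own arguments:

```json
{
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "offload_param": { "device": "cpu", "pin_memory": true },
    "overlap_comm": true,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto"
}
```

Combining this with `gradient_checkpointing=True` in `TrainingArguments` trades compute for activation memory; offload is slow, but it can distinguish "states don't fit" from a genuine framework limitation.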
Model Series
Qwen2
What are the models used?
Qwen2-57B-A14B
What is the scenario where the problem happened?
train with transformers
Is this a known issue?
Information about environment
accelerate 0.33.0
addict 2.4.0
aiohappyeyeballs 2.3.6
aiohttp 3.10.3
aiosignal 1.3.1
altair 5.4.0
annotated-types 0.7.0
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
arxiv 2.1.3
asttokens 2.0.5
astunparse 1.6.3
async-lru 2.0.4
async-timeout 4.0.3
attrs 23.1.0
babel 2.16.0
backcall 0.2.0
beautifulsoup4 4.12.2
bitsandbytes 0.43.3
bleach 6.1.0
blinker 1.8.2
boltons 23.0.0
Brotli 1.0.9
cachetools 5.4.0
certifi 2023.7.22
cffi 1.15.1
chardet 4.0.0
charset-normalizer 2.0.4
click 8.1.7
colorama 0.4.6
comm 0.2.2
conda 23.9.0
conda-build 3.27.0
conda-content-trust 0.2.0
conda_index 0.3.0
conda-libmamba-solver 23.7.0
conda-package-handling 2.2.0
conda_package_streaming 0.9.0
contourpy 1.2.1
cryptography 41.0.3
cycler 0.12.1
datasets 2.21.0
debugpy 1.8.5
decorator 5.1.1
deepspeed 0.14.5
defusedxml 0.7.1
dill 0.3.8
distro 1.9.0
dnspython 2.4.2
dropout-layer-norm 0.1
duckduckgo_search 5.3.1b1
einops 0.8.0
et-xmlfile 1.1.0
exceptiongroup 1.0.4
executing 0.8.3
expecttest 0.1.6
fastjsonschema 2.20.0
feedparser 6.0.11
filelock 3.9.0
flash-attn 2.3.5
fonttools 4.53.1
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2023.10.0
func-timeout 4.3.5
gitdb 4.0.11
GitPython 3.1.43
gmpy2 2.1.2
griffe 0.49.0
h11 0.14.0
h2 4.1.0
hjson 3.1.0
hpack 4.0.0
httpcore 1.0.5
httpx 0.27.0
huggingface-hub 0.24.5
hyperframe 6.0.1
hypothesis 6.88.4
idna 3.4
imageio 2.35.0
importlib_metadata 8.2.0
ipykernel 6.29.5
ipython 8.15.0
ipywidgets 8.1.3
isoduration 20.11.0
jedi 0.18.1
Jinja2 3.1.2
json5 0.9.25
jsonpatch 1.32
jsonpointer 2.1
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter_client 8.6.2
jupyter-console 6.6.3
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.2
jupyter_server_terminals 0.5.3
jupyterlab 4.2.4
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
jupyterlab_widgets 3.0.11
kiwisolver 1.4.5
lagent 0.2.3
lazy_loader 0.4
libarchive-c 2.9
libmambapy 1.5.1
markdown-it-py 3.0.0
MarkupSafe 2.1.1
matplotlib 3.9.2
matplotlib-inline 0.1.6
mdurl 0.1.2
mistune 3.0.2
mkl-fft 1.3.8
mkl-random 1.2.4
mkl-service 2.4.0
mmengine 0.10.4
modelscope 1.17.1
more-itertools 8.12.0
mpi4py_mpich 3.1.5
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
narwhals 1.4.2
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.1
ninja 1.11.1.1
notebook 7.2.1
notebook_shim 0.2.4
numpy 1.26.0
nvidia-ml-py 12.560.30
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
openpyxl 3.1.5
overrides 7.7.0
packaging 23.1
pandas 2.2.2
pandocfilters 1.5.1
parso 0.8.3
peft 0.12.0
pexpect 4.8.0
phx-class-registry 4.1.0
pickleshare 0.7.5
Pillow 10.0.1
pip 23.3
pkginfo 1.9.6
platformdirs 4.2.2
pluggy 1.0.0
prometheus_client 0.20.0
prompt-toolkit 3.0.36
protobuf 5.27.3
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 17.0.0
pycosat 0.6.6
pycparser 2.21
pydantic 2.8.2
pydantic_core 2.20.1
pydeck 0.9.1
Pygments 2.15.1
pyOpenSSL 23.2.0
pyparsing 3.1.2
PySocks 1.7.1
python-dateutil 2.9.0.post0
python-etcd 0.4.5
python-json-logger 2.0.7
pytz 2023.3.post1
PyYAML 6.0.1
pyzmq 26.1.0
qtconsole 5.5.2
QtPy 2.4.1
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rpds-py 0.20.0
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.6
safetensors 0.4.4
scikit-image 0.24.0
scipy 1.14.0
Send2Trash 1.8.3
sentencepiece 0.2.0
setuptools 68.0.0
sgmllib3k 1.0.0
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
socksio 1.0.0
sortedcontainers 2.4.0
soupsieve 2.5
stack-data 0.2.0
streamlit 1.37.1
sympy 1.11.1
tenacity 8.5.0
termcolor 2.4.0
terminado 0.18.1
tifffile 2024.8.10
tiktoken 0.7.0
timeout-decorator 0.5.0
tinycss2 1.3.0
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 2.1.1
torchaudio 2.1.1
torchelastic 0.2.2
torchvision 0.16.1
tornado 6.4.1
tqdm 4.66.5
traitlets 5.7.1
transformers 4.42.1
transformers-stream-generator 0.0.5
triton 2.1.0
truststore 0.8.0
types-dataclasses 0.6.6
types-python-dateutil 2.9.0.20240316
typing_extensions 4.12.2
tzdata 2024.1
uri-template 1.3.0
urllib3 1.26.18
watchdog 4.0.2
wcwidth 0.2.5
webcolors 24.8.0
webencodings 0.5.1
websocket-client 1.8.0
wheel 0.41.2
widgetsnbextension 4.0.11
xxhash 3.4.1
yapf 0.40.2
yarl 1.9.4
zipp 3.20.0
zstandard 0.19.0
Log output
Description
Hi! When I use transformers to SFT Qwen2-57B-A14B on 32 × A100 GPUs with a 2048-token input length, it encounters OOM in the backward stage. Is there something wrong with my settings?
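One way to localize the problem: under ZeRO-3 the model, gradient, and optimizer states are sharded across all ranks, so on 32 GPUs the sharded states themselves should fit easily; an OOM specifically in backward usually points at what ZeRO-3 does not shard. A rough check, assuming (illustratively, not from this thread) 16 bytes of state per parameter:

```python
# Per-GPU sharded state under ZeRO-3 across 32 ranks.
# Assumption (illustrative): bf16 weights + bf16 grads + fp32 Adam states
# ~= 16 bytes/param total, sharded evenly across all ranks.
world_size = 32
total_params = 57e9  # Qwen2-57B-A14B
per_gpu_states_gib = total_params * 16 / world_size / 1024**3
print(f"~{per_gpu_states_gib:.1f} GiB of sharded states per 80 GiB GPU")

# Not sharded by ZeRO-3, and all growing with sequence length / batch size:
# - activations kept for backward (mitigated by gradient checkpointing)
# - the temporarily all-gathered full weights of each layer being computed
# - MoE token-routing / dispatch buffers, which are large for a 57B-A14B model
# These are the usual suspects when states fit but backward still OOMs.
```

If that diagnosis holds, gradient checkpointing and a smaller micro-batch attack the activation term, while expert parallelism (as the maintainer suggested) attacks the MoE buffers.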