jquesnelle / yarn

YaRN: Efficient Context Window Extension of Large Language Models
MIT License

Training takes a long time #32

Closed Michelleable closed 11 months ago

Michelleable commented 11 months ago

Why does it take so long for me to fine-tune llama2-7b-64k? Each epoch takes 300+ seconds. I used 8x A100, turned on DeepSpeed, and used "yarn" for the RoPE type. Is it a problem with flash attention? As far as I can see, modeling_llama_together_yarn.py uses flash attention by default. Thanks a lot.

jquesnelle commented 11 months ago

Unfortunately, at 64k this is actually quite fast, believe it or not (attention compute is quadratic in sequence length). Our run was also around 300s/epoch on an 8x A100 node. It took about 24 hours to train for 400 steps.
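
To make the quadratic-cost point concrete, here is a rough back-of-the-envelope sketch, not a measurement from this repo. It assumes the published Llama-2-7B shape (hidden size 4096, 32 layers, MLP width 11008), counts only forward-pass matrix multiplies, and ignores the savings from causal masking. The function names are made up for illustration.

```python
# Rough forward-pass FLOP counts for one sequence.
# Assumed Llama-2-7B shape: hidden=4096, layers=32, MLP width=11008.

def attention_matmul_flops(seq_len: int, hidden: int = 4096, layers: int = 32) -> int:
    """FLOPs for QK^T and softmax(QK^T)V, which grow with seq_len squared."""
    # Each of the two matmuls costs about 2 * n^2 * d per layer (summed over heads).
    return 4 * seq_len * seq_len * hidden * layers


def linear_matmul_flops(seq_len: int, hidden: int = 4096,
                        mlp_width: int = 11008, layers: int = 32) -> int:
    """FLOPs for the Q/K/V/O projections and the MLP, which grow linearly."""
    projections = 4 * 2 * seq_len * hidden * hidden   # Q, K, V, O
    mlp = 3 * 2 * seq_len * hidden * mlp_width        # gate, up, down
    return (projections + mlp) * layers


def attention_share(seq_len: int) -> float:
    """Fraction of the counted FLOPs spent in the quadratic attention matmuls."""
    attn = attention_matmul_flops(seq_len)
    return attn / (attn + linear_matmul_flops(seq_len))


if __name__ == "__main__":
    for n in (4096, 65536):
        print(f"seq_len={n}: attention share of counted FLOPs = {attention_share(n):.2f}")
```

Under these assumptions, going from 4k to 64k tokens multiplies the attention matmul cost by 256 while the rest of the model grows by 16, so attention goes from a minor share of the compute to the dominant one.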

AndyW-llm commented 8 months ago

@jquesnelle @Michelleable Hi! I am encountering OOM issues when trying to fine-tune llama2-7b-64k with DeepSpeed on an 8x A100 (80GB) node.

May I ask whether you ran the fine-tuning code using DeepSpeed or FSDP? If FSDP, would you mind sharing the answers you gave to the 'accelerate config' questions?

I also suspect the current repo might not be compatible with the latest 'accelerate' and/or 'transformers' packages. Did you see warnings like "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda')." when you fine-tuned the model?

Thank you so much!

For more details, I have attached a log of the failed execution: yarn_testing.log

And here is the list of packages installed in the training environment:

accelerate==0.26.1
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
archspec @ file:///croot/archspec_1697725767277/work
attrs==23.2.0
boltons @ file:///work/ci_py311/boltons_1677685195580/work
Brotli @ file:///work/ci_py311/brotli-split_1676830125088/work
certifi @ file:///croot/certifi_1700501669400/work/certifi
cffi @ file:///croot/cffi_1700254295673/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work
conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work
conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work
contourpy==1.2.0
cryptography @ file:///croot/cryptography_1702070282333/work
cycler==0.12.1
datasets==2.16.1
deepspeed==0.12.6
dill==0.3.7
distro @ file:///croot/distro_1701455004953/work
einops==0.7.0
evaluate==0.4.1
filelock @ file:///croot/filelock_1700591183607/work
flash-attn==2.4.2
fonttools==4.47.2
frozenlist==1.4.1
fsspec==2023.10.0
gmpy2 @ file:///work/ci_py311/gmpy2_1676839849213/work
hjson==3.1.0
huggingface-hub==0.20.2
idna @ file:///work/ci_py311/idna_1676822698822/work
Jinja2 @ file:///work/ci_py311/jinja2_1676823587943/work
joblib==1.3.2
jsonpatch @ file:///tmp/build/80754af9/jsonpatch_1615747632069/work
jsonpointer==2.1
kiwisolver==1.4.5
libmambapy @ file:///croot/mamba-split_1698782620632/work/libmambapy
MarkupSafe @ file:///croot/markupsafe_1704205993651/work
matplotlib==3.8.2
menuinst @ file:///croot/menuinst_1702390294373/work
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
mpmath @ file:///croot/mpmath_1690848262763/work
multidict==6.0.4
multiprocess==0.70.15
networkx @ file:///croot/networkx_1690561992265/work
ninja==1.11.1.1
numpy @ file:///croot/numpy_and_numpy_base_1704311704800/work/dist/numpy-1.26.3-cp311-cp311-linux_x86_64.whl#sha256=10a078151ecec16bafb535f7487635217625fa06536dec8509e514648c78d626
nvidia-ml-py3==7.352.0
packaging @ file:///croot/packaging_1693575174725/work
pandas==2.1.4
Pillow @ file:///croot/pillow_1696580024257/work
platformdirs @ file:///croot/platformdirs_1692205439124/work
pluggy @ file:///work/ci_py311/pluggy_1676822818071/work
protobuf==3.19.6
psutil==5.9.7
py-cpuinfo==9.0.0
pyarrow==14.0.2
pyarrow-hotfix==0.6
pycosat @ file:///croot/pycosat_1696536503704/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pydantic==2.5.3
pydantic_core==2.14.6
pynvml==11.5.0
pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
pyparsing==3.1.1
PySocks @ file:///work/ci_py311/pysocks_1676822712504/work
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML @ file:///croot/pyyaml_1698096049011/work
regex==2023.12.25
requests @ file:///croot/requests_1690400202158/work
responses==0.18.0
ruamel.yaml @ file:///work/ci_py311/ruamel.yaml_1676838772170/work
safetensors==0.4.1
-e git+https://github.com/jquesnelle/yarn@651cbeb012d4fc098400823f54f04bdd51d8f0ac#egg=scaled_rope
scikit-learn==1.3.2
scipy==1.11.4
sentencepiece==0.1.99
six==1.16.0
sympy @ file:///croot/sympy_1701397643339/work
threadpoolctl==3.2.0
tokenizers==0.15.0
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tqdm @ file:///croot/tqdm_1679561862951/work
transformers==4.36.2
triton==2.1.0
truststore @ file:///croot/truststore_1695244293384/work
typing_extensions @ file:///croot/typing_extensions_1705005625920/work
tzdata==2023.4
urllib3 @ file:///croot/urllib3_1698257533958/work
xxhash==3.4.1
yarl==1.9.4
zstandard @ file:///work/ci_py311_2/zstandard_1679339489613/work
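
For anyone reasoning about the OOM above, a rough estimate of model-state memory may help narrow things down. This is a sketch under stated assumptions, not a diagnosis of the attached log. It assumes about 7e9 parameters, mixed-precision Adam at the commonly cited 16 bytes per parameter (fp16 weights and gradients, plus fp32 master weights, momentum, and variance), and full sharding of those states across GPUs as in ZeRO stage 3. The function name is hypothetical, and activation memory is deliberately left out.

```python
# Back-of-the-envelope model-state memory for mixed-precision Adam.
# Assumptions: ~7e9 parameters, 16 bytes per parameter, states sharded
# evenly across GPUs (ZeRO-3 style). Activations are NOT included.

GIB = 2 ** 30


def model_state_gib_per_gpu(num_params: float, num_gpus: int,
                            bytes_per_param: int = 16) -> float:
    """Per-GPU memory in GiB for weights, gradients, and optimizer states."""
    return num_params * bytes_per_param / num_gpus / GIB


if __name__ == "__main__":
    unsharded = model_state_gib_per_gpu(7e9, num_gpus=1)
    sharded = model_state_gib_per_gpu(7e9, num_gpus=8)
    print(f"unsharded model states: {unsharded:.1f} GiB")
    print(f"sharded across 8 GPUs:  {sharded:.1f} GiB per GPU")
```

Under these assumptions the unsharded states alone exceed a single 80GB card, while full sharding across 8 GPUs leaves most of each card free. If OOM still occurs with full sharding, activation memory at 64k tokens is a likely suspect, which this estimate does not cover.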

YL-9 commented 4 months ago

I also encountered this problem. Have you solved it yet? @AndyW-llm