BugReporterZ opened this issue 10 months ago (status: Open)
It seems like multiple dependencies (mostly flash-attn) along with CUDA version discrepancies are the main issues causing major hiccups.
After a lot of trial and error I finally got this environment working on my WSL Ubuntu build with a 4070; hope it helps someone, at least until the build is fixed. The Docker image was also broken for me, which is why I fiddled with the conda environment so much.
name: ft
channels:
- pytorch
- nvidia
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py311h6a678d5_7
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.12.12=h06a4308_0
- certifi=2023.11.17=py311h06a4308_0
- cffi=1.16.0=py311h5eee18b_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cryptography=41.0.7=py311hdda0065_0
- cuda-cccl=12.3.101=0
- cuda-command-line-tools=12.1.1=0
- cuda-compiler=12.3.2=0
- cuda-cudart=12.1.105=0
- cuda-cudart-dev=12.1.105=0
- cuda-cudart-static=12.1.105=0
- cuda-cuobjdump=12.3.101=0
- cuda-cupti=12.1.105=0
- cuda-cupti-static=12.1.105=0
- cuda-cuxxfilt=12.3.101=0
- cuda-documentation=12.3.101=0
- cuda-driver-dev=12.3.101=0
- cuda-gdb=12.3.101=0
- cuda-libraries=12.1.0=0
- cuda-libraries-dev=12.1.0=0
- cuda-libraries-static=12.1.0=0
- cuda-nsight=12.3.101=0
- cuda-nsight-compute=12.3.2=0
- cuda-nvcc=12.3.107=0
- cuda-nvdisasm=12.3.101=0
- cuda-nvml-dev=12.3.101=0
- cuda-nvprof=12.3.101=0
- cuda-nvprune=12.3.101=0
- cuda-nvrtc=12.1.105=0
- cuda-nvrtc-dev=12.1.105=0
- cuda-nvrtc-static=12.1.105=0
- cuda-nvtx=12.1.105=0
- cuda-nvvp=12.3.101=0
- cuda-opencl=12.3.101=0
- cuda-opencl-dev=12.3.101=0
- cuda-profiler-api=12.3.101=0
- cuda-runtime=12.1.0=0
- cuda-sanitizer-api=12.3.101=0
- cuda-toolkit=12.1.0=0
- cuda-tools=12.1.0=0
- cuda-visual-tools=12.1.0=0
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=py311h06a4308_0
- freetype=2.12.1=h4a9f257_0
- gds-tools=1.8.1.2=0
- giflib=5.2.1=h5eee18b_3
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py311hc9b5ff0_0
- gnutls=3.6.15=he1e5248_0
- idna=3.4=py311h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.2=py311h06a4308_0
- jpeg=9e=h5eee18b_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libcublas=12.1.0.26=0
- libcublas-dev=12.1.0.26=0
- libcublas-static=12.1.0.26=0
- libcufft=11.0.2.4=0
- libcufft-dev=11.0.2.4=0
- libcufft-static=11.0.2.4=0
- libcufile=1.8.1.2=0
- libcufile-dev=1.8.1.2=0
- libcufile-static=1.8.1.2=0
- libcurand=10.3.4.107=0
- libcurand-dev=10.3.4.107=0
- libcurand-static=10.3.4.107=0
- libcusolver=11.4.4.55=0
- libcusolver-dev=11.4.4.55=0
- libcusolver-static=11.4.4.55=0
- libcusparse=12.0.2.55=0
- libcusparse-dev=12.0.2.55=0
- libcusparse-static=12.0.2.55=0
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.4=h5eee18b_0
- libjpeg-turbo=2.0.0=h9bf148f_0
- libnpp=12.0.2.50=0
- libnpp-dev=12.0.2.50=0
- libnpp-static=12.0.2.50=0
- libnvjitlink=12.1.105=0
- libnvjitlink-dev=12.1.105=0
- libnvjpeg=12.1.1.14=0
- libnvjpeg-dev=12.1.1.14=0
- libnvjpeg-static=12.1.1.14=0
- libnvvm-samples=12.1.105=0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp=1.3.2=h11a3e52_0
- libwebp-base=1.3.2=h5eee18b_0
- llvm-openmp=14.0.6=h9e868ea_0
- lz4-c=1.9.4=h6a678d5_0
- markupsafe=2.1.3=py311h5eee18b_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py311h5eee18b_1
- mkl_fft=1.3.8=py311h5eee18b_0
- mkl_random=1.2.4=py311hdb19cb5_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py311h06a4308_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.1=py311h06a4308_0
- nsight-compute=2023.3.1.1=0
- numpy=1.26.3=py311h08b1b3b_0
- numpy-base=1.26.3=py311hf175353_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h3ad879b_0
- openssl=3.0.12=h7f8727e_0
- pillow=10.0.1=py311ha6cbd5a_0
- pip=23.3.1=py311h06a4308_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=23.2.0=py311h06a4308_0
- pysocks=1.7.1=py311h06a4308_0
- python=3.11.7=h955ad1f_0
- pytorch-cuda=12.1=ha16c6d3_5
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py311h5eee18b_0
- readline=8.2=h5eee18b_0
- requests=2.31.0=py311h06a4308_0
- setuptools=68.2.2=py311h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- sympy=1.12=py311h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=2.1.2=py311_cu121
- torchvision=0.16.2=py311_cu121
- typing_extensions=4.9.0=py311h06a4308_1
- urllib3=1.26.18=py311h06a4308_0
- wheel=0.41.2=py311h06a4308_0
- xz=5.4.5=h5eee18b_0
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.5=hc292b87_0
- pip:
- absl-py==2.1.0
- accelerate==0.25.0.dev0
- addict==2.4.0
- aiobotocore==2.7.0
- aiofiles==23.2.1
- aiohttp==3.9.1
- aioitertools==0.11.0
- aiosignal==1.3.1
- alembic==1.13.1
- altair==5.2.0
- anyio==4.2.0
- appdirs==1.4.4
- art==6.1
- attrs==23.2.0
- bert-score==0.3.13
- bitsandbytes==0.42.0
- blinker==1.7.0
- botocore==1.31.64
- cachetools==5.3.2
- click==8.1.7
- cloudpickle==3.0.0
- cmake==3.28.1
- colorama==0.4.6
- coloredlogs==15.0.1
- contourpy==1.2.0
- cycler==0.12.1
- databricks-cli==0.18.0
- datasets==2.16.1
- decorator==5.1.1
- deepspeed==0.12.6
- dill==0.3.7
- docker==6.1.3
- docker-pycreds==0.4.0
- docstring-parser==0.15
- einops==0.7.0
- entrypoints==0.4
- evaluate==0.4.0
- fastapi==0.109.0
- ffmpy==0.3.1
- fire==0.5.0
- flash-attn==2.4.2
- flask==3.0.0
- fonttools==4.47.2
- frozenlist==1.4.1
- fschat==0.2.34
- fsspec==2023.10.0
- gcsfs==2023.10.0
- gitdb==4.0.11
- gitpython==3.1.41
- google-api-core==2.15.0
- google-auth==2.26.2
- google-auth-oauthlib==1.2.0
- google-cloud-core==2.4.1
- google-cloud-storage==2.14.0
- google-crc32c==1.5.0
- google-resumable-media==2.7.0
- googleapis-common-protos==1.62.0
- gradio==3.50.2
- gradio-client==0.6.1
- greenlet==3.0.3
- grpcio==1.60.0
- gunicorn==21.2.0
- h11==0.14.0
- hf-transfer==0.1.4
- hjson==3.1.0
- httpcore==1.0.2
- httpx==0.26.0
- huggingface-hub==0.20.2
- humanfriendly==10.0
- importlib-metadata==7.0.1
- importlib-resources==6.1.1
- itsdangerous==2.1.2
- jmespath==1.0.1
- joblib==1.3.2
- jsonschema==4.21.0
- jsonschema-specifications==2023.12.1
- kiwisolver==1.4.5
- lit==17.0.6
- llvmlite==0.41.1
- mako==1.3.0
- markdown==3.5.2
- markdown-it-py==3.0.0
- markdown2==2.4.12
- matplotlib==3.8.2
- mdurl==0.1.2
- mlflow==2.9.2
- multidict==6.0.4
- multiprocess==0.70.15
- nh3==0.2.15
- ninja==1.11.1.1
- nltk==3.8.1
- numba==0.58.1
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cufft-cu11==10.9.0.58
- nvidia-curand-cu11==10.2.10.91
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-nccl-cu11==2.14.3
- nvidia-nvtx-cu11==11.7.91
- oauthlib==3.2.2
- optimum==1.13.2
- orjson==3.9.12
- packaging==23.2
- pandas==2.1.4
- peft==0.7.0
- prompt-toolkit==3.0.43
- protobuf==4.23.4
- psutil==5.9.7
- py-cpuinfo==9.0.0
- pyarrow==14.0.2
- pyarrow-hotfix==0.6
- pyasn1==0.5.1
- pyasn1-modules==0.3.0
- pydantic==1.10.13
- pydub==0.25.1
- pygments==2.17.2
- pyjwt==2.8.0
- pynvml==11.5.0
- pyparsing==3.1.1
- python-dateutil==2.8.2
- python-multipart==0.0.6
- pytz==2023.3.post1
- querystring-parser==1.2.4
- referencing==0.32.1
- regex==2023.12.25
- requests-oauthlib==1.3.1
- responses==0.18.0
- rich==13.7.0
- rouge-score==0.1.2
- rpds-py==0.17.1
- rsa==4.9
- s3fs==2023.10.0
- safetensors==0.4.1
- scikit-learn==1.2.2
- scipy==1.11.4
- semantic-version==2.10.0
- sentencepiece==0.1.99
- sentry-sdk==1.39.2
- setproctitle==1.3.3
- shortuuid==1.0.11
- shtab==1.6.5
- six==1.16.0
- smmap==5.0.1
- sniffio==1.3.0
- sqlalchemy==2.0.25
- sqlparse==0.4.4
- starlette==0.35.1
- svgwrite==1.4.3
- tabulate==0.9.0
- tensorboard==2.15.1
- tensorboard-data-server==0.7.2
- termcolor==2.4.0
- threadpoolctl==3.2.0
- tiktoken==0.5.2
- tokenizers==0.15.0
- toolz==0.12.0
- torch==2.0.1
- tqdm==4.66.1
- transformers==4.37.0.dev0
- triton==2.0.0
- trl==0.7.9
- tyro==0.6.5
- tzdata==2023.4
- uvicorn==0.26.0
- wandb==0.16.2
- wavedrom==2.0.3.post3
- wcwidth==0.2.13
- websocket-client==1.7.0
- websockets==11.0.3
- werkzeug==3.0.1
- wrapt==1.16.0
- xformers==0.0.22
- xxhash==3.4.1
- yarl==1.9.4
- zipp==3.17.0
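To recreate this environment, a minimal sketch assuming the YAML above is saved as environment.yml (the file name is just an example):

```bash
# Build the conda environment from the exported YAML above,
# then activate it (the env is named "ft" in the file)
conda env create -f environment.yml
conda activate ft
```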
EDIT: It seems updating xformers to xformers==0.0.23.post1 resolves this issue. See their release notes.
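For anyone applying the same fix to an existing environment, something like this should work:

```bash
# Replace the pinned xformers 0.0.22 with the patched release
pip install --upgrade xformers==0.0.23.post1
```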
This patch fixed the installation process for me:
diff --git a/requirements.txt b/requirements.txt
index 4583850..1b985c0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.7.0
+torch==2.1.2
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
tokenizers==0.15.0
bitsandbytes>=0.41.1
@@ -14,7 +15,7 @@ flash-attn==2.3.3
sentencepiece
wandb
einops
-xformers==0.0.22
+xformers==0.0.23.post1
optimum==1.13.2
hf_transfer
colorama
diff --git a/setup.py b/setup.py
index 235018d..8ea9d76 100644
--- a/setup.py
+++ b/setup.py
@@ -49,6 +49,9 @@ setup(
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
+ "torch": [
+ "torch==2.1.2",
+ ],
"flash-attn": [
"flash-attn==2.3.3",
],
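A sketch of applying the patch from a repo checkout (the patch file name and the install extras are just examples; pick whichever extras you normally use):

```bash
# Apply the diff above, then reinstall axolotl in editable mode
git apply fix-xformers-torch.patch
pip install -e '.[flash-attn,deepspeed]'
```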
@gardner That appears to fix axolotl not getting installed and running in my case, but there are still issues with training: memory usage seems unusually high compared to past axolotl commits (I definitely didn't have these issues back in December).
I'm currently OOMing (at batch size 1) on 24 GB with 4096-token sequences on a 7B model (Mistral-7B) with QLoRA (4-bit). Training that used to be possible even without FA, and even after enabling it, memory usage still seems unusually high now.
I've read suggestions that the flash-attn version could be involved, but upgrading to 2.4.2 didn't solve it.
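If it helps narrow this down, two quick checks (standard tooling, nothing axolotl-specific) to confirm which flash-attn build is actually loaded and to watch VRAM as the run warms up:

```bash
# Print the flash-attn version Python actually imports
python -c "import flash_attn; print(flash_attn.__version__)"

# Watch GPU memory usage while training starts
watch -n 1 nvidia-smi
```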
I was able to kick off a TinyLlama QLoRA run with 4096 tokens on my 12GB 3060. It's hanging out at 11.4 GB usage but it's running (pretty slowly).
You might want to try downgrading flash-attn?
There's also https://github.com/unslothai/unsloth if you don't mind writing a bit of code.
Reverting to an axolotl commit from mid-December (5f79b82, though I haven't investigated exactly when the issues began), reinstalling the packages, then uninstalling flash-attn and doing pip install flash-attn==2.3.2 fixes the issue. Training Mistral-7B in 4-bit with 4096-token sequences at batch size 1 and FA enabled now takes about 9000 MB of VRAM. Attempting to finetune a llamafied InternLM2 20B in 4-bit, again with 4096-token sequences, now takes about 21500 MB instead of OOMing as with newer axolotl commits.
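Roughly, the sequence looks like this (the install extras and the --no-build-isolation flag are illustrative, not something axolotl mandates):

```bash
# Check out the mid-December commit and reinstall axolotl
git checkout 5f79b82
pip install -e '.[deepspeed]'

# Swap the flash-attn that got pulled in for 2.3.2
pip uninstall -y flash-attn
pip install flash-attn==2.3.2 --no-build-isolation
```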
Sounds like a reproducible regression. Maybe log it in a new issue with steps to reproduce so someone can try to narrow down what happened.
The increased VRAM usage could possibly be related to https://github.com/OpenAccess-AI-Collective/axolotl/issues/1127
Last I checked, setup.py asks for torch 2.1.1, so the install will work with this version of PyTorch:
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
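A quick sanity check after that install, to confirm the CUDA 12.1 wheels were actually picked up:

```bash
# Expect a 2.1.1 build, CUDA 12.1 and True on a working setup
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```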
In the current setup.py:
try:
    torch_version = version("torch")
    if torch_version.startswith("2.1.1"):
        _install_requires.pop(_install_requires.index("xformers==0.0.22"))
        _install_requires.append("xformers==0.0.23")
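In other words, if torch 2.1.1 is already present when setup.py runs, the xformers pin is swapped from 0.0.22 to 0.0.23. A sketch of an install order that takes advantage of that (the extras are just an example):

```bash
# Install torch 2.1.1 first so setup.py detects it ...
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121
# ... then install axolotl; the xformers==0.0.23 pin is used instead of 0.0.22
pip install -e '.[flash-attn,deepspeed]'
```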
I wrote a guide on Reddit and it still works as of today's commit (59a31fe613c7a39563c0f412b15fc9ac109e3009): https://www.reddit.com/r/LocalLLaMA/comments/18pk6wm/how_to_qlora_fine_tune_using_axolotl_zero_to/?utm_source=share&utm_medium=web2x&context=3
I tracked down the issue to flash-attn from pip. Version 2.3.2 works; the newer one as per requirements.txt (2.3.3) causes problems. At the moment I'm on torch 2.0.1, though.
I'm on torch 2.1.1 and flash-attn-2.3.3 and have no issues. They really need to change the torch version from 2.0.1 to 2.1.1 imo.
Please check that this issue hasn't been reported before.
Expected Behavior
Axolotl installation instructions for Conda should guide the user step-by-step through the entire process from start to finish. It is meant to be a way of obtaining reproducible environments.
Current behaviour
Starting from a fresh Conda environment, the installation instructions on the main page are currently either broken or incomplete: axolotl won't start, and the training script fails with the error
ModuleNotFoundError: No module named 'axolotl'
Attempting to install/uninstall/reinstall packages manually may work around this error, but eventually gives another:
ImportError: libcudart.so.12: cannot open shared object file: No such file or directory
Even working around that one reveals further issues (e.g. with FlashAttention).
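For reference, a quick way to confirm the libcudart.so.12 part (the workaround line is only a guess at pointing the loader at the conda environment's libraries):

```bash
# Check whether any CUDA 12 runtime is visible to the dynamic loader
ldconfig -p | grep libcudart

# Possible workaround: expose the conda env's libraries to the loader
export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"
```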
Steps to reproduce
Create a new Conda environment, activate it, install the CUDA toolkit (12.1.1), then the latest stable PyTorch built for the same CUDA version (12.1) via pip. Afterwards, install axolotl following the Quick Start instructions on the main page, as sketched below.
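Roughly, the steps translate to something like this (the channel label and package names are indicative; the final line follows the Quick Start's editable install with extras):

```bash
# Fresh environment with the CUDA 12.1 toolkit from NVIDIA's conda channel
conda create -y -n axolotl python=3.10
conda activate axolotl
conda install -c "nvidia/label/cuda-12.1.1" cuda-toolkit

# Latest stable PyTorch built against CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Axolotl as per the Quick Start on the main page
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip install -e '.[flash-attn,deepspeed]'
```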
The last step in the above process gives this error. Interestingly, while the instructions on the Axolotl page suggest installing the latest stable PyTorch (2.1.2 as of writing), the installation process removes it and installs 2.0.1 instead:
Attempting to run the finetuning script either with accelerate or python suggests that something went wrong:

Config yaml
No response
Possible solution
No response
Which Operating Systems are you using?
Python Version
3.10
axolotl branch-commit
main/317fa2555ad622f5a64f7a1108fc4ce96df548f6
Acknowledgements