axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

Installation instructions for Conda are incomplete or broken #1144

Open · BugReporterZ opened this issue 10 months ago

BugReporterZ commented 10 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

Axolotl's installation instructions for Conda should guide the user step by step through the entire process from start to finish; Conda is meant to provide reproducible environments.

Current behaviour

Starting from a fresh Conda environment, the installation instructions on the main page are currently either broken or incomplete: axolotl won't start, and the training script fails with the error ModuleNotFoundError: No module named 'axolotl'.

Attempting to install/uninstall/reinstall packages manually may work around this error, but eventually produces another: ImportError: libcudart.so.12: cannot open shared object file: No such file or directory

Even working around that one reveals further issues (e.g. with FlashAttention).

Steps to reproduce

Create a new Conda environment, activate it, install the CUDA toolkit (12.1.1), then install the latest stable PyTorch built against the same CUDA version (12.1) via pip. Afterwards, install axolotl following the Quick Start instructions on the main page.

conda create -n axolotl python=3.10
conda activate axolotl
conda install nvidia/label/cuda-12.1.1::cuda-toolkit
pip3 install torch torchvision torchaudio
pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"

The last step in the above process gives this error. Interestingly, while the instructions on the Axolotl page suggest installing the latest stable PyTorch (2.1.2 as of writing), the installation process removes it and installs 2.0.1 instead:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.1.2 requires torch==2.1.2, but you have torch 2.0.1 which is incompatible.
torchvision 0.16.2 requires torch==2.1.2, but you have torch 2.0.1 which is incompatible.
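
For reference, a quick way to confirm which torch build actually ended up in the environment (assuming the axolotl conda env is active):

python -c "import torch; print(torch.__version__, torch.version.cuda)"
pip show torch | grep -i '^version'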

Attempting to run the finetuning script, either with accelerate or with python, confirms that something went wrong:

python ./scripts/finetune.py script.yaml 

Traceback (most recent call last):
  File "/home/anon/bin/axolotl/./scripts/finetune.py", line 8, in <module>
    from axolotl.cli import (
ModuleNotFoundError: No module named 'axolotl'
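
For what it's worth, a quick check (not a fix) that the package really isn't present in the active environment:

python -c "import axolotl"    # fails with the same ModuleNotFoundError
pip show axolotl              # warns that the package is not found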

Config yaml

No response

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10

axolotl branch-commit

main/317fa2555ad622f5a64f7a1108fc4ce96df548f6

Acknowledgements

Lswhiteh commented 10 months ago

Seems like the main issues are multiple dependencies (mostly flash-attn) causing major hiccups, along with CUDA version discrepancies.

After a lot of trial and error I finally got this environment working on my WSL Ubuntu build with a 4070; hope it helps someone, at least until the build is fixed.

The Docker image was also broken for me, which is why I fiddled with the conda environment so much.

name: ft
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py311h6a678d5_7
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.12.12=h06a4308_0
  - certifi=2023.11.17=py311h06a4308_0
  - cffi=1.16.0=py311h5eee18b_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=41.0.7=py311hdda0065_0
  - cuda-cccl=12.3.101=0
  - cuda-command-line-tools=12.1.1=0
  - cuda-compiler=12.3.2=0
  - cuda-cudart=12.1.105=0
  - cuda-cudart-dev=12.1.105=0
  - cuda-cudart-static=12.1.105=0
  - cuda-cuobjdump=12.3.101=0
  - cuda-cupti=12.1.105=0
  - cuda-cupti-static=12.1.105=0
  - cuda-cuxxfilt=12.3.101=0
  - cuda-documentation=12.3.101=0
  - cuda-driver-dev=12.3.101=0
  - cuda-gdb=12.3.101=0
  - cuda-libraries=12.1.0=0
  - cuda-libraries-dev=12.1.0=0
  - cuda-libraries-static=12.1.0=0
  - cuda-nsight=12.3.101=0
  - cuda-nsight-compute=12.3.2=0
  - cuda-nvcc=12.3.107=0
  - cuda-nvdisasm=12.3.101=0
  - cuda-nvml-dev=12.3.101=0
  - cuda-nvprof=12.3.101=0
  - cuda-nvprune=12.3.101=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvrtc-dev=12.1.105=0
  - cuda-nvrtc-static=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-nvvp=12.3.101=0
  - cuda-opencl=12.3.101=0
  - cuda-opencl-dev=12.3.101=0
  - cuda-profiler-api=12.3.101=0
  - cuda-runtime=12.1.0=0
  - cuda-sanitizer-api=12.3.101=0
  - cuda-toolkit=12.1.0=0
  - cuda-tools=12.1.0=0
  - cuda-visual-tools=12.1.0=0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=py311h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - gds-tools=1.8.1.2=0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py311hc9b5ff0_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py311h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.2=py311h06a4308_0
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=12.1.0.26=0
  - libcublas-dev=12.1.0.26=0
  - libcublas-static=12.1.0.26=0
  - libcufft=11.0.2.4=0
  - libcufft-dev=11.0.2.4=0
  - libcufft-static=11.0.2.4=0
  - libcufile=1.8.1.2=0
  - libcufile-dev=1.8.1.2=0
  - libcufile-static=1.8.1.2=0
  - libcurand=10.3.4.107=0
  - libcurand-dev=10.3.4.107=0
  - libcurand-static=10.3.4.107=0
  - libcusolver=11.4.4.55=0
  - libcusolver-dev=11.4.4.55=0
  - libcusolver-static=11.4.4.55=0
  - libcusparse=12.0.2.55=0
  - libcusparse-dev=12.0.2.55=0
  - libcusparse-static=12.0.2.55=0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libnpp=12.0.2.50=0
  - libnpp-dev=12.0.2.50=0
  - libnpp-static=12.0.2.50=0
  - libnvjitlink=12.1.105=0
  - libnvjitlink-dev=12.1.105=0
  - libnvjpeg=12.1.1.14=0
  - libnvjpeg-dev=12.1.1.14=0
  - libnvjpeg-static=12.1.1.14=0
  - libnvvm-samples=12.1.105=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp=1.3.2=h11a3e52_0
  - libwebp-base=1.3.2=h5eee18b_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lz4-c=1.9.4=h6a678d5_0
  - markupsafe=2.1.3=py311h5eee18b_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py311h5eee18b_1
  - mkl_fft=1.3.8=py311h5eee18b_0
  - mkl_random=1.2.4=py311hdb19cb5_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py311h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.1=py311h06a4308_0
  - nsight-compute=2023.3.1.1=0
  - numpy=1.26.3=py311h08b1b3b_0
  - numpy-base=1.26.3=py311hf175353_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=3.0.12=h7f8727e_0
  - pillow=10.0.1=py311ha6cbd5a_0
  - pip=23.3.1=py311h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.2.0=py311h06a4308_0
  - pysocks=1.7.1=py311h06a4308_0
  - python=3.11.7=h955ad1f_0
  - pytorch-cuda=12.1=ha16c6d3_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py311h5eee18b_0
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=py311h06a4308_0
  - setuptools=68.2.2=py311h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - sympy=1.12=py311h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=2.1.2=py311_cu121
  - torchvision=0.16.2=py311_cu121
  - typing_extensions=4.9.0=py311h06a4308_1
  - urllib3=1.26.18=py311h06a4308_0
  - wheel=0.41.2=py311h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
      - absl-py==2.1.0
      - accelerate==0.25.0.dev0
      - addict==2.4.0
      - aiobotocore==2.7.0
      - aiofiles==23.2.1
      - aiohttp==3.9.1
      - aioitertools==0.11.0
      - aiosignal==1.3.1
      - alembic==1.13.1
      - altair==5.2.0
      - anyio==4.2.0
      - appdirs==1.4.4
      - art==6.1
      - attrs==23.2.0
      - bert-score==0.3.13
      - bitsandbytes==0.42.0
      - blinker==1.7.0
      - botocore==1.31.64
      - cachetools==5.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - cmake==3.28.1
      - colorama==0.4.6
      - coloredlogs==15.0.1
      - contourpy==1.2.0
      - cycler==0.12.1
      - databricks-cli==0.18.0
      - datasets==2.16.1
      - decorator==5.1.1
      - deepspeed==0.12.6
      - dill==0.3.7
      - docker==6.1.3
      - docker-pycreds==0.4.0
      - docstring-parser==0.15
      - einops==0.7.0
      - entrypoints==0.4
      - evaluate==0.4.0
      - fastapi==0.109.0
      - ffmpy==0.3.1
      - fire==0.5.0
      - flash-attn==2.4.2
      - flask==3.0.0
      - fonttools==4.47.2
      - frozenlist==1.4.1
      - fschat==0.2.34
      - fsspec==2023.10.0
      - gcsfs==2023.10.0
      - gitdb==4.0.11
      - gitpython==3.1.41
      - google-api-core==2.15.0
      - google-auth==2.26.2
      - google-auth-oauthlib==1.2.0
      - google-cloud-core==2.4.1
      - google-cloud-storage==2.14.0
      - google-crc32c==1.5.0
      - google-resumable-media==2.7.0
      - googleapis-common-protos==1.62.0
      - gradio==3.50.2
      - gradio-client==0.6.1
      - greenlet==3.0.3
      - grpcio==1.60.0
      - gunicorn==21.2.0
      - h11==0.14.0
      - hf-transfer==0.1.4
      - hjson==3.1.0
      - httpcore==1.0.2
      - httpx==0.26.0
      - huggingface-hub==0.20.2
      - humanfriendly==10.0
      - importlib-metadata==7.0.1
      - importlib-resources==6.1.1
      - itsdangerous==2.1.2
      - jmespath==1.0.1
      - joblib==1.3.2
      - jsonschema==4.21.0
      - jsonschema-specifications==2023.12.1
      - kiwisolver==1.4.5
      - lit==17.0.6
      - llvmlite==0.41.1
      - mako==1.3.0
      - markdown==3.5.2
      - markdown-it-py==3.0.0
      - markdown2==2.4.12
      - matplotlib==3.8.2
      - mdurl==0.1.2
      - mlflow==2.9.2
      - multidict==6.0.4
      - multiprocess==0.70.15
      - nh3==0.2.15
      - ninja==1.11.1.1
      - nltk==3.8.1
      - numba==0.58.1
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cuda-cupti-cu11==11.7.101
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cudnn-cu11==8.5.0.96
      - nvidia-cufft-cu11==10.9.0.58
      - nvidia-curand-cu11==10.2.10.91
      - nvidia-cusolver-cu11==11.4.0.1
      - nvidia-cusparse-cu11==11.7.4.91
      - nvidia-nccl-cu11==2.14.3
      - nvidia-nvtx-cu11==11.7.91
      - oauthlib==3.2.2
      - optimum==1.13.2
      - orjson==3.9.12
      - packaging==23.2
      - pandas==2.1.4
      - peft==0.7.0
      - prompt-toolkit==3.0.43
      - protobuf==4.23.4
      - psutil==5.9.7
      - py-cpuinfo==9.0.0
      - pyarrow==14.0.2
      - pyarrow-hotfix==0.6
      - pyasn1==0.5.1
      - pyasn1-modules==0.3.0
      - pydantic==1.10.13
      - pydub==0.25.1
      - pygments==2.17.2
      - pyjwt==2.8.0
      - pynvml==11.5.0
      - pyparsing==3.1.1
      - python-dateutil==2.8.2
      - python-multipart==0.0.6
      - pytz==2023.3.post1
      - querystring-parser==1.2.4
      - referencing==0.32.1
      - regex==2023.12.25
      - requests-oauthlib==1.3.1
      - responses==0.18.0
      - rich==13.7.0
      - rouge-score==0.1.2
      - rpds-py==0.17.1
      - rsa==4.9
      - s3fs==2023.10.0
      - safetensors==0.4.1
      - scikit-learn==1.2.2
      - scipy==1.11.4
      - semantic-version==2.10.0
      - sentencepiece==0.1.99
      - sentry-sdk==1.39.2
      - setproctitle==1.3.3
      - shortuuid==1.0.11
      - shtab==1.6.5
      - six==1.16.0
      - smmap==5.0.1
      - sniffio==1.3.0
      - sqlalchemy==2.0.25
      - sqlparse==0.4.4
      - starlette==0.35.1
      - svgwrite==1.4.3
      - tabulate==0.9.0
      - tensorboard==2.15.1
      - tensorboard-data-server==0.7.2
      - termcolor==2.4.0
      - threadpoolctl==3.2.0
      - tiktoken==0.5.2
      - tokenizers==0.15.0
      - toolz==0.12.0
      - torch==2.0.1
      - tqdm==4.66.1
      - transformers==4.37.0.dev0
      - triton==2.0.0
      - trl==0.7.9
      - tyro==0.6.5
      - tzdata==2023.4
      - uvicorn==0.26.0
      - wandb==0.16.2
      - wavedrom==2.0.3.post3
      - wcwidth==0.2.13
      - websocket-client==1.7.0
      - websockets==11.0.3
      - werkzeug==3.0.1
      - wrapt==1.16.0
      - xformers==0.0.22
      - xxhash==3.4.1
      - yarl==1.9.4
      - zipp==3.17.0
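
To reuse this, save the YAML above to a file and recreate the environment with conda (the filename here is just an example):

conda env create -f ft.yaml
conda activate ft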

gardner commented 10 months ago

EDIT: It seems updating xformers to xformers==0.0.23.post1 resolves this issue. See their release notes

This patch fixed the installation process for me:

diff --git a/requirements.txt b/requirements.txt
index 4583850..1b985c0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
+torch==2.1.2
 transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
@@ -14,7 +15,7 @@ flash-attn==2.3.3
 sentencepiece
 wandb
 einops
-xformers==0.0.22
+xformers==0.0.23.post1
 optimum==1.13.2
 hf_transfer
 colorama
diff --git a/setup.py b/setup.py
index 235018d..8ea9d76 100644
--- a/setup.py
+++ b/setup.py
@@ -49,6 +49,9 @@ setup(
     install_requires=install_requires,
     dependency_links=dependency_links,
     extras_require={
+        "torch": [
+            "torch==2.1.2",
+        ],
         "flash-attn": [
             "flash-attn==2.3.3",
         ],
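
If it helps, a rough sketch of how this patch can be applied from a local clone before installing (the patch filename is hypothetical):

git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
git apply fix-torch-xformers.patch    # the diff above, saved to a file
pip install -e '.[flash-attn,deepspeed]'
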
BugReporterZ commented 10 months ago

@gardner That appears to fix axolotl failing to install and run in my case, but there are still issues with training: memory usage seems unusually high compared to past axolotl commits (I definitely didn't have these issues back in December).

I'm currently OOMing (at batch size 1) on 24GB with 4096-token sequences on a 7B model (Mistral-7B) with QLoRA (4-bit). Training that used to be possible even without FA, and even with FA enabled, memory usage still seems unusually high now.

I've read suggestions that the flash-attn version could be involved, but upgrading to 2.4.2 didn't solve it.

gardner commented 10 months ago

I was able to kick off a tinyllama QLoRA run with 4096 tokens on my 12GB 3060. It's hanging out at 11.4 GB usage, but it's running (pretty slowly).

You might want to try downgrading flash-attn?

There's also https://github.com/unslothai/unsloth if you don't mind writing a bit of code.

BugReporterZ commented 10 months ago

Reverting to an axolotl commit from mid-December (5f79b82; I haven't investigated exactly when the issues began), reinstalling the packages, then uninstalling flash-attn and doing pip install flash-attn==2.3.2 fixes the issue. Training Mistral-7B in 4-bit with 4096-token sequences at batch size 1 and FA enabled now takes about 9000 MB of VRAM.

Attempting to finetune a llamafied InternLM2 20B in 4-bit, again with 4096-token sequences, now takes about 21500 MB instead of OOMing as it does with newer axolotl commits.
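
Roughly, the steps I mean, assuming a local clone installed in editable mode:

git checkout 5f79b82
pip install -e '.[deepspeed]'
pip uninstall -y flash-attn
pip install flash-attn==2.3.2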

gardner commented 10 months ago

Sounds like a reproducible regression. Maybe log it in a new issue with steps to reproduce so someone can try to narrow down what happened.

BugReporterZ commented 10 months ago

The increased VRAM usage could possibly be related to https://github.com/OpenAccess-AI-Collective/axolotl/issues/1127

Nero10578 commented 10 months ago

Last I checked setup.py, the PyTorch version it asks for is 2.1.1, so the install will work with this version of PyTorch:

pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121

In the current setup.py:

try:
    torch_version = version("torch")
    if torch_version.startswith("2.1.1"):
        _install_requires.pop(_install_requires.index("xformers==0.0.22"))
        _install_requires.append("xformers==0.0.23")
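
As a sanity check (not part of setup.py), you can confirm the installed torch and xformers versions line up with what that logic expects:

pip show torch xformers | grep -E '^(Name|Version)'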

I wrote a guide on Reddit, and it still works as of today's commit (59a31fe613c7a39563c0f412b15fc9ac109e3009): https://www.reddit.com/r/LocalLLaMA/comments/18pk6wm/how_to_qlora_fine_tune_using_axolotl_zero_to/?utm_source=share&utm_medium=web2x&context=3

BugReporterZ commented 10 months ago

I tracked down the issue to flash-attn from pip. Version 2.3.2 works; the newer one as per requirements.txt (2.3.3) causes problems. At the moment I'm on torch 2.0.1, though.

Nero10578 commented 10 months ago

I tracked down the issue to flash-attn from pip. Version 2.3.2 works; the newer one as per requirements.txt (2.3.3) causes problems. At the moment I'm on torch 2.0.1, though.

I'm on torch 2.1.1 and flash-attn-2.3.3 and have no issues. They really need to change the torch version from 2.0.1 to 2.1.1 imo.