dvmazur / mixtral-offloading

Run Mixtral-8x7B models in Colab or consumer desktops
MIT License

CUDA OOM errors in wsl2 #18

Open MrNova111 opened 10 months ago

MrNova111 commented 10 months ago

Trying to run this in WSL2 on Windows 10 with an RTX 3080 Ti (12 GB VRAM). Setting offload_per_layer=7 does not seem to help: VRAM usage never goes above 6.5 GB, so there seems to be plenty of headroom.
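
Note that the traceback below fails inside `pin_memory`, which allocates page-locked *host* RAM, not VRAM, so low VRAM usage does not rule this error out. Judging from the failing list comprehension in `expert_cache.py`, `ExpertCache` pins one host buffer of `module_size` bytes for each of `offload_size` offloaded experts. The sketch below assumes the convention from the repository's demo notebook, `offload_size = num_hidden_layers * offload_per_layer`, and uses Mixtral-8x7B's 32 layers and 8 experts per layer. `pinned_demand_bytes` is a hypothetical helper and the per-expert size is a placeholder, not a measured value. Under that assumption, raising `offload_per_layer` to 7 keeps VRAM low but *increases* the pinned host memory requested.

```python
# Hypothetical estimate of the pinned (page-locked) host memory requested by
# ExpertCache. Assumptions: one pinned buffer of `module_size` bytes per
# offloaded expert (as in the traceback) and
# offload_size = num_hidden_layers * offload_per_layer (demo-notebook convention).

NUM_HIDDEN_LAYERS = 32  # Mixtral-8x7B decoder layers
NUM_EXPERTS = 8         # experts per MoE layer in Mixtral-8x7B


def pinned_demand_bytes(offload_per_layer: int, module_size: int,
                        num_layers: int = NUM_HIDDEN_LAYERS) -> int:
    """Estimated total pinned host bytes for all offloaded experts."""
    if not 0 <= offload_per_layer <= NUM_EXPERTS:
        raise ValueError(f"offload_per_layer must be in [0, {NUM_EXPERTS}]")
    offload_size = num_layers * offload_per_layer  # number of pinned buffers
    return offload_size * module_size


if __name__ == "__main__":
    # Placeholder per-expert size; substitute the real ExpertCache.module_size.
    module_size = 100 * 2**20  # 100 MiB, illustrative only
    for k in (4, 5, 6, 7):
        gib = pinned_demand_bytes(k, module_size) / 2**30
        print(f"offload_per_layer={k}: ~{gib:.1f} GiB pinned host memory")
```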

/home/mrnova/.conda/envs/mixtral/lib/python3.10/site-packages/torch/nn/init.py:412: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
Traceback (most recent call last):
  File "/home/mrnova/mixtral-offloading/main.py", line 54, in <module>
    model = build_model(
  File "/home/mrnova/mixtral-offloading/src/build_model.py", line 204, in build_model
    expert_cache = ExpertCache(
  File "/home/mrnova/mixtral-offloading/src/expert_cache.py", line 67, in __init__
    self.offloaded_storages = [
  File "/home/mrnova/mixtral-offloading/src/expert_cache.py", line 68, in <listcomp>
    torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)]
  File "/home/mrnova/.conda/envs/mixtral/lib/python3.10/site-packages/torch/storage.py", line 226, in pin_memory
    cast(Storage, self)).pin_memory(device)
RuntimeError: CUDA error: out of memory
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
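
Because the failure happens in `pin_memory`, one way to test the pinned-memory hypothesis outside the model is to allocate page-locked buffers in a loop, the way `ExpertCache.__init__` does, and see where it stops. NVIDIA's CUDA on WSL user guide lists limited pinned system memory among its known limitations, which would be consistent with this error appearing while VRAM usage stays low. The sketch below uses `torch.empty(nbytes, dtype=torch.uint8, pin_memory=True)` rather than the exact `torch.UntypedStorage(...).pin_memory(...)` call; `probe_pinned_buffers` and `torch_pinned_alloc` are hypothetical helpers, and the allocator is injectable so the loop can be exercised without a GPU.

```python
# Sketch: count how many pinned host buffers of a given size can be allocated
# before the allocator fails, mimicking the list comprehension in
# ExpertCache.__init__. Helper names are hypothetical.
from typing import Callable, List


def torch_pinned_alloc(nbytes: int) -> object:
    """Allocate `nbytes` of page-locked host memory via PyTorch."""
    import torch  # imported lazily so the probe loop can be tested without torch
    return torch.empty(nbytes, dtype=torch.uint8, pin_memory=True)


def probe_pinned_buffers(nbytes: int, max_buffers: int,
                         alloc: Callable[[int], object] = torch_pinned_alloc) -> int:
    """Return how many `nbytes` buffers were allocated before a failure."""
    held: List[object] = []  # keep references so every buffer stays allocated
    try:
        for _ in range(max_buffers):
            held.append(alloc(nbytes))
    except RuntimeError:
        # PyTorch raises RuntimeError (or its subclass OutOfMemoryError) when
        # the allocation fails, e.g. "CUDA error: out of memory" above.
        pass
    # The buffers are dropped when `held` goes out of scope on return.
    return len(held)
```

For example, `probe_pinned_buffers(64 * 2**20, 224)` attempts up to 224 buffers of 64 MiB each (illustrative numbers). A count well below the target on WSL2, with plenty of free system RAM, would suggest a host-side pinned-memory limit rather than a VRAM shortage.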

conda list output:

# packages in environment at /home/mrnova/.conda/envs/mixtral:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
brotli-python             1.1.0           py310hc6cd4ac_1    conda-forge
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2023.12.12           h06a4308_0
certifi                   2023.11.17         pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
cuda                      12.3.2                        0    nvidia
cuda-cccl                 12.3.101                      0    nvidia
cuda-command-line-tools   12.3.2                        0    nvidia
cuda-compiler             12.3.2                        0    nvidia
cuda-cudart               12.3.101                      0    nvidia
cuda-cudart-dev           12.3.101                      0    nvidia
cuda-cudart-static        12.3.101                      0    nvidia
cuda-cuobjdump            12.3.101                      0    nvidia
cuda-cupti                12.3.101                      0    nvidia
cuda-cupti-static         12.3.101                      0    nvidia
cuda-cuxxfilt             12.3.101                      0    nvidia
cuda-demo-suite           12.3.101                      0    nvidia
cuda-documentation        12.3.101                      0    nvidia
cuda-driver-dev           12.3.101                      0    nvidia
cuda-gdb                  12.3.101                      0    nvidia
cuda-libraries            12.3.2                        0    nvidia
cuda-libraries-dev        12.3.2                        0    nvidia
cuda-libraries-static     12.3.2                        0    nvidia
cuda-nsight               12.3.101                      0    nvidia
cuda-nsight-compute       12.3.2                        0    nvidia
cuda-nvcc                 12.3.107                      0    nvidia
cuda-nvdisasm             12.3.101                      0    nvidia
cuda-nvml-dev             12.3.101                      0    nvidia
cuda-nvprof               12.3.101                      0    nvidia
cuda-nvprune              12.3.101                      0    nvidia
cuda-nvrtc                12.3.107                      0    nvidia
cuda-nvrtc-dev            12.3.107                      0    nvidia
cuda-nvrtc-static         12.3.107                      0    nvidia
cuda-nvtx                 12.3.101                      0    nvidia
cuda-nvvp                 12.3.101                      0    nvidia
cuda-opencl               12.3.101                      0    nvidia
cuda-opencl-dev           12.3.101                      0    nvidia
cuda-profiler-api         12.3.101                      0    nvidia
cuda-runtime              12.3.2                        0    nvidia
cuda-sanitizer-api        12.3.101                      0    nvidia
cuda-toolkit              12.3.2                        0    nvidia
cuda-tools                12.3.2                        0    nvidia
cuda-version              11.8                 h70ddcb2_2    conda-forge
cuda-visual-tools         12.3.2                        0    nvidia
cudatoolkit               11.8.0               h6a678d5_0
cudnn                     8.9.2.26               cuda11_0
filelock                  3.13.1             pyhd8ed1ab_0    conda-forge
fsspec                    2023.12.2          pyhca7485f_0    conda-forge
gds-tools                 1.8.1.2                       0    nvidia
hqq                       0.1.1                    pypi_0    pypi
hqq-aten                  0.0.0                    pypi_0    pypi
huggingface_hub           0.20.2             pyhd8ed1ab_0    conda-forge
idna                      3.6                pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.38                 h1181459_1
libcublas                 12.3.4.1                      0    nvidia
libcublas-dev             12.3.4.1                      0    nvidia
libcublas-static          12.3.4.1                      0    nvidia
libcufft                  11.0.12.1                     0    nvidia
libcufft-dev              11.0.12.1                     0    nvidia
libcufft-static           11.0.12.1                     0    nvidia
libcufile                 1.8.1.2                       0    nvidia
libcufile-dev             1.8.1.2                       0    nvidia
libcufile-static          1.8.1.2                       0    nvidia
libcurand                 10.3.4.107                    0    nvidia
libcurand-dev             10.3.4.107                    0    nvidia
libcurand-static          10.3.4.107                    0    nvidia
libcusolver               11.5.4.101                    0    nvidia
libcusolver-dev           11.5.4.101                    0    nvidia
libcusolver-static        11.5.4.101                    0    nvidia
libcusparse               12.2.0.103                    0    nvidia
libcusparse-dev           12.2.0.103                    0    nvidia
libcusparse-static        12.2.0.103                    0    nvidia
libffi                    3.4.4                h6a678d5_0
libgcc-ng                 13.2.0               h807b86a_3    conda-forge
libgomp                   13.2.0               h807b86a_3    conda-forge
libnpp                    12.2.3.2                      0    nvidia
libnpp-dev                12.2.3.2                      0    nvidia
libnpp-static             12.2.3.2                      0    nvidia
libnvjitlink              12.3.101                      0    nvidia
libnvjitlink-dev          12.3.101                      0    nvidia
libnvjpeg                 12.3.0.81                     0    nvidia
libnvjpeg-dev             12.3.0.81                     0    nvidia
libnvjpeg-static          12.3.0.81                     0    nvidia
libstdcxx-ng              13.2.0               h7e041cc_3    conda-forge
libuuid                   1.41.5               h5eee18b_0
nccl                      2.19.4.1             h6103f9b_0    conda-forge
ncurses                   6.4                  h6a678d5_0
nsight-compute            2023.3.1.1                    0    nvidia
numpy                     1.24.4                   pypi_0    pypi
openssl                   3.2.0                hd590300_1    conda-forge
packaging                 23.2               pyhd8ed1ab_0    conda-forge
pip                       23.3.1          py310h06a4308_0
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.10.12              h955ad1f_0
python_abi                3.10                    2_cp310    conda-forge
pyyaml                    6.0.1           py310h2372a71_1    conda-forge
readline                  8.2                  h5eee18b_0
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
setuptools                68.2.2          py310h06a4308_0
sqlite                    3.41.2               h5eee18b_0
termcolor                 2.4.0                    pypi_0    pypi
timm                      0.9.12                   pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
torch                     2.1.2                    pypi_0    pypi
torchvision               0.16.2                   pypi_0    pypi
tqdm                      4.66.1             pyhd8ed1ab_0    conda-forge
transformers              4.36.1                   pypi_0    pypi
typing-extensions         4.9.0                hd8ed1ab_0    conda-forge
typing_extensions         4.9.0              pyha770c72_0    conda-forge
tzdata                    2023d                h04d1e81_0
urllib3                   2.1.0              pyhd8ed1ab_0    conda-forge
wheel                     0.41.2          py310h06a4308_0
xz                        5.4.5                h5eee18b_0
yaml                      0.2.5                h7f98852_2    conda-forge
zlib                      1.2.13               h5eee18b_0
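
The conda environment above mixes CUDA 12.3 packages from the `nvidia` channel (for example `cuda-runtime 12.3.2`) with CUDA 11.8 ones (`cuda-version 11.8`, `cudatoolkit 11.8.0`, and `cudnn` built as `cuda11_0`). It is not established that this causes the OOM, but such a mix is worth ruling out. The sketch below, built around a hypothetical `cuda_majors` helper, groups packages whose names start with `cuda` by major version. Library packages such as `libcufft` are skipped on purpose, because their version numbers do not track the CUDA release (`libcufft 11.0.12.1` sits next to `cuda 12.3.2` here).

```python
# Sketch: group CUDA toolkit packages from `conda list` output by major
# version. Only names starting with "cuda" (cuda, cuda-runtime, cuda-version,
# cudatoolkit, ...) are considered; lib* packages use their own numbering.
from collections import defaultdict
from typing import Dict, List


def cuda_majors(conda_list_text: str) -> Dict[int, List[str]]:
    groups: Dict[int, List[str]] = defaultdict(list)
    for line in conda_list_text.splitlines():
        if not line.strip() or line.lstrip().startswith("#"):
            continue  # skip blank lines and header comments
        fields = line.split()
        if len(fields) < 2:
            continue
        name, version = fields[0], fields[1]
        head = version.split(".")[0]
        if name.startswith("cuda") and head.isdigit():
            groups[int(head)].append(name)
    return dict(groups)


SAMPLE = """\
# Name                    Version                   Build  Channel
cuda-runtime              12.3.2                        0    nvidia
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0               h6a678d5_0
cudnn                     8.9.2.26               cuda11_0
"""

if __name__ == "__main__":
    groups = cuda_majors(SAMPLE)
    if len(groups) > 1:
        print("Mixed CUDA major versions:", groups)
```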

pip freeze output:

absl-py==2.0.0
accelerate==0.25.0
bitsandbytes==0.41.2.post2
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1695989787169/work
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
filelock==3.13.1
fsspec==2023.12.1
ftfy==6.1.3
google-auth==2.24.0
grpcio==1.59.3
hqq @ git+https://github.com/mobiusml/hqq.git@37502bea31f2969c6680c0c4a88ca74b3bb234a5
hqq-aten==0.0.0
huggingface-hub==0.20.1
idna==3.6
inquirerpy==0.3.4
Jinja2==3.1.2
JPype1==1.4.1
Markdown==3.5.1
markdown2==2.4.10
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
packaging==23.2
pandas==2.1.3
patsy==0.5.3
pfzy==0.3.4
Pillow==10.1.0
prompt-toolkit==3.0.43
psutil==5.9.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
PyPDF2==3.0.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1695373428874/work
regex==2023.10.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
safetensors==0.4.1
scipy==1.11.4
statsmodels==0.14.0
sympy==1.12
tabula-py==2.9.0
termcolor==2.4.0
timm==0.9.12
tokenizers==0.15.0
torch==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.36.1
triton==2.1.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.1.0
wcwidth==0.2.12
Werkzeug==3.0.1
xformers==0.0.22.post7
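
On the pip side, `torch==2.1.2` is installed alongside `nvidia-*-cu12` wheels at 12.1.x (for example `nvidia-cuda-runtime-cu12==12.1.105`), which suggests PyTorch is loading its own pip-installed CUDA 12.1 runtime rather than the conda toolkit; printing `torch.version.cuda` inside the environment would confirm it. The sketch below parses `pip freeze` text into a name-to-version map and picks out those CUDA wheels. `parse_freeze` and `cuda_wheels` are hypothetical helpers; names are only lower-cased, not fully normalized.

```python
# Sketch: parse `pip freeze` output into {name: version}. Direct references
# of the form "name @ <url>" carry no version and map to None.
import re
from typing import Dict, Optional

_PINNED = re.compile(r"^([A-Za-z0-9_.\-]+)==(\S+)$")
_DIRECT = re.compile(r"^([A-Za-z0-9_.\-]+) @ (\S+)$")


def parse_freeze(text: str) -> Dict[str, Optional[str]]:
    packages: Dict[str, Optional[str]] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if m := _PINNED.match(line):
            packages[m.group(1).lower()] = m.group(2)
        elif m := _DIRECT.match(line):
            packages[m.group(1).lower()] = None
    return packages


def cuda_wheels(packages: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
    """Keep only NVIDIA CUDA 12 runtime wheels (nvidia-*-cu12)."""
    return {n: v for n, v in packages.items()
            if n.startswith("nvidia-") and n.endswith("-cu12")}


SAMPLE = """\
hqq @ git+https://github.com/mobiusml/hqq.git@37502bea31f2969c6680c0c4a88ca74b3bb234a5
nvidia-cuda-runtime-cu12==12.1.105
nvidia-nvjitlink-cu12==12.3.101
torch==2.1.2
"""

if __name__ == "__main__":
    pkgs = parse_freeze(SAMPLE)
    print("torch", pkgs.get("torch"), "with", cuda_wheels(pkgs))
```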