mit-han-lab / llm-awq

[MLSys 2024 Best Paper Award] AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration

CUDA out of memory when trying to run AWQ search on A100-80GB #144

Closed · isaac-vidas closed 4 months ago

isaac-vidas commented 4 months ago

Thanks for the latest updates and improvements! I was looking into the different LLaVA example notebooks and the VILA example, and I'm getting torch.cuda.OutOfMemoryError: CUDA out of memory. on an A100-80GB while trying to run the AWQ search script:

python -m awq.entry \
    --model_path /home/gcpuser/sky_workdir/llava-v1.5-7b \
    --w_bit 4 \
    --q_group_size 128 \
    --run_awq \
    --dump_awq /home/gcpuser/sky_workdir/awq_cache/llava-v1.5-7b-w4-g128.pt
Quantization config: {'zero_point': True, 'q_group_size': 128}
* Building model /home/gcpuser/sky_workdir/llava-v1.5-7b
You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards:   0%|                                                                                                                                                                | 0/2 [00:00<?, ?it/s]/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.58it/s]
/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/huggingface_hub/repocard.py:105: UserWarning: Repo card metadata block was not found. Setting CardData to empty.
  warnings.warn("Repo card metadata block was not found. Setting CardData to empty.")
Token indices sequence length is longer than the specified maximum sequence length for this model (8322 > 2048). Running this sequence through the model will result in indexing errors
 * Split into 65 blocks
Running AWQ...:  19%|██████████████████████████████▍                                                                                                                                   | 6/32 [01:53<08:12, 18.95s/it]
Traceback (most recent call last):
  File "/opt/conda/envs/quantize_llava/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/quantize_llava/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/gcpuser/sky_workdir/llm-awq/awq/entry.py", line 299, in <module>
    main()
  File "/home/gcpuser/sky_workdir/llm-awq/awq/entry.py", line 239, in main
    model, enc = build_model_and_enc(args.model_path)
  File "/home/gcpuser/sky_workdir/llm-awq/awq/entry.py", line 161, in build_model_and_enc
    awq_results = run_awq(
  File "/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/llm-awq/awq/quantize/pre_quant.py", line 181, in run_awq
    scales_list = auto_scale_block(
  File "/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/gcpuser/sky_workdir/llm-awq/awq/quantize/auto_scale.py", line 217, in auto_scale_block
    _auto_get_scale(
  File "/home/gcpuser/sky_workdir/llm-awq/awq/quantize/auto_scale.py", line 163, in _auto_get_scale
    scales = _search_module_scale(module2inspect, layers, inp, kwargs)
  File "/home/gcpuser/sky_workdir/llm-awq/awq/quantize/auto_scale.py", line 134, in _search_module_scale
    out = block(x, **kwargs)
  File "/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 710, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/opt/conda/envs/quantize_llava/lib/python3.10/site-packages/transformers/cache_utils.py", line 127, in update
    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.54 GiB. GPU 0 has a total capacty of 79.15 GiB of which 2.41 GiB is free. Including non-PyTorch memory, this process has 76.73 GiB memory in use. Of the allocated memory 74.07 GiB is allocated by PyTorch, and 2.16 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I noticed that the parallelism parameters aren't being used in the run_awq call in the entry.py script. Is there anything I'm missing?

Environment setup details:

conda create -n quantize_llava python=$PYTHON -y
conda activate quantize_llava
pip install --upgrade pip
pip install nvitop

echo "Install LLaVA"
cd ~/sky_workdir
git clone https://github.com/haotian-liu/LLaVA.git
cd ~/sky_workdir/LLaVA/
pip install -e .

echo "Install llm-awq and kernels"
cd ~/sky_workdir
git clone https://github.com/mit-han-lab/llm-awq.git
cd ~/sky_workdir/llm-awq/
pip install -e .
cd ~/sky_workdir/llm-awq/awq/kernels
python setup.py install

echo "Download LLaVA model"
mkdir ~/sky_workdir/llava-v1.5-7b
huggingface-cli download \
  --local-dir ~/sky_workdir/llava-v1.5-7b \
  --local-dir-use-symlinks False \
  liuhaotian/llava-v1.5-7b

echo "Create awq folders"
mkdir ~/sky_workdir/awq_cache
mkdir ~/sky_workdir/quant_cache
isaac-vidas commented 4 months ago

@ys-2020 any advice would be highly appreciated. Thanks in advance!

ys-2020 commented 4 months ago

Hi @isaac-vidas. Thank you for your interest in AWQ.

You are using the right commands to quantize the llava model. I tried to reproduce your environment for quantizing llava-7b, but I did not encounter the CUDA out of memory problem.

I used transformers==4.32.0 in my environment. My PyTorch and CUDA versions are the same as yours.

Since the original 7B model only takes ~14 GB of memory, and we only move one layer onto the GPU at a time during quantization, the CUDA out of memory error is not expected. I suggest you double-check whether your GPU is running other jobs at the same time.

I hope this information is helpful.
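
For context, the quantization loop follows a layer-at-a-time pattern roughly like the sketch below (illustrative only, with toy Linear blocks standing in for the real decoder layers; this is not the actual run_awq code):

import torch
import torch.nn as nn

@torch.no_grad()
def run_layerwise(blocks, x):
    # Only one block's weights live on the GPU at any moment, so peak GPU
    # memory should stay near a single layer plus its activations.
    for block in blocks:
        block.cuda()                 # move this block's weights to the GPU
        x = block(x.cuda()).cpu()    # forward on GPU, keep outputs on CPU
        block.cpu()                  # move the weights back to host memory
        torch.cuda.empty_cache()     # release cached allocator blocks
    return x

# Toy stand-in: four "decoder layers" as plain Linear modules.
blocks = [nn.Linear(64, 64) for _ in range(4)]
out = run_layerwise(blocks, torch.randn(2, 64))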

isaac-vidas commented 4 months ago

Got it, thanks for verifying @ys-2020!

🤔 I'm running it on a brand-new instance and had nvitop running in the background to make sure nothing else was using the GPU(s). I also tried it on an A100 (40GB) and an A100-80GB. What I'm seeing is that as AWQ runs, it slowly takes more and more memory, so I suspect a memory leak of some kind. Specifically, it seems to run out of memory in _search_module_scale.

With the A100-80GB I'm able to reach 19%, while with the A100-40GB I'm only able to reach 12% before getting torch.cuda.OutOfMemoryError:

Running AWQ...:  19%|██████████████████████████████▍ 

I was surprised to run into this error because I'm able to run the LLaVA 7B model on an L4 with no issues.

I'll try again with transformers==4.32.0 and see if it makes any difference.
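
For anyone reproducing this: a minimal way to watch the growth without nvitop is to print PyTorch's allocator counters between blocks (a diagnostic sketch, not code from this repo):

import torch

def log_gpu_memory(tag: str = "") -> None:
    # memory_allocated: bytes held by live tensors; memory_reserved: blocks
    # the caching allocator has claimed from the driver.
    alloc = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"[{tag}] allocated={alloc:.2f} GiB, reserved={reserved:.2f} GiB")

# Calling e.g. log_gpu_memory(f"block {i}") at the top of the AWQ block loop
# should show "allocated" climbing steadily if something is leaking.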

isaac-vidas commented 4 months ago

One more question, if you have a moment: which commit hash of LLaVA are you using? I just pulled the latest main branch, which is 5d8f1760c08b7dfba3ae97b71cbd4c6f17d12dbd.

isaac-vidas commented 4 months ago

Another update: the error seems to change depending on the order in which I install the packages.

Installing LLaVA before llm-awq

Performing this order of operations will result in the CUDA out of memory error:

git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
pip install -e .

cd ..
git clone https://github.com/mit-han-lab/llm-awq.git
cd llm-awq
pip install -e .
cd awq/kernels
python setup.py install

Installed packages

Package                   Version     Editable project location
------------------------- ----------- ---------------------------------
absl-py                   2.1.0
accelerate                0.21.0
aiofiles                  23.2.1
aiohttp                   3.9.3
aiosignal                 1.3.1
altair                    5.2.0
annotated-types           0.6.0
anyio                     4.3.0
async-timeout             4.0.3
attributedict             0.3.0
attrs                     23.2.0
awq                       0.1.0       /home/gcpuser/sky_workdir/llm-awq
awq_inference_engine      0.0.0
bitsandbytes              0.42.0
blessings                 1.7
cachetools                5.3.2
certifi                   2024.2.2
chardet                   5.2.0
charset-normalizer        3.3.2
click                     8.1.7
codecov                   2.1.13
colorama                  0.4.6
coloredlogs               15.0.1
colour-runner             0.1.1
contourpy                 1.2.0
coverage                  7.4.3
cycler                    0.12.1
DataProperty              1.0.1
datasets                  2.17.1
deepdiff                  6.7.1
dill                      0.3.8
distlib                   0.3.8
distro                    1.9.0
einops                    0.6.1
einops-exts               0.0.4
exceptiongroup            1.2.0
fastapi                   0.110.0
ffmpy                     0.3.2
filelock                  3.13.1
fonttools                 4.49.0
frozenlist                1.4.1
fsspec                    2023.10.0
gradio                    3.35.2
gradio_client             0.2.9
h11                       0.14.0
hf_transfer               0.1.5
httpcore                  0.17.3
httpx                     0.24.0
huggingface-hub           0.20.3
humanfriendly             10.0
idna                      3.6
importlib_resources       6.1.2
inspecta                  0.1.3
Jinja2                    3.1.3
joblib                    1.3.2
jsonlines                 4.0.0
jsonschema                4.21.1
jsonschema-specifications 2023.12.1
kiwisolver                1.4.5
linkify-it-py             2.0.3
llava                     1.2.2.post1 /home/gcpuser/sky_workdir/LLaVA
lm-eval                   0.3.0
markdown-it-py            2.2.0
markdown2                 2.4.13
MarkupSafe                2.1.5
matplotlib                3.8.3
mbstrdecoder              1.1.3
mdit-py-plugins           0.3.3
mdurl                     0.1.2
mpmath                    1.3.0
multidict                 6.0.5
multiprocess              0.70.16
networkx                  3.2.1
nltk                      3.8.1
numexpr                   2.9.0
numpy                     1.26.4
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu12      12.1.0.106
nvidia-ml-py              12.535.133
nvidia-nccl-cu12          2.18.1
nvidia-nvjitlink-cu12     12.3.101
nvidia-nvtx-cu12          12.1.105
nvitop                    1.3.2
openai                    1.12.0
ordered-set               4.1.0
orjson                    3.9.15
packaging                 23.2
pandas                    2.2.1
pathvalidate              3.2.0
peft                      0.8.2
pillow                    10.2.0
pip                       24.0
platformdirs              4.2.0
pluggy                    1.4.0
portalocker               2.8.2
protobuf                  4.25.3
psutil                    5.9.8
pyarrow                   15.0.0
pyarrow-hotfix            0.6
pybind11                  2.11.1
pycountry                 23.12.11
pydantic                  1.10.14
pydantic_core             2.16.3
pydub                     0.25.1
Pygments                  2.17.2
pyparsing                 3.1.1
pyproject-api             1.6.1
pytablewriter             1.2.0
python-dateutil           2.8.2
python-multipart          0.0.9
pytz                      2024.1
PyYAML                    6.0.1
referencing               0.33.0
regex                     2023.12.25
requests                  2.31.0
rich                      13.7.0
rootpath                  0.1.1
rouge_score               0.1.2
rpds-py                   0.18.0
ruff                      0.2.2
sacrebleu                 1.5.0
safetensors               0.4.2
scikit-learn              1.2.2
scipy                     1.12.0
semantic-version          2.10.0
sentencepiece             0.1.99
setuptools                69.1.1
shellingham               1.5.4
shortuuid                 1.0.11
six                       1.16.0
sniffio                   1.3.1
sqlitedict                2.1.0
starlette                 0.36.3
svgwrite                  1.4.3
sympy                     1.12
tabledata                 1.3.3
tcolorpy                  0.1.4
termcolor                 2.4.0
texttable                 1.7.0
threadpoolctl             3.3.0
timm                      0.6.13
tokenizers                0.15.1
toml                      0.10.2
tomli                     2.0.1
tomlkit                   0.12.0
toolz                     0.12.1
torch                     2.1.2
torchvision               0.16.2
tox                       4.13.0
tqdm                      4.66.2
tqdm-multiprocess         0.0.11
transformers              4.37.2
triton                    2.1.0
typepy                    1.3.2
typer                     0.9.0
typing_extensions         4.10.0
tzdata                    2024.1
uc-micro-py               1.0.3
urllib3                   2.2.1
uvicorn                   0.27.1
virtualenv                20.25.1
wavedrom                  2.0.3.post3
websockets                11.0.3
wheel                     0.42.0
xxhash                    3.4.1
yarl                      1.9.4
zstandard                 0.22.0

Installing llm-awq before LLaVA

git clone https://github.com/mit-han-lab/llm-awq.git
cd llm-awq
pip install -e .
cd awq/kernels
python setup.py install
cd ../../..

git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
pip install -e .

Performing this order of operations will result in an import error:

$ python -m awq.entry     --model_path /home/gcpuser/sky_workdir/llava-v1.5-7b     --w_bit 4     --q_group_size 128     --run_awq     --dump_awq /home/gcpuser/sky_workdir/awq_cache/llava-v1.5-7b-w4-g128.pt
Traceback (most recent call last):
  File "/opt/conda/envs/quantize_llava/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/quantize_llava/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/gcpuser/sky_workdir/llm-awq/awq/entry.py", line 15, in <module>
    from awq.quantize.pre_quant import run_awq, apply_awq
  File "/home/gcpuser/sky_workdir/llm-awq/awq/quantize/pre_quant.py", line 12, in <module>
    from tinychat.models import LlavaLlamaForCausalLM
  File "/home/gcpuser/sky_workdir/llm-awq/tinychat/models/__init__.py", line 1, in <module>
    from .falcon import FalconForCausalLM
  File "/home/gcpuser/sky_workdir/llm-awq/tinychat/models/falcon.py", line 11, in <module>
    import awq_inference_engine
ImportError: /opt/conda/envs/quantize_llava/lib/python3.10/site-packages/awq_inference_engine-0.0.0-py3.10-linux-x86_64.egg/awq_inference_engine.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops19empty_memory_format4callEN3c108ArrayRefINS2_6SymIntEEESt8optionalINS2_10ScalarTypeEES6_INS2_6LayoutEES6_INS2_6DeviceEES6_IbES6_INS2_12MemoryFormatEE
casper-hansen commented 4 months ago

Hi, let me add some more context on why this is happening. In transformers 4.36.0, they introduced a new caching system that broke a lot of quantization systems, including AWQ. Use a version prior to that, or try passing use_cache=False when the model is created.
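
A minimal sketch of the use_cache=False option (the path and AutoModel class here are placeholders for however your pipeline builds the model):

from transformers import AutoModelForCausalLM

# Extra kwargs that from_pretrained does not consume are applied to the model
# config, so this disables the KV cache for every forward pass; the cache
# that was growing inside _search_module_scale is then never populated.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llava-v1.5-7b",  # placeholder path
    use_cache=False,
)

# Equivalent after the model is already built:
# model.config.use_cache = False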

ys-2020 commented 4 months ago

Thank you so much!!! @casper-hansen

ys-2020 commented 4 months ago

@isaac-vidas I was using ba72f82cc610b01dc27764b483dfe982948b0633 for LLaVA.

As for installing llm-awq before LLaVA: the problem is that installing LLaVA can break the awq environment. Specifically, LLaVA will reinstall PyTorch in the current environment, and the awq kernels are built as PyTorch extensions, so changing PyTorch leads to the import error in awq.

To avoid this, you may need to comment out LLaVA's torch requirement when installing LLaVA. Alternatively, you can compile the awq CUDA kernels again (in awq/kernels, with python setup.py install) against the current PyTorch.
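
A quick way to check whether the installed kernels still match the current PyTorch (run inside the quantize_llava environment):

import torch
print("torch:", torch.__version__)

# awq_inference_engine is the compiled extension from awq/kernels; if LLaVA
# swapped the torch build underneath it, this import raises the
# "undefined symbol" ImportError shown above until the kernels are rebuilt.
import awq_inference_engine
print("awq kernels import OK")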

isaac-vidas commented 4 months ago

@casper-hansen and @ys-2020, thank you both so much! This worked on my end with use_cache=False. I had initially added use_cache, but in a different section of the code, the one that loads the model when it isn't LLaVA 🤦‍♂️. Anyway, it's working now, and I also created #145 to add use_cache=False by default in case anyone else runs into this issue.

isaac-vidas commented 4 months ago

Fixed with #145