QwenLM / Qwen2-VL

Qwen2-VL is the multimodal large language model series developed by the Qwen team at Alibaba Cloud.
Apache License 2.0

Cannot do AWQ quantization on Qwen2-VL 7B #522

Open lebronjamesking opened 5 days ago

lebronjamesking commented 5 days ago

Hi there,

I've been struggling to run AWQ quantization with AutoAWQ, as mentioned on the home page. I'm trying to quantize Qwen2-VL 7B, but even with 2 A100 GPUs (80 GB VRAM each) I still get CUDA OOM.

The calibration data is already small: 1000 image-instruction samples.

    res size: 1000
    Qwen2VLRotaryEmbedding can now be fully parameterized by passing the model config through the config argument. All other arguments will be removed in v4.46
    Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 9.60it/s]
    [<PIL.Image.Image image mode=RGB size=56x56 at 0x7FBBC8C1BDF0>, ..., <PIL.Image.Image image mode=RGB size=56x56 at 0x7FBBC90B69B0>]  (10 RGB 56x56 images)
    Traceback (most recent call last):
      File "/kefu-nas/fuxinyu/data/posts/src/LLaMA-Factory-0.9.0/data/AutoAWQ/awq_quantize.py", line 84, in <module>
        model.quantize(calib_data=inputs, quant_config=quant_config, max_calib_samples=1, max_calib_seq_len=16, n_parallel_calib_samples=1, max_chunk_memory=128128128)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "/kefu-nas/fuxinyu/data/posts/src/LLaMA-Factory-0.9.0/data/AutoAWQ/awq/models/qwen2vl.py", line 233, in quantize
        self.quantizer = Qwen2VLAwqQuantizer(
      File "/kefu-nas/fuxinyu/data/posts/src/LLaMA-Factory-0.9.0/data/AutoAWQ/awq/quantize/quantizer.py", line 69, in __init__
        self.modules, self.module_kwargs, self.inps = self.init_quant(
      File "/kefu-nas/fuxinyu/data/posts/src/LLaMA-Factory-0.9.0/data/AutoAWQ/awq/models/qwen2vl.py", line 77, in init_quant
        self.model(**samples)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1686, in forward
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1049, in forward
        hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 431, in forward
        hidden_states = hidden_states + self.attn(
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/root/miniconda3/envs/fxy2/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 404, in forward
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
    torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 42.50 GiB. GPU
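
For context, the failing script appears to boil down to the sketch below. This is a hedged reconstruction, not the author's actual file: the import path is inferred from `awq/models/qwen2vl.py` in the traceback, the model id and `quant_config` values are assumptions, and the real preprocessing of the 1000 image-instruction samples is omitted.

    # Hedged reconstruction of the failing quantization flow (not the author's script).
    from PIL import Image
    # Assumption: the local AutoAWQ checkout exposes the class used later in this thread
    # at the module path shown in the traceback (awq/models/qwen2vl.py).
    from awq.models.qwen2vl import Qwen2VLAWQForConditionalGeneration

    model_path = "Qwen/Qwen2-VL-7B-Instruct"  # assumption: base model being quantized
    # Typical AutoAWQ 4-bit settings; the issue does not show the actual quant_config.
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    model = Qwen2VLAWQForConditionalGeneration.from_pretrained(
        model_path,
        model_type="qwen2_vl",
        use_cache=False,
    )

    # Placeholder calibration data: the log above shows a list of small RGB PIL images;
    # the real image-instruction samples and their preprocessing are not shown in the issue.
    inputs = [Image.new("RGB", (56, 56)) for _ in range(10)]

    # Arguments copied from the traceback; the OOM in the log happens inside the
    # vision tower while this call runs the calibration forward pass.
    model.quantize(
        calib_data=inputs,
        quant_config=quant_config,
        max_calib_samples=1,
        max_calib_seq_len=16,
        n_parallel_calib_samples=1,
        max_chunk_memory=128128128,
    )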

mehamednews commented 2 days ago

It seems enabling flash attention can help with this. I found it after reading this post: https://bhavyajoshi809.medium.com/fine-tuning-qwen2-vl-mllm-on-custom-data-for-ocr-part-3-quantization-of-custom-qwen2-vl-2b-mllm-2c94577f83a5

so basically instead of this:

    model = Qwen2VLAWQForConditionalGeneration.from_pretrained(
        model_path,
        model_type="qwen2_vl",
        use_cache=False
    )

do this:

    model = Qwen2VLAWQForConditionalGeneration.from_pretrained(
        model_path,
        model_type="qwen2_vl",
        use_cache=False,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    model.to("cuda")
lebronjamesking commented 2 days ago

> It seems enabling flash attention can help with this. […]

Hi, many thanks. I did try flash_attention_2 and it does help; however, I still get CUDA OOM once the calibration data exceeds 10 samples, so it still doesn't fully work. BTW, I have 4 A100s with 80 GB each.

mehamednews commented 1 day ago

Same, I had to use 8 samples (the result in my case was really good)
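
In other words (continuing the hypothetical `inputs` list and `quant_config` from the sketch earlier in the thread), capping the calibration set is just a slice before the quantize call:

    # Keep only 8 calibration samples, which fit in memory per the comments above.
    calib_subset = inputs[:8]
    model.quantize(calib_data=calib_subset, quant_config=quant_config)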

lebronjamesking commented 1 day ago

> Same, I had to use 8 samples (the result in my case was really good)

Great, let me try it now. Indeed, the performance shows no degradation. Anyway, let's see if @kq-chen and the Qwen2-VL team can give more suggestions on how to speed up vLLM inference, as you mentioned here: https://github.com/QwenLM/Qwen2-VL/issues/532