wejoncy / QLLM

A general 2-8 bit quantization toolbox with GPTQ/AWQ/HQQ, and easy export to ONNX / ONNX Runtime.
Apache License 2.0

Alibaba-NLP/gte-Qwen2-7B-instruct doesn't load properly #119

Closed prattcmp closed 5 months ago

prattcmp commented 5 months ago

The library unfortunately isn't working with the Alibaba-NLP/gte-Qwen2-7B-instruct Transformers model.

python -m qllm --model Alibaba-NLP/gte-Qwen2-7B-instruct --method gptq --save ./gte-Qwen2-7B-4bit --export_onnx ./gte-Qwen2-7B-4bit_onnx --allow_mix_bits --true-sequential

Namespace(method='gptq', model='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer='', dataset='wikitext2', seed=0, nsamples=128, percdamp=0.01, static_groups=False, wbits=4, mix_qlayer_conf=None, groupsize=128, eval=False, save='./gte-Qwen2-7B-4bit', save_safetensors='', load='', sym=False, act_order=False, true_sequential=True, allow_mix_bits=True, export_onnx='./gte-Qwen2-7B-4bit_onnx', use_plugin=False, pack_mode='AUTO')
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-06-19 02:37:27,252 - qllm - INFO - loading model from Alibaba-NLP/gte-Qwen2-7B-instruct
Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00,  3.06it/s]
2024-06-19 02:37:30,775 - qllm - INFO - loading dataset from wikitext2
2024-06-19 02:37:30,775 - qllm - INFO - found cached dataloader in /tmp/qllm_vubuntu/_Alibaba-NLPgte-Qwen2-7B-instruct_wikitext2_128_2048_0_dataloader.pt
Starting ...
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.10/site-packages/qllm/__main__.py", line 6, in <module>
    main()
  File "/opt/conda/lib/python3.10/site-packages/qllm/run.py", line 78, in main
    model_quanter.run(args)
  File "/opt/conda/lib/python3.10/site-packages/qllm/auto_model_quantization.py", line 214, in run
    quantizers = self.__dispatch_quant(model, inputs_dataloader, config, "cuda")
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/qllm/auto_model_quantization.py", line 42, in __dispatch_quant
    return quantizer.quantize(model, inputs_dataloader, dev)
  File "/opt/conda/lib/python3.10/site-packages/qllm/quantization/quant_frame_base.py", line 119, in quantize
    quantizers.update(self.do_quantize(model, dataloader, prefix, dev))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/qllm/quantization/quant_gptq.py", line 91, in do_quantize
    inps, outs, attention_layers, layer_input_args = self.hijack_block_inputs(model, dataloader, model_prefix, dev)
  File "/opt/conda/lib/python3.10/site-packages/qllm/quantization/quant_frame_base.py", line 90, in hijack_block_inputs
    model(batch[0].to(dev))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/Alibaba-NLP/gte-Qwen2-7B-instruct/1efef6cb8e5b06824152b8fa2a42e762bd4a3571/modeling_qwen.py", line 1198, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/Alibaba-NLP/gte-Qwen2-7B-instruct/1efef6cb8e5b06824152b8fa2a42e762bd4a3571/modeling_qwen.py", line 1038, in forward
    attention_mask = _prepare_4d_attention_mask_for_sdpa(
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 439, in _prepare_4d_attention_mask_for_sdpa
    batch_size, key_value_length = mask.shape
AttributeError: 'NoneType' object has no attribute 'shape'
wejoncy commented 5 months ago

Hi @prattcmp, thanks for filing the issue.

It's because modeling_qwen.py doesn't handle a None attention_mask when using the sdpa backend. https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct/blob/1efef6cb8e5b06824152b8fa2a42e762bd4a3571/modeling_qwen.py#L1038
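
For reference, the failure reproduces in isolation (a minimal illustration, assuming the same transformers version as in the traceback above):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

# With a real 2D mask this builds a 4D sdpa mask; with None it tries to
# unpack mask.shape and raises AttributeError, matching the traceback.
_prepare_4d_attention_mask_for_sdpa(None, torch.float16)
```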

The workaround would be to add a parameter here to specify the flash-attn backend: https://github.com/wejoncy/QLLM/blob/7ecb24b9c53b0ba7b46c140457170b44682e631a/qllm/modeling/base.py#L176

llm = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, torch_dtype=torch.float16,
            trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2")

Or fix the Qwen function https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct/blob/1efef6cb8e5b06824152b8fa2a42e762bd4a3571/modeling_qwen.py#L1038 to support a None mask:

```
attention_mask = _prepare_4d_attention_mask_for_sdpa(
    attention_mask, inputs_embeds.dtype
)
```
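
A minimal sketch of such a None-safe variant (only the changed lines of modeling_qwen.py are shown; the surrounding forward code is assumed unchanged):

```python
# Sketch: skip the sdpa mask preparation entirely when no attention_mask is given,
# so scaled_dot_product_attention simply runs without masking.
if attention_mask is not None:
    attention_mask = _prepare_4d_attention_mask_for_sdpa(
        attention_mask, inputs_embeds.dtype
    )
```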

Or fix here to pass the attention_mask:

    try:  # noqa:SIM105
        model(batch[0].to(dev), batch[1].to(dev))
    except ValueError:
        pass


https://github.com/wejoncy/QLLM/blob/7ecb24b9c53b0ba7b46c140457170b44682e631a/qllm/quantization/quant_frame_base.py#L90
wejoncy commented 5 months ago

I created a fix PR to support specifying use_flash_attn via an environment variable: USE_FLASH_ATTN=1
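
For illustration, an environment-variable gate around the from_pretrained call could look roughly like this (a sketch with a hypothetical load_model helper, not necessarily the PR's exact change):

```python
import os

import torch
from transformers import AutoModelForCausalLM


def load_model(pretrained_model_name_or_path: str, trust_remote_code: bool = True):
    """Hypothetical helper: request flash-attn only when USE_FLASH_ATTN=1 is set."""
    kwargs = {}
    if os.environ.get("USE_FLASH_ATTN", "0") == "1":
        # flash_attention_2 avoids the sdpa mask-preparation path that fails on a None mask.
        kwargs["attn_implementation"] = "flash_attention_2"
    return AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
```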

prattcmp commented 5 months ago

Wow, how did you untangle that so quickly?

Env variable is great. Would also be nice to have it as a CLI flag. Any regression risk if the flag is default true/enabled?

wejoncy commented 5 months ago

> Wow, how did you untangle that so quickly?
>
> Env variable is great. Would also be nice to have it as a CLI flag. Any regression risk if the flag is default true/enabled?

If use_flash_attention is enabled by default, it will require users to install the flash-attn package. However, flash-attn requires sm >= 8.0, so it won't work on GPUs like the V100 or older.

Besides, Transformers will select an appropriate attention backend (eager/sdpa/flash-attn) automatically, and the first two work well on all GPUs.
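
So a caller who wants to opt in only on capable hardware could check the compute capability first (a sketch, not part of QLLM's CLI):

```python
import os

import torch

# flash-attn kernels need compute capability (sm) >= 8.0, i.e. Ampere or newer,
# so only set the opt-in variable when the current GPU supports it.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    os.environ["USE_FLASH_ATTN"] = "1"
```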