OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.
MIT License

TypeError: QuantLlamaDecoderLayer.forward() got an unexpected keyword argument 'padding_mask' #44

Closed xianwujie closed 6 months ago

xianwujie commented 7 months ago

Hi, I have a problem evaluating the quantized llama-2-7b model. Can anyone help?

I quantize the llama-2-7b model with the command below:

CUDA_VISIABLE_DEVICES=6 python main.py --model ../llama_2-7b --epochs 0 --output_dir ./log/llama-2b-w4a16 --wbits 4 --abits 16 --lwc --eval_ppl

and get the following error:

Traceback (most recent call last):
  File "/home/user/workspace/quantization/s24_quant/OmniQuant/main.py", line 376, in <module>
    main()
  File "/home/user/workspace/quantization/s24_quant/OmniQuant/main.py", line 352, in main
    evaluate(lm, args, logger)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/workspace/quantization/s24_quant/OmniQuant/main.py", line 124, in evaluate
    outputs = lm.model.model(batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 925, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: QuantLlamaDecoderLayer.forward() got an unexpected keyword argument 'padding_mask'

My environment is configured as follows:

torch: 2.1.0
transformers: 4.34.0
python: 3.10.2

Alvant commented 7 months ago

A possible "workaround" to try would be adding a padding_mask parameter to QuantLlamaDecoderLayer.forward():

def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        padding_mask: Optional[torch.Tensor] = None,  # <- Here
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

https://github.com/OpenGVLab/OmniQuant/blob/main/models/int_llama_layer.py#L220

It might work, but I can't guarantee that it will lead to correct results :sweat_smile: However, you can check the file /usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py and make sure that padding_mask is not used anywhere in the computation. If that is true, then the workaround above should be OK.
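If you'd rather not edit the repository file, another option (just a minimal, untested sketch, not the project's official fix) is to drop the extra keyword that your transformers version passes to the decoder layers before it ever reaches forward. This assumes the quantized layers sit at lm.model.model.layers, as in the stock LLaMA model that the traceback goes through:

    import functools

    def ignore_padding_mask(forward_fn):
        """Wrap a decoder layer's forward so an unexpected 'padding_mask' kwarg is dropped."""
        @functools.wraps(forward_fn)
        def wrapped(*args, **kwargs):
            # transformers 4.34 passes this extra kwarg; the quantized layer does not accept it
            kwargs.pop("padding_mask", None)
            return forward_fn(*args, **kwargs)
        return wrapped

    # Usage sketch: patch each quantized decoder layer before running evaluation.
    # for layer in lm.model.model.layers:
    #     layer.forward = ignore_padding_mask(layer.forward)

Same caveat as above: this only silently ignores padding_mask, so it is only safe if that argument is not actually used in the computation for your model and settings.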