hao-ai-lab / LookaheadDecoding


Incompatible with LlamaSdpaAttention in transformers v4.36 #35

Closed: learning-chip closed this issue 6 months ago

learning-chip commented 6 months ago

transformers v4.36 implemented LlamaSdpaAttention (https://github.com/huggingface/transformers/pull/26572), which routes attention through PyTorch's scaled_dot_product_attention by default (and may dispatch to FlashAttention kernels under the hood).
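For reference, a quick check confirms the new default attention class (the checkpoint name is illustrative; any Llama model shows the same in v4.36):

from transformers import AutoModelForCausalLM

# Illustrative checkpoint; v4.36 selects the SDPA attention class by default.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
print(type(model.model.layers[0].self_attn).__name__)  # LlamaSdpaAttention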

But running LOAD_LADE=1 USE_LADE=1 python minimal.py leads to:

Traceback (most recent call last):
  File "/home/LookaheadDecoding/minimal.py", line 31, in <module>
    greedy_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/LookaheadDecoding/lade/decoding.py", line 23, in greedy_search_proxy
    return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs)
  File "/home/LookaheadDecoding/lade/decoding.py", line 278, in jacobi_greedy_search_multilevel
    outputs = self.jforward_multilevel(
  File "/home/LookaheadDecoding/lade/models/llama.py", line 383, in jforward_multilevel
    outputs = self.model.LlamaModeljforward(
  File "/home/LookaheadDecoding/lade/models/llama.py", line 235, in LlamaModeljforward
    layer_outputs = decoder_layer.forward(
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'padding_mask'
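The error comes from lade's patched decoder path still passing the legacy padding_mask keyword, which LlamaSdpaAttention.forward() no longer accepts (the eager LlamaAttention only emits a deprecation warning for it). Until that is fixed, one workaround sketch, assuming the kwarg mismatch is the only incompatibility, is to force the eager attention path when loading the model:

from transformers import AutoModelForCausalLM

# Workaround sketch: select the eager attention implementation, whose
# forward() in v4.36 still tolerates the deprecated padding_mask kwarg.
# The checkpoint name is illustrative; use whatever minimal.py loads.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    attn_implementation="eager",
)

Alternatively, pinning transformers below 4.36 (pip install "transformers<4.36") avoids the new SDPA class entirely.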
jiqing-feng commented 6 months ago

This PR can fix the problem.
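For anyone patching locally in the meantime, a hypothetical compatibility shim (the helper name and approach are mine, not necessarily what the PR does) could drop the keyword whenever the attention class no longer declares it:

import inspect

def call_self_attn(attn_module, *args, padding_mask=None, **kwargs):
    # Hypothetical helper: only forward padding_mask when the attention
    # class explicitly declares it; otherwise drop it, since recent
    # transformers versions deprecate and ignore the argument anyway.
    params = inspect.signature(attn_module.forward).parameters
    if padding_mask is not None and "padding_mask" in params:
        kwargs["padding_mask"] = padding_mask
    return attn_module(*args, **kwargs)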