Closed: learning-chip closed this issue 6 months ago
transformers v4.36 implemented `LlamaSdpaAttention` (https://github.com/huggingface/transformers/pull/26572), which calls FlashAttention by default. But running `LOAD_LADE=1 USE_LADE=1 python minimal.py` leads to:
```
Traceback (most recent call last):
  File "/home/LookaheadDecoding/minimal.py", line 31, in <module>
    greedy_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/LookaheadDecoding/lade/decoding.py", line 23, in greedy_search_proxy
    return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs)
  File "/home/LookaheadDecoding/lade/decoding.py", line 278, in jacobi_greedy_search_multilevel
    outputs = self.jforward_multilevel(
  File "/home/LookaheadDecoding/lade/models/llama.py", line 383, in jforward_multilevel
    outputs = self.model.LlamaModeljforward(
  File "/home/LookaheadDecoding/lade/models/llama.py", line 235, in LlamaModeljforward
    layer_outputs = decoder_layer.forward(
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'padding_mask'
```
This PR can fix the problem.
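Until a fix lands, one possible stopgap is to filter out keyword arguments that the installed attention class no longer accepts before calling it. The failure happens because LookaheadDecoding's patched forward still passes `padding_mask`, which transformers v4.36 removed from the Llama attention `forward()` signature. The sketch below is a generic, hedged workaround (the helper name `call_attention_compat` is made up for illustration, not part of either codebase): it inspects the target `forward()` signature and drops unknown kwargs.

```python
import inspect

def call_attention_compat(self_attn, hidden_states, **kwargs):
    """Call an attention module, dropping kwargs its forward() does not accept.

    Works around signature changes across transformers versions, e.g. the
    removal of `padding_mask` from LlamaAttention.forward() in v4.36.
    """
    params = inspect.signature(self_attn.forward).parameters
    # If forward() takes **kwargs itself, pass everything through unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return self_attn(hidden_states, **kwargs)
    # Otherwise keep only kwargs that match named parameters.
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return self_attn(hidden_states, **accepted)
```

A caller site such as the `self.self_attn(...)` invocation in the patched decoder layer could then route through this helper, so stale kwargs like `padding_mask` are silently discarded instead of raising `TypeError`. Pinning `transformers<4.36` is the simpler alternative if version flexibility is not needed.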