hao-ai-lab / LookaheadDecoding

[ICML 2024] Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
https://arxiv.org/abs/2402.02057
Apache License 2.0

Compatibility with Flash Attention 2 #46

Closed: jasonli0707 closed this issue 8 months ago

jasonli0707 commented 8 months ago

Environment:

flash-attn==2.4.2, torch==2.0.1+cu118, transformers==4.36.2

It seems that the current version does not support FlashAttention-2. I encountered the following error when running `minimal.py` with `attn_implementation="flash_attention_2"`:

```
Traceback (most recent call last):
  File "/home/jasonlcl/dev/LookaheadDecoding/minimal.py", line 47, in <module>
    greedy_output = model.generate(model_inputs, max_new_tokens=256, do_sample=False)
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/jasonlcl/dev/LookaheadDecoding/lade/decoding.py", line 24, in greedy_search_proxy
    return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs)
  File "/home/jasonlcl/dev/LookaheadDecoding/lade/decoding.py", line 278, in jacobi_greedy_search_multilevel
    outputs = self.jforward_multilevel(
  File "/home/jasonlcl/dev/LookaheadDecoding/lade/models/llama.py", line 383, in jforward_multilevel
    outputs = self.model.LlamaModeljforward(
  File "/home/jasonlcl/dev/LookaheadDecoding/lade/models/llama.py", line 234, in LlamaModeljforward
    layer_outputs = decoder_layer.forward(
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 796, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 547, in forward
    attn_output = self._flash_attention_forward(
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 597, in _flash_attention_forward
    attn_output_unpad = flash_attn_varlen_func(
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 1059, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 576, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/home/jasonlcl/miniconda3/envs/llm/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 85, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
```
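For reference, a minimal sketch of the setup that reproduces the failure is below. The model name, prompt, and lade configuration values are assumptions for illustration; the actual `minimal.py` in this repo may differ.

```python
# Hypothetical repro sketch -- model name, prompt, and lade config values are
# placeholders; see minimal.py in this repo for the exact script.
import torch
import lade
from transformers import AutoModelForCausalLM, AutoTokenizer

# Patch transformers' generation with lookahead decoding (per the repo README).
lade.augment_all()
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)

model_name = "meta-llama/Llama-2-7b-chat-hf"  # assumed model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # this is what triggers the RuntimeError
    device_map="cuda",
)

prompt = "Hello, how are you?"  # assumed prompt
model_inputs = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

# Fails inside flash_attn_varlen_func with
# "cu_seqlens_q must have shape (batch_size + 1)".
greedy_output = model.generate(model_inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
```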

Viol2000 commented 8 months ago

The currently released version does not support FlashAttention. However, I have already made it compatible with FlashAttention-2 (along with many other augmentations), and the update will be released in a few days.
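Until the updated version is out, one possible workaround, sketched under the assumption of the standard transformers loading API (this is not the official fix), is to fall back to the default attention implementation when loading the model:

```python
# Workaround sketch (not the official fix): load the model without
# FlashAttention-2 so lookahead decoding uses the default attention path.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # assumed model name
    torch_dtype=torch.float16,
    attn_implementation="eager",      # or simply omit the argument
    device_map="cuda",
)
```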

jasonli0707 commented 8 months ago

I see, thanks! Looking forward to it!