huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Wrong output of Gemma-2 models using `flash_attention_2` #32309

Closed: tanliboy closed this issue 1 month ago

tanliboy commented 1 month ago

I recall that the soft-capping issue was resolved for the forward pass in flash_attn. However, I still see poor model outputs when I enable use_flash_attention_2 in Transformers, even for inference.

Did I miss something? Or is it a recent regression?

Who can help?

@ArthurZucker

Reproduction

  1. Load the Gemma-2 9B IT model with use_flash_attention_2 turned on:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # Generates nonsense when set to True
)
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
```
  2. Observe the nonsensical output and compare it with the output when use_flash_attention_2=False. This reproduces consistently.

Expected behavior

See the difference below:

[Screenshots (2024-07-29): generated output with and without flash_attention_2]
zucchini-nlp commented 1 month ago

Yes, there's a PR to fix it (https://github.com/huggingface/transformers/pull/32188)

tanliboy commented 1 month ago

Thank you, @zucchini-nlp !

After the fix, will we be able to use flash_attention_2 for both forward (inference) and backward (training) paths of Gemma2 models in transformers?

Since FlashAttention currently doesn't support a static cache, do you think this issue will also impact other libraries (e.g., vLLM and other frameworks) when using flash_attention_2 with the Gemma2 model? If so, do you think we can address this within the FlashAttention library?

zucchini-nlp commented 1 month ago

@tanliboy FA2 should now work in Transformers for both the forward and backward passes.

For other libraries: I'm not very familiar with all of them, but in vLLM Gemma2 should work the same way as other models, because vLLM does not use the same `StaticCache` that we do. Also note that vLLM currently doesn't apply the sliding window in every second attention block, as per the comment I see here.
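For readers unfamiliar with the "sliding window in every second attention block" pattern mentioned above, here is a minimal pure-Python sketch of the masks involved. It assumes (an assumption, not confirmed by this thread) that even-indexed layers use a local sliding-window causal mask and odd-indexed layers use a full causal mask. The helper names `causal_mask` and `mask_for_layer` are hypothetical and not part of Transformers or vLLM; check your model config for the real window size and layer pattern.

```python
# Hypothetical sketch of Gemma-2-style alternating attention masks.
# ASSUMPTION: even-indexed layers use a sliding window, odd-indexed layers
# attend globally. Verify against your model config before relying on this.

def causal_mask(seq_len, window=None):
    """Return a seq_len x seq_len boolean mask; True means "may attend".

    Query i may attend to key j when j <= i (causal) and, if a window is
    given, when i - j < window (sliding window).
    """
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            allowed = j <= i and (window is None or i - j < window)
            row.append(allowed)
        mask.append(row)
    return mask


def mask_for_layer(layer_idx, seq_len, sliding_window):
    # Local (windowed) attention on even layers, global attention on odd layers.
    window = sliding_window if layer_idx % 2 == 0 else None
    return causal_mask(seq_len, window)


if __name__ == "__main__":
    # Tiny sizes for readability; real models use windows of thousands of tokens.
    for layer in range(2):
        print(f"layer {layer}:")
        for row in mask_for_layer(layer, seq_len=5, sliding_window=2):
            print("".join("x" if allowed else "." for allowed in row))
```

A backend that skips the window on the local layers lets those layers see keys the model was never trained to attend to at long context lengths, which is why the per-layer choice matters.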

tanliboy commented 1 month ago

Thank you for the details, @zucchini-nlp !

xenova commented 1 month ago

@tanliboy Glad to see it's fixed! Let me know if I can close the issue 😇

zucchini-nlp commented 1 month ago

sure, the PR is merged already, closing the issue :)

HuangBugWei commented 1 month ago

@zucchini-nlp, thank you very much for fixing this. Since PR #32188 was merged 5 days ago, I assume the latest release, v4.43.3 ("Patch deepspeed"), does not include this fix, right? Should we install from source with `pip install git+https://github.com/huggingface/transformers` to get it?

zucchini-nlp commented 1 month ago

@HuangBugWei correct! We might have a release soon, but until then it should be installed from source
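Before enabling FA2 for Gemma-2, it may help to sanity-check the installed versions. The sketch below is a hypothetical helper, not a Transformers API. Its thresholds are assumptions drawn from this thread: the fix is not in v4.43.3, and the traceback later in the thread shows that soft-capping is only passed to flash-attn when `is_flash_attn_greater_or_equal("2.6.0")` holds. Note that a source install reports a `.dev0` version, which this check treats as new enough regardless of which commit it was built from.

```python
# Hypothetical environment check; thresholds are assumptions from this thread.
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(ver):
    """Parse the leading numeric release part, e.g. '4.44.0.dev0' -> (4, 44, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Pad to three components so (2, 6) compares equal to (2, 6, 0).
    return parts + (0,) * (3 - len(parts))


def check_env(transformers_ver, flash_attn_ver):
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    if release_tuple(transformers_ver) <= (4, 43, 3):
        problems.append("transformers <= 4.43.3: install from source for the Gemma-2 FA2 fix")
    if flash_attn_ver is None:
        problems.append("flash-attn is not installed")
    elif release_tuple(flash_attn_ver) < (2, 6, 0):
        problems.append("flash-attn < 2.6.0: attention soft-capping is not applied")
    return problems


def installed(dist_name):
    """Return the installed version of a distribution, or None if missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    tf_ver = installed("transformers")
    if tf_ver is None:
        print("transformers is not installed")
    else:
        for problem in check_env(tf_ver, installed("flash_attn")):
            print("WARNING:", problem)
```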

tanliboy commented 1 month ago

@zucchini-nlp thanks for the fix!

I installed the latest release but ran into the error below when using flash_attention_2 (it works fine without flash_attention_2).

../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [12,0,0], thread: [31,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 12
      5 streamer = TextStreamer(tokenizer)
      6 terminators = (
      7     [
      8         tokenizer.eos_token_id,
      9         tokenizer.convert_tokens_to_ids("<end_of_turn>"),
     10     ]
     11 )
---> 12 _ = model.generate(**input_ids, streamer=streamer, eos_token_id=terminators, max_new_tokens=2048)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2981 # forward pass to get next token
-> 2982 outputs = self(**model_inputs, return_dict=True)
   2984 if synced_gpus and this_peer_finished:
   2985     continue  # don't waste resources running the code we don't need

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:999, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    996 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    998 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 999 outputs = self.model(
   1000     input_ids=input_ids,
   1001     attention_mask=attention_mask,
   1002     position_ids=position_ids,
   1003     past_key_values=past_key_values,
   1004     inputs_embeds=inputs_embeds,
   1005     use_cache=use_cache,
   1006     output_attentions=output_attentions,
   1007     output_hidden_states=output_hidden_states,
   1008     return_dict=return_dict,
   1009     cache_position=cache_position,
   1010 )
   1012 hidden_states = outputs[0]
   1013 logits = self.lm_head(hidden_states)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:847, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    836     layer_outputs = self._gradient_checkpointing_func(
    837         decoder_layer.__call__,
    838         hidden_states,
   (...)
    844         cache_position,
    845     )
    846 else:
--> 847     layer_outputs = decoder_layer(
    848         hidden_states,
    849         attention_mask=causal_mask,
    850         position_ids=position_ids,
    851         past_key_value=past_key_values,
    852         output_attentions=output_attentions,
    853         use_cache=use_cache,
    854         cache_position=cache_position,
    855     )
    857 hidden_states = layer_outputs[0]
    859 if output_attentions:

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:590, in Gemma2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    587 hidden_states = self.input_layernorm(hidden_states)
    589 # Self Attention
--> 590 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    591     hidden_states=hidden_states,
    592     attention_mask=attention_mask,
    593     position_ids=position_ids,
    594     past_key_value=past_key_value,
    595     output_attentions=output_attentions,
    596     use_cache=use_cache,
    597     cache_position=cache_position,
    598 )
    599 hidden_states = self.post_attention_layernorm(hidden_states)
    600 hidden_states = residual + hidden_states

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:423, in Gemma2FlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    420     key_states = key_states.to(target_dtype)
    421     value_states = value_states.to(target_dtype)
--> 423 attn_output = _flash_attention_forward(
    424     query_states,
    425     key_states,
    426     value_states,
    427     attention_mask,
    428     q_len,
    429     dropout=dropout_rate,
    430     softmax_scale=self.scaling,
    431     is_causal=self.is_causal,
    432     use_top_left_mask=self._flash_attn_uses_top_left_mask,
    433     softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
    434 )
    436 attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    437 attn_output = self.o_proj(attn_output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py:246, in _flash_attention_forward(query_states, key_states, value_states, attention_mask, query_length, is_causal, dropout, position_ids, softmax_scale, sliding_window, use_top_left_mask, softcap, deterministic)
    244 if attention_mask is not None:
    245     batch_size = query_states.shape[0]
--> 246     query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
    247         query_states, key_states, value_states, attention_mask, query_length
    248     )
    249     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    250     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py:121, in _upad_input(query_layer, key_layer, value_layer, attention_mask, query_length)
    118 else:
    119     # The -q_len: slice assumes left padding.
    120     attention_mask = attention_mask[:, -query_length:]
--> 121     query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
    123 return (
    124     query_layer,
    125     key_layer,
   (...)
    129     (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    130 )

File /opt/conda/envs/handbook/lib/python3.10/site-packages/flash_attn/bert_padding.py:110, in unpad_input(hidden_states, attention_mask)
     99 """
    100 Arguments:
    101     hidden_states: (batch, seqlen, ...)
   (...)
    107     max_seqlen_in_batch: int
    108 """
    109 seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
--> 110 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    111 max_seqlen_in_batch = seqlens_in_batch.max().item()
    112 cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Here is the code to reproduce it:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# Print tokens as they are generated.
streamer = TextStreamer(tokenizer)
# Stop on either the regular EOS token or Gemma's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<end_of_turn>"),
]
_ = model.generate(**input_ids, streamer=streamer, eos_token_id=terminators, max_new_tokens=2048)
```

Did I miss something?

zucchini-nlp commented 1 month ago

@tanliboy yeah, it seems other changes in how the attention mask is prepared broke FA2 again... Will open a new PR.
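To see how a change in attention-mask preparation can surface as an "index out of bounds" assert, here is a pure-Python sketch of the bookkeeping that `unpad_input` performs in the traceback above (`seqlens_in_batch`, `indices`, `cu_seqlens`). It uses lists instead of tensors and hypothetical helper names, and it illustrates one plausible failure mode under those simplified assumptions; it is not a diagnosis of the actual bug.

```python
# Simplified, list-based sketch of flash_attn's unpadding bookkeeping.
# Helper names are hypothetical; the real code operates on CUDA tensors.

def unpad_indices(attention_mask):
    """attention_mask: list of rows of 0/1 ints, shape (batch, seq_len)."""
    # Number of real (non-padding) tokens in each sequence.
    seqlens = [sum(row) for row in attention_mask]
    # Positions of real tokens in the flattened (batch * seq_len) layout.
    flat = [v for row in attention_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v]
    # Cumulative lengths with a leading 0: sequence b spans cu[b]:cu[b + 1].
    cu_seqlens = [0]
    for n in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return indices, cu_seqlens, max(seqlens)


def gather_tokens(flat_tokens, indices):
    # Mirrors the gather step: every index must be < len(flat_tokens).
    # If the mask covers more positions than the tensor it describes, this
    # raises IndexError; on GPU a gather with an out-of-range index instead
    # trips a device-side assert like the one reported above.
    return [flat_tokens[i] for i in indices]


if __name__ == "__main__":
    mask = [[0, 1, 1], [1, 1, 1]]  # left padding in the first sequence
    print(unpad_indices(mask))
```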

tanliboy commented 1 month ago

Thank you, @zucchini-nlp !

tanliboy commented 1 month ago

I tested the fix, and it worked well. Thank you!

I also ran a side-by-side fine-tuning comparison with and without flash_attention_2. Surprisingly, fine-tuning with flash_attention_2 showed only a marginal improvement over eager mode on my 8x A100 setup.

The "GPU Time Spent Accessing Memory" metric was around 40% with flash_attention_2, lower than the ~47% observed with eager but still higher than what I see when fine-tuning other models (~32%). "Process GPU Memory" was ~91% with flash_attention_2, compared with ~97% with eager.

With flash_attention_2:

[Screenshot, 2024-08-09 5:00:10 PM]

With eager:

[Screenshot, 2024-08-09 5:00:31 PM]
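A rough estimate helps explain where FA2 saves memory. Eager attention materializes a (batch, heads, seq_len, seq_len) score tensor per layer, while FlashAttention computes the softmax in tiles without storing that full matrix. The numbers below (16 heads, bf16, batch size 1) are illustrative assumptions, not measurements from this thread. Weights, optimizer state, and other activations are not counted, which may be one reason the observed gap (~91% vs ~97%) is modest.

```python
# Back-of-the-envelope estimate of the attention-score tensor that eager
# attention materializes per layer. Head count and dtype are illustrative.

def eager_scores_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Size in bytes of one layer's (batch, heads, seq_len, seq_len) scores.

    bytes_per_elem=2 corresponds to bf16/fp16.
    """
    return batch * heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    for seq_len in (1024, 4096, 8192):
        mib = eager_scores_bytes(batch=1, heads=16, seq_len=seq_len) / 2**20
        print(f"seq_len={seq_len:5d}: {mib:8.1f} MiB per layer")
```

The quadratic growth in `seq_len` is the part FlashAttention avoids; at short sequence lengths the saving is small relative to the rest of the training footprint.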
tanliboy commented 1 month ago

@zucchini-nlp, is this warning still accurate, or should we remove it given the fix?

It is strongly recommended to train Gemma2 models with the eager attention implementation instead of flash_attention_2. Use eager with AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager').

zucchini-nlp commented 1 month ago

Yes, I believe it still holds, since it wasn't about FA2 being unsupported but rather about small numerical precision differences between eager and non-eager attention.

ArthurZucker commented 1 month ago

No, it's no longer true, since soft-capping is now supported in flash attention. Will remove it.
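Soft-capping squashes each attention score before the softmax as `cap * tanh(score / cap)`. Because this happens inside the attention computation, a fused kernel has to support it natively, which is consistent with the traceback above passing `softcap` to flash-attn only for versions >= 2.6.0. The minimal sketch below shows the math for a single query row. The cap of 50.0 is an assumption about Gemma-2's `attn_logit_softcapping`; in practice, read it from the model config.

```python
import math

# ASSUMPTION: 50.0 mirrors Gemma-2's `attn_logit_softcapping`; read the real
# value from the model config instead of hard-coding it.
ATTN_SOFTCAP = 50.0


def softcap(score, cap=ATTN_SOFTCAP):
    """Squash a raw score into the open interval (-cap, cap)."""
    return cap * math.tanh(score / cap)


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def capped_attention_weights(scores, cap=ATTN_SOFTCAP):
    """Soft-cap the scaled QK^T scores of one query row, then apply softmax."""
    return softmax([softcap(s, cap) for s in scores])


if __name__ == "__main__":
    scores = [2.0, 10.0, 200.0]
    print("capped scores:", [round(softcap(s), 2) for s in scores])
    print("weights without cap:", softmax(scores))
    print("weights with cap:   ", capped_attention_weights(scores))
```

Small scores pass through almost unchanged, while very large ones are bounded near the cap, so a kernel that skips this step produces noticeably different attention weights for outlier logits.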

zucchini-nlp commented 1 month ago

I guess SDPA is not yet supported?

ArthurZucker commented 1 month ago

Yes, we need to integrate flex attention for that!
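SDPA computes a fixed `softmax(QK^T * scale) V` formula, so there is no place to insert a per-score transformation such as soft-capping. FlexAttention-style APIs instead accept a `score_mod` callable. The pure-Python sketch below illustrates the idea for one query row. Its function names and simplified `(score, q_idx, kv_idx)` hook signature are hypothetical (PyTorch's `torch.nn.attention.flex_attention` also passes batch and head indices), and the cap of 50.0 is illustrative.

```python
import math

# Conceptual sketch of attention with a pluggable score_mod hook.
# Names and the simplified hook signature are hypothetical.


def attention_row(q, keys, values, q_idx, score_mod=None):
    """Causal attention output for a single query vector q at position q_idx."""
    scale = 1.0 / math.sqrt(len(q))
    scores = []
    for kv_idx, k in enumerate(keys[: q_idx + 1]):  # causal: keys 0..q_idx only
        s = scale * sum(a * b for a, b in zip(q, k))
        if score_mod is not None:
            s = score_mod(s, q_idx, kv_idx)  # per-score hook, e.g. soft-capping
        scores.append(s)
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum of the visible value vectors.
    dim = len(values[0])
    return [sum(w * values[j][d] for j, w in enumerate(weights)) for d in range(dim)]


def softcap_mod(score, q_idx, kv_idx, cap=50.0):
    # Soft-capping expressed as a score modification; positions are unused here.
    return cap * math.tanh(score / cap)


if __name__ == "__main__":
    q = [1.0, 0.0]
    keys = [[1.0, 0.0], [0.0, 1.0]]
    values = [[1.0, 0.0], [0.0, 1.0]]
    print(attention_row(q, keys, values, q_idx=1))
    print(attention_row(q, keys, values, q_idx=1, score_mod=softcap_mod))
```

The same hook could express other per-score tweaks (for example, masking positions outside a window) without writing a new fused kernel for each variant.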