huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Idefics2 generation erroring with flash_attention_2 #32237

Closed · tctrautman closed this issue 3 months ago

tctrautman commented 3 months ago

System Info

- `transformers` version: 4.44.0.dev0
- Platform: Linux-5.4.0-155-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.24.2
- Safetensors version: 0.4.3
- Accelerate version: 0.33.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes (see script)
- GPU type: NVIDIA RTX A6000

Who can help?

@zucchini-nlp

Reproduction

The script below is the same as the one included in the Idefics2 blog post, with three additional lines inside the `AutoModelForVision2Seq.from_pretrained` call; each new line is marked with a comment.

```python
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image

DEVICE = "cuda:0"
dtype = torch.bfloat16

# Note that passing the image urls (instead of the actual pil images) to the processor is also possible
image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    attn_implementation="flash_attention_2", # This is a new line
    torch_dtype=dtype, # This is a new line
    device_map=DEVICE, # This is a new line
).to(DEVICE)

# Create inputs
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What do we see in this image?"},
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "And how about this image?"},
        ]
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image1, image2], return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

# Generate
generated_ids = model.generate(**inputs, max_new_tokens=500)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

print(generated_texts)
```

When this block of code is run, it yields the error below.

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 55
     51 inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
     54 # Generate
---> 55 generated_ids = model.generate(**inputs, max_new_tokens=500)
     56 generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
     58 print(generated_texts)

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1990, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1982     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1983         input_ids=input_ids,
   1984         expand_size=generation_config.num_return_sequences,
   1985         is_encoder_decoder=self.config.is_encoder_decoder,
   1986         **model_kwargs,
   1987     )
   1989     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1990     result = self._sample(
   1991         input_ids,
   1992         logits_processor=prepared_logits_processor,
   1993         logits_warper=prepared_logits_warper,
   1994         stopping_criteria=prepared_stopping_criteria,
   1995         generation_config=generation_config,
   1996         synced_gpus=synced_gpus,
   1997         streamer=streamer,
   1998         **model_kwargs,
   1999     )
   2001 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2002     # 11. prepare logits warper
   2003     prepared_logits_warper = (
   2004         self._get_logits_warper(generation_config, device=input_ids.device)
   2005         if generation_config.do_sample
   2006         else None
   2007     )

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2933, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2930 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2932 # forward pass to get next token
-> 2933 outputs = self(**model_inputs, return_dict=True)
   2935 if synced_gpus and this_peer_finished:
   2936     continue  # don't waste resources running the code we don't need

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1575, in Idefics2ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1572 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1574 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1575 outputs = self.model(
   1576     input_ids=input_ids,
   1577     attention_mask=attention_mask,
   1578     position_ids=position_ids,
   1579     past_key_values=past_key_values,
   1580     inputs_embeds=inputs_embeds,
   1581     pixel_values=pixel_values,
   1582     pixel_attention_mask=pixel_attention_mask,
   1583     image_hidden_states=image_hidden_states,
   1584     use_cache=use_cache,
   1585     output_attentions=output_attentions,
   1586     output_hidden_states=output_hidden_states,
   1587     return_dict=return_dict,
   1588 )
   1590 hidden_states = outputs[0]
   1591 logits = self.lm_head(hidden_states)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1408, in Idefics2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
   1399 if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
   1400     # When we generate, we don't want to replace the potential image_token_id that we generated by images
   1401     # that simply don't exist
   1402     inputs_embeds = self.inputs_merger(
   1403         input_ids=input_ids,
   1404         inputs_embeds=inputs_embeds,
   1405         image_hidden_states=image_hidden_states,
   1406     )
-> 1408 outputs = self.text_model(
   1409     inputs_embeds=inputs_embeds,
   1410     attention_mask=attention_mask,
   1411     position_ids=position_ids,
   1412     past_key_values=past_key_values,
   1413     output_attentions=output_attentions,
   1414     output_hidden_states=output_hidden_states,
   1415     return_dict=return_dict,
   1416 )
   1418 if return_legacy_cache and use_cache:
   1419     outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:805, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    794     layer_outputs = self._gradient_checkpointing_func(
    795         decoder_layer.__call__,
    796         hidden_states,
   (...)
    802         cache_position,
    803     )
    804 else:
--> 805     layer_outputs = decoder_layer(
    806         hidden_states,
    807         attention_mask=causal_mask,
    808         position_ids=position_ids,
    809         past_key_value=past_key_values,
    810         output_attentions=output_attentions,
    811         use_cache=use_cache,
    812         cache_position=cache_position,
    813     )
    815 hidden_states = layer_outputs[0]
    817 if use_cache:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:546, in MistralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    543 hidden_states = self.input_layernorm(hidden_states)
    545 # Self Attention
--> 546 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    547     hidden_states=hidden_states,
    548     attention_mask=attention_mask,
    549     position_ids=position_ids,
    550     past_key_value=past_key_value,
    551     output_attentions=output_attentions,
    552     use_cache=use_cache,
    553     cache_position=cache_position,
    554     **kwargs,
    555 )
    556 hidden_states = residual + hidden_states
    558 # Fully Connected

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:379, in MistralFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    376 key_states = key_states.transpose(1, 2)
    377 value_states = value_states.transpose(1, 2)
--> 379 attn_output = _flash_attention_forward(
    380     query_states,
    381     key_states,
    382     value_states,
    383     attention_mask,
    384     q_len,
    385     position_ids=position_ids,
    386     dropout=dropout_rate,
    387     sliding_window=getattr(self.config, "sliding_window", None),
    388     use_top_left_mask=self._flash_attn_uses_top_left_mask,
    389     is_causal=self.is_causal,
    390 )
    392 attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
    393 attn_output = self.o_proj(attn_output)

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_flash_attention_utils.py:278, in _flash_attention_forward(query_states, key_states, value_states, attention_mask, query_length, is_causal, dropout, position_ids, softmax_scale, sliding_window, use_top_left_mask, softcap, deterministic)
    275     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    276     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
--> 278     attn_output = flash_attn_varlen_func(
    279         query_states,
    280         key_states,
    281         value_states,
    282         cu_seqlens_q=cu_seqlens_q,
    283         cu_seqlens_k=cu_seqlens_k,
    284         max_seqlen_q=max_seqlen_in_batch_q,
    285         max_seqlen_k=max_seqlen_in_batch_k,
    286         dropout_p=dropout,
    287         softmax_scale=softmax_scale,
    288         causal=causal,
    289         **flash_kwargs,
    290     )
    292     attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    294 else:

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:1124, in flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, block_table)
   1051 def flash_attn_varlen_func(
   1052     q,
   1053     k,
   (...)
   1067     block_table=None,
   1068 ):
   1069     """dropout_p should be set to 0.0 during evaluation
   1070     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
   1071     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
   (...)
   1122             pattern (negative means that location was dropped, nonnegative means it was kept).
   1123     """
-> 1124     return FlashAttnVarlenFunc.apply(
   1125         q,
   1126         k,
   1127         v,
   1128         cu_seqlens_q,
   1129         cu_seqlens_k,
   1130         max_seqlen_q,
   1131         max_seqlen_k,
   1132         dropout_p,
   1133         softmax_scale,
   1134         causal,
   1135         window_size,
   1136         softcap,
   1137         alibi_slopes,
   1138         deterministic,
   1139         return_attn_probs,
   1140         block_table,
   1141     )

File /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:620, in FlashAttnVarlenFunc.forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_softmax, block_table)
    618 if softmax_scale is None:
    619     softmax_scale = q.shape[-1] ** (-0.5)
--> 620 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
    621     q,
    622     k,
    623     v,
    624     cu_seqlens_q,
    625     cu_seqlens_k,
    626     max_seqlen_q,
    627     max_seqlen_k,
    628     dropout_p,
    629     softmax_scale,
    630     causal=causal,
    631     window_size=window_size,
    632     softcap=softcap,
    633     alibi_slopes=alibi_slopes,
    634     return_softmax=return_softmax and dropout_p > 0,
    635     block_table=block_table,
    636 )
    637 ctx.save_for_backward(
    638     q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
    639 )
    640 ctx.dropout_p = dropout_p

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:90, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, block_table, leftpad_k, seqused_k)
     70 def _flash_attn_varlen_forward(
     71     q,
     72     k,
   (...)
     87     seqused_k=None,
     88 ):
     89     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 90     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
     91         q,
     92         k,
     93         v,
     94         None,
     95         cu_seqlens_q,
     96         cu_seqlens_k,
     97         seqused_k,
     98         leftpad_k,
     99         block_table,
    100         alibi_slopes,
    101         max_seqlen_q,
    102         max_seqlen_k,
    103         dropout_p,
    104         softmax_scale,
    105         False,
    106         causal,
    107         window_size[0],
    108         window_size[1],
    109         softcap,
    110         return_softmax,
    111         None,
    112     )
    113     # if out.isnan().any() or softmax_lse.isnan().any():
    114     #     breakpoint()
    115     return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: batch size must be positive
```
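
For context, here is a minimal sketch (my addition, with hypothetical shapes; it illustrates when the flash-attn varlen kernel emits this message, not the root cause of this issue): the kernel raises `batch size must be positive` when the cumulative-sequence-length tensor it receives describes zero sequences, i.e. when the unpadding step upstream produced an empty batch.

```python
# Illustration only (hypothetical head counts and head dim; requires flash-attn and a CUDA GPU):
# a cu_seqlens tensor with a single entry describes zero sequences, so the kernel's
# batch-size check fails with the same message as above.
import torch
from flash_attn import flash_attn_varlen_func

q = torch.empty(0, 32, 128, dtype=torch.bfloat16, device="cuda")  # 0 total query tokens
k = torch.empty(0, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.empty(0, 8, 128, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.zeros(1, dtype=torch.int32, device="cuda")     # describes 0 sequences

flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=0, max_seqlen_k=0,
)
# -> RuntimeError: batch size must be positive
```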

Expected behavior

I'd expect the above script to generate without error (a similar one did earlier this week, but it now yields the same error).
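
For comparison, a minimal sketch (my addition, not part of the original report) that isolates the attention implementation: the same load with `attn_implementation="eager"`, everything else unchanged, is expected to generate normally, assuming the regression is specific to the flash_attention_2 path.

```python
# Continuing from the reproduction script above (reuses DEVICE, dtype, processor, inputs).
# Assumption: only the flash_attention_2 path is affected, so the eager path should complete.
model_eager = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    attn_implementation="eager",  # swap out flash_attention_2
    torch_dtype=dtype,
    device_map=DEVICE,
)

generated_ids = model_eager.generate(**inputs, max_new_tokens=500)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```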

I believe one of these two issues might be related to this one:

zucchini-nlp commented 3 months ago

Hey! Indeed, flash attention seems to be broken in the latest release, caused by https://github.com/huggingface/transformers/pull/31629. I've located the cause and will work on a fix; in the meantime you can downgrade transformers to at most v4.42.4 and try generating again :)
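
To make the suggested pin concrete (a sketch; only the version bound comes from the comment above, the rest is a hypothetical convenience):

```python
# Workaround from the comment above: stay on transformers <= 4.42.4 until the fix is
# released, e.g.  pip install "transformers<=4.42.4"
# The runtime check below is my own addition, not something suggested in the thread.
import warnings

import transformers
from packaging import version  # packaging is already a transformers dependency

# Assumption: the regression shipped with the 4.43 release line.
if version.parse(transformers.__version__) >= version.parse("4.43.0"):
    warnings.warn(
        f"transformers {transformers.__version__} may hit the Idefics2 flash_attention_2 "
        "regression described in this issue; consider downgrading to <=4.42.4 or loading "
        "the model without flash_attention_2."
    )
```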

tctrautman commented 3 months ago

Thank you, @zucchini-nlp!