huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

The implementations of `LlamaAttention` and `LlamaSdpaAttention` are not equivalent. #32086

Open SDaoer opened 1 month ago

SDaoer commented 1 month ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

  1. Change `output_attentions=output_attentions` to `output_attentions=True` at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L617
  2. Run the official example script:
    
    from PIL import Image
    import requests
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generate_ids = model.generate(**inputs, max_new_tokens=15)

    print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])


The generation gets cut short, like this:

USER: \nWhat's the content of the image? ASSISTANT: the


This is because the implementations of `LlamaAttention` and `LlamaSdpaAttention` are not equivalent.
It can be fixed by aligning `LlamaAttention.forward` with the execution logic of `torch.nn.functional.scaled_dot_product_attention`, like this:
    causal_mask = attention_mask
    is_causal = True if causal_mask is None and q_len > 1 else False
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask
    elif is_causal:
        # No mask was passed in, so build the same additive causal bias that SDPA applies internally.
        attn_bias = torch.zeros(
            query_states.shape[-2], key_states.shape[-2],
            dtype=query_states.dtype, device=query_states.device,
        )
        temp_mask = torch.ones(
            query_states.shape[-2], key_states.shape[-2],
            dtype=torch.bool, device=query_states.device,
        ).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_weights = attn_weights + attn_bias
Changing the implementation at https://github.com/huggingface/transformers/blob/e316c5214fe51de0bf8e824245bfd6225c9925aa/src/transformers/models/llama/modeling_llama.py#L330 to the code above works for me (`transformers` version: 4.43.0.dev0).
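
For reference, here is a minimal standalone sketch (plain PyTorch, not the transformers code) showing that SDPA with is_causal=True only matches the eager computation once such an additive causal bias is applied:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    bsz, num_heads, q_len, head_dim = 1, 2, 5, 8
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, q_len, head_dim)
    v = torch.randn(bsz, num_heads, q_len, head_dim)

    # SDPA path: no explicit mask, causality handled internally via is_causal=True.
    sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

    # Eager path with the additive causal bias built as in the snippet above.
    attn_bias = torch.zeros(q_len, q_len)
    temp_mask = torch.ones(q_len, q_len, dtype=torch.bool).tril(diagonal=0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    attn_weights = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5 + attn_bias, dim=-1)
    eager_out = attn_weights @ v

    print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # True -> equivalent once the bias is applied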

Expected behavior

The output should be like this:

What's the content of the image? ASSISTANT: The image features a street scene with a stop sign, a red building,

qubvel commented 1 month ago

Hi @SDaoer, thanks for reporting the issue!

Agreed, here is the difference:

In SDPA attention, if attention_mask is None, then is_causal=True is set for scaled_dot_product_attention:

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        ...

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,   # <-------   if causal_mask is None, then is_causal=True
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

For eager attention, if attention_mask is None, the causal mask is not created and no mask is applied at all:


        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

However, the causal_mask logic is correctly handled in the LlamaModel class and then propagated to the attention layers, so we do not observe differences in the model itself.

See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L934
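
Roughly, what happens at the model level (a simplified illustration, not the actual mask-preparation code in LlamaModel) is that for the eager path a 4D additive causal mask is materialized up front, so the attention layers there receive a non-None attention_mask and take the slicing branch shown above:

    import torch

    def build_causal_mask(batch_size, q_len, kv_len, dtype=torch.float32):
        # 0.0 where attention is allowed (on and below the diagonal), a large negative
        # value above it; shape (batch, 1, q_len, kv_len) as the attention layers expect.
        mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, dtype=dtype)
        mask = torch.triu(mask, diagonal=1)
        return mask[None, None, :, :].expand(batch_size, 1, q_len, kv_len)

    print(build_causal_mask(1, 4, 4))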

So, to have it correctly handled, you should probably avoid hard-coding output_attentions=True in the LlamaDecoderLayer and instead pass this argument to the model's forward, for example as sketched below.
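
For example (a sketch of that recommended route, not tested here), the attentions can be requested directly through generate; with output_attentions=True the SDPA layers fall back to the eager implementation so the weights can be returned:

    from PIL import Image
    import requests
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
    image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    out = model.generate(
        **inputs,
        max_new_tokens=15,
        output_attentions=True,        # forwarded down to every decoder layer
        return_dict_in_generate=True,  # so the attentions are returned alongside the ids
    )
    # out.attentions: one tuple per generated token, each containing per-layer attention maps
    print(len(out.attentions), len(out.attentions[0]))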

In general, it would be nice to have the same behavior for both attention implementations, as suggested by @SDaoer.

SDaoer commented 1 month ago

@qubvel Thanks for your patience! (,,・ω・,,) I ran into this issue because I used a hook to modify the module's methods, like this:

        ....
        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=True,  # @daoer: set `output_attentions=True` to get the attention weights from `LlamaAttention.forward`, instead of `torch.nn.functional.scaled_dot_product_attention`
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hk_self.self_attn_weights_trace.append(self_attn_weights)
        ...

source: https://github.com/huggingface/transformers/blob/fe008d6ebea1f5770b740991daeefd9322fa434a/src/transformers/models/llama/modeling_llama.py#L611

So I didn't pass this parameter to the model but directly modified it in LlamaDecoderLayer. Thank you for your patient response; now I thoroughly understand this matter. ^_^

Hope this issue will be helpful to others with similar needs.
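
For anyone with a similar need, a possible alternative to patching LlamaDecoderLayer (a sketch; the model.language_model.model.layers path is an assumption for this checkpoint and version) is to force the eager implementation at load time and collect the weights with forward hooks:

    from transformers import LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",
        attn_implementation="eager",  # run LlamaAttention instead of LlamaSdpaAttention everywhere
    )

    attn_trace = []

    def save_attn(module, args, output):
        # LlamaAttention.forward returns (attn_output, attn_weights, past_key_value);
        # attn_weights is only populated when output_attentions=True is passed to the model.
        attn_trace.append(output[1])

    hooks = [
        layer.self_attn.register_forward_hook(save_attn)
        for layer in model.language_model.model.layers  # assumed layout for this checkpoint
    ]

    # ... run model(**inputs, output_attentions=True) or generate(..., output_attentions=True),
    # then remove the hooks:
    for hook in hooks:
        hook.remove()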

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.