SDaoer opened 1 month ago
Hi @SDaoer, thanks for reporting the issue!
Agreed, here is the difference:
In SDPA attention, if `attention_mask` is None, then `is_causal=True` is set for `scaled_dot_product_attention`:
```python
causal_mask = attention_mask
if attention_mask is not None:
    causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

...

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,  # <------- if causal_mask is None, then is_causal=True
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=is_causal,
)
```
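For intuition, here is a minimal, self-contained sketch (toy shapes and random tensors, not library code) showing that `scaled_dot_product_attention` with `attn_mask=None` and `is_causal=True` behaves the same as an explicit additive causal mask:

```python
import torch
import torch.nn.functional as F

# Toy tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# Explicit additive causal mask: -inf above the diagonal, 0 elsewhere,
# analogous to the `causal_mask` the model builds.
causal_mask = torch.full((5, 5), float("-inf")).triu(diagonal=1)

out_is_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

print(torch.allclose(out_is_causal, out_explicit, atol=1e-6))  # True
```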
For eager attention, if `attention_mask` is None, the `causal_mask` is not created and the attention mask is not applied:
```python
if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask
```
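To see the resulting discrepancy, here is a small illustrative sketch (again toy tensors, not library code): with no mask added, the plain `softmax(QK^T / sqrt(d)) @ V` of the eager path attends to future positions and therefore differs from the causal SDPA result:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# Eager-style attention with no causal_mask added (the attention_mask=None branch)
attn_weights = (q @ k.transpose(-2, -1)) / (8 ** 0.5)
eager_no_mask = torch.softmax(attn_weights, dim=-1) @ v

# SDPA path with is_causal=True
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(eager_no_mask, causal))  # False: non-causal vs. causal attention
```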
However, the `causal_mask` logic is handled correctly in the `LlamaModel` class and then propagated to the attention layers, so we do not observe differences in the model itself.
So, to have it handled correctly, you should probably avoid hard-coding `output_attentions=True` in `LlamaDecoderLayer` and instead pass this argument to the model's `forward`.
In general, it would be nice to have the same behavior for both attention implementations, as suggested by @SDaoer.
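As an illustration of the recommended usage (a hedged sketch; the checkpoint name is just a placeholder), attention weights can be requested through the model's `forward` instead of patching `LlamaDecoderLayer`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# `attn_implementation="eager"` keeps the returned attention weights well-defined,
# since the SDPA kernels do not materialize them.
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# Tuple of per-layer attention tensors, each of shape (batch, heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[0].shape)
```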
@qubvel Thanks for your patience! (,,・ω・,,) I encountered this issue because I used a hook to modify the module's methods like this:
```python
...
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=True,  # @daoer: set `output_attentions=True` to get the attention weights from `LlamaAttention.forward`, instead of `torch.nn.functional.scaled_dot_product_attention`
    use_cache=use_cache,
    cache_position=cache_position,
)
hk_self.self_attn_weights_trace.append(self_attn_weights)
...
```
So I didn't pass this parameter through the model, but directly modified it in `LlamaDecoderLayer`. Thank you for your patient response; now I thoroughly understand this matter. ^_^
Hope this issue will be helpful to others with similar needs.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
`transformers` version: 4.43.0.dev0

Who can help?
@ArthurZucker

Information
Tasks
An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
Modify `output_attentions=output_attentions` to be `output_attentions=True` at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L617, then run:

```python
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt")

generate_ids = model.generate(**inputs, max_new_tokens=15)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```
Output with the modification:
USER: \nWhat's the content of the image? ASSISTANT: the

Expected behavior
What's the content of the image? ASSISTANT: The image features a street scene with a stop sign, a red building,