pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Replacing attention implementation with FlexAttention seems to break Llama3 inference #46

Closed: kyleliang919 closed this issue 1 month ago

kyleliang919 commented 1 month ago

I tried to replace the attention implementation here https://github.com/huggingface/transformers/blob/238b13478df209ab534f2195a397dc64a3930883/src/transformers/models/llama/modeling_llama.py#L419

With

from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def attn_causal(b, h, q_idx, kv_idx, q_len):
    causal_mask = q_idx >= kv_idx
    if q_len > 1:
        return causal_mask
    else:
        # when q_len == 1 we are decoding one token at a time, so attend to all cached keys
        return q_idx >= -1

block_mask = create_block_mask(
    lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx, q_len),
    B=None, H=None, Q_LEN=q_len, KV_LEN=key_states.shape[-2]
)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)

The generation became nonsensical given the same input and temperature. I also compared the activations; they are close but not identical.

The only difference in implementation I can think of is the upcast to fp32 in the original Llama code, but that shouldn't make this much of a difference.

For reference, here is the original HF implementation:

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
Chillee commented 1 month ago

These might not be identical, depending on what exactly is being passed in the HF implementation. During decoding, is the HF impl passing in an attention_mask?

kyleliang919 commented 1 month ago

The attention mask is passed in as a lower-triangular additive mask: 0 on the lower triangle and -3.3895e+38 (effectively -inf) on the upper triangle.
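For concreteness, that large negative value is torch.finfo(torch.bfloat16).min rather than a literal -inf. A minimal sketch of how such an additive causal mask is typically built (illustrative sizes and names, not the exact HF code):

import torch

seq_len, dtype = 8, torch.bfloat16
min_value = torch.finfo(dtype).min                     # -3.3895e+38 for bfloat16
additive_mask = torch.full((seq_len, seq_len), min_value, dtype=dtype)
additive_mask = torch.triu(additive_mask, diagonal=1)  # 0 on/below the diagonal, min_value above
# later: attn_weights = attn_weights + additive_mask[None, None, :, :]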

kyleliang919 commented 1 month ago

@Chillee you are right: HF inserts empty pre-allocated tokens when using the static cache, and those were not handled in the code above. After masking them out manually with a score_mod, the outputs are normal.
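The score_mod used for that fix is not shown here; a minimal sketch of what masking out the pre-allocated cache slots could look like, assuming a hypothetical per-batch tensor cache_len that holds the number of valid tokens currently in the cache:

import torch
from torch.nn.attention.flex_attention import flex_attention

# cache_len[b] = number of valid (non-padding) tokens in the static cache for batch b (hypothetical)
def mask_padded_cache(score, b, h, q_idx, kv_idx):
    return torch.where(kv_idx < cache_len[b], score, -float("inf"))

attn_output = flex_attention(query_states, key_states, value_states, score_mod=mask_padded_cache)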

mlaugharn commented 3 weeks ago

@kyleliang919 would you mind sharing the fix?

kyleliang919 commented 3 weeks ago

> @kyleliang919 would you mind sharing the fix?

Sure, you need two kinds of masks: one for prefilling the context and one for generative inference, as shown below. Another thing to pay attention to: don't use the static cache. It turns out Hugging Face pre-allocates padding tokens in that cache, and handling those tokens dynamically is a nightmare. Using the dynamic cache fixes the problem.

# Prefill: a standard causal mask over the full prompt
def attn_causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(
    attn_causal,
    B=None, H=None, Q_LEN=query_states.shape[1], KV_LEN=key_states.shape[1]
)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
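The snippet above is the prefill mask. The second kind of mask, used for generative (one-token-at-a-time) inference, simply lets the single query attend to every cached key, as in the q_len == 1 branch of the first snippet; a minimal sketch with illustrative names:

# Decode: the single new query token attends to the whole cache
def gen_mask(b, h, q_idx, kv_idx):
    return q_idx >= -1  # always True

decode_block_mask = create_block_mask(
    gen_mask, B=None, H=None, Q_LEN=1, KV_LEN=key_states.shape[1]
)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=decode_block_mask)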
Chillee commented 3 weeks ago

> it turns out Hugging Face pre-allocates padding tokens in that cache, and handling those tokens dynamically is a nightmare

I actually don't think static cache is that hard to handle. You just need to track your current sequence position and have a mask like

def causal_offset(b, h, q_idx, kv_idx):
    # offset[b] tracks the current sequence position for batch element b
    return offset[b] >= kv_idx
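A hedged usage sketch of that approach with a static cache, assuming offset is a per-batch tensor of current positions that you advance after each decoding step (batch_size, max_cache_len, and the surrounding state are illustrative, not part of the original comment):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# offset[b] = position currently being generated for batch element b (illustrative)
offset = torch.zeros(batch_size, dtype=torch.int64, device=query_states.device)

def causal_offset(b, h, q_idx, kv_idx):
    return offset[b] >= kv_idx

# During decoding Q_LEN is 1 and KV_LEN is the full static-cache length; the block mask
# is recreated (or its indices recomputed) as offset changes from step to step.
block_mask = create_block_mask(causal_offset, B=batch_size, H=None, Q_LEN=1, KV_LEN=max_cache_len)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
offset += 1  # advance after each generated token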