These might not be identical, depending on what exactly is being passed in the HF implementation. During decoding, is the HF impl passing in an attention_mask?
The attention mask is passed in as a lower-triangular additive mask: the lower-triangle entries are 0 and the upper-triangle entries are -3.3895e+38 (torch.finfo(dtype).min for bfloat16), which is effectively -inf.
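For reference, if the goal is only to reproduce that behavior, an HF-style additive mask can be folded straight into flex_attention through a score_mod. A minimal sketch, assuming attention_mask is the usual (bsz, 1, q_len, kv_len) additive tensor from the HF forward pass and query_states/key_states/value_states are the (batch, heads, seq_len, head_dim) tensors from the surrounding attention module:

from torch.nn.attention.flex_attention import flex_attention

def additive_mask_score_mod(score, b, h, q_idx, kv_idx):
    # adds 0 where attention is allowed and the large negative value where it is masked
    return score + attention_mask[b, 0, q_idx, kv_idx]

attn_output = flex_attention(query_states, key_states, value_states, score_mod=additive_mask_score_mod)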
@Chillee you are right, HF inserts empty (placeholder) tokens when using the static cache, which was not properly handled in the code above. After masking those out manually with a score_mod, the outputs are normal.
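Concretely, that kind of score_mod might look like the following minimal sketch, where cache_len is a hypothetical (bsz,) tensor giving the number of real (non-padding) entries currently in the static cache:

import torch

def mask_cache_padding(score, b, h, q_idx, kv_idx):
    # keep scores for real cache entries, push pre-allocated padding slots to -inf
    return torch.where(kv_idx < cache_len[b], score, -float("inf"))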
@kyleliang919 would you mind sharing the fix?
Sure, you need two kinds of masks: one for prefilling the context and one for generative inference, as follows. Another thing to pay attention to: don't use the static cache. It turns out HuggingFace pre-allocates padding tokens in that cache, and handling those tokens dynamically is a nightmare; using DynamicCache fixes the problem.
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Prefill: a plain causal mask over the prompt.
def attn_causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# query/key/value states are assumed to be (batch, heads, seq_len, head_dim),
# as in the HF attention module, so the sequence length is dim -2.
block_mask = create_block_mask(
    attn_causal,
    B=None, H=None, Q_LEN=query_states.shape[-2], KV_LEN=key_states.shape[-2],
)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
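Since the snippet above is the prefill mask, here is a minimal sketch of the decode-time counterpart, assuming DynamicCache (so the cache holds only real tokens) and q_len == 1 during generation. The plain causal mask cannot be reused at that point: the new token's q_idx restarts at 0, so q_idx >= kv_idx would let it see only the first cache entry.

# Decode step: the single new query may attend to every cached key, so the mask is all-True.
def attn_decode(b, h, q_idx, kv_idx):
    return kv_idx >= 0

decode_block_mask = create_block_mask(
    attn_decode,
    B=None, H=None, Q_LEN=query_states.shape[-2], KV_LEN=key_states.shape[-2],
)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=decode_block_mask)

Equivalently, block_mask can simply be omitted for this step, since no mask means full attention.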
It turns out HuggingFace pre-allocates padding tokens in that cache, and handling those tokens dynamically is a nightmare
I actually don't think static cache is that hard to handle. You just need to track your current sequence position and have a mask like
def causal_offset(b, h, q_idx, kv_idx):
    # offset: (bsz,) tensor holding each sequence's current position in the static cache
    return offset[b] >= kv_idx
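A sketch of how that could be wired into a static-cache decode step (offset, and the choice to rebuild the block mask every step, are assumptions, not part of the original comment):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# offset: hypothetical (bsz,) int tensor tracking each sequence's current position in
# the pre-allocated static cache. The block mask is rebuilt each decode step for
# simplicity, since offset changes every step; KV_LEN spans the whole pre-allocated
# cache, and the mask hides every slot past the current position (padding included).
block_mask = create_block_mask(
    causal_offset,
    B=offset.shape[0], H=None, Q_LEN=query_states.shape[-2], KV_LEN=key_states.shape[-2],
)
attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
offset += 1  # advance the tracked position after generating the next token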
I tried to replace the attention implementation here https://github.com/huggingface/transformers/blob/238b13478df209ab534f2195a397dc64a3930883/src/transformers/models/llama/modeling_llama.py#L419
With
The generation became nonsensical given the same input and temperature. I also compared the activations; they seem close, but not identical.
The only difference in implementation I can think of is the upcasting to 32-bit in the original Llama code, but that shouldn't make this much of a difference.
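For what it's worth, here is a self-contained sketch (random bf16 tensors on CUDA, not the actual model activations) of how one might bound that gap, comparing SDPA with an HF-style additive causal mask against flex_attention with an equivalent block mask:

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.manual_seed(0)
flex_attention = torch.compile(flex_attention)
B, H, S, D = 1, 8, 128, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

# Additive causal mask in the HF style: 0 where allowed, dtype-min where masked.
additive = torch.full((S, S), torch.finfo(torch.bfloat16).min, device="cuda", dtype=torch.bfloat16)
additive = torch.triu(additive, diagonal=1)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)

block_mask = create_block_mask(
    lambda b, h, q_idx, kv_idx: q_idx >= kv_idx,
    B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda",
)
out = flex_attention(q, k, v, block_mask=block_mask)

# Exact equality is not expected in bf16, since the two paths reduce in different orders.
print((out - ref).abs().max())

The interesting question is whether that max difference stays at rounding level or is large enough to explain nonsensical generations.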
For reference, here is the original HF implementation: