**[Open]** learning-chip opened this issue 7 months ago
For the `LlamaAttention._attn()` implementation: `self.norm_coef` is never defined, and the corresponding `else` branch is never entered. So the code is equivalent to the following (I checked that the code below gives identical results):
```python
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    if query.size(0) == 1:
        # baddbmm folds the additive mask into the matmul: mask + q @ k^T
        attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                     key.squeeze(0).transpose(-1, -2))
    else:
        attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
        attn_weights = attn_weights.add_(attention_mask)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```
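(Side note: the batch-size-1 branch relies on `torch.baddbmm`'s defaults `beta=1, alpha=1`, so `baddbmm(mask, q, kᵀ)` is just `mask + q @ kᵀ`. A quick standalone check, with shapes made up for illustration:)

```python
import torch

# baddbmm(input, batch1, batch2) with beta=1, alpha=1 (the defaults) computes
# input + batch1 @ batch2, i.e. the additive mask folded into the matmul
mask = torch.randn(32, 9, 9)    # (heads, q_len, kv_len)
q = torch.randn(32, 9, 128)     # (heads, q_len, head_dim)
k = torch.randn(32, 9, 128)     # (heads, kv_len, head_dim)

fused = torch.baddbmm(mask, q, k.transpose(-1, -2))
manual = mask + torch.matmul(q, k.transpose(-1, -2))
print(torch.allclose(fused, manual))  # True
```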
which is very standard masked attention. Why is it not equivalent to using `scaled_dot_product_attention`?
OK, I see: the original `_attn()` function is missing the scaling factor. Setting `scale=1.0` for `scaled_dot_product_attention` fixes the problem.
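For reference, `scaled_dot_product_attention`'s default scale is `1/sqrt(head_dim)`, and the explicit `scale` kwarg needs a recent PyTorch (2.1+, I believe). A quick check of that default:

```python
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 4, 64) for _ in range(3))  # (batch, heads, seq, head_dim)

# Default behaves as scale=1/sqrt(head_dim)
default_out = F.scaled_dot_product_attention(q, k, v)
explicit = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(q.size(-1)))
print(torch.allclose(default_out, explicit))  # True

# scale=1.0 leaves the scores unscaled, matching _attn() above
unscaled = F.scaled_dot_product_attention(q, k, v, scale=1.0)
print(torch.allclose(default_out, unscaled))  # False
```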
Here's a simple test:
```python
import torch
import torch.nn.functional as F


def attn(query, key, value, attention_mask=None):
    if query.size(0) == 1:
        attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                     key.squeeze(0).transpose(-1, -2))
    else:
        attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
        attn_weights = attn_weights.add_(attention_mask)
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output


def sdp_attn(query, key, value, attention_mask=None):
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=1.0)


@torch.inference_mode()
def main():
    # Force the math backend so the results are exactly comparable
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    batch_size = 2
    q_len, kv_len = 9, 9
    head_num = 32
    hidden_size = 128

    q_shape = [batch_size, head_num, q_len, hidden_size]
    kv_shape = [batch_size, head_num, kv_len, hidden_size]
    mask_shape = [batch_size, 1, q_len, kv_len]
    dtype = torch.float16
    device = "cuda"

    query = torch.randn(q_shape, dtype=dtype).to(device)
    key = torch.randn(kv_shape, dtype=dtype).to(device)
    value = torch.randn(kv_shape, dtype=dtype).to(device)
    attention_mask = torch.zeros(mask_shape, dtype=dtype).to(device)

    out = attn(query, key, value, attention_mask)
    sdp_out = sdp_attn(query, key, value, attention_mask)
    print(torch.allclose(out, sdp_out))  # True


if __name__ == "__main__":
    main()
```
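One caveat about the test: the all-zeros mask only exercises the no-masking case. For a stricter check, something like the usual HF-style additive causal mask (my own sketch, not code from this repo) would also exercise the masked positions:

```python
import torch

# Additive causal mask: 0 at allowed positions, a very large negative value
# above the diagonal; broadcastable to (batch, 1, q_len, kv_len) like the
# mask in the test above
q_len, kv_len = 9, 9
dtype = torch.float16
causal = torch.triu(torch.full((q_len, kv_len), torch.finfo(dtype).min, dtype=dtype),
                    diagonal=1)
attention_mask = causal[None, None].expand(2, 1, q_len, kv_len).to("cuda")
```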
Now with

```python
def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
    # Disable the math fallback so SDPA dispatches to a fused kernel
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=1.0), None
```

I can get correct results:
```
lookahead:False time:3.198s speed:36.9token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
lookahead:False time:1.500s speed:78.7token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
lookahead:True time:1.305s speed:90.4token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
lookahead:True time:0.976s speed:120.9token/s response:["I'm here to help you.\nI'm just an AI, I don't have personal experiences or emotions like humans do, but I'm here to assist you in any way I can. Is there something specific you would like to know or discuss?\n\nPlease let me know if", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
```
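As an aside, `torch.backends.cuda.sdp_kernel` is deprecated on recent PyTorch (2.3+, if I remember right) in favor of `torch.nn.attention.sdpa_kernel`; an equivalent sketch of the same kernel restriction:

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
    # Allow only the fused backends, mirroring enable_math=False above
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=1.0), None
```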
I attempted to swap in FlashAttention for batched llama by simply changing `self._attn()` to `self._sdp_attn()` inside `LlamaAttention.forward()`:

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L372-L375
https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L404-L407

where `_sdp_attn` is defined as:

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L327-L329
However, the model generates incorrect results. The original `llama_batch_example.py` gives:

The modified model gives:

So `LlamaAttention._attn()` is doing something extra beyond standard attention?