alipay / PainlessInferenceAcceleration


Changing naive attention to SDPA gives wrong result for batched llama example #22

Open learning-chip opened 7 months ago

learning-chip commented 7 months ago

I attempted to swap in FlashAttention for the batched llama model by simply changing self._attn() to self._sdp_attn() inside LlamaAttention.forward():

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L372-L375

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L404-L407

where _sdp_attn is defined as:

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L327-L329
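For readers without the permalink handy: the linked _sdp_attn is essentially a thin wrapper around F.scaled_dot_product_attention. A rough sketch of it, inferred from the modified version posted later in this thread rather than copied from the linked lines:

    def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # thin wrapper over PyTorch SDPA (F is torch.nn.functional);
        # sketch inferred from the fix shown later in the thread, not copied verbatim
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask), None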

However, the modified model generates wrong results. The original llama_batch_example.py gives:

lookahead:False time:3.326s speed:35.5token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
...

The modified model gives:

lookahead:False time:3.271s speed:39.1token/s response:['the    “ nobody nobody     “ nobody   “ nobody  “ nobody   “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody ', 'nobody “ nobody “ nobody “ nobody to nobody. Unterscheidung the. Unterscheidung nobody. Unterscheidung ( ,  “ nobody, MS nobody, MS nobodyMS nobodyMS nobodyMS nobodyMS nobodyMS nobody. Unterscheidung,MS nobodyMS nobody. Unterscheidung,MS nobodyMS nobodyMS nobodyMS nobody. UnterscheidungMS nobodyMS']

So is LlamaAttention._attn() doing something extra beyond standard attention?

learning-chip commented 7 months ago

For the LlamaAttention._attn() implementation:

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L299-L325

self.norm_coef is never defined, and the corresponding else branch is never entered, so the code is equivalent to the following (I checked that the code below gives identical results):

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # note: the query-key product is never scaled by 1/sqrt(head_dim)
        if query.size(0) == 1:
            attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                         key.squeeze(0).transpose(-1, -2))
        else:
            attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
            attn_weights = attn_weights.add_(attention_mask)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

which is just standard masked attention... so why is it not equivalent to using scaled_dot_product_attention?

learning-chip commented 7 months ago

OK I see: the original _attn() function never applies the 1/sqrt(head_dim) scaling factor, whereas scaled_dot_product_attention scales by 1/sqrt(head_dim) by default. Setting scale=1.0 for scaled_dot_product_attention fixes the problem.
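For reference, the math-backend semantics of F.scaled_dot_product_attention (per the PyTorch docs) look roughly like the sketch below; dropout and is_causal handling are omitted, and the names are mine. The default scale is 1/sqrt(head_dim), which is exactly the factor _attn() never applies:

    import math
    import torch.nn.functional as F

    def sdpa_math_reference(query, key, value, attn_mask=None, scale=None):
        # what F.scaled_dot_product_attention computes with the math backend;
        # the default scale is 1/sqrt(head_dim), while _attn() above effectively uses scale=1.0
        if scale is None:
            scale = 1.0 / math.sqrt(query.size(-1))
        attn_weights = (query @ key.transpose(-2, -1)) * scale
        if attn_mask is not None:
            attn_weights = attn_weights + attn_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        return attn_weights @ value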

Here's a simple test:

import torch
import torch.nn.functional as F

def attn(query, key, value, attention_mask=None):
    if query.size(0) == 1:
        attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                     key.squeeze(0).transpose(-1, -2))
    else:
        attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
        attn_weights = attn_weights.add_(attention_mask)

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output

def sdp_attn(query, key, value, attention_mask=None):
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=1.0)

@torch.inference_mode()
def main():
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    batch_size = 2
    q_len, kv_len = 9, 9
    head_num = 32
    hidden_size = 128

    q_shape = [batch_size, head_num, q_len, hidden_size]
    kv_shape = [batch_size, head_num, kv_len, hidden_size]
    mask_shape = [batch_size, 1, q_len, kv_len]

    dtype = torch.float16
    device = "cuda"
    query = torch.randn(q_shape, dtype=dtype).to(device)
    key = torch.randn(kv_shape, dtype=dtype).to(device)
    value = torch.randn(kv_shape, dtype=dtype).to(device)
    attention_mask = torch.zeros(mask_shape, dtype=dtype).to(device)

    out = attn(query, key, value, attention_mask)
    sdp_out = sdp_attn(query, key, value, attention_mask)

    print(torch.allclose(out, sdp_out)) # True

if __name__ == "__main__":
    main()
learning-chip commented 7 months ago

Now with

    def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # enable_math=False leaves only the flash / memory-efficient SDPA backends;
        # scale=1.0 matches the unscaled scores computed by _attn() above
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, scale=1.0), None

I can get correct results:

lookahead:False time:3.198s speed:36.9token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]

lookahead:False time:1.500s speed:78.7token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]

lookahead:True time:1.305s speed:90.4token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]

lookahead:True time:0.976s speed:120.9token/s response:["I'm here to help you.\nI'm just an AI, I don't have personal experiences or emotions like humans do, but I'm here to assist you in any way I can. Is there something specific you would like to know or discuss?\n\nPlease let me know if", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
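As a side note, torch.backends.cuda.sdp_kernel is deprecated on newer PyTorch releases in favor of torch.nn.attention.sdpa_kernel; an equivalent version of the wrapper above would look roughly like this (a sketch, assuming PyTorch >= 2.3; not part of the original patch):

    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # allow only the flash / memory-efficient backends, mirroring enable_math=False above
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, scale=1.0), None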