thu-ml / SageAttention

Quantized Attention that achieves speedups of 2.1x and 2.7x compared to FlashAttention2 and xformers, respectively, without losing end-to-end metrics across various models.
BSD 3-Clause "New" or "Revised" License

Accuracy Comparison in Kernel Level #10

Closed DoubleClark closed 2 weeks ago

DoubleClark commented 2 weeks ago

May I ask how to compare accuracy at the kernel level? I tried comparing against the fp16 version, but torch.allclose always returns False, even after raising the absolute and relative tolerances to 0.5. I understand that for quantization we usually compare the quantized result against the dequantized result so that both sides share the same calculation precision, which makes a direct fp16 comparison somewhat unfair here. However, when I inspect the output element-wise, the differences are still quite large, whereas for a W4A16-quantized linear layer most elements of the quantized output still track the fp16 output closely (a rough sketch of the element-wise check I use follows the script below).

Could you help me spot my mistake? It would be very helpful if a kernel-level accuracy validation script could be provided.

import math
import torch
import torch.nn.functional as F
from sageattention import sageattn

batch_size = 1
seq_len = 1024
num_heads = 32
head_dim = 128

scale = math.sqrt(head_dim)  # scale passed to both kernels below

q = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
k = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
v = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)

value = F.scaled_dot_product_attention(q, k, v, None, is_causal=False, scale=scale)
value_sage = sageattn(q, k, v, None, is_causal=False, scale=scale, smooth_k=True)
print(torch.allclose(value, value_sage, atol=5e-1, rtol=5e-1))
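
For reference, this is roughly how I inspect the outputs element-wise instead of relying on a single allclose. The helper below is just a sketch (the function name and the printed statistics are my own choice), operating on the value / value_sage tensors from the script above:

def summarize_error(ref, out):
    # compare in float32 so the metric itself is not limited by fp16 rounding
    ref32, out32 = ref.float(), out.float()
    diff = (ref32 - out32).abs()
    rel = diff / ref32.abs().clamp_min(1e-6)
    cos = F.cosine_similarity(ref32.flatten(), out32.flatten(), dim=0)
    print(f"max abs diff : {diff.max().item():.4e}")
    print(f"mean abs diff: {diff.mean().item():.4e}")
    print(f"mean rel diff: {rel.mean().item():.4e}")
    print(f"cosine sim   : {cos.item():.6f}")

summarize_error(value, value_sage)
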
jt-zhang commented 2 weeks ago

Thank you for reaching out. You could use the following code, revised from yours:

import math, torch
import torch.nn.functional as F
from sageattention import sageattn

batch_size = 1
seq_len = 1024
num_heads = 32
head_dim = 128

q = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
k = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
v = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)

value = F.scaled_dot_product_attention(q, k, v, is_causal=False)
value_sage = sageattn(q, k, v, is_causal=False)
print(torch.allclose(value, value_sage, atol=1e-3, rtol=1e-3))
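
For reference, the scale argument of F.scaled_dot_product_attention is the factor applied to Q @ K^T before the softmax, and it defaults to 1 / sqrt(head_dim); passing math.sqrt(head_dim), as in your script, makes the logits head_dim times larger than with the default. If you want to pass the scale explicitly, the equivalent of the default looks roughly like this (variable names here are only illustrative):

# explicit form of the default scaling: softmax((Q @ K^T) * (1 / sqrt(head_dim))) @ V
explicit_scale = 1.0 / math.sqrt(head_dim)
value_explicit = F.scaled_dot_product_attention(q, k, v, is_causal=False, scale=explicit_scale)
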
DoubleClark commented 2 weeks ago

Thanks for your reply. It turns out I passed a careless scale factor, which led to numerical overflow; when I use the default scale, the issue disappears. Cheers!