DoubleClark closed this issue 2 weeks ago.
Thank you for reaching out. You could use the following code, revised from yours:
import torch
import torch.nn.functional as F
from sageattention import sageattn
# Tensors are laid out as (batch, heads, seq_len, head_dim).
batch_size = 1
seq_len = 1024
num_heads = 32
head_dim = 128
# Random fp16 inputs on the GPU, drawn from N(0, 0.5^2).
q = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
k = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
v = torch.empty((batch_size, num_heads, seq_len, head_dim), dtype=torch.float16, device="cuda").normal_(mean=0.0, std=0.5)
# Reference output from PyTorch's built-in attention (default softmax scale).
value = F.scaled_dot_product_attention(q, k, v, is_causal=False)
# SageAttention output for the same inputs.
value_sage = sageattn(q, k, v, is_causal=False)
# Element-wise closeness check with tight tolerances.
print(torch.allclose(value, value_sage, atol=1e-3, rtol=1e-3))
Thanks for your reply. It seems that I passed an arbitrary scale factor, which may have caused numerical overflow; when I use the default scale value, the issue seems to disappear. Cheers!
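For context on the scale issue: the conventional softmax scale is 1/sqrt(head_dim), and that is what `F.scaled_dot_product_attention` applies when no scale is given. A larger custom scale enlarges the pre-softmax logits, and in fp16 (largest finite value 65504) that can plausibly overflow. The sketch below is a plain float32 reference, not SageAttention's kernel; `reference_attention` is a hypothetical helper, and it runs on CPU with small illustrative shapes.

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v, scale=None):
    # Hypothetical helper: softmax(Q K^T * scale) V, computed in float32.
    # scale=None falls back to the conventional default 1/sqrt(head_dim).
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
    q, k, v = q.float(), k.float(), v.float()
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)


torch.manual_seed(0)
# Small CPU tensors in (batch, heads, seq_len, head_dim) layout.
q = torch.randn(1, 4, 64, 32) * 0.5
k = torch.randn(1, 4, 64, 32) * 0.5
v = torch.randn(1, 4, 64, 32) * 0.5

ref = reference_attention(q, k, v)
sdpa = F.scaled_dot_product_attention(q, k, v)

# The manual reference with the default scale should match PyTorch's SDPA.
print(torch.allclose(ref, sdpa, atol=1e-5, rtol=1e-5))
# Largest finite fp16 value; logits beyond this overflow to inf.
print(torch.finfo(torch.float16).max)
```

Keeping the reference in float32 means any mismatch you later observe comes from the kernel under test rather than from the reference itself.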
May I ask how to compare accuracy at the kernel level? I tried comparing against the fp16 version, but `torch.allclose` always returns False, even after raising both the absolute and relative tolerances to 0.5. I understand that in quantization work we usually compare the quantized result with the dequantized result to check consistency in computational precision, so a direct fp16 comparison is somewhat unfair here. However, when I inspect individual output elements, the differences are still quite large, whereas with W4A16 quantization of a linear layer, a portion of the output elements in the quantized version still stay close to the fp16 version.
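One way to follow the "compare against the dequantized result" idea is a fake-quant reference: quantize Q and K to INT8, dequantize them, and run ordinary float32 attention. A kernel can then be checked against that simulated path, which separates quantization error from implementation bugs. This is only a sketch under a loud assumption: it uses symmetric per-token INT8 scales, which is a simplification; SageAttention's real kernels may use different granularity (e.g. per-block scales) and extra steps such as smoothing K. The helper names are hypothetical.

```python
import math

import torch


def fake_quant_int8_per_token(x):
    # ASSUMPTION: symmetric per-token INT8 along head_dim, chosen for illustration.
    # The actual kernel's quantization granularity may differ.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int = torch.clamp(torch.round(x / scale), -127, 127)
    return x_int * scale  # dequantized back to float


def attention_fp32(q, k, v):
    # Plain float32 attention with the default 1/sqrt(head_dim) scale.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q.float() @ k.float().transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v.float()


torch.manual_seed(0)
q = torch.randn(1, 4, 128, 64) * 0.5
k = torch.randn(1, 4, 128, 64) * 0.5
v = torch.randn(1, 4, 128, 64) * 0.5

exact = attention_fp32(q, k, v)
simulated = attention_fp32(fake_quant_int8_per_token(q), fake_quant_int8_per_token(k), v)

# Quantization alone already moves individual elements, while the
# overall direction of the output stays nearly unchanged.
max_abs = (exact - simulated).abs().max().item()
cos = torch.nn.functional.cosine_similarity(exact.flatten(), simulated.flatten(), dim=0).item()
print(f"max abs diff: {max_abs:.2e}, cosine similarity: {cos:.6f}")
```

Because attention outputs are weighted averages of V, they tend to be small in magnitude, so element-wise relative tolerances can be misleading even when aggregate metrics look healthy.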
Could you help correct my mistake? It would be quite helpful if a kernel-level accuracy validation script could be provided.
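In the meantime, a common alternative to `torch.allclose` for quantized attention is to report aggregate error metrics such as cosine similarity, relative L1 error, and RMSE against a higher-precision reference. The sketch below is not the maintainers' script; `accuracy_report` is a hypothetical helper, and any pass/fail thresholds are a judgment call. The usage runs on CPU with an fp16 round-trip as a stand-in for a kernel output; on a GPU you would instead pass a float32 attention output as `reference` and the `sageattn(q, k, v, is_causal=False)` output as `candidate`.

```python
import torch


def accuracy_report(reference, candidate):
    # Hypothetical helper: compute metrics in float32 so the
    # comparison itself does not add fp16 rounding error.
    ref = reference.float().flatten()
    out = candidate.float().flatten()
    diff = out - ref
    return {
        "cos_sim": torch.nn.functional.cosine_similarity(ref, out, dim=0).item(),
        "rel_l1": (diff.abs().sum() / ref.abs().sum()).item(),
        "rmse": diff.pow(2).mean().sqrt().item(),
        "max_abs": diff.abs().max().item(),
    }


torch.manual_seed(0)
reference = torch.randn(1, 4, 128, 64)
# Stand-in for a kernel output: the reference rounded through fp16.
candidate = reference.half().float()

report = accuracy_report(reference, candidate)
for name, val in report.items():
    print(f"{name}: {val:.3e}")
```

Aggregate metrics like these stay meaningful when a few elements deviate noticeably, which is exactly the situation where `allclose` with any reasonable tolerance keeps returning False.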