From https://pytorch.org/blog/accelerated-pytorch-2/:

> PyTorch 2.0 release are the Flash Attention kernel (sdpa_flash, for 16-bit floating point training and inference on Nvidia GPUs with SM80+ architecture level)
I do have an SM80+ GPU, an NVIDIA RTX 4080 (SM89), so I tried tweaking the A100 GPU check to:

```python
if device_properties.major >= 8:
```

instead of

```python
if device_properties.major == 8 and device_properties.minor == 0:
```
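That condition sits in the device-detection block that decides which SDPA backends get enabled. As far as I understand it, the logic is roughly this (paraphrasing; the names `SDPConfig` / `cuda_config` are mine and the exact code in the repo may differ):

```python
from collections import namedtuple
import torch

# flags later handed to torch.backends.cuda.sdp_kernel(...)
SDPConfig = namedtuple('SDPConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

if device_properties.major == 8 and device_properties.minor == 0:
    # A100 detected: enable only the flash kernel
    cuda_config = SDPConfig(True, False, False)
else:
    # any other GPU: fall back to the math / memory-efficient kernels
    cuda_config = SDPConfig(False, True, True)
```

So with my tweak, the RTX 4080 also ends up on the flash-only path.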
Unfortunately, I got an error message on `F.scaled_dot_product_attention`:

```
No available kernel. Aborting execution.
```
A warning was printed just before:

```
UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: float, Key dtype: float, and Value dtype: float instead.
```
That makes sense given the "accelerated-pytorch-2" quote above.
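I believe the same behaviour can be reproduced outside of BS-RoFormer with something like this (untested sketch, shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device='cuda')  # float32 by default
k = torch.randn(2, 8, 1024, 64, device='cuda')
v = torch.randn(2, 8, 1024, 64, device='cuda')

flash_only = dict(enable_flash=True, enable_math=False, enable_mem_efficient=False)

# float32 inputs: the flash kernel refuses to run -> "No available kernel"
with torch.backends.cuda.sdp_kernel(**flash_only):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as err:
        print(err)

# half-precision inputs: the flash kernel runs
with torch.backends.cuda.sdp_kernel(**flash_only):
    out = F.scaled_dot_product_attention(q.half(), k.half(), v.half())
    print(out.dtype)  # torch.float16
```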
I call the code with these model settings, using the same testing code as in the README.md:
```python
model = BSRoformer(
    dim = 384,
    depth = 6,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)
```
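and roughly this kind of forward call (a sketch only; the exact README snippet may differ, and the input length here is made up):

```python
import torch

model = model.cuda()

# a raw-audio batch of shape (batch, time); the length is hypothetical
x = torch.randn(2, 352800).cuda()

out = model(x)  # this forward pass is where F.scaled_dot_product_attention gets hit
```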
Maybe I am doing something wrong?
Additional Information:
By adding a simple cast in the `Attention` module's `forward()`:
```python
if exists(self.rotary_embed):
    q = self.rotary_embed.rotate_queries_or_keys(q)
    k = self.rotary_embed.rotate_queries_or_keys(k)

out = self.attend(q.type(torch.bfloat16), k.type(torch.bfloat16), v.type(torch.bfloat16)).type(x.dtype)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
```
there is no more error, but I presume the float32 -> bfloat16 -> float32 round trip hurts arithmetic precision?
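An alternative I am considering, instead of hard-casting the tensors by hand, is to run the forward pass under autocast (a sketch, not yet tested with this repo):

```python
import torch

# Under autocast, matmul-like ops (including F.scaled_dot_product_attention)
# run in bfloat16 automatically, so the flash kernel gets a supported dtype,
# while the model parameters themselves stay in float32.
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    out = model(x)
```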
EDIT: From https://cloud.google.com/tpu/docs/bfloat16?hl=en:

> The dynamic range of bfloat16 and float32 are equivalent. However, bfloat16 takes up half the memory space.

and

> Most computations within a deep neural network can accomplish a task with the same accuracy using lower-precision values. Some models can even reach a higher accuracy with lower-precision values.
Does this kind of assertion hold for the BS-RoFormer model?
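For what it's worth, the cast only costs mantissa bits, not range, which is easy to check:

```python
import torch

x = torch.tensor(1.2345678)                 # float32
y = x.to(torch.bfloat16).to(torch.float32)  # round trip through bfloat16

print(x.item(), y.item())  # ~1.2345678 vs 1.234375: only ~3 significant decimal digits survive
# bfloat16 keeps float32's 8 exponent bits (same dynamic range) but only 7 mantissa bits
```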
Thank you very much for your code. You rock!
Is Flash Attention only supported by the A100 GPU?