lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
MIT License

Flash Attention support #8

Closed: dorpxam closed this issue 10 months ago

dorpxam commented 10 months ago

Thank you very much for your code. You rock!

Is Flash Attention supported only on the A100 GPU?

dorpxam commented 10 months ago

From: https://pytorch.org/blog/accelerated-pytorch-2/

…PyTorch 2.0 release are the Flash Attention kernel (sdpa_flash, for 16-bit floating point training and inference on Nvidia GPUs with SM80+ architecture level)

I do have an SM80+ GPU, an NVIDIA RTX 4080 (SM89), but unfortunately, after changing the A100 check to:

if device_properties.major >= 8:

instead of

if device_properties.major == 8 and device_properties.minor == 0:
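The two conditions differ only in which compute capabilities they accept. As a minimal sketch, the hypothetical helpers `exact_a100_check` and `sm80_plus_check` below mirror the two `if` statements on a `(major, minor)` tuple; in PyTorch, that tuple is what `torch.cuda.get_device_capability()` returns.

```python
# Hypothetical helpers mirroring the two `if` conditions quoted above.
# A capability is a (major, minor) tuple, e.g. (8, 0) for an A100
# and (8, 9) for an RTX 4080.


def exact_a100_check(major: int, minor: int) -> bool:
    """Original condition: only SM 8.0 (A100) matches."""
    return major == 8 and minor == 0


def sm80_plus_check(major: int, minor: int) -> bool:
    """Tweaked condition: any SM 8.x or newer GPU matches."""
    return major >= 8


for name, cap in [("A100", (8, 0)), ("RTX 4080", (8, 9)), ("V100", (7, 0))]:
    print(name, exact_a100_check(*cap), sm80_plus_check(*cap))
```

The RTX 4080 fails the original check and passes the tweaked one, which is what routes it onto the flash-only code path.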

I got the error message "No available kernel. Aborting execution." from F.scaled_dot_product_attention.

A warning was printed just before:

UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: float, Key dtype: float, and Value dtype: float instead.

This is consistent with the 16-bit requirement in the "accelerated-pytorch-2" quote above.
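The warning and the error fit together. Flash Attention is eligible only when both the GPU and the input dtype qualify, so with float32 tensors and (as the error suggests) no other backend enabled on this path, nothing is left to run. The sketch below is a toy model of that selection logic, not PyTorch's actual dispatcher: `select_kernel` and the requirement constants are hypothetical, and only the flash requirements (16-bit inputs, SM80+) come from the blog quote.

```python
# Toy model of SDPA backend selection; this is NOT PyTorch's real dispatcher.
# Only the flash requirements (16-bit dtypes, SM80+) come from the blog quote.
from typing import Iterable, Tuple

FLASH_DTYPES = {"float16", "bfloat16"}
FLASH_MIN_CAPABILITY = (8, 0)


def select_kernel(enabled: Iterable[str], dtype: str, capability: Tuple[int, int]) -> str:
    """Return the first enabled backend that can run, else raise the error seen here."""
    for backend in enabled:
        if backend == "flash":
            # Flash needs both a 16-bit dtype and an SM80+ GPU.
            if dtype in FLASH_DTYPES and capability >= FLASH_MIN_CAPABILITY:
                return "flash"
        elif backend == "math":
            # The reference math path is assumed here to accept any dtype.
            return "math"
    raise RuntimeError("No available kernel. Aborting execution.")


rtx_4080 = (8, 9)
print(select_kernel(["flash"], "bfloat16", rtx_4080))
print(select_kernel(["flash", "math"], "float32", rtx_4080))
try:
    select_kernel(["flash"], "float32", rtx_4080)
except RuntimeError as err:
    print(err)
```

In this model, either casting to a 16-bit dtype or enabling a fallback backend avoids the error; the cast is the route taken later in this thread.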

I run the model with these settings, using the same test code as in the README.md:

model = BSRoformer(
    dim = 384,
    depth = 6,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

Am I doing something wrong?

dorpxam commented 10 months ago

Additional Information:

By adding a simple cast in the Attention module's forward():

if exists(self.rotary_embed):
    q = self.rotary_embed.rotate_queries_or_keys(q)
    k = self.rotary_embed.rotate_queries_or_keys(k)

# cast q, k, v to bfloat16 so the flash kernel is eligible, then cast the output back to the input dtype
out = self.attend(q.type(torch.bfloat16), k.type(torch.bfloat16), v.type(torch.bfloat16)).type(x.dtype)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

The error is gone, but I presume the float32 -> bfloat16 -> float32 round trip is bad for arithmetic precision?
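The round trip does lose precision, but in a bounded way: bfloat16 is the upper 16 bits of a float32, so it keeps all 8 exponent bits and only 7 of the 23 stored mantissa bits. The pure-Python sketch below emulates the conversion with round-to-nearest-even on the float32 bit pattern. `to_bfloat16` is a hypothetical helper, and NaN and overflow handling are omitted. The relative error stays below 2^-8 (about 0.4%).

```python
import struct


def to_bfloat16(x: float) -> float:
    """Emulate float32 -> bfloat16 -> float32 with round-to-nearest-even.

    Simplification: ignores NaN and values that would round to infinity.
    """
    # Reinterpret the value as its 32-bit float32 bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest (ties to even) on the low 16 bits that bfloat16 drops.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in [1.0, 3.14159265, 0.1, 1234.5678, 1e-30]:
    approx = to_bfloat16(value)
    rel_err = abs(approx - value) / abs(value)
    print(f"{value:>12g} -> {approx:<12g} relative error {rel_err:.1e}")
```

For example, 3.14159265 becomes 3.140625. Even 1e-30 survives, because the exponent range is unchanged; only about 2 to 3 significant decimal digits are kept.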

EDIT: from https://cloud.google.com/tpu/docs/bfloat16?hl=en

The dynamic range of bfloat16 and float32 are equivalent. However, bfloat16 takes up half the memory space.
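The "same dynamic range" claim comes down to exponent bits: bfloat16 keeps float32's 8-bit exponent, whereas float16 has only 5. A small sketch computing the limits from the bit layout (`format_limits` is a hypothetical helper; subnormals are ignored):

```python
# Hypothetical helper: limits of an IEEE-754-style binary format from its bit layout.
# Subnormals are ignored; the top exponent code is reserved for inf/NaN.


def format_limits(exponent_bits: int, mantissa_bits: int) -> tuple:
    """Return (largest finite value, smallest positive normal value)."""
    bias = 2 ** (exponent_bits - 1) - 1
    largest = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    smallest_normal = 2.0 ** (1 - bias)
    return largest, smallest_normal


FORMATS = {
    "float32": (8, 23),
    "bfloat16": (8, 7),
    "float16": (5, 10),
}

for name, (e_bits, m_bits) in FORMATS.items():
    largest, smallest = format_limits(e_bits, m_bits)
    print(f"{name:>8}: max {largest:.4e}, min normal {smallest:.4e}")
```

float16 tops out at 65504, while bfloat16 and float32 both reach about 3.4e38 and share the smallest normal value 2^-126. This is why bfloat16, rather than float16, is usually the lower-risk 16-bit target for a cast like the one above, since values do not overflow.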

And

Most computations within a deep neural network can accomplish a task with the same accuracy using a lower-precision values. Some models can even reach a higher accuracy with lower-precision values.
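For attention specifically, one way to probe this claim is to compare a full-precision result with one whose inputs were first rounded to bfloat16 precision. The sketch below does this for a single query in plain Python. `round_significand` is a hypothetical helper that keeps 8 significant bits (bfloat16's precision) but ignores exponent limits, and only input rounding is modeled: the arithmetic itself runs in double precision. The random data is arbitrary, so this says nothing about BS-RoFormer's separation quality.

```python
import math
import random


def round_significand(x: float, bits: int = 8) -> float:
    """Keep `bits` significant bits (bfloat16 has 8), ignoring exponent limits."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(round(m * 2 ** bits), e - bits)


def attention(q, keys, values):
    """Single-query scaled dot-product attention in plain Python."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def to_low(vectors):
    """Round every element of a list of vectors to bfloat16 precision."""
    return [[round_significand(x) for x in vec] for vec in vectors]


rng = random.Random(0)
d, n = 8, 16
q = [rng.uniform(-1, 1) for _ in range(d)]
keys = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]

exact = attention(q, keys, values)
approx = attention([round_significand(x) for x in q], to_low(keys), to_low(values))
max_abs_err = max(abs(a - b) for a, b in zip(exact, approx))
print(f"max |exact - bf16-input| = {max_abs_err:.1e}")
```

Because the softmax output is a weighted average, small input perturbations stay small: for this toy setup (d = 8, entries in [-1, 1]), a worst-case bound on the output change is about 0.05, and the observed error is typically much smaller. Whether that holds for BS-RoFormer would still have to be measured on the model itself.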

Does this claim hold for the BS-RoFormer model?