facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/
Other
8k stars, 565 forks

`memory_efficient_attention` runs in f32 with `autocast` #742

Open mitchellnw opened 1 year ago

mitchellnw commented 1 year ago

Hello, I am using torch.cuda.amp.autocast with bfloat16.

I noticed that the xformers RotaryEmbedding produces float32 outputs, which then requires casting before passing to memory_efficient_attention.

This raises a question: should I cast the keys and queries to bfloat16, or the values to float32? I believe it's the former, but I would appreciate your input to confirm. Thanks!

danthe3rd commented 1 year ago

Hi, I would recommend that you cast your keys/queries to bf16 for best performance.
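For reference, the recommended cast could look roughly like the sketch below. The helper name `match_qk_to_v` is hypothetical and not part of xformers; the tensors use the `[batch, seq_len, heads, head_dim]` layout, and the actual attention call is left as a comment because it needs a CUDA device with xformers installed.

```python
import torch


def match_qk_to_v(q, k, v):
    """Cast q and k to v's dtype so all three attention inputs agree.

    Hypothetical helper for illustration; not an xformers function.
    """
    return q.to(v.dtype), k.to(v.dtype), v


# float32 queries/keys, e.g. as returned by a rotary embedding
q = torch.randn(2, 16, 4, 32)
k = torch.randn(2, 16, 4, 32)
# bfloat16 values, e.g. as produced by a linear layer under autocast
v = torch.randn(2, 16, 4, 32, dtype=torch.bfloat16)

q, k, v = match_qk_to_v(q, k, v)

# On a CUDA device with xformers installed, the call would then be:
# out = xformers.ops.memory_efficient_attention(q, k, v)
```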

Teoge commented 1 year ago

Hi @danthe3rd, I've encountered a similar issue: memory_efficient_attention seems to run in fp32 during training, even though it is enclosed in an autocast context. Manually casting q, k, and v does solve the problem, but I'm curious whether there's a more streamlined solution. It would be great if the function could be made autocast-compatible, automatically casting its inputs to fp16 or bf16. That would be more convenient and efficient. Thanks!

danthe3rd commented 1 year ago

That's a good point indeed. This would require some changes on our side, though. Updating the issue title accordingly.

Luciennnnnnn commented 1 month ago

@danthe3rd Any progress?