mitchellnw opened 1 year ago
Hi, I would recommend that you cast your keys/queries to bf16 for best performance.
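Something along these lines should work (a minimal sketch; tensor names and sizes are illustrative):

```python
import torch
from xformers.ops import memory_efficient_attention

B, M, H, K = 2, 1024, 8, 64  # batch, seq len, heads, head dim (example sizes)
q = torch.randn(B, M, H, K, device="cuda")  # float32 by default
k = torch.randn(B, M, H, K, device="cuda")
v = torch.randn(B, M, H, K, device="cuda")

# Cast all three inputs to the same low-precision dtype; the kernel
# requires q/k/v to share a dtype, so don't mix fp32 and bf16.
q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
out = memory_efficient_attention(q, k, v)  # now dispatches to the bf16 kernel
```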
Hi @danthe3rd, I've encountered a similar issue where `memory_efficient_attention` seems to run in fp32 during training, despite being enclosed within an autocast context. Although manually casting the qkv does solve the problem, I'm curious whether there's a more streamlined solution. It would be great if the function could be made autocast compatible, automatically handling the cast to either fp16 or bf16. Thanks!
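To illustrate, something along these lines is what I have in mind; `autocast_mea` is just an illustrative name, not an existing xformers API:

```python
import torch
from xformers.ops import memory_efficient_attention

def autocast_mea(q, k, v, **kwargs):
    # Query the state set up by the enclosing torch.cuda.amp.autocast(...)
    # and cast q/k/v to the active autocast dtype (fp16 or bf16).
    if torch.is_autocast_enabled():
        dtype = torch.get_autocast_gpu_dtype()
        q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
    return memory_efficient_attention(q, k, v, **kwargs)

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    # Factory functions are unaffected by autocast, so these stay fp32
    # until the wrapper casts them.
    q = torch.randn(2, 1024, 8, 64, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = autocast_mea(q, k, v)  # runs in bf16
```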
That's a good point indeed - this would require some changes on our side, though. Updating the issue title accordingly.
@danthe3rd Any progress?
Hello, I am using `torch.cuda.amp.autocast` with `bfloat16`. I noticed that the xformers `RotaryEmbedding` produces `float32` outputs, which then require casting before passing to `memory_efficient_attention`. However, this raises the question of whether to cast the keys and queries to `bfloat16` or the values to `float32`. I believe it's the former, but if possible I would appreciate your input to confirm -- thanks!
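For concreteness, here is roughly what I mean by the former option (a minimal sketch; it assumes `RotaryEmbedding` is importable from `xformers.components.positional_embedding` and upcasts its outputs to `float32`, and it uses single-head `(batch, seq, dim)` tensors for simplicity):

```python
import torch
from xformers.components.positional_embedding import RotaryEmbedding
from xformers.ops import memory_efficient_attention

rotary = RotaryEmbedding(dim_model=64).cuda()

# Under bf16 autocast, the qkv projections would produce bfloat16 tensors.
q = torch.randn(2, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

q, k = rotary(q, k)  # comes back as float32, as described above

# Cast queries/keys back down to the values' dtype so all three inputs
# match; memory_efficient_attention requires q/k/v to share a dtype.
q, k = q.to(v.dtype), k.to(v.dtype)
out = memory_efficient_attention(q, k, v)
```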