alexzhang13 / flashattention2-custom-mask

Triton implementation of FlashAttention2 that adds Custom Masks.
Apache License 2.0

precision issue #7

Open dyhBUPT opened 2 months ago

dyhBUPT commented 2 months ago

Hi, thanks for your work. I'm trying your fa2-cm, but it raises an error because of the following assertions:

assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
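For context, q/k/v in attention code are often produced by a permute or transpose, which leaves them non-contiguous with strides that differ from a freshly allocated tensor of the same shape. A minimal sketch of that pattern (shapes are illustrative only):

import torch

q = torch.randn(2, 8, 128, 64)  # (B, H, S, D), contiguous
# A (B, S, H, D) -> (B, H, S, D) permute is a common source of
# non-contiguous tensors in attention code.
k = torch.randn(2, 128, 8, 64).permute(0, 2, 1, 3)
print(q.is_contiguous(), k.is_contiguous())   # True False
print(q.stride() == k.stride())               # False: same shape, different strides
print(k.contiguous().stride() == q.stride())  # True after .contiguous()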

I solved this problem by calling .contiguous() as follows:

q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
o = o.contiguous()
do = do.contiguous()

Another issue is that it only supports fp16 inputs, so I convert q/k/v from fp32 to fp16 before the call, and convert the output back from fp16 to fp32 after fa2.
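Roughly, my workaround looks like this (attention_fn is a stand-in name for the library's custom-mask entry point; the real name and signature may differ):

import torch

def fa2_cm_fp32(q, k, v, mask, attention_fn):
    # attention_fn stands in for the repo's custom-mask FlashAttention2
    # kernel; its exact name and signature are assumptions here.
    # The Triton kernel asserts contiguous, stride-matched fp16 inputs,
    # so cast down and force contiguity first.
    q16 = q.half().contiguous()
    k16 = k.half().contiguous()
    v16 = v.half().contiguous()
    out = attention_fn(q16, k16, v16, mask)
    # Cast the fp16 output back to fp32 for the surrounding model.
    return out.float()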

After these corrections, I can run fa2-cm normally. However, the results seem bad because the gradients explode. Could you suggest possible reasons? Is it caused by my aforementioned modifications?
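One thing I am unsure about: fp16 saturates around 65504, so intermediate values that are harmless in fp32 can overflow to inf in half precision, e.g.:

import torch

x = torch.full((1,), 60000.0, dtype=torch.float16)
print(x * 2)          # tensor([inf], dtype=torch.float16): fp16 max is ~65504
print(x.float() * 2)  # tensor([120000.]) is fine in fp32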

Looking forward to your reply~

alexzhang13 commented 2 months ago

Hi,

For FP32 support, I need to make a minor edit. This was attempted by someone else earlier, but their code was buggy so I had to revert it. I will make this edit later this week and let you know!

yyhyyh17 commented 1 month ago


Do you have any idea why fp32 produces wrong results in the backward pass?