sony / ctm


Mismatch between QKVFlashAttention and QKVAttentionLegacy #5

Closed paulhuangkm closed 5 months ago

paulhuangkm commented 6 months ago

Dear authors,

I am working with V100 GPUs, which flash attention unfortunately does not support. I noticed that QKVFlashAttention and QKVAttentionLegacy manipulate the tensor dimensions differently, so the provided checkpoints do not work on V100 GPUs (or with QKVAttentionLegacy in general).

To resolve this issue, I have updated the dimension manipulation in QKVAttentionLegacy and also added xformers scaled dot-product attention as an additional option. Should I create a pull request? (Left image: original QKVAttentionLegacy; right image: updated QKVAttentionLegacy.)
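For readers without the images, here is a minimal sketch of the kind of mismatch described above. It assumes, as in the upstream guided-diffusion / consistency-models code this repository appears to build on, that the legacy path reads the fused qkv channels as (head, three, dim) while the flash path reads them as (three, head, dim). The function names are hypothetical and only illustrate the index bookkeeping; this is not the code from the pull request.

```python
# Hypothetical helpers: map (head, q/k/v slot, dim) to a channel index
# of the fused qkv tensor under each assumed layout.

def legacy_index(h, t, d, num_heads, head_dim):
    # Assumed legacy layout: channels grouped as (head, three, dim).
    # num_heads is unused here; kept so both functions share a signature.
    return (h * 3 + t) * head_dim + d

def flash_index(h, t, d, num_heads, head_dim):
    # Assumed flash layout: channels grouped as (three, head, dim).
    return (t * num_heads + h) * head_dim + d

def channels_for(index_fn, num_heads, head_dim):
    """Return {(head, 'q'|'k'|'v'): [channel indices]} for one layout."""
    out = {}
    for h in range(num_heads):
        for t, name in enumerate("qkv"):
            out[(h, name)] = [
                index_fn(h, t, d, num_heads, head_dim) for d in range(head_dim)
            ]
    return out

def flash_to_legacy_permutation(num_heads, head_dim):
    """perm[i] is the flash-layout channel that legacy slot i should read."""
    perm = [0] * (3 * num_heads * head_dim)
    for h in range(num_heads):
        for t in range(3):
            for d in range(head_dim):
                src = flash_index(h, t, d, num_heads, head_dim)
                dst = legacy_index(h, t, d, num_heads, head_dim)
                perm[dst] = src
    return perm

legacy = channels_for(legacy_index, num_heads=2, head_dim=2)
flash = channels_for(flash_index, num_heads=2, head_dim=2)
perm = flash_to_legacy_permutation(num_heads=2, head_dim=2)

# The same checkpoint channels land in different roles under the two layouts.
print("legacy head 0 keys:", legacy[(0, "k")])
print("flash  head 0 keys:", flash[(0, "k")])
print("reorder for legacy:", perm)
```

Under these assumptions, a checkpoint trained with one layout feeds the wrong channels into q, k and v when run through the other, which would be consistent with the behaviour reported here. Reindexing the fused qkv channels with such a permutation (or changing the reshape in the legacy class to match) is one way the two paths could be made to agree.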

Thanks!

Kim-Dongjun commented 6 months ago

Hi Paul,

That looks fantastic to us. We previously tried to solve this issue but failed due to time constraints. It would be really appreciated if you could create a pull request!

Best, Dongjun

paulhuangkm commented 6 months ago

Hi @Kim-Dongjun,

I have created a pull request at #6.

Best, Paul

ChiehHsinJesseLai commented 5 months ago

Hi @paulhuangkm,

Thanks for the request. Now merged!

Best, Chieh-Hsin (Jesse)