Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Support for Volta arch #68

Open vadimcn opened 2 years ago

vadimcn commented 2 years ago

Would you mind explaining what the sticking point is for supporting Volta arch (sm70) devices? As far as I understand, CUTLASS does support Volta.

tridao commented 2 years ago

The tensor cores (for matmul) on Volta are different from the tensor cores on Turing and Ampere, and the required shared memory layout is different as well. The xformers team has an implementation of FlashAttention / memory-efficient attention for V100 that you could use.

Each Volta tensor core instruction multiplies 4 pairs of matrices of shape 8x4 and 4x8 (i.e. the mma.sync.aligned.m8n8k4 instruction). Each Turing and Ampere tensor core instruction multiplies one pair of matrices of shape 16x8 and 8x8, or 16x16 and 16x8 (i.e. the mma.sync.aligned.m16n8k8 and mma.sync.aligned.m16n8k16 instructions). This means that the inputs to the mma instructions need to be laid out differently in shared memory, and this shared memory load/store is the trickiest part. You'll notice that CUTLASS always has a separate codepath / implementation for Volta, while the Turing/Ampere implementations are mostly the same.
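To make the shape difference concrete, here is a minimal, compile-only sketch (an illustration, not code from this repo) of the two PTX mma instructions named above, showing the per-thread register fragments each one expects. How those fragments are staged through shared memory is exactly the part that differs between the architectures.

```cuda
// Sketch only: contrasts the warp-level mma shapes discussed above.
// Fragments a/b/acc are assumed to already hold each thread's slice of the
// tiles in the layout the instruction requires; staging them through shared
// memory is the hard part that differs between Volta and Turing/Ampere.

// Ampere (sm_80+): one 16x16 by 16x8 fp16 multiply per warp-level instruction.
__device__ void mma_m16n8k16(const unsigned (&a)[4],  // this thread's slice of the 16x16 fp16 A tile (2 halves per register)
                             const unsigned (&b)[2],  // this thread's slice of the 16x8 fp16 B tile
                             float (&acc)[4]) {       // this thread's slice of the 16x8 fp32 accumulator
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
      : "+f"(acc[0]), "+f"(acc[1]), "+f"(acc[2]), "+f"(acc[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}

// Volta (sm_70): 4 independent 8x4 by 4x8 fp16 multiplies per warp-level
// instruction, with a different per-thread fragment layout than above.
__device__ void mma_m8n8k4(const unsigned (&a)[2],    // this thread's slice of the fp16 A operands
                           const unsigned (&b)[2],    // this thread's slice of the fp16 B operands
                           float (&acc)[8]) {         // this thread's slice of the fp32 accumulators
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700
  asm volatile(
      "mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
      "{%0,%1,%2,%3,%4,%5,%6,%7};\n"
      : "+f"(acc[0]), "+f"(acc[1]), "+f"(acc[2]), "+f"(acc[3]),
        "+f"(acc[4]), "+f"(acc[5]), "+f"(acc[6]), "+f"(acc[7])
      : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(b[1]));
#endif
}
```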

vadimcn commented 2 years ago

> The xformers team has an implementation of FlashAttention / memory-efficient attention for V100 that you could use.

Thanks, I'll take a look.

I've seen mentions of memory-efficient attention in connection with FlashAttention before, but I couldn't wrap my head around what's different between the two.
Do they implement different algorithms? How do they stack up in terms of performance and memory usage? Do they support the same features, like unpadding and sparse attention?

tridao commented 2 years ago

We've been collaborating with the xformers team. It's (mostly) the same algorithm, with different implementations. xformers will dispatch to our FlashAttention code when it's available (e.g., on Turing/Ampere) and faster (the backward pass almost always dispatches to our FlashAttention code, I think). In other cases it will dispatch to their cutlass-based implementation.
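As a rough illustration of that kind of dispatch (not the actual xformers logic; the function names below are placeholder stubs), the choice essentially comes down to the device's compute capability:

```cuda
// Hypothetical host-side sketch: pick an attention kernel based on compute capability.
#include <cuda_runtime.h>
#include <cstdio>

// Placeholder stubs standing in for the real attention implementations.
void run_flash_attention()            { std::printf("FlashAttention kernel\n"); }
void run_memory_efficient_attention() { std::printf("cutlass-based memory-efficient kernel\n"); }

void attention_dispatch(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  const int sm = prop.major * 10 + prop.minor;  // 70 = Volta, 75 = Turing, 80 = Ampere
  if (sm >= 75) {
    run_flash_attention();             // Turing/Ampere: FlashAttention kernels are available
  } else {
    run_memory_efficient_attention();  // Volta (sm70): fall back to the cutlass-based path
  }
}

int main() {
  attention_dispatch(/*device=*/0);
  return 0;
}
```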

I haven't kept up with what features are exposed in xformers, but you can check out their API and docs.