FlagOpen / FlagAttention

A collection of memory-efficient attention operators implemented in the Triton language.

Optimization of operators #2

Closed: iclementine closed this issue 9 months ago

iclementine commented 9 months ago

Optimizations for flash_attention and piecewise_attention:

  1. flash_attention's backward pass is now split into two kernels: one for the gradients of k and v, the other for the gradient of q (first sketch after this list). This brings a 4x~5x speedup, though it is still slower than FlashAttention.
  2. Apply less masking. When an input size is divisible by the tile size, no masking is applied along that dimension. We further remove some unnecessary masking to avoid tl.where calls (second sketch below).
  3. Prefer tl.math.exp2 over tl.exp and tl.math.exp: folding log2(e) into the softmax scale removes the per-element multiply that exp otherwise lowers to, saving FMAs (third sketch below).
  4. Disable the dot-with-identity trick when headdim is 128, since it would require a 128 x 128 identity matrix I (fourth sketch below).
  5. Tune the tile sizes, num_stages, and num_warps better (fifth sketch below).
  6. Update the READMEs and include the latest benchmark results.
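
First sketch (for item 1): a minimal PyTorch reference, not the repository's Triton code. It shows why the split works: dk and dv only accumulate over the query axis, while dq only accumulates over the key axis, so each half can live in its own kernel. The block size and function name are illustrative, and a real kernel would typically recompute P from a saved log-sum-exp rather than materialize the full matrix.

```python
import torch

def attention_backward_two_pass(q, k, v, do, block=128):
    """Single-head reference; q, k, v, do have shape [seq_len, head_dim]."""
    scale = q.shape[-1] ** -0.5
    p = torch.softmax((q @ k.T) * scale, dim=-1)   # attention probabilities P
    o = p @ v
    delta = (do * o).sum(-1, keepdim=True)         # D_i = rowsum(dO_i * O_i)

    # "dk & dv kernel": one program per k/v block, looping over q blocks.
    dk, dv = torch.zeros_like(k), torch.zeros_like(v)
    for j in range(0, k.shape[0], block):
        for i in range(0, q.shape[0], block):
            p_ij = p[i:i+block, j:j+block]
            ds_ij = p_ij * (do[i:i+block] @ v[j:j+block].T - delta[i:i+block])
            dv[j:j+block] += p_ij.T @ do[i:i+block]
            dk[j:j+block] += ds_ij.T @ q[i:i+block] * scale

    # "dq kernel": one program per q block, looping over k/v blocks.
    dq = torch.zeros_like(q)
    for i in range(0, q.shape[0], block):
        for j in range(0, k.shape[0], block):
            p_ij = p[i:i+block, j:j+block]
            ds_ij = p_ij * (do[i:i+block] @ v[j:j+block].T - delta[i:i+block])
            dq[i:i+block] += ds_ij @ k[j:j+block] * scale
    return dq, dk, dv
```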
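
Second sketch (for item 2): the divisibility specialization shown on a toy elementwise kernel, not the attention kernels themselves. When the problem size divides the tile size, the masked load/store (whose `other` value implies a tl.where-style select) is replaced by an unmasked one. The `DIVISIBLE` flag and the kernel are hypothetical.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, y_ptr, n, alpha, BLOCK: tl.constexpr, DIVISIBLE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    if DIVISIBLE:
        x = tl.load(x_ptr + offs)                       # no mask, no select needed
    else:
        x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)
    y = x * alpha
    if DIVISIBLE:
        tl.store(y_ptr + offs, y)
    else:
        tl.store(y_ptr + offs, y, mask=offs < n)

def scale(x, alpha, block=1024):
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, block),)
    # Specialize at launch time: divisible inputs get the mask-free variant.
    scale_kernel[grid](x, y, n, alpha, BLOCK=block, DIVISIBLE=(n % block == 0))
    return y
```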
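
Third sketch (for item 3): the change relies on exp(x) = exp2(x * log2(e)). Since the score is already multiplied by a softmax scale, the log2(e) factor can be folded into that scale once outside the inner loop; tl.math.exp2 then maps directly to the hardware exp2 instruction, whereas tl.exp is typically lowered to an extra per-element multiply followed by the same instruction. A quick numerical check of the identity in plain PyTorch:

```python
import math
import torch

s = torch.randn(4, 4, dtype=torch.float64)              # raw q.k scores
sm_scale = 1.0 / math.sqrt(64)                           # usual 1/sqrt(headdim) scale

ref = torch.exp(s * sm_scale)                            # tl.exp-style path
alt = torch.exp2(s * (sm_scale * math.log2(math.e)))     # tl.math.exp2 path, log2(e) folded into the scale

print(torch.allclose(ref, alt))                          # True, up to rounding
```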
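
Fourth sketch (for item 4): the trick multiplies a tile by an identity matrix through tl.dot. The toy kernel below is hypothetical and says nothing about what the trick buys at smaller head dimensions; it only shows the shape cost: the identity tile is HEADDIM x HEADDIM, so at headdim = 128 it is a full 128 x 128 matrix held on chip, which is the stated reason for turning it off.

```python
import triton
import triton.language as tl

@triton.jit
def dot_identity_tile(x_ptr, y_ptr, HEADDIM: tl.constexpr, USE_DOT_I: tl.constexpr):
    rows = tl.arange(0, HEADDIM)[:, None]
    cols = tl.arange(0, HEADDIM)[None, :]
    x = tl.load(x_ptr + rows * HEADDIM + cols)
    if USE_DOT_I:
        # The identity tile is HEADDIM x HEADDIM; at HEADDIM == 128 that is a
        # 128 x 128 matrix, hence the trick is disabled for that head dimension.
        eye = tl.where(rows == cols, 1.0, 0.0)
        x = tl.dot(x, eye.to(x.dtype))
    tl.store(y_ptr + rows * HEADDIM + cols, x)
```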
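
Fifth sketch (for item 5): one common way to expose tile sizes, num_stages, and num_warps for tuning in Triton is triton.autotune; the repository may instead pick them via heuristics, but the knobs are the same. The configs and the kernel stub below are placeholders, not the values chosen in this change.

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64},  num_stages=2, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64},  num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8),
    ],
    key=["seqlen_q", "headdim"],   # re-benchmark the configs when these change
)
@triton.jit
def attn_fwd_kernel(q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, headdim,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Kernel body elided; only the tuning surface matters for this sketch.
    pass
```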