Optimizations for flash_attention and piecewise_attention:
flash_attention's backward computation is now split into two kernels: one computes the gradients of k and v, the other computes the gradient of q. This gives a 4x to 5x speedup, though it is still slower than FlashAttention.
Apply less masking. When an input size is divisible by the tile size, no masking is applied along that dimension. We also remove some unnecessary masking to avoid tl.where calls.
Prefer tl.math.exp2 over tl.exp and tl.math.exp, since it saves FMA instructions.
Disable the dot I trick when headdim is 128, since it requires a 128 x 128 matrix I.
Better tune the tile sizes, num_stages, and num_warps.
Update the READMEs and include the latest benchmark results.
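The two-kernel backward split for flash_attention can be illustrated with a CPU-only reference in plain Python. This is a sketch under assumptions, not the Triton kernels from this change: it uses the standard recompute-based attention backward (saving the per-row logsumexp L in the forward pass), and the function names and block parameters are hypothetical. Pass 1 gives each "program" a block of K/V rows and sweeps over all queries; pass 2 gives each program a block of Q rows and sweeps over all keys, so every program writes only its own output rows and no atomic accumulation is needed.

```python
import math

# CPU-only sketch (not the actual Triton kernels) of a recompute-based attention
# backward split into two passes. Each outer "program" loop owns a block of rows,
# so gradients are accumulated without atomic adds. Names are hypothetical.


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def forward(Q, K, V, scale):
    """Return the output O and the per-row logsumexp L saved for backward."""
    O, L = [], []
    for q in Q:
        s = [dot(q, k) * scale for k in K]
        m = max(s)
        lse = m + math.log(sum(math.exp(x - m) for x in s))
        p = [math.exp(x - lse) for x in s]
        O.append([sum(p[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))])
        L.append(lse)
    return O, L


def prob(Q, K, L, scale, i, j):
    # Recompute P[i][j] from the saved logsumexp instead of storing P.
    return math.exp(dot(Q[i], K[j]) * scale - L[i])


def backward_dkdv(Q, K, V, O, dO, L, scale, block_n):
    """Pass 1: each program owns a block of K/V rows and sweeps over all queries."""
    delta = [dot(dO[i], O[i]) for i in range(len(Q))]  # rowsum(dO * O)
    dK = [[0.0] * len(K[0]) for _ in K]
    dV = [[0.0] * len(V[0]) for _ in V]
    for start in range(0, len(K), block_n):  # one program per K/V block
        for j in range(start, min(start + block_n, len(K))):
            for i in range(len(Q)):
                p = prob(Q, K, L, scale, i, j)
                ds = p * (dot(dO[i], V[j]) - delta[i])
                for c in range(len(V[0])):
                    dV[j][c] += p * dO[i][c]
                for c in range(len(K[0])):
                    dK[j][c] += ds * Q[i][c] * scale
    return dK, dV


def backward_dq(Q, K, V, O, dO, L, scale, block_m):
    """Pass 2: each program owns a block of Q rows and sweeps over all keys."""
    dQ = [[0.0] * len(Q[0]) for _ in Q]
    for start in range(0, len(Q), block_m):  # one program per Q block
        for i in range(start, min(start + block_m, len(Q))):
            delta_i = dot(dO[i], O[i])
            for j in range(len(K)):
                p = prob(Q, K, L, scale, i, j)
                ds = p * (dot(dO[i], V[j]) - delta_i)
                for c in range(len(Q[0])):
                    dQ[i][c] += ds * K[j][c] * scale
    return dQ
```

Both passes recompute the attention probabilities from L, trading extra FLOPs for not storing the full probability matrix; the split trades one more recomputation for write patterns that never conflict across programs.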
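The "mask only when the size is not divisible by the tile size" rule can be sketched in plain Python. The helper names are hypothetical; in a Triton kernel the equivalent flag would typically be passed as a tl.constexpr, so the unmasked branch compiles to a plain tl.load with no bounds comparison and no select, while the masked branch uses tl.load's mask and other arguments.

```python
# Plain-Python sketch of "mask only when needed". Helper names are hypothetical.


def needs_mask(size, block):
    """A dimension needs masking only if its last tile runs past the end."""
    return size % block != 0


def load_tile(row, start, block, masked, other=0.0):
    offsets = range(start, start + block)
    if not masked:
        # Every offset is in bounds, so skip the comparison and the select.
        return [row[o] for o in offsets]
    # Out-of-bounds lanes get `other`, similar to tl.load(..., mask=..., other=...).
    return [row[o] if o < len(row) else other for o in offsets]


def load_all_tiles(row, block):
    masked = needs_mask(len(row), block)  # decided once per dimension, not per tile
    return [load_tile(row, s, block, masked) for s in range(0, len(row), block)]
```

For example, a sequence length of 1024 with 128-wide tiles needs no mask at all, while a length of 1000 does.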
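The exp2 preference rests on the identity exp(x) = 2^(x * log2(e)). A common rationale, stated here as an assumption rather than taken from this change, is that GPU hardware evaluates 2^x natively, so exp(x) costs an extra multiply by log2(e) per element; folding log2(e) into the softmax scale once lets each element be handled by a single multiply-add followed by exp2. The sketch below compares the two forms in plain Python; the function names are hypothetical.

```python
import math

# Sketch of the exp -> exp2 rewrite in softmax (plain Python, names hypothetical).
# exp(x) == 2 ** (x * log2(e)), so log2(e) can be folded into the softmax scale.
LOG2E = 1.4426950408889634  # log2(e)


def softmax_exp(scores, sm_scale):
    m = max(scores)
    e = [math.exp((s - m) * sm_scale) for s in scores]
    total = sum(e)
    return [x / total for x in e]


def softmax_exp2(scores, sm_scale):
    qk_scale = sm_scale * LOG2E  # computed once, outside the per-element work
    m = max(scores) * qk_scale  # row max in the base-2 domain (assumes sm_scale > 0)
    e = [2.0 ** (s * qk_scale - m) for s in scores]  # one multiply-add, then exp2
    total = sum(e)
    return [x / total for x in e]
```

Both functions return the same probabilities up to floating-point rounding; only the placement of the log2(e) factor differs.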
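Why the dot I trick gets expensive at headdim 128 can be seen from the size of I alone. This is a hypothetical illustration that assumes I is an identity matrix whose side equals the head dimension: multiplying a tile by I leaves it unchanged, but I has headdim x headdim entries, so at 128 it holds 16384 values, four times as many as at 64. The switch function only mirrors the condition stated above and is not taken from the code.

```python
# Hypothetical illustration: the identity matrix used by a dot-with-I trick grows
# quadratically with the head dimension. Names here are hypothetical.


def identity(n):
    # Built by comparing row and column offsets, as one might with tl.where in Triton.
    return [[1.0 if r == c else 0.0 for c in range(n)] for r in range(n)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def use_dot_i(head_dim):
    # Mirrors the stated change: skip the trick when head_dim is 128.
    return head_dim != 128
```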