Hi Phil,

I was wondering what your thoughts on adding Flash Attention 2 are? Something like the following:

```python
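# the snippet below is the body of the block's forward pass; it assumes
# `from einops import rearrange` and `from flash_attn import flash_attn_func`
# (flash-attn 2.x) are in scope, along with the module's existing rotary helpers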
n, device, h = x.shape[1], x.device, self.heads
# pre layernorm
x = self.norm(x)
# attention queries, keys, values, and feedforward inner
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)
# rotary embeddings
positions = self.get_rotary_embedding(n, device)
q = apply_rotary_pos_emb(positions, q)
k = apply_rotary_pos_emb(positions, k)
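# broadcast the single key / value head across all query heads so k and v
# match q's head count for the flash_attn_func call below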
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
"""
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
out: (batch_size, seqlen, nheads, headdim).
"""
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
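# note: flash_attn_func runs on CUDA and expects fp16 / bf16 inputs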
attn = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=self.scale, causal=True)
# merge heads
out = rearrange(attn, "b n h d -> b n (h d)")
out = self.attn_out(out) + self.ff_out(ff)
```
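
One possible simplification, if it helps: as far as I understand, flash-attn 2.x also accepts k / v with fewer heads than q (the MQA / GQA path hinted at by `nheads_k` above), so the `expand_as` could likely be dropped and a single key / value head passed straight through. A minimal standalone sketch of that call, with dummy shapes of my own choosing rather than anything from the repo:

```python
import torch
from flash_attn import flash_attn_func

b, n, h, d = 2, 128, 8, 64
q = torch.randn(b, n, h, d, device="cuda", dtype=torch.float16)
# single key / value head (multi-query attention); flash-attn broadcasts it across q's heads
k = torch.randn(b, n, 1, d, device="cuda", dtype=torch.float16)
v = torch.randn(b, n, 1, d, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=d ** -0.5, causal=True)
assert out.shape == (b, n, h, d)
```

If that holds up, it would keep the multi-query structure explicit all the way into the kernel call.
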
Thank you,
Enrico