lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM.
MIT License

Flash Attention 2 #54

Closed: conceptofmind closed this issue 4 months ago

conceptofmind commented 1 year ago

Hi Phil,

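I was wondering what your thoughts are on adding Flash Attention 2. Here is a rough sketch of how it could fit into the forward pass of the parallel attention/feedforward block: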

```python
from einops import rearrange
from flash_attn import flash_attn_func  # flash-attn 2.x; needs CUDA tensors in fp16 or bf16

# body of the block's forward(self, x); apply_rotary_pos_emb is the repo's existing helper

n, device, h = x.shape[1], x.device, self.heads

# pre layernorm
x = self.norm(x)

# attention queries, keys, values, and feedforward inner
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)

# rotary embeddings (k is still (b, n, d) here: a single shared key head)
positions = self.get_rotary_embedding(n, device)
q = apply_rotary_pos_emb(positions, q)
k = apply_rotary_pos_emb(positions, k)

# flash_attn_func layout:
#   q:   (batch_size, seqlen, nheads, headdim)
#   k:   (batch_size, seqlen, nheads_k, headdim)
#   v:   (batch_size, seqlen, nheads_k, headdim)
#   out: (batch_size, seqlen, nheads, headdim)
# it accepts fewer key/value heads than query heads (MQA/GQA), so the single
# key/value head can be passed as nheads_k = 1 instead of being expanded to h heads
q = rearrange(q, "b h n d -> b n h d")
k, v = map(lambda t: rearrange(t, "b n d -> b n 1 d"), (k, v))

attn = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=self.scale, causal=True)

# merge heads
out = rearrange(attn, "b n h d -> b n (h d)")
out = self.attn_out(out) + self.ff_out(ff)
return out
```
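Because `flash_attn_func` only runs on a CUDA GPU with fp16/bf16 inputs, it may help to check the new path against a slow, plain reference. Below is a minimal NumPy sketch, assuming NumPy is available (it is not used by the repo snippet above); the function name `causal_mqa_reference` is hypothetical. It computes causal attention in the same `(batch, seqlen, nheads, headdim)` layout, with no dropout (matching `dropout_p=0.0`), and shares each key/value head across a group of query heads. It also shows that passing one key/value head gives the same result as expanding `k` and `v` to `h` heads, so the `expand_as` calls should not be needed for multi-query attention.

```python
import numpy as np


def causal_mqa_reference(q, k, v, softmax_scale=None):
    """Hypothetical naive causal attention in flash_attn_func's tensor layout.

    q:    (batch, seqlen, nheads, headdim)
    k, v: (batch, seqlen, nheads_k, headdim), nheads divisible by nheads_k
    returns (batch, seqlen, nheads, headdim)
    """
    b, n, h, d = q.shape
    h_k = k.shape[2]
    assert h % h_k == 0, "nheads must be a multiple of nheads_k"
    if softmax_scale is None:
        softmax_scale = d ** -0.5  # same default as flash_attn_func

    # query head i uses key/value head i // (h // h_k); for MQA (h_k = 1) all share one
    k = np.repeat(k, h // h_k, axis=2)
    v = np.repeat(v, h // h_k, axis=2)

    # attention scores: (batch, nheads, seqlen_q, seqlen_k)
    scores = np.einsum("bihd,bjhd->bhij", q, k) * softmax_scale

    # causal mask: query position i may only attend to key positions j <= i
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)

    # numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # weighted sum of values, back to (batch, seqlen, nheads, headdim)
    return np.einsum("bhij,bjhd->bihd", weights, v)


# usage: one shared key/value head versus the same head expanded to h heads
rng = np.random.default_rng(0)
b, n, h, d = 2, 5, 4, 8
q = rng.standard_normal((b, n, h, d))
k = rng.standard_normal((b, n, 1, d))
v = rng.standard_normal((b, n, 1, d))

out_mqa = causal_mqa_reference(q, k, v)
out_expanded = causal_mqa_reference(q, np.repeat(k, h, axis=2), np.repeat(v, h, axis=2))
```

On a GPU, one could cast the same random tensors to fp16, run them through `flash_attn_func(..., causal=True)`, and compare with a loose tolerance, since fp16 will not match float64 exactly.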

Thank you,

Enrico