Closed. iiSeymour closed this issue 1 year ago.
@iiSeymour hey Chris! been a while since that oxford nanopore contracting gig, hope your team is doing well
threw in a quick change in 0.3.1, let me know if the following works
import torch
from rotary_embedding_torch import RotaryEmbedding

# seq_before_head_dim = True marks the layout as (batch, seq len, heads, head dim)
rotary_emb = RotaryEmbedding(
    dim = 32,
    seq_before_head_dim = True
)

q = torch.randn(1, 1024, 8, 64)  # (batch, seq len, heads, head dim)
k = torch.randn(1, 1024, 8, 64)

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)
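For context, a minimal sketch (not from the thread) of how the rotated tensors would then flow straight into xformers' memory_efficient_attention, which also takes (batch, seq len, heads, head dim), so no permutes are needed on either side of the rotation. The xformers import path is an assumption here.

# illustrative sketch: rotary embedding + xformers attention with no permutes
import torch
from xformers.ops import memory_efficient_attention
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32, seq_before_head_dim = True)

q = torch.randn(1, 1024, 8, 64)  # (batch, seq len, heads, head dim)
k = torch.randn(1, 1024, 8, 64)
v = torch.randn(1, 1024, 8, 64)

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

out = memory_efficient_attention(q, k, v)  # (1, 1024, 8, 64)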
Thanks @lucidrains 🙌🏻 yes all good thanks!
Works great and I see a nice speed-up, but I'm still seeing one final costly CatArrayBatchedCopy in the profile, which is down to the final torch.cat in apply_rotary_emb. For my problem size the following gives a 10% improvement:
def apply_rotary_emb(freqs, t, start_index=0, seq_dim=-2):
    # in-place variant: writes the rotated slice back into t, so no torch.cat
    # (rotate_half comes from the same module as the original apply_rotary_emb)
    rot_dim, seq_len = freqs.shape[-1], t.shape[seq_dim]
    freqs = freqs[-seq_len:].to(t)
    end_index = start_index + rot_dim
    cos_values, sin_values = freqs.cos(), freqs.sin()
    rotated_t = rotate_half(t[..., start_index:end_index])
    t[..., start_index:end_index] = t[..., start_index:end_index].mul_(cos_values).add_(rotated_t.mul_(sin_values))
    return t
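For reference, a rough reconstruction of the out-of-place upstream apply_rotary_emb (an approximation, not the exact source): the final torch.cat over the untouched left part, the rotated slice, and the untouched right part is what shows up as CatArrayBatchedCopy in the profile, and it is exactly what the in-place variant above avoids.

# approximate sketch of the upstream out-of-place version, for comparison only
def apply_rotary_emb_original(freqs, t, start_index=0, scale=1., seq_dim=-2):
    rot_dim, seq_len = freqs.shape[-1], t.shape[seq_dim]
    freqs = freqs[-seq_len:].to(t)
    end_index = start_index + rot_dim
    # split off the dimensions that are not rotated
    t_left, t_mid, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t_mid = (t_mid * freqs.cos() * scale) + (rotate_half(t_mid) * freqs.sin() * scale)
    # this concatenation allocates a new tensor and copies all three pieces
    return torch.cat((t_left, t_mid, t_right), dim=-1)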
does your change work for both training and inference?
@iiSeymour also, are you aware of flash attention 2 and all the rotary embedding fusion it does?
@lucidrains no, I was not, that sounds ideal! I can see flash_attn_func is a drop-in replacement for memory_efficient_attention, but I don't see how to fuse?
N, T, C = x.shape
# project to QKV, keeping the (batch, seq len, heads, head dim) layout throughout
QKV = self.in_proj(x).view(N, T, self.nhead, 3 * self.head_dim)
Q, K, V = QKV.chunk(3, dim=-1)

Q = self.rotary_emb.rotate_queries_or_keys(Q)
K = self.rotary_emb.rotate_queries_or_keys(K)

attn_output = memory_efficient_attention(Q, K, V)  # -> flash_attn_func(Q, K, V)
x = self.out_proj(attn_output.view(N, T, C))
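For what it's worth, the drop-in swap mentioned above would look roughly like this (a sketch only; it assumes the flash-attn package is installed and that Q, K, V are fp16/bf16 tensors on GPU, in the same (batch, seq len, heads, head dim) layout). The rotary rotation here still runs as a separate op, so this alone does not fuse it into the attention kernel.

from flash_attn import flash_attn_func

# rotary embedding still applied separately, i.e. not fused into the kernel
Q = self.rotary_emb.rotate_queries_or_keys(Q)
K = self.rotary_emb.rotate_queries_or_keys(K)

attn_output = flash_attn_func(Q, K, V)  # assumes Q, K, V are already fp16/bf16 on GPU
x = self.out_proj(attn_output.view(N, T, C))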
@iiSeymour flash attention 2 will eventually be in pytorch 2.1, however i'm not sure about their fused rotary embedding offerings. that repository does way more than just attention these days (also has fused rmsnorm, feedforwards, etc)
@iiSeymour recommend you just follow the instructions at that repository, and issues go to that repo of course
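If it helps, the PyTorch-native path referred to above is torch.nn.functional.scaled_dot_product_attention (available since 2.0). A minimal sketch follows; note that it expects (batch, heads, seq len, head dim), so the permutes this issue set out to avoid would come back, and flash-style kernels are only dispatched for fp16/bf16 inputs on supported GPUs.

import torch
import torch.nn.functional as F

# illustrative only: built-in SDPA wants heads before the sequence dimension
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

out = F.scaled_dot_product_attention(q, k, v)  # (1, 8, 1024, 64)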
Hey @lucidrains 👋🏻
Is it possible to support (batch, seq len, heads, dimension of head) for use with memory_efficient_attention, which wants the sequence length in dim=1, to avoid some costly permutes?
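For context, a hedged sketch of the permute-heavy path this request is trying to avoid, under the assumption that the rotary embedding by default rotates along the second-to-last dimension, i.e. a (batch, heads, seq len, head dim) layout:

# hypothetical illustration of the permutes the issue asks to avoid
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim = 32)

q = torch.randn(1, 1024, 8, 64)   # (batch, seq len, heads, head dim)
q = q.permute(0, 2, 1, 3)         # -> (batch, heads, seq len, head dim) for the rotation
q = rotary_emb.rotate_queries_or_keys(q)
q = q.permute(0, 2, 1, 3)         # -> back to (batch, seq len, heads, head dim) for attention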