lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch
MIT License
535 stars, 43 forks

Support for sequence length ordering #12

Closed: iiSeymour closed this issue 1 year ago

iiSeymour commented 1 year ago

Hey @lucidrains 👋🏻

Is it possible to support (batch, seq len, heads, dimension of head) for use with memory_efficient_attention, which wants the sequence length in dim=1? This would avoid some costly permutes.

lucidrains commented 1 year ago

@iiSeymour hey Chris! been a while since that oxford nanopore contracting gig, hope your team is doing well

threw in a quick change in 0.3.1, let me know if the following works

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(
    dim = 32,
    seq_before_head_dim = True
)

q = torch.randn(1, 1024, 8, 64)
k = torch.randn(1, 1024, 8, 64)

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)
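for reference, a minimal sketch of feeding the rotated tensors into xformers' memory_efficient_attention, which takes this same (batch, seq len, heads, dim of head) layout (the value tensor v here is just for illustration):

from xformers.ops import memory_efficient_attention

v = torch.randn(1, 1024, 8, 64)

# q, k, v already carry the sequence length in dim 1, so no permutes are needed
# (in practice these would live on a cuda device)
out = memory_efficient_attention(q, k, v)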

iiSeymour commented 1 year ago

Thanks @lucidrains 🙌🏻 yes all good thanks!

Works great and I see a nice speed-up, but I'm still seeing one final costly CatArrayBatchedCopy in the profile, which is down to the final torch.cat in apply_rotary_emb. For my problem size, the following in-place variant gives a 10% improvement:

def apply_rotary_emb(freqs, t, start_index=0, seq_dim=-2):
    rot_dim, seq_len = freqs.shape[-1], t.shape[seq_dim]
    freqs = freqs[-seq_len:].to(t)
    end_index = start_index + rot_dim
    cos_values, sin_values = freqs.cos(), freqs.sin()
    rotated_t = rotate_half(t[..., start_index:end_index])
    t[..., start_index:end_index] = t[..., start_index:end_index].mul_(cos_values).add_(rotated_t.mul_(sin_values))
    return t
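(the snippet above relies on the library's rotate_half helper; as a rough, self-contained sketch, it performs something like the interleaved pair rotation below, though the repo's exact implementation may differ)

import torch
from einops import rearrange

def rotate_half(x):
    # pair up adjacent feature dims and rotate each pair: (x1, x2) -> (-x2, x1)
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')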

lucidrains commented 1 year ago

does your change work for both training and inference?

lucidrains commented 1 year ago

@iiSeymour also, are you aware of flash attention 2 and all the rotary embedding fusion it does?

iiSeymour commented 1 year ago

@lucidrains no, I was not, that sounds ideal! I can see flash_attn_func is a drop-in replacement for memory_efficient_attention, but I don't see how to fuse the rotary embeddings?

N, T, C = x.shape
QKV = self.in_proj(x).view(N, T, self.nhead, 3 * self.head_dim)
Q, K, V = QKV.chunk(3, dim=-1)

Q = self.rotary_emb.rotate_queries_or_keys(Q)
K = self.rotary_emb.rotate_queries_or_keys(K)

attn_output = memory_efficient_attention(Q, K, V)  # -> flash_attn_func(Q, K, V)

x = self.out_proj(attn_output.view(N, T, C))
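(for the drop-in swap itself, assuming flash-attn 2's flash_attn_func, which also takes (batch, seq len, heads, dim of head) tensors but wants half-precision inputs on a cuda device, the sketch would be)

from flash_attn import flash_attn_func

# same layout as memory_efficient_attention; flash-attn kernels expect fp16/bf16 on cuda
attn_output = flash_attn_func(Q, K, V)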

lucidrains commented 1 year ago

@iiSeymour flash attention 2 will eventually be in pytorch 2.1, but i'm not sure about their fused rotary embedding offerings. that repository does way more than just attention these days (it also has fused rmsnorm, feedforwards, etc)
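for reference, the pytorch-native entry point is torch.nn.functional.scaled_dot_product_attention, which dispatches to fused kernels where it can; note it expects (batch, heads, seq len, dim of head), so the seq-first layout needs a transpose. a rough sketch:

import torch.nn.functional as F

# sdpa puts heads before the sequence dimension, hence the transposes
out = F.scaled_dot_product_attention(
    Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
).transpose(1, 2)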

lucidrains commented 1 year ago

@iiSeymour recommend you just follow the instructions at that repository, and issues go to that repo of course