Closed staghado closed 3 weeks ago
LGTM (although I am not an expert with these kernels, so if someone else can do a review it would be cool).
To me the `qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)` line could be shared between the fa2 and non-fa2 paths, but besides that it's fine.
I launched a sanity check run to compare the results of the previous implementation and this one.
I added a single rotary kernel path which also passes the tests.
If we want to be a bit more thorough, we could follow flash_attn and test against the GPT-NeoX Transformers code.
The test run looks sane, the throughput is good, and it solves the pad/unpad issue; I think we can merge it.
@warner-benjamin I'm not really familiar with GPT-NeoX, so feel free to add the test.
is it normal that the GLUE test in test_glue.py is failing?
No. They should pass. I think #34 broke them.
For the cos/sin caches, we can tie them at the model level so that only one copy is stored in memory.
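A minimal sketch of what such tying could look like (the class names `RotaryCache` and `AttentionLayer` are hypothetical, not the actual code): the caches are built once and every layer holds a reference to the same module instead of its own copy.

```python
import torch
import torch.nn as nn

class RotaryCache(nn.Module):
    # Hypothetical sketch: build the cos/sin caches once at the model level.
    def __init__(self, max_seqlen, rotary_dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        angles = torch.outer(torch.arange(max_seqlen).float(), inv_freq)
        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

class AttentionLayer(nn.Module):
    def __init__(self, rotary_cache):
        super().__init__()
        self.rotary_cache = rotary_cache  # shared reference, not a per-layer copy

# Every layer points at the same cache, so it is stored only once in memory.
cache = RotaryCache(max_seqlen=4096, rotary_dim=64)
layers = nn.ModuleList(AttentionLayer(cache) for _ in range(12))
```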
@staghado I'm merging this into main so we can use it for ablations. We can add cos_sin_cache tying in a future PR.
This PR fixes an issue encountered when using RoPE with the pad/unpad logic where all the sequences in a batch are shorter than the global maximum sequence length. The issue can be fixed by always padding to the maximum sequence length, but a better solution is to apply the rotary embeddings to each sequence directly, without padding.
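For context, "unpadding" concatenates all non-padding tokens into a single dimension and records per-sequence boundaries in a cumulative-lengths tensor. A schematic sketch, assuming the usual flash_attn-style layout (the name `unpad_input` and its exact signature are illustrative):

```python
import torch
import torch.nn.functional as F

def unpad_input(x, attention_mask):
    # x: (batch, seqlen, hidden); attention_mask: (batch, seqlen), 1 = real token
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # cu_seqlens[i] is the offset of sequence i in the concatenated tensor
    cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
    indices = attention_mask.flatten().nonzero(as_tuple=False).flatten()
    x_unpad = x.reshape(-1, x.shape[-1])[indices]  # (total_nnz, hidden)
    return x_unpad, cu_seqlens, indices
```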
This brings some speed improvements, especially for long sequences and bigger models (up to 10% for bert-base at a sequence length of 4096).
This is done by combining the functionality of `ApplyRotaryEmb` and `ApplyRotaryEmbQKV_`. Neither of these two classes expects unpadded QKV sequences as input, but they can be adapted to do so. That is what `ApplyRotaryEmbUnpad` does: it takes a tensor of shape `(total_nnz, 3, nheads, headdim)` and applies the rotary kernel in place to the queries and keys.
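A pure-PyTorch sketch of the same operation (the real code uses a fused Triton kernel; the function name `apply_rotary_unpad` and the half-split rotation layout are assumptions here, matching the common flash_attn convention):

```python
import torch

def apply_rotary_unpad(qkv, cos, sin, cu_seqlens):
    # qkv: (total_nnz, 3, nheads, headdim), all sequences concatenated
    # cos, sin: (max_seqlen, rotary_dim // 2) caches
    # cu_seqlens: (batch + 1,) cumulative sequence lengths, e.g. [0, 3, 7]
    rotary_dim = 2 * cos.shape[-1]
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        seqlen = end - start
        c = cos[:seqlen, None, :]  # each sequence restarts at position 0
        s = sin[:seqlen, None, :]  # broadcast over heads
        for i in (0, 1):  # rotate queries (0) and keys (1); values untouched
            x = qkv[start:end, i, :, :rotary_dim]
            x1, x2 = x.chunk(2, dim=-1)
            qkv[start:end, i, :, :rotary_dim] = torch.cat(
                (x1 * c - x2 * s, x2 * c + x1 * s), dim=-1
            )
    return qkv
```

Because each slice `qkv[start:end]` is rotated with positions starting from zero, no sequence ever reads cos/sin entries for padded positions.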
This also adds a test in the tests directory that compares the forward and backward passes of the unpadded RoPE against a plain PyTorch implementation.