AnswerDotAI / bert24

Apache License 2.0

Rotary Embedding for unpadded sequences #55

Closed staghado closed 3 weeks ago

staghado commented 1 month ago

This PR fixes an issue encountered when using RoPE with the pad/unpad logic when all the sequences in a batch are shorter than the global maximum sequence length. The issue could be worked around by always padding to the maximum sequence length, but a better solution is to apply the rotary embeddings to each sequence directly, without padding.

This brings speed improvements, especially for long sequences and bigger models (up to 10% for bert-base at a sequence length of 4096).

This is done by combining the functionality of ApplyRotaryEmb and ApplyRotaryEmbQKV_. Neither of these two classes expects unpadded QKV sequences as input, but they can be adapted to do so. This is what ApplyRotaryEmbUnpad does: it takes a tensor of shape (total_nnz, 3, nheads, headdim) and applies the rotary kernel in place to the queries and keys.
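For illustration, here is a minimal pure-PyTorch sketch of the same idea (not the Triton kernel itself): per-token positions are derived from cu_seqlens, and the rotation is applied only to the query and key slices of the packed tensor. The function name and the (max_seqlen, rotary_dim // 2) shape of the cos/sin tables are assumptions following the flash-attn convention, not code from this PR.

```python
import torch

def unpadded_rope_reference(qkv, cos, sin, cu_seqlens):
    """Pure-PyTorch reference for RoPE on an unpadded/packed QKV tensor.

    qkv:        (total_nnz, 3, nheads, headdim)
    cos, sin:   (max_seqlen, rotary_dim // 2)
    cu_seqlens: (batch + 1,) cumulative sequence lengths (flash-attn style)
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    # Position of each token inside its own sequence, e.g. [0, 1, 2, 0, 1, ...]
    positions = torch.cat([torch.arange(s, device=qkv.device) for s in seqlens])
    cos_t = cos[positions][:, None, :]  # broadcast over heads
    sin_t = sin[positions][:, None, :]
    rotary_dim = cos.shape[-1] * 2

    out = qkv.clone()
    for i in (0, 1):  # rotate queries and keys only; values stay untouched
        x1, x2 = qkv[:, i, :, :rotary_dim].chunk(2, dim=-1)
        out[:, i, :, :rotary_dim] = torch.cat(
            [x1 * cos_t - x2 * sin_t, x1 * sin_t + x2 * cos_t], dim=-1
        )
    return out
```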

This also adds a test in the tests directory which compares the output of the forward and backward passes of the unpadded RoPE against a PyTorch implementation.
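As a rough illustration of what such a check looks like (the kernel entry point is passed in as a callable rather than imported, since the exact import path and signature here are assumptions), building on the reference sketched above:

```python
import torch

def check_unpadded_rope(apply_rotary_unpadded):
    """Compare an unpadded rotary kernel against unpadded_rope_reference on the
    forward output and on the gradients w.r.t. the packed QKV tensor.
    The assumed signature (qkv, cos, sin, cu_seqlens, max_seqlen) is a guess."""
    torch.manual_seed(0)
    device, nheads, headdim = "cuda", 12, 64
    seqlens = torch.tensor([5, 17, 3], dtype=torch.int32, device=device)
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    total_nnz, max_seqlen = int(seqlens.sum()), int(seqlens.max())

    qkv = torch.randn(total_nnz, 3, nheads, headdim, device=device, requires_grad=True)
    qkv_ref = qkv.detach().clone().requires_grad_(True)

    # Standard RoPE cos/sin tables, one row per position.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2, device=device).float() / headdim))
    t = torch.arange(max_seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()

    # Clone before the call in case the kernel works in place on its input.
    out = apply_rotary_unpadded(qkv.clone(), cos, sin, cu_seqlens, max_seqlen)
    out_ref = unpadded_rope_reference(qkv_ref, cos, sin, cu_seqlens)
    torch.testing.assert_close(out, out_ref, atol=1e-4, rtol=1e-4)

    grad = torch.randn_like(out)
    out.backward(grad)
    out_ref.backward(grad)
    torch.testing.assert_close(qkv.grad, qkv_ref.grad, atol=1e-4, rtol=1e-4)
```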

NohTow commented 4 weeks ago

LGTM (although I am not an expert on these kernels, so it would be good if someone else could also review). To me, the qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset) line could be shared between the FA2 and non-FA2 paths, but besides that it looks fine.
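Something along these lines, as a sketch with stubbed branches and made-up class names rather than the actual module: the rotary call is hoisted above the FA2/non-FA2 branch so it exists once.

```python
import torch.nn as nn

class UnpaddedAttentionSketch(nn.Module):
    """Sketch only: rotary embeddings are applied once before branching on
    Flash Attention 2 availability; the branch bodies are placeholders."""

    def __init__(self, rotary_emb: nn.Module, use_fa2: bool):
        super().__init__()
        self.rotary_emb = rotary_emb
        self.use_fa2 = use_fa2

    def forward(self, qkv, cu_seqlens, max_seqlen, seqlen_offset=0):
        # Shared between both code paths instead of being duplicated in each.
        qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens,
                              max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)
        if self.use_fa2:
            ...  # e.g. call the varlen flash-attn kernel on the packed qkv
        else:
            ...  # pad, run the non-FA2 attention path, then unpad
        return qkv
```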

I launched a sanity check run to compare the results of the previous implementation and this one.

warner-benjamin commented 4 weeks ago

I added a single rotary kernel path which also passes the tests.

If we want to be more thorough, we could follow flash_attn and also test against the GPT-NeoX implementation in Transformers.

NohTow commented 3 weeks ago

The test run looks sane, the throughput is good, and it solves the pad/unpad issue. I think we can merge it.

staghado commented 3 weeks ago

@warner-benjamin I'm not really familiar with GPT-NeoX, so feel free to add that test.

staghado commented 3 weeks ago

Is it normal that the GLUE test in test_glue.py is failing?

warner-benjamin commented 3 weeks ago

> Is it normal that the GLUE test in test_glue.py is failing?

No. They should pass. I think #34 broke them.

staghado commented 3 weeks ago

For the cos/sin caches, we can tie them at the model level so that only one copy is stored in memory.
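Something like the following sketch (class names are made up for illustration): build one rotary-embedding module on the model and hand the same instance to every layer, so its cached cos/sin tensors exist once rather than once per layer.

```python
import torch.nn as nn

class AttentionLayerSketch(nn.Module):
    def __init__(self, rotary_emb: nn.Module):
        super().__init__()
        # Store the shared instance; nn.Module is fine with a submodule
        # appearing under several parents, as with weight tying.
        self.rotary_emb = rotary_emb

class EncoderWithTiedRopeCache(nn.Module):
    def __init__(self, num_layers: int, rotary_emb: nn.Module):
        super().__init__()
        self.layers = nn.ModuleList(
            AttentionLayerSketch(rotary_emb) for _ in range(num_layers)
        )

# Every layer points at the same object, hence the same cos/sin cache:
# assert all(l.rotary_emb is model.layers[0].rotary_emb for l in model.layers)
```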

warner-benjamin commented 3 weeks ago

@staghado I'm merging this into main so we can use it for ablations. We can add cos_sin_cache tying in a future PR.