AnswerDotAI / bert24

Apache License 2.0

Add RoPE with FlexBert blocks #36

Closed staghado closed 4 months ago

staghado commented 4 months ago

This implements RoPE for BERT. See #3

It uses the rotary embedding kernel from flash_attn. In the unpadded version, we first have to pad -> RoPE -> unpad; this is cheap, especially for the model sizes we are considering.

The unpadded RoPE version still achieves better throughput than the padded one at large sequence lengths, even with the extra pad/unpad logic, which suggests the rotary operation is negligible compared to the attention itself.

[plot: compare-seqlen-rope]
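For reference, the pad -> RoPE -> unpad flow can be sketched in plain NumPy. This is a toy illustration of the logic only, not the flash_attn kernel; the function names, the `cu_seqlens` token-packing convention, and the pairwise (interleaved) rotation layout are assumptions for the sketch:

```python
import numpy as np

def rope_angles(max_seqlen, head_dim, base=10000.0):
    """Per-position angles: theta[pos, i] = pos / base**(2i / head_dim)."""
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)
    return np.outer(np.arange(max_seqlen), inv_freq)  # (max_seqlen, head_dim // 2)

def apply_rope(x, angles):
    """Rotate consecutive feature pairs; x: (batch, seqlen, heads, head_dim)."""
    cos = np.cos(angles)[None, :, None, :]
    sin = np.sin(angles)[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def pad_rope_unpad(x_unpad, cu_seqlens, angles):
    """pad -> rope -> unpad for token-packed x_unpad: (total_tokens, heads, head_dim).

    cu_seqlens holds cumulative sequence lengths, e.g. [0, 3, 7] for lengths 3 and 4.
    """
    seqlens = np.diff(cu_seqlens)
    batch, max_seqlen = len(seqlens), int(seqlens.max())
    padded = np.zeros((batch, max_seqlen) + x_unpad.shape[1:], dtype=x_unpad.dtype)
    for b, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        padded[b, : e - s] = x_unpad[s:e]
    rotated = apply_rope(padded, angles[:max_seqlen])
    # Unpad: drop the padding rows and re-pack into one token dimension.
    return np.concatenate([rotated[b, :n] for b, n in enumerate(seqlens)])
```

Since RoPE is a pure rotation, it preserves per-token norms, and position 0 (zero angle) is left unchanged, which makes the round trip easy to sanity-check.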

ohallstrom commented 4 months ago

Also, a question to the other reviewers: would we want to support RoPE based on position IDs?

ohallstrom commented 4 months ago

Also, it seems like @staghado has incorporated some code from #35. We should make sure these changes align fully with #35 once it has been merged.