Closed by staghado 4 months ago
Also, a question to the other reviewers: would we want to support RoPE based on position IDs?
Also, it seems that @staghado has incorporated some code from #35. We should make sure these changes align fully with #35 once it has been merged.
This implements RoPE for BERT. See #3
It uses the rotary embedding kernel from flash_attn. In the unpadded version, we first have to pad -> rope -> unpad. This is cheap, especially for the model sizes we are considering.
Even with the pad/unpad logic, the unpadded RoPE version still achieves better throughput than the padded one at large sequence lengths, which suggests that the cost of the rotary operation is negligible compared to the attention.
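For reviewers who want to see the pad -> rope -> unpad flow concretely, here is a minimal pure-Python sketch. It is not the code in this PR and does not call the flash_attn kernel: the helper names (`rope_rotate`, `pad`, `unpad`, `rope_unpadded_via_padding`, `rope_unpadded_via_position_ids`), the `cu_seqlens` layout, the base of 10000, and the non-interleaved (rotate-half) pairing are all assumptions made for illustration.

```python
import math


def rope_rotate(vec, pos, base=10000.0):
    """Rotate one token vector by its position (assumed rotate-half layout)."""
    dim = len(vec)
    half = dim // 2
    out = list(vec)
    for i in range(half):
        # Assumed frequency schedule: base^(-2i/dim) for pair i.
        theta = pos * base ** (-2.0 * i / dim)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x1 * s + x2 * c
    return out


def pad(tokens, cu_seqlens):
    """(total_tokens, dim) -> (batch, max_len, dim), zero-filled on the right."""
    dim = len(tokens[0])
    lens = [cu_seqlens[b + 1] - cu_seqlens[b] for b in range(len(cu_seqlens) - 1)]
    max_len = max(lens)
    padded = []
    for b, n in enumerate(lens):
        seq = [list(t) for t in tokens[cu_seqlens[b]:cu_seqlens[b + 1]]]
        seq += [[0.0] * dim for _ in range(max_len - n)]
        padded.append(seq)
    return padded, lens


def unpad(padded, lens):
    """(batch, max_len, dim) -> (total_tokens, dim), dropping the padding rows."""
    return [tok for seq, n in zip(padded, lens) for tok in seq[:n]]


def rope_unpadded_via_padding(tokens, cu_seqlens):
    """pad -> rope -> unpad: positions come from the padded sequence axis."""
    padded, lens = pad(tokens, cu_seqlens)
    rotated = [[rope_rotate(tok, pos) for pos, tok in enumerate(seq)] for seq in padded]
    return unpad(rotated, lens)


def rope_unpadded_via_position_ids(tokens, cu_seqlens):
    """Alternative: rotate packed tokens directly using per-token position IDs."""
    position_ids = [
        p
        for b in range(len(cu_seqlens) - 1)
        for p in range(cu_seqlens[b + 1] - cu_seqlens[b])
    ]
    return [rope_rotate(tok, pos) for tok, pos in zip(tokens, position_ids)]


# Two packed sequences of lengths 2 and 3, with head dimension 4.
tokens = [[1.0, 2.0, 3.0, 4.0] for _ in range(5)]
cu_seqlens = [0, 2, 5]
via_padding = rope_unpadded_via_padding(tokens, cu_seqlens)
via_position_ids = rope_unpadded_via_position_ids(tokens, cu_seqlens)
```

In this sketch the two paths give the same result, because each token's position restarts at 0 at every sequence boundary either way. That is one way the position-ID question raised in the review could be approached, though whether the flash_attn kernel supports it directly is not something this sketch establishes.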