AnswerDotAI / bert24

Apache License 2.0

Add support for local attention #95

Closed · warner-benjamin closed this 3 months ago

warner-benjamin commented 3 months ago

This PR adds support for local attention by using Flash Attention's built-in sliding window support.
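For context, here is a minimal sketch of the underlying mechanism, assuming Flash Attention 2's `window_size` argument rather than the PR's actual code; tensor shapes and the window size are illustrative:

```python
import torch
from flash_attn import flash_attn_func

# Hypothetical shapes: (batch, seqlen, nheads, headdim)
q = torch.randn(2, 1024, 12, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

local_window = 128  # hypothetical local attention span per side

# Local (sliding window) attention: each token attends only to tokens
# within `local_window` positions on either side.
out_local = flash_attn_func(q, k, v, window_size=(local_window, local_window))

# Global attention: window_size=(-1, -1) disables the sliding window.
out_global = flash_attn_func(q, k, v, window_size=(-1, -1))
```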

It also allows Gemma 2/Character.AI-style alternating local/global attention via the `global_attn_every_n_layers` setting. For a 22-layer model, global attention every three layers seems to work quite well, with a maximum measured speedup of ~20% when all samples in a batch are at the maximum sequence length. In practice the speedup appears to be in the ~5% range. There appears to be no loss in model performance from alternating every three layers.
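A hedged sketch of how such an alternating pattern could be expressed; the helper name and defaults below are assumptions for illustration, not the PR's actual implementation:

```python
def window_size_for_layer(layer_idx: int,
                          global_attn_every_n_layers: int = 3,
                          local_window: int = 128) -> tuple[int, int]:
    """Pick the Flash Attention window_size for a given layer index.

    Every n-th layer gets global attention; the rest use a sliding window.
    (Hypothetical helper; the repo's real layer wiring may differ.)
    """
    if layer_idx % global_attn_every_n_layers == 0:
        return (-1, -1)  # global attention: no sliding window
    return (local_window, local_window)  # local (sliding window) attention

# Example: a 22-layer model with global attention every three layers.
pattern = ["global" if window_size_for_layer(i) == (-1, -1) else "local"
           for i in range(22)]
# -> ['global', 'local', 'local', 'global', 'local', 'local', ...]
```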

warner-benjamin commented 3 months ago

Merging since I've successfully been training with it.