This PR adds support for setting local attention by using Flash Attention's built-in sliding window support.
It also allows Gemma 2/Character.ai-style alternating local/global attention via the `global_attn_every_n_layers` setting. For a 22-layer model, global attention every three layers seems to work quite well, with a maximum measured speedup of ~20% when all samples in a batch are max sequence length. In practice the speedup appears to be in the ~5% range. There appears to be no loss in model performance from alternating every three layers.
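As a rough illustration (not the exact code in this PR), the sketch below shows how alternating local/global attention can be wired up with Flash Attention's `window_size` argument, where `(-1, -1)` means full/global attention. The `layer_window_size` helper and the `sliding_window` / `global_attn_every_n_layers` names are assumptions for the example; only `flash_attn_func` and its `window_size` parameter (flash-attn >= 2.3) are real API.

```python
# Minimal sketch of alternating local/global attention via Flash Attention's
# sliding window support. Requires a CUDA GPU and flash-attn >= 2.3.
import torch
from flash_attn import flash_attn_func


def layer_window_size(layer_idx: int,
                      sliding_window: int,
                      global_attn_every_n_layers: int) -> tuple[int, int]:
    """Return the (left, right) attention window for a given layer.

    Every `global_attn_every_n_layers`-th layer uses global attention;
    the remaining layers use a symmetric local window of `sliding_window` tokens.
    """
    if global_attn_every_n_layers > 0 and layer_idx % global_attn_every_n_layers == 0:
        return (-1, -1)  # global attention (no window)
    half = sliding_window // 2
    return (half, half)  # local attention: each token sees +/- half tokens


# Example: 22 layers, 128-token local window, global attention every 3rd layer.
q = k = v = torch.randn(2, 512, 8, 64, dtype=torch.bfloat16, device="cuda")
for layer_idx in range(22):
    window = layer_window_size(layer_idx, sliding_window=128,
                               global_attn_every_n_layers=3)
    out = flash_attn_func(q, k, v, causal=False, window_size=window)
```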