This PR adds the `attn_implementation` arg mentioned in #169 to allow users to choose between `sdpa`, `eager`, and `flash_attention_2`. If `attn_implementation` is not specified, it falls back to the original behavior, which uses either `sdpa` or `eager`, depending on the PyTorch version.
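The fallback behavior can be sketched roughly as follows. This is an illustrative standalone sketch, not the PR's actual code: the function name `resolve_attn_implementation` and the version cutoff (2.1.1, the PyTorch release commonly required for `sdpa` support) are assumptions.

```python
def resolve_attn_implementation(attn_implementation=None,
                                torch_version=(2, 1, 1)):
    """Sketch of the selection logic described above (names/threshold assumed).

    An explicit choice is validated and used as-is; otherwise fall back to
    "sdpa" on sufficiently new PyTorch and "eager" on older versions.
    """
    allowed = {"sdpa", "eager", "flash_attention_2"}
    if attn_implementation is not None:
        if attn_implementation not in allowed:
            raise ValueError(f"unknown attn_implementation: {attn_implementation!r}")
        return attn_implementation
    # Original behavior: pick based on the installed PyTorch version.
    return "sdpa" if torch_version >= (2, 1, 1) else "eager"
```

For example, `resolve_attn_implementation("flash_attention_2")` returns the explicit choice, while `resolve_attn_implementation(None, (1, 13, 0))` falls back to `"eager"`.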