Closed — warner-benjamin closed this 4 months ago
LGTM. I think SDPA not working with unpadded inputs is extremely minor at the moment, since I doubt there's a large crowd that is both savvy enough to finetune with unpadded inputs and wouldn't want to use flash_attn.
Merging this so we can integrate it with the changes in https://github.com/AnswerDotAI/bert24/pull/36
This PR fixes some attention bugs, adds a config option `attn_use_fa2` to enable FA2, and adds a test comparing the FA2 and SDPA backends. The test allows a 1% error rate in model parameters post-training, as there appears to be some non-determinism. The GLUE test errors when calling SDPA with unpadded inputs, but all other unpadded tests pass, so for now I skip the GLUE test when running SDPA with unpadded inputs.