AnswerDotAI / bert24

Apache License 2.0
60 stars 3 forks source link

Attention fixes #35

Closed warner-benjamin closed 4 months ago

warner-benjamin commented 4 months ago

This PR fixes some attention bugs, adds a config option to use FA2 attn_use_fa2, and adds a test to compare FA2 and SDPA backends. The test allows a 1% error rate in model parameters post training, as there appears to be some non-determinism.

The Glue test errors when calling SDPA with unpadded inputs, but all other unpadded tests pass, so I currently skip Glue SDPA when unpadded with SDPA.

bclavie commented 4 months ago

LGTM. I think SDPA not working with unpadded inputs is extremely minor at the moment, since I doubt there's a large crowd that is both savvy enough to finetune with unpadded inputs and wouldn't want to use flash_attn.

Merging this so we can integrate it to the changes in https://github.com/AnswerDotAI/bert24/pull/36