keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Fix `attention_mask` computation in `MultiHeadAttention` #20488

Closed: james77777778 closed this 1 week ago

james77777778 commented 1 week ago

This PR fixes an issue in KerasHub where the attention_mask could be incorrectly expanded.

A test has been added to prevent future breakages.
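For readers unfamiliar with what "expanding" an attention mask means, here is a minimal NumPy sketch of the general idea: a mask of shape `(batch, target_len, source_len)` gets a singleton axis inserted so it broadcasts across the heads axis of the attention scores. This is not the code changed in this PR. The helper name `expand_attention_mask`, its parameters, and the shapes are assumptions chosen for illustration only.

```python
import numpy as np


def expand_attention_mask(attention_mask, scores_ndim, num_attention_axes=1):
    """Hypothetical helper: add singleton axes so the mask broadcasts
    against attention scores of rank `scores_ndim`.

    The new axes are inserted just before the trailing
    (target, source) axes, i.e. where the heads axis sits.
    """
    mask = np.asarray(attention_mask)
    # Position of the heads axis, counted from the end of the scores tensor.
    expansion_axis = -(num_attention_axes * 2) - 1
    for _ in range(scores_ndim - mask.ndim):
        mask = np.expand_dims(mask, axis=expansion_axis)
    return mask


# Assumed shapes: scores are (batch, heads, target_len, source_len).
batch, heads, target_len, source_len = 2, 4, 3, 5
scores = np.zeros((batch, heads, target_len, source_len))

# Boolean mask that hides the last source position for every query.
mask = np.ones((batch, target_len, source_len), dtype=bool)
mask[:, :, -1] = False

expanded = expand_attention_mask(mask, scores.ndim)
# Masked-out positions get a large negative value before the softmax.
masked_scores = np.where(expanded, scores, -1e9)
```

If the singleton axis lands in the wrong position, the mask can still broadcast whenever the sizes happen to match, so the masking is silently applied along the wrong dimension. A shape-based regression test is one way to guard against that class of bug.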

cc @divyashreepathihalli

codecov-commenter commented 1 week ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 60.00%. Comparing base (96e07ec) to head (c83e1d2).

:exclamation: There is a different number of reports uploaded between BASE (96e07ec) and HEAD (c83e1d2).

HEAD has 6 uploads less than BASE

| Flag | BASE (96e07ec) | HEAD (c83e1d2) |
|------|------|------|
|keras|4|1|
|keras-torch|1|0|
|keras-tensorflow|1|0|
|keras-jax|1|0|

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##           master   #20488       +/-   ##
===========================================
- Coverage   82.08%   60.00%   -22.08%
===========================================
  Files         515      515
  Lines       47543    47544        +1
  Branches     7455     7454        -1
===========================================
- Hits        39024    28529    -10495
- Misses       6709    17264    +10555
+ Partials     1810     1751       -59
```

| [Flag](https://app.codecov.io/gh/keras-team/keras/pull/20488/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras](https://app.codecov.io/gh/keras-team/keras/pull/20488/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `60.00% <100.00%> (-21.93%)` | :arrow_down: |
| [keras-jax](https://app.codecov.io/gh/keras-team/keras/pull/20488/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `?` | |
| [keras-numpy](https://app.codecov.io/gh/keras-team/keras/pull/20488/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `60.00% <100.00%> (+0.01%)` | :arrow_up: |
| [keras-tensorflow](https://app.codecov.io/gh/keras-team/keras/pull/20488/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `?` | |
| [keras-torch](https://app.codecov.io/gh/keras-team/keras/pull/20488/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `?` | |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.


divyashreepathihalli commented 1 week ago

Thanks James!