keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Add `training=False` behavior for float8-trained `Dense` and `EinsumDense` #19682

Closed: james77777778 closed this 1 week ago

james77777778 commented 1 week ago

Related to #19671

This PR introduces `training=False` behavior for float8-trained `Dense` and `EinsumDense` layers.

We could eliminate `amax_history` and preprocess the weights to bypass the transpose op in the compiled graph, but this would make the layers unrecoverable for further training. (We want to be able to continue training when fitting.)

Perhaps this post-processing could be considered as future work.
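
To illustrate the distinction described above, here is a minimal pure-Python sketch of delayed float8 scaling in which `training=False` reuses the stored scale and leaves `amax_history` untouched. It is not the Keras implementation: `Float8Scaler`, `qdq`, and `history_length` are hypothetical names, and the quantize-dequantize step only scales and clips to the float8 e4m3 range (448) without the mantissa rounding a real float8 cast performs.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


class Float8Scaler:
    """Toy delayed-scaling state (hypothetical helper, not a Keras class)."""

    def __init__(self, history_length=4):
        self.scale = 1.0
        # Rolling window of the absolute maxima seen in recent training steps.
        self.amax_history = deque([0.0] * history_length, maxlen=history_length)

    def qdq(self, values, training):
        """Quantize-dequantize stand-in: divide by scale, clip, multiply back."""
        if training:
            # Refresh the scale from the largest recorded amax, then record
            # the amax of the current batch for later steps.
            amax = max(self.amax_history)
            if amax > 0.0:
                self.scale = amax / E4M3_MAX
            self.amax_history.append(max(abs(v) for v in values))
        # With training=False the stored scale is used as-is and no state changes.
        out = []
        for v in values:
            q = max(-E4M3_MAX, min(E4M3_MAX, v / self.scale))
            out.append(q * self.scale)
        return out


scaler = Float8Scaler()
scaler.qdq([1.0, -896.0], training=True)  # records amax 896.0
scaler.qdq([3.0], training=True)          # scale becomes 896 / 448 = 2.0
state = (scaler.scale, list(scaler.amax_history))
scaler.qdq([5000.0], training=False)      # inference call: state is left untouched
assert state == (scaler.scale, list(scaler.amax_history))
```

Because the inference path only reads `scale`, the history is still available afterwards, so the same layer state can go back into training. That is the property that would be lost if `amax_history` were dropped.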

codecov-commenter commented 1 week ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 69.24%. Comparing base (f12a205) to head (0c0a008).

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master   #19682      +/-   ##
==========================================
- Coverage   78.41%   69.24%   -9.18%
==========================================
  Files         498      498
  Lines       45507    45513       +6
  Branches     8379     8381       +2
==========================================
- Hits        35686    31516    -4170
- Misses       8091    12402    +4311
+ Partials     1730     1595     -135
```

| [Flag](https://app.codecov.io/gh/keras-team/keras/pull/19682/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras](https://app.codecov.io/gh/keras-team/keras/pull/19682/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `69.14% <100.00%> (-9.13%)` | :arrow_down: |
| [keras-jax](https://app.codecov.io/gh/keras-team/keras/pull/19682/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `?` | |
| [keras-numpy](https://app.codecov.io/gh/keras-team/keras/pull/19682/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `?` | |
| [keras-tensorflow](https://app.codecov.io/gh/keras-team/keras/pull/19682/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `63.41% <100.00%> (+<0.01%)` | :arrow_up: |
| [keras-torch](https://app.codecov.io/gh/keras-team/keras/pull/19682/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `62.06% <100.00%> (+<0.01%)` | :arrow_up: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.
