keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Add `scan` op #19681

Closed: james77777778 closed this pull request 1 week ago

james77777778 commented 1 week ago

Related to #19519

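Originally, `tf.scan` stacks the successive outputs of `f` to form the final result and does not support a separate carry threaded through the loop iterations. I've reimplemented `scan` for the TensorFlow backend to match the behavior of `jax.lax.scan`, where `f` returns a `(carry, y)` pair: the carry is passed to the next iteration, and the `y` values are stacked into the final output.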
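For readers comparing the two contracts, here is a minimal pure-Python model of the semantics described above. `reference_scan`, `tf_style_scan`, and `cumsum_step` are hypothetical names used only for illustration. They are not the Keras, TensorFlow, or JAX implementations, and real backends stack tensors rather than building Python lists.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def reference_scan(
    f: Callable[[C, X], Tuple[C, Y]], init: C, xs: Sequence[X]
) -> Tuple[C, List[Y]]:
    """Hypothetical model of jax.lax.scan-style semantics.

    `f` maps (carry, x) -> (new_carry, y). The carry is threaded through the
    loop, while the per-step outputs `y` are collected separately.
    """
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def tf_style_scan(
    fn: Callable[[C, X], C], elems: Sequence[X], initializer: C
) -> List[C]:
    """Hypothetical model of the original tf.scan behavior.

    `fn` maps (accumulator, x) -> new_accumulator, and the result is the
    stack of accumulators: there is no separate per-step output.
    """
    acc = initializer
    outputs: List[C] = []
    for x in elems:
        acc = fn(acc, x)
        outputs.append(acc)
    return outputs


def cumsum_step(carry: int, x: int) -> Tuple[int, Tuple[int, int]]:
    """Running sum whose y is a (running_sum, x) pair, unlike the int carry."""
    new_carry = carry + x
    return new_carry, (new_carry, x)


final, ys = reference_scan(cumsum_step, 0, [1, 2, 3, 4])
print(final)  # 10
print(ys)     # [(1, 1), (3, 2), (6, 3), (10, 4)]
print(tf_style_scan(lambda a, x: a + x, [1, 2, 3, 4], 0))  # [1, 3, 6, 10]
```

In the jax-style model the final carry and the stacked outputs are returned separately, so `y` is free to have a different structure from the carry; the accumulator-only model can only return the stack of accumulators.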

However, the current TensorFlow implementation requires the per-step output `y` of `f` to have the same shape and dtype as `carry`. The other backends do not have this restriction.
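To make the restriction concrete, here is a hypothetical pure-Python model of one way such a constraint can arise; it is an assumption for illustration, not a description of the PR's code. The sketch emulates the `(carry, y)` contract on top of an accumulator-only scan by packing both values into one accumulator. Because the accumulator needs an initial value for the `y` slot, the sketch reuses `init` as that placeholder, so every `y` must then match `init`. `packed_scan` and `same_structure` are made-up names, and `same_structure` (nested tuple lengths plus leaf types) is only a crude stand-in for TensorFlow's shape/dtype consistency checks.

```python
from typing import Any, Callable, List, Sequence, Tuple


def same_structure(a: Any, b: Any) -> bool:
    """Crude stand-in for a shape/dtype check: compare nested tuples and leaf types."""
    if isinstance(a, tuple) and isinstance(b, tuple):
        return len(a) == len(b) and all(same_structure(p, q) for p, q in zip(a, b))
    return type(a) is type(b)


def packed_scan(
    f: Callable[[Any, Any], Tuple[Any, Any]], init: Any, xs: Sequence[Any]
) -> Tuple[Any, List[Any]]:
    """Emulate a (carry, y) scan with a single accumulator seeded by (init, init)."""
    acc = (init, init)  # assumption: the y slot is seeded with `init` as a placeholder
    ys: List[Any] = []
    for x in xs:
        carry, _ = acc
        new_carry, y = f(carry, x)
        # An accumulator-only scan expects every step to match its seed,
        # so a y that differs from `init` is rejected here.
        if not same_structure(y, init):
            raise TypeError(f"y {y!r} does not match the structure of carry {init!r}")
        acc = (new_carry, y)
        ys.append(y)
    return acc[0], ys


print(packed_scan(lambda c, x: (c + x, c * x), 0, [1, 2, 3]))  # (6, [0, 2, 9])
try:
    packed_scan(lambda c, x: (c + x, (c, x)), 0, [1, 2, 3])
except TypeError as err:
    print("rejected:", err)
```

A `y` that merely reuses the carry's type passes, while a tuple-valued `y` is rejected. Collecting each `y` in its own buffer, as the jax-style model does, avoids needing a seed value for the outputs at all.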

codecov-commenter commented 1 week ago

Codecov Report

Attention: Patch coverage is 96.62162%, with 5 lines in your changes missing coverage. Please review.

Project coverage is 78.49%. Comparing base (10c27c0) to head (fbab23e).

| Files | Patch % | Lines |
|---|---|---|
| keras/src/backend/tensorflow/core.py | 92.30% | 2 Missing and 2 partials :warning: |
| keras/api/_tf_keras/keras/ops/__init__.py | 0.00% | 1 Missing :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master   #19681      +/-   ##
==========================================
+ Coverage   78.42%   78.49%   +0.07%
==========================================
  Files         498      498
  Lines       45551    45699     +148
  Branches     8394     8448      +54
==========================================
+ Hits        35723    35872     +149
+ Misses       8094     8093       -1
  Partials     1734     1734
```

| Flag | Coverage Δ | |
|---|---|---|
| keras | `78.34% <96.62%> (+0.07%)` | :arrow_up: |
| keras-jax | `61.94% <19.59%> (-0.13%)` | :arrow_down: |
| keras-numpy | `56.30% <39.86%> (-0.05%)` | :arrow_down: |
| keras-tensorflow | `63.38% <48.64%> (-0.05%)` | :arrow_down: |
| keras-torch | `62.01% <39.18%> (-0.08%)` | :arrow_down: |

Flags with carried forward coverage won't be shown.
