keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Introduce dtype inference and improve dtype in `ops.numpy.*` #938

Closed james77777778 closed 10 months ago

james77777778 commented 10 months ago

This PR unifies the default dtype behavior across ops.numpy.* and ensures that these ops respect backend.floatx().
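To make "respecting backend.floatx()" concrete, here is a minimal pure-Python sketch of the idea: an op falls back to a globally configured float type only when the caller passes no dtype. The names `set_floatx`, `resolve_dtype`, and `toy_ones` are hypothetical stand-ins written for this illustration; they are not the code in this PR.

```python
# Illustrative sketch only: a toy global float setting, not the Keras implementation.
_FLOATX = "float32"
_VALID_FLOATX = ("float16", "bfloat16", "float32", "float64")


def floatx():
    """Return the currently configured default float dtype name."""
    return _FLOATX


def set_floatx(value):
    """Change the default float dtype name (hypothetical helper for this sketch)."""
    global _FLOATX
    if value not in _VALID_FLOATX:
        raise ValueError(f"Unknown float type: {value!r}")
    _FLOATX = value


def resolve_dtype(dtype=None):
    """Use the explicit dtype if one is given, otherwise fall back to floatx()."""
    return dtype if dtype is not None else floatx()


def toy_ones(shape, dtype=None):
    """Toy creation op: returns a description of the tensor, not a real tensor."""
    return {"shape": tuple(shape), "dtype": resolve_dtype(dtype)}


default = toy_ones((2, 3))                   # dtype follows floatx()
explicit = toy_ones((2, 3), dtype="int32")   # an explicit dtype always wins
```

The point of the sketch is that the fallback lives in one place (`resolve_dtype` here), so every creation op picks up a change to the default in the same way.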

This also caught a subtle bug in dropout_rnn_cell_test.py: the test needs a custom mixed-precision check, because the cell can't be initialized with dtype="mixed_float16" in self.run_layer_test.

EDITED:

WIP:

codecov[bot] commented 10 months ago

Codecov Report

Patch coverage: 69.91% and project coverage change: +11.40% :tada:

Comparison is base (6383d8a) 72.28% compared to head (94daabd) 83.69%. Report is 1 commits behind head on main.

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main     #938       +/-   ##
===========================================
+ Coverage   72.28%   83.69%   +11.40%
===========================================
  Files         319      320        +1
  Lines       28879    29058      +179
  Branches     5529     5579       +50
===========================================
+ Hits        20876    24320     +3444
+ Misses       6632     3195     -3437
- Partials     1371     1543      +172
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/938/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/938/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `83.58% <69.91%> (+11.35%)` | :arrow_up: |
| [keras_core-jax](https://app.codecov.io/gh/keras-team/keras-core/pull/938/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `67.02% <54.23%> (+<0.01%)` | :arrow_up: |
| [keras_core-numpy](https://app.codecov.io/gh/keras-team/keras-core/pull/938/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `60.51% <56.35%> (?)` | |
| [keras_core-tensorflow](https://app.codecov.io/gh/keras-team/keras-core/pull/938/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `66.99% <51.69%> (+<0.01%)` | :arrow_up: |
| [keras_core-torch](https://app.codecov.io/gh/keras-team/keras-core/pull/938/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `68.95% <55.08%> (?)` | |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras\_core/backend/common/dtypes.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL2NvbW1vbi9kdHlwZXMucHk=) | `61.76% <61.76%> (ø)` | |
| [keras\_core/ops/numpy.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9vcHMvbnVtcHkucHk=) | `93.85% <62.85%> (-0.57%)` | :arrow_down: |
| [keras\_core/backend/torch/numpy.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL3RvcmNoL251bXB5LnB5) | `94.56% <84.21%> (+94.56%)` | :arrow_up: |
| [keras\_core/backend/numpy/numpy.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL251bXB5L251bXB5LnB5) | `97.32% <92.85%> (+97.32%)` | :arrow_up: |
| [keras\_core/backend/jax/numpy.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL2pheC9udW1weS5weQ==) | `97.77% <93.33%> (-0.23%)` | :arrow_down: |
| [keras\_core/backend/tensorflow/numpy.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL3RlbnNvcmZsb3cvbnVtcHkucHk=) | `93.83% <93.33%> (-0.11%)` | :arrow_down: |
| [keras\_core/backend/\_\_init\_\_.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL19faW5pdF9fLnB5) | `95.12% <100.00%> (+25.12%)` | :arrow_up: |
| [keras\_core/backend/common/\_\_init\_\_.py](https://app.codecov.io/gh/keras-team/keras-core/pull/938?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL2NvbW1vbi9fX2luaXRfXy5weQ==) | `100.00% <100.00%> (ø)` | |

... and [45 files with indirect coverage changes](https://app.codecov.io/gh/keras-team/keras-core/pull/938/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team)


fchollet commented 10 months ago

Also, we should start testing consistency between symbolic outputs and real op outputs. That's a lot of checks overall, so it would justify the introduction of a new TestCase for dtypes.

james77777778 commented 10 months ago

> We should add dtype checks in unit tests for all operations affected here, to check that we're in fact getting the same dtype across backends, including for array. I think there might be ops where some backends will return float64 instead of float32. This will help us avoid inconsistencies.

> Also, we should start testing consistency between symbolic outputs and real op outputs. That's a lot of checks overall, so it would justify the introduction of a new TestCase for dtypes.

I can add some new test cases in keras_core/ops/numpy_test.py:

class NumpySymbolicDtypeTest(testing.TestCase):
    ...

class NumpyTensorDtypeTest(testing.TestCase):
    ...

Does that sound good?

However, it may take some time to implement a result_dtype-like function for all backends.

fchollet commented 10 months ago

> class NumpyTensorDtypeTest(testing.TestCase):

Yes, that sounds good!

> However, it may take some time to implement a result_dtype-like function for all backends.

We may be able to use test parameterization to save time/code. We can parameterize the input dtype, for instance. But in some cases we may also be able to parameterize the op functions, for groups of ops that have similar arguments.
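As a rough sketch of that suggestion, the test below iterates over the product of a group of ops and a list of input dtypes using only the standard library (`unittest.subTest`). Everything in it is hypothetical: `TOY_OPS` and `toy_float_unary_dtype` stand in for real ops, and the expected-dtype rule (float inputs keep their dtype, integer inputs become the default float type) is an assumption made for illustration. A decorator-based approach such as absl's `parameterized.named_parameters` would be another way to express the same thing.

```python
import itertools
import unittest

FLOATX = "float32"  # assumed default float type for this sketch
FLOAT_DTYPES = ("float16", "float32", "float64")
INT_DTYPES = ("int8", "int16", "int32", "int64")


def toy_float_unary_dtype(input_dtype):
    """Stand-in for the output dtype of a float-valued unary op."""
    return input_dtype if input_dtype.startswith("float") else FLOATX


# Hypothetical group of ops that share one signature, so one test covers them all.
TOY_OPS = {
    "sqrt": toy_float_unary_dtype,
    "exp": toy_float_unary_dtype,
    "log": toy_float_unary_dtype,
}


class ToyDtypeTest(unittest.TestCase):
    def test_float_unary_ops(self):
        # Parameterize over both the op and the input dtype.
        cases = itertools.product(TOY_OPS.items(), FLOAT_DTYPES + INT_DTYPES)
        for (op_name, op), dtype in cases:
            # subTest reports each (op, dtype) pair separately on failure.
            with self.subTest(op=op_name, dtype=dtype):
                expected = dtype if dtype in FLOAT_DTYPES else FLOATX
                self.assertEqual(op(dtype), expected)
```

Grouping ops by signature keeps the number of test methods small, while the per-case `subTest` context still identifies which op and dtype combination failed.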

james77777778 commented 10 months ago

Hi @fchollet

I want to verify whether this PR is on the right track.

I am attempting to implement a Keras Core version of result_dtype in keras_core/backend/common/dtypes.py. Currently, the result matches jnp.result_type when the inputs are Python scalar types (as demonstrated in keras_core/backend/common/dtypes_test.py).

If this looks good, I will refactor some of ops.numpy.* and add the previously mentioned tests.
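For readers unfamiliar with the idea, here is a deliberately simplified sketch of what a result-dtype function does. It is not the code in dtypes.py and does not reproduce JAX's full promotion lattice: it only handles `bool`, signed integers, and floats, and it assumes that Python scalars are "weak", meaning they decide the kind of the result but set its width only when no explicit dtype of that kind is present. The name `toy_result_dtype` and its `floatx` parameter are hypothetical.

```python
# Simplified promotion sketch; dtype names are plain strings.
KINDS = ("bool", "int", "float")  # promotion order, lowest to highest

_DTYPES = {"bool": ("bool", 0)}
_DTYPES.update({f"int{bits}": ("int", bits) for bits in (8, 16, 32, 64)})
_DTYPES.update({f"float{bits}": ("float", bits) for bits in (16, 32, 64)})


def _describe(value, floatx):
    """Return (kind, bits, is_weak) for a dtype name or a Python scalar."""
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return "bool", 0, True
    if isinstance(value, int):
        return "int", 32, True  # assumed default integer width
    if isinstance(value, float):
        return "float", _DTYPES[floatx][1], True
    if value in _DTYPES:
        kind, bits = _DTYPES[value]
        return kind, bits, False
    raise ValueError(f"Unsupported input: {value!r}")


def toy_result_dtype(*values, floatx="float32"):
    """Pick a common dtype name for the given dtype names and Python scalars."""
    if not values:
        raise ValueError("At least one input is required.")
    described = [_describe(v, floatx) for v in values]
    # The highest kind among the inputs wins.
    top_kind = max((kind for kind, _, _ in described), key=KINDS.index)
    same_kind = [d for d in described if d[0] == top_kind]
    # Explicit dtypes fix the width; weak scalars matter only if none exist.
    strong_bits = [bits for _, bits, weak in same_kind if not weak]
    bits = max(strong_bits) if strong_bits else max(b for _, b, _ in same_kind)
    return "bool" if top_kind == "bool" else f"{top_kind}{bits}"


print(toy_result_dtype(1, 2.0))          # float32
print(toy_result_dtype("float16", 2.0))  # float16
print(toy_result_dtype("int8", 1))       # int8
```

Under these assumptions a Python float combined with a `float16` input stays `float16`, which is the kind of behavior a dtype test suite would need to pin down for each backend.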

fchollet commented 10 months ago

Keras Core is becoming Keras 3, and we're switching development to the main repository! Please reopen this PR in the keras-team/keras repository. Unfortunately we aren't able to automatically transfer PRs (but we have transferred all issues).