keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Refactor torch's `affine_transform` #929

Closed: james77777778 closed this 10 months ago

james77777778 commented 10 months ago

Now that `map_coordinates` is implemented in the torch backend, we can use it to perform `affine_transform`. Several performance issues have been resolved so that the new version aligns with the original implementation based on `tnn.grid_sample`.

| `tnn.grid_sample` | `ops.image.map_coordinates` |
|---|---|
| 1.702 s | 0.255 s |

The codebase will be cleaner and more consistent with other backends after the changes.
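As a rough illustration of the idea (not the code from this PR), the NumPy sketch below builds an affine transform on top of a `map_coordinates`-style sampler. It assumes the 8-parameter projective convention documented for `keras_core.ops.image.affine_transform`: a row `[a0, a1, a2, b0, b1, b2, c0, c1]` maps output point `(x, y)` to input point `((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)` with `k = c0 x + c1 y + 1`. It also assumes a single 2D image, bilinear interpolation, and constant fill. `map_coordinates_bilinear` and `affine_transform_2d` are hypothetical helper names.

```python
import numpy as np


def map_coordinates_bilinear(image, coords, fill_value=0.0):
    """Sample a 2D `image` at fractional `coords` of shape (2, ...) = (rows, cols).

    Hypothetical helper: bilinear interpolation where any corner of the 2x2
    neighborhood that falls outside the image contributes `fill_value`.
    """
    rows, cols = coords
    r0 = np.floor(rows).astype(int)
    c0 = np.floor(cols).astype(int)
    dr = rows - r0  # fractional offset along rows
    dc = cols - c0  # fractional offset along columns
    h, w = image.shape

    def gather(r, c):
        # Read pixels at integer positions, substituting fill_value outside the image.
        valid = (r >= 0) & (r < h) & (c >= 0) & (c < w)
        vals = image[np.clip(r, 0, h - 1), np.clip(c, 0, w - 1)]
        return np.where(valid, vals, fill_value)

    return (
        gather(r0, c0) * (1 - dr) * (1 - dc)
        + gather(r0, c0 + 1) * (1 - dr) * dc
        + gather(r0 + 1, c0) * dr * (1 - dc)
        + gather(r0 + 1, c0 + 1) * dr * dc
    )


def affine_transform_2d(image, transform, fill_value=0.0):
    """Hypothetical helper: compute input coordinates for every output pixel,
    then delegate the actual sampling to the map_coordinates-style sampler."""
    a0, a1, a2, b0, b1, b2, c0, c1 = transform
    h, w = image.shape
    y, x = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    k = c0 * x + c1 * y + 1.0
    x_in = (a0 * x + a1 * y + a2) / k
    y_in = (b0 * x + b1 * y + b2) / k
    return map_coordinates_bilinear(image, np.stack([y_in, x_in]), fill_value)


img = np.arange(12, dtype=float).reshape(3, 4)
# a2 = 1 samples one pixel to the right, shifting content left; the last column is filled.
shifted = affine_transform_2d(img, [1, 0, 1, 0, 1, 0, 0, 0])
```

Keeping the coordinate math separate from the sampler mirrors the refactor's idea: `affine_transform` only decides where to sample, and `map_coordinates` does the sampling.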

codecov[bot] commented 10 months ago

Codecov Report

Patch coverage: 87.50% and project coverage change: +0.02%

Comparison is base (0aa999b) 83.65% compared to head (227416b) 83.68%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #929      +/-   ##
==========================================
+ Coverage   83.65%   83.68%   +0.02%
==========================================
  Files         318      318
  Lines       28666    28637      -29
  Branches     5464     5462       -2
==========================================
- Hits        23980    23964      -16
+ Misses       3168     3158      -10
+ Partials     1518     1515       -3
```

| Flag | Coverage Δ |
|---|---|
| keras_core | `83.58% <87.50%> (+0.02%)` |
| keras_core-jax | `67.17% <0.00%> (+0.06%)` |
| keras_core-numpy | `60.41% <0.00%> (+0.06%)` |
| keras_core-tensorflow | `66.85% <0.00%> (+0.06%)` |
| keras_core-torch | `69.19% <87.50%> (+0.01%)` |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ |
|---|---|
| keras\_core/ops/image.py | `76.22% <ø> (ø)` |
| keras\_core/backend/torch/image.py | `83.80% <87.50%> (+4.85%)` |
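The project-level percentages in this report can be reproduced from the hit, miss, and partial counts. A minimal sketch, assuming Codecov's coverage ratio `hits / (hits + misses + partials)` and its default of rounding down to two decimals; `coverage_pct` and `round_down` are hypothetical helpers.

```python
import math


def coverage_pct(hits, misses, partials):
    """Coverage ratio in percent: partial lines count toward the total but not as hits."""
    return 100.0 * hits / (hits + misses + partials)


def round_down(value, precision=2):
    """Round toward negative infinity at `precision` decimals (assumed Codecov default)."""
    factor = 10 ** precision
    return math.floor(value * factor) / factor


base = coverage_pct(hits=23980, misses=3168, partials=1518)  # main (0aa999b)
head = coverage_pct(hits=23964, misses=3158, partials=1515)  # head (227416b)

print(f"{round_down(base):.2f}% -> {round_down(head):.2f}% ({round_down(head - base):+.2f}%)")
# prints: 83.65% -> 83.68% (+0.02%)
```

Rounding the raw difference (about 0.029 points) down gives the reported +0.02%, even though the two rounded percentages differ by 0.03.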
