keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Add `ops.map_coordinates` #906

Closed james77777778 closed 10 months ago

james77777778 commented 10 months ago

Related to keras-team/keras#18442

This PR implements `ops.map_coordinates` for all backends, building on the PR from @mihirparadkar: keras-team/keras-core#784

Obtaining a jittable `map_coordinates` for TensorFlow was challenging, but I found a solution. The key is to use `tf.unstack` to separate the coordinates into a list of tensors for the subsequent operations.

The unit test is borrowed from JAX and has been simplified: https://github.com/google/jax/blob/bcc545a69232e983ae31b0395f4972979f2789c0/tests/scipy_ndimage_test.py#L79
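To illustrate the idea for readers, here is a minimal NumPy sketch of order-1 (linear) interpolation driven by a per-axis list of coordinates. It is not the code in this PR: NumPy stands in for TensorFlow, iterating over the first axis stands in for `tf.unstack`, the name `map_coordinates_linear` is hypothetical, and clipping out-of-range indices to the nearest edge is an assumption made to keep the sketch short.

```python
import itertools

import numpy as np


def map_coordinates_linear(x, coords):
    """Linearly interpolate `x` at fractional `coords` (hypothetical helper).

    `coords` has shape (x.ndim, *out_shape). Indices outside the array are
    clipped to the nearest edge (an assumption of this sketch).
    """
    x = np.asarray(x)
    # Split the stacked coordinates into one array per input axis,
    # mirroring what tf.unstack would do along axis 0.
    per_axis = list(np.asarray(coords, dtype="float64"))
    if len(per_axis) != x.ndim:
        raise ValueError("need exactly one coordinate array per input axis")

    # For each axis, record the two neighbouring indices and their weights.
    corners = []
    for c, size in zip(per_axis, x.shape):
        lower = np.floor(c)
        upper_weight = c - lower
        lo = np.clip(lower.astype("int64"), 0, size - 1)
        hi = np.clip(lower.astype("int64") + 1, 0, size - 1)
        corners.append([(lo, 1.0 - upper_weight), (hi, upper_weight)])

    # Sum the weighted contributions of all 2**ndim surrounding grid points.
    out = np.zeros(per_axis[0].shape, dtype="float64")
    for combo in itertools.product(*corners):
        index = tuple(i for i, _ in combo)
        weight = np.ones_like(out)
        for _, w in combo:
            weight = weight * w
        out += weight * x[index]
    return out


x = np.arange(60, dtype="float64").reshape(3, 4, 5)
coords = np.array([[0.5], [1.0], [2.25]])  # one query point at (0.5, 1.0, 2.25)
print(map_coordinates_linear(x, coords))
```

Because `x[i, j, k] = 20*i + 5*j + k` is affine in the indices, linear interpolation reproduces it exactly at in-range coordinates, which makes this input a convenient sanity check.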

The standalone script:

```python
import math

import numpy as np
import tensorflow as tf

from keras_core.backend.jax.image import map_coordinates as jax_map_coordinates
from keras_core.backend.numpy.image import map_coordinates as np_map_coordinates
from keras_core.backend.tensorflow.image import map_coordinates as tf_map_coordinates
from keras_core.backend.torch.image import map_coordinates as torch_map_coordinates

np.random.seed(42)
shape = (3, 4, 5)
coords_shape = (2, 3, 4)
dtype = "float32"
x = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
# One coordinate array per input axis, each scaled into [0, size - 1].
coords = [
    (size - 1) * np.random.uniform(size=coords_shape).astype(dtype)
    for size in shape
]

# The third argument is the interpolation order (1 = linear).
print("jax:")
print(jax_map_coordinates(x, coords, 1))
print("np:")
print(np_map_coordinates(x, coords, 1))
print("torch:")
print(torch_map_coordinates(x, coords, 1))
print("tf eager:")
print(tf_map_coordinates(x, coords, 1))
# Compile with XLA to confirm the TensorFlow implementation is jittable.
print("tf xla:")
print(tf.function(tf_map_coordinates, jit_compile=True)(x, coords, 1))
```

Results:

```bash
Using TensorFlow backend
jax:
[[[24.009495  50.545628  36.153202  34.760387 ]
  [18.884958  10.515846  13.828117  40.892403 ]
  [25.374344  43.34012   15.488769  52.22368  ]]

 [[39.421623  11.044044  20.851446  15.36548  ]
  [15.1240015 30.588694  18.357327  28.497757 ]
  [28.654016  19.465136  19.45043   23.250359 ]]]
np:
[[[24.009495  50.54563   36.153202  34.76039  ]
  [18.884958  10.515847  13.828115  40.892403 ]
  [25.374344  43.340122  15.488769  52.22368  ]]

 [[39.42162   11.044042  20.851444  15.36548  ]
  [15.1240015 30.588696  18.357325  28.497759 ]
  [28.654016  19.465137  19.450432  23.250357 ]]]
torch:
tensor([[[24.0095, 50.5456, 36.1532, 34.7604],
         [18.8850, 10.5158, 13.8281, 40.8924],
         [25.3743, 43.3401, 15.4888, 52.2237]],

        [[39.4216, 11.0440, 20.8514, 15.3655],
         [15.1240, 30.5887, 18.3573, 28.4978],
         [28.6540, 19.4651, 19.4504, 23.2504]]], device='cuda:0')
tf eager:
tf.Tensor(
[[[24.009495  50.545628  36.153202  34.760387 ]
  [18.884958  10.515846  13.828117  40.892403 ]
  [25.374344  43.34012   15.488769  52.22368  ]]

 [[39.421623  11.044044  20.851446  15.36548  ]
  [15.1240015 30.588694  18.357327  28.497757 ]
  [28.654016  19.465136  19.45043   23.250359 ]]], shape=(2, 3, 4), dtype=float32)
tf xla:
tf.Tensor(
[[[24.009495  50.545628  36.153202  34.760387 ]
  [18.884958  10.515846  13.828117  40.892403 ]
  [25.374344  43.34012   15.488769  52.22368  ]]

 [[39.421623  11.044044  20.851446  15.36548  ]
  [15.1240015 30.588694  18.357327  28.497757 ]
  [28.654016  19.465136  19.45043   23.250359 ]]], shape=(2, 3, 4), dtype=float32)
```

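The backends agree to float32 precision: the printed values differ only in the last digit or so. A small check on the first 3x4 block of the `jax` and `np` outputs, with the numbers copied from the results above (the `1e-6` tolerance is an illustrative choice here, not one taken from the unit test):

```python
import math

# First 3x4 block of the "jax" and "np" outputs, flattened row by row.
jax_out = [
    24.009495, 50.545628, 36.153202, 34.760387,
    18.884958, 10.515846, 13.828117, 40.892403,
    25.374344, 43.34012, 15.488769, 52.22368,
]
np_out = [
    24.009495, 50.54563, 36.153202, 34.76039,
    18.884958, 10.515847, 13.828115, 40.892403,
    25.374344, 43.340122, 15.488769, 52.22368,
]

# Largest relative difference between corresponding entries.
max_rel_diff = max(abs(a - b) / abs(a) for a, b in zip(jax_out, np_out))
all_close = all(
    math.isclose(a, b, rel_tol=1e-6) for a, b in zip(jax_out, np_out)
)

print(f"max relative difference: {max_rel_diff:.2e}")
print("agree within 1e-6:", all_close)
```
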
codecov[bot] commented 10 months ago

Codecov Report

Patch coverage: 86.06% and project coverage change: +0.01% :tada:

Comparison is base (b4019bc) 83.63% compared to head (722a9d1) 83.64%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main   keras-team/keras-core#906   +/-   ##
==========================================
+ Coverage   83.63%   83.64%   +0.01%
==========================================
  Files         318      318
  Lines       28391    28556     +165
  Branches     5409     5440      +31
==========================================
+ Hits        23745    23887     +142
- Misses       3147     3160      +13
- Partials     1499     1509      +10
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/906/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/906/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `83.54% <86.06%> (+0.01%)` | :arrow_up: |
| [keras_core-jax](https://app.codecov.io/gh/keras-team/keras-core/pull/906/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `67.29% <15.75%> (-0.30%)` | :arrow_down: |
| [keras_core-numpy](https://app.codecov.io/gh/keras-team/keras-core/pull/906/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `60.50% <21.21%> (-0.23%)` | :arrow_down: |
| [keras_core-tensorflow](https://app.codecov.io/gh/keras-team/keras-core/pull/906/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `66.94% <43.03%> (-0.14%)` | :arrow_down: |
| [keras_core-torch](https://app.codecov.io/gh/keras-team/keras-core/pull/906/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `69.32% <49.09%> (-0.12%)` | :arrow_down: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/906?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras\_core/backend/jax/image.py](https://app.codecov.io/gh/keras-team/keras-core/pull/906?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL2pheC9pbWFnZS5weQ==) | `76.00% <42.85%> (-3.42%)` | :arrow_down: |
| [keras\_core/backend/numpy/image.py](https://app.codecov.io/gh/keras-team/keras-core/pull/906?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL251bXB5L2ltYWdlLnB5) | `79.06% <71.42%> (-1.49%)` | :arrow_down: |
| [keras\_core/ops/image.py](https://app.codecov.io/gh/keras-team/keras-core/pull/906?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9vcHMvaW1hZ2UucHk=) | `76.22% <73.68%> (-0.47%)` | :arrow_down: |
| [keras\_core/backend/tensorflow/image.py](https://app.codecov.io/gh/keras-team/keras-core/pull/906?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL3RlbnNvcmZsb3cvaW1hZ2UucHk=) | `80.73% <90.47%> (+13.34%)` | :arrow_up: |
| [keras\_core/backend/torch/image.py](https://app.codecov.io/gh/keras-team/keras-core/pull/906?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL3RvcmNoL2ltYWdlLnB5) | `78.94% <93.54%> (+8.30%)` | :arrow_up: |
