keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Support advanced ND scatter/reduce operations #20061

Open aboubezari opened 1 month ago

aboubezari commented 1 month ago

Proposal

Support operations that scatter updates into a tensor with the following requirements:

Essentially, support TensorFlow's tensor_scatter_nd_* operations, like:
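As a reference for the intended semantics, the loop below mirrors what TensorFlow's tf.tensor_scatter_nd_update, _add, _sub, _min and _max compute. Each row of indices addresses an element or slice of a copy of tensor, and the matching row of updates is combined into it. This is a minimal NumPy sketch, not a proposed implementation: the helper name tensor_scatter_nd_reference is hypothetical, and for "update" with duplicate indices this loop simply keeps the last write.

```python
import numpy as np

# Element-wise combiners, one per tensor_scatter_nd_* flavour.
_COMBINE = {
    "update": lambda old, new: new,
    "add": lambda old, new: old + new,
    "sub": lambda old, new: old - new,
    "min": np.minimum,
    "max": np.maximum,
}


def tensor_scatter_nd_reference(tensor, indices, updates, reduction="update"):
    """Loop-based reference for scatter-with-reduction (hypothetical helper).

    tensor:  array of shape S with rank R
    indices: int array of shape [..., K], K <= R; each row picks tensor[idx],
             which has shape S[K:]
    updates: array of shape indices.shape[:-1] + S[K:]
    """
    out = np.array(tensor, copy=True)  # never mutate the caller's tensor
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    k = indices.shape[-1]
    flat_idx = indices.reshape(-1, k)                    # [N, K]
    flat_upd = updates.reshape((-1,) + out.shape[k:])    # [N] + S[K:]
    combine = _COMBINE[reduction]
    # Apply updates in order; duplicates are folded in one after another.
    for idx, upd in zip(flat_idx, flat_upd):
        key = tuple(idx)
        out[key] = combine(out[key], upd)
    return out
```

For example, with tensor [1, 5, 3, 0], indices [[1], [1], [3]] and updates [2, 7, -1], the "max" flavour yields [1, 7, 3, 0] and the "add" flavour yields [1, 14, 3, -1].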

Design discussion

API design

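Currently, keras.ops.scatter(indices, values, shape) scatters values into a zero-initialized tensor of the given shape, and keras.ops.scatter_update(inputs, indices, updates) overwrites entries of an existing tensor. Neither lets the caller choose a reduction (e.g. max or min) for the scattered values.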

Option 1: Extend keras.ops.scatter_update to accept optional tensor and reduction arguments.

Option 2: Create a dedicated tensor_scatter_nd API for these new functions.
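To make the two options concrete, here is a NumPy sketch of what each API surface could look like. The signatures scatter_update(..., reduction=...), tensor_scatter_nd_add and tensor_scatter_nd_max below are hypothetical and exist only for discussion; they are not the current Keras API, and only the reduction argument of Option 1 is shown. The shared core uses NumPy's ufunc.at, which applies every duplicate index instead of keeping just one.

```python
import numpy as np

# Reduction name -> NumPy ufunc; None means plain overwrite.
_UFUNCS = {"update": None, "add": np.add, "mul": np.multiply,
           "min": np.minimum, "max": np.maximum}


def _scatter_core(inputs, indices, updates, reduction):
    out = np.array(inputs, copy=True)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    # [..., K] index rows -> tuple of K per-axis index arrays for fancy indexing.
    key = tuple(np.moveaxis(indices, -1, 0))
    ufunc = _UFUNCS[reduction]
    if ufunc is None:
        out[key] = updates  # winner among duplicate indices is not guaranteed
    else:
        ufunc.at(out, key, updates)  # unbuffered: every duplicate is applied
    return out


# Option 1 (hypothetical): keep one op and add a reduction argument.
def scatter_update(inputs, indices, updates, reduction="update"):
    return _scatter_core(inputs, indices, updates, reduction)


# Option 2 (hypothetical): a dedicated, TF-style family of ops.
def tensor_scatter_nd_add(tensor, indices, updates):
    return _scatter_core(tensor, indices, updates, "add")


def tensor_scatter_nd_max(tensor, indices, updates):
    return _scatter_core(tensor, indices, updates, "max")
```

The trade-off shows up directly in the sketch: Option 1 keeps a single entry point selected by a string, while Option 2 adds one entry point per reduction but maps one-to-one onto TensorFlow's naming.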

Technical implementation

Overall, the technical implementation is straightforward: flatten the N-D indices to the 1-D case, then invoke each backend's segment-reduction function (e.g. segment_max for TF and scatter_reduce for torch).
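The flattening step can be sketched in NumPy as follows, assuming indices of shape [..., K] that address the first K axes of tensor. np.ravel_multi_index turns each index row into one linear segment id, after which a single reduction along axis 0 finishes the job. NumPy's ufunc.at stands in here for the backend primitive; in TF the unsorted variant (tf.math.unsorted_segment_max) is likely the better fit, since flattened ids are generally not sorted. The function name scatter_nd_via_flatten is hypothetical.

```python
import numpy as np

# Reduction name -> ufunc used as the 1-D "segment reduce" primitive.
_REDUCERS = {"add": np.add, "min": np.minimum, "max": np.maximum}


def scatter_nd_via_flatten(tensor, indices, updates, reduction="max"):
    """Hypothetical sketch: turn an N-D scatter into a 1-D segment reduction."""
    tensor = np.asarray(tensor)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    k = indices.shape[-1]
    outer_shape, inner_shape = tensor.shape[:k], tensor.shape[k:]
    inner_size = int(np.prod(inner_shape))  # 1 when K equals the tensor rank
    # 1) Flatten each K-dim index row into a single linear segment id.
    seg_ids = np.ravel_multi_index(tuple(indices.reshape(-1, k).T), outer_shape)
    # 2) View tensor as [num_segments, inner_size], updates as [N, inner_size].
    flat = tensor.reshape(-1, inner_size).copy()
    flat_updates = updates.reshape(-1, inner_size)
    # 3) One reduction along axis 0, folding duplicates into the base values.
    _REDUCERS[reduction].at(flat, seg_ids, flat_updates)
    return flat.reshape(tensor.shape)
```

Because every backend already has a 1-D scatter/segment reduction, only steps 1 and 2 need to be written once in shared code.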

Implementation plan

I have already implemented a tensor_scatter_nd function for Nuro's internal use, and it works well for all our cases (and is XLA-compatible on all backends). Once we settle on a design, I can upstream the implementation.

aboubezari commented 1 month ago

cc @haohuanw

haohuanw commented 1 month ago

@fchollet let us know which option you prefer for these ops. I think scatter_update is closer to what JAX and torch have, so I slightly prefer adding support to scatter_update.