jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License

TensorFlowTensor.index_update fails for int64/float64 tensors and int/float values #58

Open ymerkli opened 1 year ago


TensorFlow's tf.tensor_scatter_nd_update requires the updates to have the same dtype as the tensor being updated. TensorFlowTensor.index_update currently builds those updates from a Python scalar with tf.fill, which yields an int32 or float32 tensor regardless of the tensor's own dtype, so the call fails whenever the two dtypes differ.
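The mismatch can be reproduced in plain TensorFlow, without eagerpy. This is a minimal sketch assuming TensorFlow 2.x in eager mode: it checks which dtypes tf.fill infers from Python scalars and then attempts a scatter into an int64 tensor with int32 updates. The exact exception type may depend on the TensorFlow version and execution path, so the sketch catches a few plausible ones.

```python
import tensorflow as tf

x = tf.range(4, dtype=tf.int64)   # target tensor with a non-default dtype
idx = tf.constant([[0], [2]])     # scatter positions, shape (num_updates, 1)

# tf.fill takes its dtype from the Python scalar, not from x.
int_updates = tf.fill((2,), 0)      # int32
float_updates = tf.fill((2,), 0.0)  # float32

# int64 tensor with int32 updates: TensorFlow rejects the dtype mismatch.
# (Exception type may vary across TF versions, hence the tuple.)
try:
    tf.tensor_scatter_nd_update(x, idx, int_updates)
    mismatch_rejected = False
except (tf.errors.InvalidArgumentError, TypeError, ValueError):
    mismatch_rejected = True

print(int_updates.dtype, float_updates.dtype, mismatch_rejected)
```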

Reproducible example:

import eagerpy as ep
import tensorflow as tf

x_int32 = ep.astensor(tf.range(4, dtype=tf.int32))
x_int64 = ep.astensor(tf.range(4, dtype=tf.int64))
indices = (ep.astensor(tf.constant([0, 2])),)

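# Each failing call raises, so run them one at a time.
x_int32.index_update(indices, 0)  # works: tf.fill gives int32, matching the tensor
x_int32.index_update(indices, 0.0)  # fails: tf.fill gives float32, tensor is int32
x_int64.index_update(indices, 0)  # fails: tf.fill gives int32, tensor is int64
x_int64.index_update(indices, 0.0)  # fails: tf.fill gives float32, tensor is int64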

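The same happens with float tensors: a float32 tensor fails with an int value, and a float64 tensor fails with both int and float values.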

Solution:

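Cast the values to the dtype of the raw tensor (for example with tf.cast(values, self.raw.dtype)) before passing them to tf.tensor_scatter_nd_update.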
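A minimal sketch of the proposed cast, written as a standalone TensorFlow function rather than as a patch to eagerpy. The helper name index_update_scalar is hypothetical, and it only covers the case from the repro above: a tuple of integer index tensors (one per dimension) and a Python scalar value. It assumes TensorFlow 2.x.

```python
import tensorflow as tf


def index_update_scalar(x, indices, value):
    """Hypothetical helper: return a copy of x with x[indices] set to value.

    `indices` is a tuple of integer index tensors, one per dimension of x.
    """
    # Stack the per-dimension index tensors into shape (num_updates, ndim),
    # the layout tf.tensor_scatter_nd_update expects.
    idx = tf.stack([tf.cast(i, tf.int64) for i in indices], axis=-1)
    # tf.fill infers int32/float32 from a Python scalar...
    updates = tf.fill(tf.shape(idx)[:1], value)
    # ...so cast the updates to the tensor's own dtype before scattering.
    updates = tf.cast(updates, x.dtype)
    return tf.tensor_scatter_nd_update(x, idx, updates)


x_int64 = tf.range(4, dtype=tf.int64)
y = index_update_scalar(x_int64, (tf.constant([0, 2]),), 0.0)
print(y.dtype, y.numpy().tolist())  # int64 [0, 1, 0, 3]
```

One design point to keep in mind: tf.cast from a float to an integer dtype truncates toward zero, so a value such as 0.7 written into an int tensor would silently become 0. Whether index_update should allow that or raise instead is a choice for the maintainers.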