TensorFlow's tensor_scatter_nd_update expects the updates to have the same dtype as the tensor being updated. Currently, this makes TensorFlowTensor.index_update fail, because it uses tf.fill, which produces int32 or float32 tensors (the default dtypes for Python ints and floats) regardless of the tensor's own dtype.
Reproducible example:
import eagerpy as ep
import tensorflow as tf
x_int32 = ep.astensor(tf.range(4, dtype=tf.int32))
x_int64 = ep.astensor(tf.range(4, dtype=tf.int64))
indices = (ep.astensor(tf.constant([0, 2])),)
x_int32.index_update(indices, 0) # this works
x_int32.index_update(indices, 0.0) # this fails
x_int64.index_update(indices, 0) # this fails
x_int64.index_update(indices, 0.0) # this fails
The same problem occurs for float32/float64 tensors.
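The root cause can be reproduced in plain TensorFlow, without eagerpy. The sketch below uses a hypothetical helper, `scatter_with_fill`, that mimics the tf.fill pattern described above (it is not eagerpy's actual code): the updates take their dtype from the Python scalar, so the scatter only succeeds when that happens to match the tensor. The exact exception raised on a dtype mismatch may differ between TensorFlow versions, so the usage catches both `TypeError` and `tf.errors.InvalidArgumentError`.

```python
import tensorflow as tf


def scatter_with_fill(x, value):
    """Hypothetical helper mimicking the problematic pattern:
    build the updates with tf.fill from a Python scalar."""
    indices = tf.constant([[0], [2]])
    # tf.fill infers the dtype from `value` (int -> int32, float -> float32),
    # not from `x`, so it may not match x.dtype.
    updates = tf.fill([2], value)
    return tf.tensor_scatter_nd_update(x, indices, updates)


# int32 tensor + int scalar: dtypes happen to match, so this works.
print(scatter_with_fill(tf.range(4, dtype=tf.int32), 0).numpy())

# int64 tensor + int scalar: updates are int32, so the op rejects them.
try:
    scatter_with_fill(tf.range(4, dtype=tf.int64), 0)
except (TypeError, tf.errors.InvalidArgumentError) as e:
    print("failed:", type(e).__name__)
```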
Solution:
Cast the values to the dtype of the raw tensor before passing them to tensor_scatter_nd_update.
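A minimal sketch of the proposed fix, operating on raw TensorFlow tensors. `index_update_cast` is a hypothetical name, not eagerpy's API, and it only covers the case from the reproducible example: a tuple of 1-D index tensors plus a scalar (or broadcastable) value. The key line is the `tf.cast(values, x.dtype)`; note that casting a float value into an integer tensor truncates it.

```python
import tensorflow as tf


def index_update_cast(x, indices, values):
    """Hypothetical helper: return a copy of `x` with `values` written at
    `indices`, casting `values` to x.dtype first.

    Assumes `indices` is a tuple of k 1-D integer tensors of equal length N,
    one per leading axis of `x` (the form used in the issue's example).
    """
    k = len(indices)
    # Stack k index vectors of length N into an [N, k] index matrix.
    idx = tf.stack([tf.convert_to_tensor(i) for i in indices], axis=-1)
    # tensor_scatter_nd_update expects updates of shape [N] + x.shape[k:].
    updates_shape = tf.concat([tf.shape(idx)[:-1], tf.shape(x)[k:]], axis=0)
    # The fix: use the tensor's own dtype instead of the int32/float32
    # that tf.fill would infer from a Python scalar.
    updates = tf.broadcast_to(tf.cast(values, x.dtype), updates_shape)
    return tf.tensor_scatter_nd_update(x, idx, updates)


x_int64 = tf.range(4, dtype=tf.int64)
print(index_update_cast(x_int64, (tf.constant([0, 2]),), 0.0).numpy())
```

Using `tf.broadcast_to` instead of `tf.fill` lets the same code path accept either a scalar or an already-shaped tensor of values.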