keras-team / keras


Problem with framework agnostic KerasVariable slicing with another KerasVariable #18439

egesko opened this issue 1 year ago

egesko commented 1 year ago

I defined a KerasVariable with shape (n, d) in a custom keras.Layer subclass using self.add_weight(). I've also defined another KerasVariable with shape (1,), dtype="int32", and initial value 0.

self.first_variable = self.add_weight(
    initializer="zeros", shape=(self.N, input_shape[-1]), trainable=False
)
self.second_variable = self.add_weight(
    initializer="zeros", shape=(1,), trainable=False, dtype="int32"
)

During a call to this custom layer, I'm trying to retrieve the row of the first variable at the index stored in the second variable with:

self.first_variable[self.second_variable.value]

This works as expected with the PyTorch backend, but throws an error with the TensorFlow backend:

Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Variable 'custom_layer/variable_1:0' shape=(1,) dtype=int32>

Arguments received by CustomLayer.call():
  • x=tf.Tensor(shape=(None, 1600), dtype=float32)
  • training=True
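
Putting the pieces together, a minimal end-to-end sketch of the layer (the class name CustomLayer and the constructor argument N are taken from the snippets and the error output above; the build/call structure and the return value are assumptions for illustration):

import keras
from keras import layers

class CustomLayer(layers.Layer):
    def __init__(self, N, **kwargs):
        super().__init__(**kwargs)
        self.N = N

    def build(self, input_shape):
        # An (N, d) buffer plus a single int32 index, both non-trainable.
        self.first_variable = self.add_weight(
            initializer="zeros", shape=(self.N, input_shape[-1]), trainable=False
        )
        self.second_variable = self.add_weight(
            initializer="zeros", shape=(1,), trainable=False, dtype="int32"
        )

    def call(self, x, training=None):
        # Works on the PyTorch backend, raises the error above on TensorFlow.
        return self.first_variable[self.second_variable.value]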
egesko commented 1 year ago

Ideally, I'm looking for a framework-agnostic solution to this, i.e. one that doesn't implement separate logic depending on config.backend. Is this possible?

fchollet commented 1 year ago

You should be able to do this using the ops.slice operation.
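
For example, something along these lines (a sketch using the variable names from the snippets above, not tested against every backend):

# Read the row of first_variable at the position stored in second_variable,
# without indexing by a Variable directly.
d = self.first_variable.shape[-1]
start = keras.ops.squeeze(self.second_variable, axis=0)  # scalar int32 index
row = keras.ops.slice(
    self.first_variable,
    start_indices=[start, 0],  # begin at (start, 0)
    shape=[1, d],              # take one full row
)
row = keras.ops.squeeze(row, axis=0)  # shape (d,)

keras.ops.take(self.first_variable, self.second_variable, axis=0) may also work here and keeps a leading dimension of 1.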

egesko commented 1 year ago

@fchollet

A slightly related question. In the layer's call():

batch_size = keras.ops.shape(x)[0]
indices = keras.ops.arange(
    self.second_variable, stop=self.second_variable + batch_size, step=1, dtype="int32"
)

the call to arange throws the following error (on the PyTorch backend):

TypeError: arange() received an invalid combination of arguments - got (Variable, Tensor, device=str, dtype=torch.dtype, step=int), but expected one of:
 * (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, *, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

Again, self.second_variable is a KerasVariable of shape (1,) and dtype "int32". Is there a way to convert the value of a KerasVariable into a plain Number?

fchollet commented 1 year ago

In backend/torch/numpy.py, in def arange(start, stop=None, step=1, dtype=None):, IMO we should be calling convert_to_tensor on start and stop, and add a test that checks Variable inputs.
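
For reference, one possible shape of that change (a hedged sketch: the helpers convert_to_tensor, get_device, and to_torch_dtype are assumed from the torch backend's core module, the real arange also infers the output dtype from its inputs, and since torch.arange only accepts plain numbers the sketch additionally extracts Python scalars):

import torch

from keras.src.backend.torch.core import (
    convert_to_tensor,
    get_device,
    to_torch_dtype,
)


def arange(start, stop=None, step=1, dtype=None):
    # Simplified: the real implementation resolves dtype from the inputs.
    dtype = to_torch_dtype(dtype if dtype is not None else "int32")

    def as_scalar(value):
        # Pass plain numbers (and None) through; convert Variables and
        # 1-element tensors to a tensor, then to a Python scalar.
        if value is None or isinstance(value, (int, float)):
            return value
        return convert_to_tensor(value).reshape(()).item()

    start, stop, step = as_scalar(start), as_scalar(stop), as_scalar(step)
    if stop is None:
        return torch.arange(end=start, dtype=dtype, device=get_device())
    return torch.arange(start, stop, step=step, dtype=dtype, device=get_device())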

fchollet commented 1 year ago

Can you open a PR?

egesko commented 1 year ago

Working on it now. Thanks for the quick reply!