keras-team / keras

Deep Learning for humans
http://keras.io/

Backward compatibility issues: arbitrary inputs and overreliance on custom layers #19314

Open AlexanderLutsenko opened 3 months ago

AlexanderLutsenko commented 3 months ago

Hi! Thank you guys for the better, cleaner new Keras! The promise of backend-agnostic models is just fantastic.

Problem is, the new framework seems to have lost some capabilities of its predecessor. Here's some example code that works perfectly in Keras 2:

Keras 2

import tensorflow as tf
from tensorflow import keras
import numpy as np

def TestModel():
    x = keras.Input(batch_shape=(1, None))

    b, l = tf.shape(x)
    y = x[:, l // 2]
    return keras.Model(inputs=x, outputs=y)

input = np.random.normal(size=(1, 10))
model = TestModel()
output = model(input)
print(output)

I want the same in Keras 3 with minimal reliance on custom layers. Alas, the straightforward approach fails.

Keras 3

import tensorflow as tf
from tensorflow import keras
import numpy as np

def TestModel():
    x = keras.Input(batch_shape=(1, None))

    shape = keras.layers.Lambda(lambda x: keras.ops.convert_to_tensor(keras.ops.shape(x)))(x)
    b, l = shape[0], shape[1]
    y = x[:, l // 2]  # ValueError: Attempt to convert a value (Ellipsis) with an unsupported type (<class 'ellipsis'>) to a Tensor.
    return keras.Model(inputs=x, outputs=y)

input = np.random.normal(size=(1, 10))
model = TestModel()
output = model(input)
print(output)
Traceback (most recent call last):
  File "/home/alex/PycharmProjects/nobuco/examples/keras3.py", line 57, in <module>
    model = TestModel()
  File "/home/alex/PycharmProjects/nobuco/examples/keras3.py", line 39, in TestModel
    y = x[:, l // 2]  # ValueError: Attempt to convert a value (Ellipsis) with an unsupported type (<class 'ellipsis'>) to a Tensor.
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/keras/src/backend/common/keras_tensor.py", line 291, in __getitem__
    return ops.GetItem().symbolic_call(self, key)
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/keras/src/ops/operation.py", line 51, in symbolic_call
    outputs = self.compute_output_spec(*args, **kwargs)
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/keras/src/ops/numpy.py", line 2694, in compute_output_spec
    num_ellipses = remaining_key.count(Ellipsis)
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/tensorflow/python/ops/tensor_math_operator_overrides.py", line 138, in _tensor_equals_factory
    return math_ops.tensor_equals(self, other)
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/tensorflow/python/framework/constant_op.py", line 108, in convert_to_eager_tensor
    return ops.EagerTensor(value, ctx.device_name, dtype)
ValueError: Attempt to convert a value (Ellipsis) with an unsupported type (<class 'ellipsis'>) to a Tensor.

Custom layers do work, but I'd like to avoid them because otherwise I would need to supply custom_objects on model loading (a partial mitigation is sketched after the snippet below).

class Shape(keras.Layer):
    def call(self, x):
        shape = keras.ops.shape(x)
        n_dims = len(shape)
        shape = keras.ops.convert_to_tensor(shape)
        return tuple(shape[i] for i in range(n_dims))

class Slice(keras.Layer):
    def call(self, x, k):
        return x[:, k]

def TestModel():
    x = keras.Input(batch_shape=(1, None))

    b, l = Shape()(x)
    y = Slice()(x, l // 2)
    return keras.Model(inputs=x, outputs=y)
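
If custom layers do end up being necessary, registering them at least makes load_model work without an explicit custom_objects dict. A sketch, assuming the default get_config suffices; note that the defining module still has to be imported before loading:

import keras

@keras.saving.register_keras_serializable(package="custom")
class Slice(keras.Layer):
    # Registration lets keras.saving.load_model reconstruct this layer
    # by its registered name instead of requiring custom_objects.
    def call(self, x, k):
        return x[:, k]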

Another drawback is that Keras layers (in contrast to, e.g., PyTorch modules) do not allow arbitrary input signatures. In my example, generalizing the Slice layer to the same degree as the __getitem__ method is not possible (a partial workaround is sketched after the traceback below).

class SliceGeneric(keras.Layer):
    def call(self, x, slices):
        return x.__getitem__(slices)

def TestModel():
    x = keras.Input(batch_shape=(1, None))

    b, l = Shape()(x)
    y = SliceGeneric()(x, slices=(slice(None, None, None), l // 2))  # ValueError: In a nested call() argument, you cannot mix tensors and non-tensors. Received invalid mixed argument: slices=(slice(None, None, None), <KerasTensor shape=(), dtype=int32, sparse=False, name=keras_tensor_3>)
    return keras.Model(inputs=x, outputs=y)
Traceback (most recent call last):
  File "/home/alex/PycharmProjects/nobuco/examples/keras3.py", line 57, in <module>
    model = TestModel()
  File "/home/alex/PycharmProjects/nobuco/examples/keras3.py", line 51, in TestModel
    y = SliceGeneric()(x, slices=(slice(None, None, None), l // 2))  # ValueError: In a nested call() argument, you cannot mix tensors and non-tensors. Received invalid mixed argument: slices=(slice(None, None, None), <KerasTensor shape=(), dtype=int32, sparse=False, name=keras_tensor_3>)
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/alex/anaconda3/envs/tf_new/lib/python3.9/site-packages/keras/src/layers/layer.py", line 1402, in __init__
    raise ValueError(
ValueError: In a nested call() argument, you cannot mix tensors and non-tensors. Received invalid mixed argument: slices=(slice(None, None, None), <KerasTensor shape=(), dtype=int32, sparse=False, name=keras_tensor_3>)
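
The closest workaround I see is to split the key: static parts (plain slices and ints) go in through the constructor, and only the dynamic index is passed to call() as a tensor. A sketch, with DynamicGetItem being a made-up helper that only covers keys with a single trailing dynamic component:

class DynamicGetItem(keras.Layer):
    def __init__(self, axis, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, x, index):
        # Rebuild the key inside call(), where mixing slices and tensors is fine.
        key = [slice(None)] * self.axis + [index]
        return x[tuple(key)]

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis})
        return config

# Usage: y = DynamicGetItem(axis=1)(x, l // 2)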

Why it matters

There's a tool that converts PyTorch models to TensorFlow using Keras 2, and the approach has proved successful. Now I'd like to take it a step further and bring the awesomeness of JAX to the PyTorch crowd.

The converter establishes a one-to-one correspondence between PyTorch modules/ops and equivalent Keras implementations. Both have the same signatures, except for the tensors, which are framework-specific.

Below, we trace three PyTorch ops (shape, __floordiv__ and __getitem__), which are then converted to Keras independently of each other. That is why I want a generic __getitem__ in Keras 3.

import torch
from torch import nn

class PytorchModule(nn.Module):
    def forward(self, x):
        b, l = x.shape
        y = x[:, l // 2]
        return y

(trace: attached image omitted)
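
To make the one-to-one mapping concrete, here's a hand-written approximation (not the tool's actual output) of the Keras 2 translation of those three ops:

import tensorflow as tf

def converted_forward(x):
    b, l = tf.shape(x)  # x.shape      -> tf.shape(x)
    k = l // 2          # __floordiv__ -> //  (same operator overload)
    y = x[:, k]         # __getitem__  -> x[:, k]
    return y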

So, that's my case regarding the new Keras' perceived lack of flexibility. Is it considered a flaw or rather a deliberate design choice? Why do some ops only work when wrapped in a layer? Is there a workaround? Any help will be greatly appreciated.

SuryanarayanaY commented 3 months ago

Hi @AlexanderLutsenko ,

Replicated the reported issue with Keras 3; Keras 2 works fine. Attached a gist for reference.

We may need to check whether this is a bug or a design change. Thanks!

sachinprasadhs commented 3 months ago

cc: @VarunS1997

grasskin commented 3 months ago

Hi @AlexanderLutsenko, I agree this is a pain point with Keras: working with different backend tensors, we haven't yet standardized an indexing method. Ideally we want users to be able to create an ops.array() and have access to indexing in all backends. We are actively exploring solutions to this.

We are aiming for an operator such that, for example, the corresponding Keras translation of the PyTorch code you added would be y = ops.at(x)[:, l // 2]. Note that for the time being this is experimental and a work in progress.

AlexanderLutsenko commented 3 months ago

@grasskin Thanks for the clarity!

One thing I still don't understand is why it works fine inside a custom layer:

class CustomLayer(keras.Layer):
    def call(self, x):
        b, l = keras.ops.shape(x)
        y = x[:, l // 2]
        return y

def TestModel():
    x = keras.Input(batch_shape=(1, None))
    y = CustomLayer()(x)
    return keras.Model(inputs=x, outputs=y)

Can it be made to work without a custom layer? Is that on the to-do list?

pedrofrodenas commented 2 months ago

Hello, thanks @AlexanderLutsenko for raising this issue. Yes, I am facing the same issues with Keras 3.

cfiliu commented 2 weeks ago

Hi! Alexander's tool is an incredible breakthrough for anyone who wants to use models from different frameworks with Keras. It would be amazing to update it to Keras 3 and take full advantage of its potential, so please fix this bug. Thanks!