keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

How can we use `ops.reshape` when the `new_shape` is obtained from `ops.shape`? #835

Closed: james77777778 closed this issue 1 year ago

james77777778 commented 1 year ago

Using ops.reshape during model construction fails when the input has dynamic shapes. With tf.keras, the equivalent script runs without any issue.

import keras_core
import tensorflow as tf

"""
keras_core: Failed
"""

inputs = keras_core.layers.Input(shape=(None, None, 3))
b, h, w, c = keras_core.ops.shape(inputs)  # symbolic Input: (None, None, None, 3)
x = keras_core.ops.reshape(inputs, (b, h, w, c, 1))  # new_shape contains None entries
model = keras_core.models.Model(inputs=inputs, outputs=x)

x = keras_core.ops.random.uniform(shape=(1, 28, 28, 3))
y = model(x)
print(y.shape)

"""
tf.keras: Succeeded
"""

inputs = tf.keras.layers.Input(shape=(None, None, 3))
b, h, w, c = tf.shape(inputs)
x = tf.reshape(inputs, (b, h, w, c, 1))
model = tf.keras.models.Model(inputs=inputs, outputs=x)

x = tf.random.uniform(shape=(1, 28, 28, 3))
y = model(x)
print(y.shape)

I believe there are two issues:

  1. ops.shape returns a tuple of ints or None entries when called on a symbolic tensor
  2. the new_shape passed to ops.reshape may therefore contain None, which causes the error

The solution might be to use tf.shape to compute the new_shape for ops.reshape, but I'm not sure how to implement the fix.

Colab: https://colab.research.google.com/drive/1UYaIxAqmHSUiR4nJyULya8w14Ha19rhB?usp=sharing

james77777778 commented 1 year ago

Just came up with a workaround after posting:

import keras_core
import tensorflow as tf
from keras_core import layers

"""
Workaround
"""
class CustomReshape(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)

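    def compute_output_shape(self, input_shape):
        # Output shape for symbolic inputs: append a trailing axis of size 1.
        return (*input_shape, 1)

    def call(self, x):
        # TensorFlow backend only: tf.shape gives the dynamic shape of the actual tensor.
        b, h, w, c = tf.shape(x)
        return tf.reshape(x, (b, h, w, c, 1))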

inputs = keras_core.layers.Input(shape=(None, None, 3))
x = CustomReshape()(inputs)
model = keras_core.models.Model(inputs=inputs, outputs=x)

x = tf.random.uniform(shape=(1, 28, 28, 3))
y = model(x)
print(y.shape)

fchollet commented 1 year ago

The custom layer is the right idea.

ops.reshape works with ops.shape, but only on actual tensors, not on symbolic tensors (that is to say, Input objects). Inside the call method of a layer, the inputs are actual tensors, so their shape can always be known there. That is why querying ops.shape and calling reshape inside call() works fine.

The reason it worked with tf.keras is that, in tf.keras, an Input is backed by a TF tensor. This is not the case in Keras Core: an Input is just a standalone Python object, and its shape may contain None entries.

james77777778 commented 1 year ago

Thanks for the detailed clarification.

I'm providing the following backend-agnostic workaround in case anyone else encounters this issue:

import keras_core
from keras_core import layers

"""
Workaround: ops.reshape works with ops.shape with actual tensors
"""

class CustomReshape(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)

    def compute_output_shape(self, input_shape):
        return (*input_shape, 1)

    def call(self, x):
        # x is an actual tensor here, so ops.shape returns its real sizes.
        b, h, w, c = keras_core.ops.shape(x)
        return keras_core.ops.reshape(x, (b, h, w, c, 1))

inputs = keras_core.layers.Input(shape=(None, None, 3))
x = CustomReshape()(inputs)
model = keras_core.models.Model(inputs=inputs, outputs=x)

x = keras_core.ops.random.uniform(shape=(1, 28, 28, 3))
y = model(x)
print(y.shape)