keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Error with quickstart example using keras_core #1940

Closed jacoverster closed 1 year ago

jacoverster commented 1 year ago

I made small changes to the quickstart example for compatibility with the keras_core API (marked below).

Happy to contribute to a fix with some guidance.

The error is as follows:

Epoch 1/8
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-4-5394a246c22b>](https://localhost:8080/#) in <cell line: 58>()
     56 
     57 # Train your model
---> 58 model.fit(
     59     train_dataset,
     60     validation_data=test_dataset,

1 frames
[/usr/local/lib/python3.10/dist-packages/keras_cv/layers/fusedmbconv.py](https://localhost:8080/#) in call(self, inputs)
    210         if self.strides == 1 and self.input_filters == self.output_filters:
    211             if self.survival_probability:
--> 212                 x = keras.layers.Dropout(
    213                     self.survival_probability,
    214                     noise_shape=(None, 1, 1, 1),

ValueError: Exception encountered when calling FusedMBConvBlock.call().

tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Arguments received by FusedMBConvBlock.call():
  • inputs=tf.Tensor(shape=(None, 75, 75, 32), dtype=float32)

Code to reproduce:

!pip install -q keras-core keras-cv tensorflow --upgrade
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
# from tensorflow import keras                                <--- CHANGE ---
import keras_core as keras
import keras_cv
import tensorflow_datasets as tfds

# Create a preprocessing pipeline with augmentations
BATCH_SIZE = 16
NUM_CLASSES = 3
augmenter = keras.Sequential(
    [
        keras_cv.layers.RandomFlip(),
        keras_cv.layers.RandAugment(value_range=(0, 255)),
        keras_cv.layers.CutMix(),
    ]
)

def preprocess_data(images, labels, augment=False):
    # labels = tf.one_hot(labels, NUM_CLASSES)                <--- CHANGE ---
    labels = keras.ops.one_hot(labels, NUM_CLASSES)
    inputs = {"images": images, "labels": labels}
    outputs = augmenter(inputs) if augment else inputs
    return outputs['images'], outputs['labels']

train_dataset, test_dataset = tfds.load(
    'rock_paper_scissors',
    as_supervised=True,
    split=['train', 'test'],
)
train_dataset = train_dataset.batch(BATCH_SIZE).map(
    lambda x, y: preprocess_data(x, y, augment=True),
        num_parallel_calls=tf.data.AUTOTUNE).prefetch(
            tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).map(
    preprocess_data, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
        tf.data.AUTOTUNE)

# Create a model using a pretrained backbone
backbone = keras_cv.models.EfficientNetV2Backbone.from_preset(
    "efficientnetv2_b0_imagenet"
)
model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=NUM_CLASSES,
    activation="softmax",
)
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy']
)

# Train your model
model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=8,
)

jacoverster commented 1 year ago

Running the example with the JAX backend:

os.environ["KERAS_BACKEND"] = "jax"

This produces an error related to keras.ops.one_hot. I'm not familiar enough with JAX to comment, but it might be of interest to the keras_core team.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-3-7e84aa32f170>](https://localhost:8080/#) in <cell line: 30>()
     28     split=['train', 'test'],
     29 )
---> 30 train_dataset = train_dataset.batch(BATCH_SIZE).map(
     31     lambda x, y: preprocess_data(x, y, augment=True),
     32         num_parallel_calls=tf.data.AUTOTUNE).prefetch(

34 frames
    [... skipping hidden 8 frame]

[/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py](https://localhost:8080/#) in _canonicalize_dtype(x64_enabled, allow_opaque_dtype, dtype)
    146     dtype_ = np.dtype(dtype)
    147   except TypeError as e:
--> 148     raise TypeError(f'dtype {dtype!r} not understood') from e
    149 
    150   if x64_enabled:

TypeError: in user code:

    File "<ipython-input-3-7e84aa32f170>", line 31, in None  *
        lambda x, y: preprocess_data(x, y, augment=True)
    File "<ipython-input-3-7e84aa32f170>", line 20, in preprocess_data  *
        labels = keras.ops.one_hot(labels, NUM_CLASSES)
    File "/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/nn.py", line 965, in one_hot  *
        x, num_classes, axis=axis, dtype=dtype or backend.floatx()
    File "/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/jax/nn.py", line 391, in one_hot  *
        return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/nn/functions.py", line 466, in one_hot  *
        return _one_hot(x, num_classes, dtype=dtype, axis=axis)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback  **
        return fun(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 250, in cache_miss
        outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
        args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/api.py", line 306, in infer_params
        return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 491, in common_infer_params
        avals.append(shaped_abstractify(a))
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 565, in shaped_abstractify
        return _shaped_abstractify_slow(x)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/api_util.py", line 552, in _shaped_abstractify_slow
        dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 163, in canonicalize_dtype
        return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 148, in _canonicalize_dtype
        raise TypeError(f'dtype {dtype!r} not understood') from e

    TypeError: dtype tf.int64 not understood

ianstenbit commented 1 year ago

@jacoverster thanks for the report!

The JAX issue is likely caused by using a Keras Core op in a TF-only preprocessing workflow:

def preprocess_data(images, labels, augment=False):
    # labels = tf.one_hot(labels, NUM_CLASSES)                <--- CHANGE ---
    labels = keras.ops.one_hot(labels, NUM_CLASSES)
    inputs = {"images": images, "labels": labels}
    outputs = augmenter(inputs) if augment else inputs
    return outputs['images'], outputs['labels']

The preprocessing pipeline here is running in tf.data, so it should use pure TF ops internally. I suspect that if you switch back to tf.one_hot this error will go away.
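A minimal sketch of that suggestion, covering only the label-encoding step: the tiny synthetic dataset below is an assumption standing in for rock_paper_scissors, and the augmenter is left out, so this does not show that the full quickstart works on the JAX backend.

```python
import tensorflow as tf

NUM_CLASSES = 3


def preprocess_data(images, labels):
    # tf.one_hot is a native TensorFlow op, so tf.data can trace it inside
    # Dataset.map regardless of which Keras backend is selected.
    labels = tf.one_hot(labels, NUM_CLASSES)
    return images, labels


# Synthetic stand-in for the real dataset (assumption, for illustration only).
images = tf.zeros((4, 8, 8, 3), dtype=tf.uint8)
labels = tf.constant([0, 1, 2, 1], dtype=tf.int64)

dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .batch(2)
    .map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

batch_images, batch_labels = next(iter(dataset))
print(batch_labels.shape, batch_labels.dtype)
```

The int64 labels never leave TensorFlow here, so nothing in the pipeline has to interpret a tf.int64 dtype outside of TF.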

With respect to the error with the TF backend -- this is less obvious to me. Some things I'd try are:

I will take a deeper look at this as soon as I have a chance -- thanks again for the thorough report!

jbischof commented 1 year ago

@ianstenbit sadly, using tf.one_hot does not fix the issue. I get this error (gist):

NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

My guess is that KerasCV, when using the JAX backend, creates JAX augmentation layers, which cannot consume a tf.Tensor. I suspect that combining tf.data with other backends could be tricky unless you juggle both Keras Core and tf.keras.

I had zero problems using the TensorFlow backend (gist).

jacoverster commented 1 year ago

Thanks for the quick response. The JAX backend works with the changes, but model.fit() still fails with the same error as above. I tested it using the TensorFlow backend with your gist.

ValueError: Exception encountered when calling FusedMBConvBlock.call().

tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Arguments received by FusedMBConvBlock.call():
  • inputs=tf.Tensor(shape=(None, 75, 75, 32), dtype=float32)

ianstenbit commented 1 year ago

@jacoverster this looks like an issue with the MBConvBlock's dropout layer. I will send a fix in a moment.

ianstenbit commented 1 year ago

This will be fixed by https://github.com/keras-team/keras-cv/pull/1951
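For readers hitting the same message elsewhere: the traceback shows a Dropout layer being constructed inside call(), which can create state on every trace. A minimal sketch of the usual remedy, building the layer once in __init__ and reusing it, might look like this. The class name ResidualDropBlock and its drop_rate argument are hypothetical, and this is an illustration of the general pattern, not the code from the linked PR.

```python
import tensorflow as tf


class ResidualDropBlock(tf.keras.layers.Layer):
    """Hypothetical block: dropout on a residual branch, layer built once."""

    def __init__(self, drop_rate, **kwargs):
        super().__init__(**kwargs)
        # Construct the sublayer here, outside call(), so any state it owns
        # (for example a random seed variable) is created exactly once.
        self.dropout = tf.keras.layers.Dropout(
            drop_rate, noise_shape=(None, 1, 1, 1)
        )

    def call(self, inputs, training=None):
        # Reuse the existing layer instead of instantiating a new one.
        x = self.dropout(inputs, training=training)
        return x + inputs


block = ResidualDropBlock(drop_rate=0.2)


@tf.function
def forward(x):
    # Tracing this function does not create new layers or variables.
    return block(x, training=True)


x = tf.ones((2, 4, 4, 3))
first = forward(x)
second = forward(x)
inference = block(x, training=False)
```

With training=False the dropout is a no-op, so the block simply returns inputs + inputs.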

jacoverster commented 1 year ago

Great work, thanks guys.