keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Keras 3 SpectralNormalization() #19183

Closed: kwchan7 closed this issue 8 months ago

kwchan7 commented 9 months ago

There seems to be a bug in the Keras 3 SpectralNormalization layer.

Following the documentation at https://www.tensorflow.org/api_docs/python/tf/keras/layers/SpectralNormalization, the code below works in Keras 3.0.5:

import keras
from keras import layers
import numpy as np

x = np.random.rand(1, 10, 10, 1)
dense = layers.SpectralNormalization(keras.layers.Dense(10))
y = dense(x)
y.shape

Out[1]: TensorShape([1, 10, 10, 10])

But the following raises an error:

import keras
from keras import layers
import numpy as np

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
num_classes = 10

# Scale pixels to [0, 1] and add a channel axis to match the (28, 28, 1) input shape.
x_train = np.expand_dims(x_train.astype("float32") / 255.0, -1)
y_train = keras.utils.to_categorical(y_train, num_classes)

inputs = keras.Input(shape=(28, 28, 1))
x = layers.Flatten()(inputs)
dense = layers.SpectralNormalization(layers.Dense(num_classes, activation="softmax"))
outputs = dense(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss=keras.losses.CategoricalCrossentropy(), metrics=["accuracy"], optimizer="adam")
model.summary()

model.fit(x_train, y_train, epochs=2)

OperatorNotAllowedInGraphError: Exception encountered when calling SpectralNormalization.call().

Iterating over a symbolic tf.Tensor is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code for more information.

Arguments received by SpectralNormalization.call():
  • inputs=tf.Tensor(shape=(32, 784), dtype=float32)
  • training=True
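For context, the layer divides the wrapped layer's kernel by an estimate of its largest singular value (its spectral norm). A common way to get that estimate, and presumably what the training=True path does through the layer's power_iterations setting, is power iteration on the kernel reshaped to a 2-D matrix. The NumPy sketch below illustrates only that technique; spectral_normalize is a hypothetical helper, not part of the Keras API, and it is not the Keras implementation.

```python
import numpy as np


def spectral_normalize(kernel, u, n_iters=1, eps=1e-12):
    """Hypothetical helper: rescale `kernel` by a power-iteration estimate of its spectral norm.

    kernel: array whose last axis is the number of output units.
    u:      current estimate of the right singular vector, shape (1, units).
    Returns (normalized kernel, updated u, sigma estimate).
    """
    if n_iters < 1:
        raise ValueError("n_iters must be >= 1")
    w = kernel.reshape(-1, kernel.shape[-1])  # treat the kernel as an (m, n) matrix
    for _ in range(n_iters):
        v = u @ w.T                            # (1, m): step toward the left singular vector
        v = v / (np.linalg.norm(v) + eps)
        u = v @ w                              # (1, n): step toward the right singular vector
        u = u / (np.linalg.norm(u) + eps)
    sigma = (v @ w @ u.T).item()               # estimated largest singular value
    return kernel / sigma, u, sigma


# Usage: a kernel whose singular values are known to be 3.0, 1.0 and 0.5.
kernel = np.diag([3.0, 1.0, 0.5])
u0 = np.random.default_rng(0).normal(size=(1, 3))
w_sn, u, sigma = spectral_normalize(kernel, u0, n_iters=50)
print(round(sigma, 6))                                      # 3.0
print(round(np.linalg.svd(w_sn, compute_uv=False)[0], 6))   # 1.0
```

Because u is carried over from one call to the next, a single iteration per training step is usually enough in practice: the estimate keeps improving while the kernel changes slowly.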

sachinprasadhs commented 9 months ago

I was able to replicate the issue with the PyTorch and JAX backends as well. The error messages are below, and a Gist is attached.

PyTorch:

RuntimeError: Exception encountered when calling SpectralNormalization.call().

Boolean value of Tensor with more than one value is ambiguous

Arguments received by SpectralNormalization.call():
  • inputs=torch.Tensor(shape=torch.Size([32, 784]), dtype=float32)
  • training=True
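The TensorFlow and PyTorch messages appear to describe the same kind of failure: Python control flow (an if, not, or for) applied to a tensor instead of a single concrete Python value. Note also that the error arguments show training=True, which fit() passes, whereas the working snippet called the layer without a training argument. That the bug sits in a training-only branch of SpectralNormalization.call() is an assumption; the thread does not say. The NumPy sketch below reproduces the analogous error with a plain array and shows the usual remedy of reducing the test to one boolean first; branch_on_array and branch_on_reduction are hypothetical names. In traced backend code even a reduced tensor has no concrete value, which is why backend-agnostic Keras 3 code generally reaches for keras.ops.cond for data-dependent branches.

```python
import numpy as np


def branch_on_array(w):
    """Naive version: uses an elementwise comparison directly as an `if` condition."""
    if w == 0:  # raises ValueError when `w` has more than one element
        return "zero"
    return "nonzero"


def branch_on_reduction(w):
    """Fixed version: collapse the elementwise test to a single bool before branching."""
    if np.all(w == 0):
        return "zero"
    return "nonzero"


kernel = np.zeros((784, 10), dtype="float32")  # same shape as the Dense kernel in the issue
try:
    branch_on_array(kernel)
except ValueError as err:
    print("ValueError:", err)  # NumPy: "The truth value of an array ... is ambiguous ..."
print(branch_on_reduction(kernel))  # zero
```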

JAX:

hertschuh commented 8 months ago

A fix has been submitted and should be available in keras-nightly tomorrow.
