tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

What will happen with layers which are in tfa, but not in other Keras frameworks, and which do not work with Keras 3? (I'm interested in the WeightNormalization layer) #2869

Open Kurdakov opened 3 months ago

Kurdakov commented 3 months ago

System information

Describe the bug

While the master branch has fixed imports for Keras 3, class WeightNormalization(tf.keras.layers.Wrapper) still does not work with Keras 3.

Code to reproduce the issue

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.layers import Conv1D, Embedding, MaxPooling1D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.optimizers import Adam

max_words = 800

(Xtrain, ytrain), (Xtest, ytest) = imdb.load_data(num_words=1000)

Xtrain = sequence.pad_sequences(Xtrain, maxlen=max_words)
Xtest = sequence.pad_sequences(Xtest, maxlen=max_words)        

model = Sequential()
model.add(Embedding(1000, 500, input_length=max_words))
model.add(tfa.layers.WeightNormalization(Conv1D(64, 3, activation='relu')))
model.add(MaxPooling1D(2,2))
model.add(tfa.layers.WeightNormalization(Conv1D(32, 3, activation='relu')))
model.add(MaxPooling1D(2,2))
model.add(Flatten())
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

model.compile(optimizer=Adam(.0001), metrics=['accuracy'], loss='binary_crossentropy')
model.fit(Xtrain, ytrain, validation_split=.2, epochs=10)

Problems:

def compute_output_shape(self, input_shape):

uses as_list(), which is not available under Keras 3; removing the as_list() call helps.
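A minimal, version-agnostic sketch of that fix, assuming the failure is only that under Keras 3 the wrapped layer's compute_output_shape() returns a plain tuple (which has no as_list()) instead of a tf.TensorShape. The helper name as_shape_tuple is hypothetical, not part of tfa:

```python
def as_shape_tuple(shape):
    """Return a shape as a plain tuple of ints/None.

    Hypothetical helper (not tfa API). Accepts TensorShape-like objects
    that expose as_list() (tf.keras / Keras 2) as well as the plain
    tuples or lists used for shapes in Keras 3.
    """
    if hasattr(shape, "as_list"):
        # TensorShape-like object: convert via as_list()
        return tuple(shape.as_list())
    # Already a plain sequence (Keras 3 style)
    return tuple(shape)


# Hypothetical patched WeightNormalization method (not the actual tfa code):
#
# def compute_output_shape(self, input_shape):
#     return as_shape_tuple(self.layer.compute_output_shape(input_shape))

print(as_shape_tuple((None, 798, 64)))  # (None, 798, 64)
```

Checking for as_list() rather than importing tf.TensorShape keeps the helper usable whichever Keras version produced the shape.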

Other problems, which I failed to resolve, are in the creation of self._naked_clone_layer.
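For context: as far as I understand tfa's implementation, _naked_clone_layer is a non-trainable copy of the wrapped layer that is used for the data-dependent initialization from Salimans & Kingma (2016) when data_init=True. Below is a minimal NumPy sketch of that initialization for a dense kernel, assuming the usual scheme (compute pre-activations with a unit-norm kernel, zero bias and no activation, then pick g and the bias so each output unit has zero mean and unit variance on the first batch). The function name data_dependent_init is hypothetical, and the exact tfa code may differ:

```python
import numpy as np


def data_dependent_init(v, x, eps=1e-10):
    # Hypothetical helper, not part of tfa: data-dependent init for a
    # weight-normalized dense kernel v of shape (in_features, out_features),
    # given a batch x of shape (batch, in_features).
    # Unit-norm direction per output unit (norm over the input axis).
    v_hat = v / np.linalg.norm(v, axis=0, keepdims=True)
    # Pre-activations of the "naked" layer: unit-norm kernel, zero bias,
    # no activation.
    t = x @ v_hat
    mean = t.mean(axis=0)
    var = t.var(axis=0)
    scale = 1.0 / np.sqrt(var + eps)
    g = scale          # initial per-unit norm of the effective kernel
    b = -mean * scale  # initial bias
    return g, b


rng = np.random.default_rng(0)
v = rng.normal(size=(16, 4))
x = rng.normal(loc=3.0, scale=2.0, size=(256, 16))
g, b = data_dependent_init(v, x)

# Effective kernel w = g * v / ||v||; pre-activations on the init batch
w = g * v / np.linalg.norm(v, axis=0, keepdims=True)
y = x @ w + b
print(y.mean(axis=0), y.std(axis=0))  # close to 0 and 1 per output unit
```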

Essentially, the problem is that WeightNormalization is absent from the other Keras frameworks, yet the tfa version does not work with Keras 3 either.

I understand that tfa is near the end of support (and has already been in minimal-support mode for almost a year), but then the question is: what should be used in place of the WeightNormalization layer in Keras 3?
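As a starting point for any replacement, the reparameterization the layer implements is small: w = g * v / ||v||, with the L2 norm taken per output channel. A minimal NumPy sketch, assuming the kernel's last axis is the output-channel axis (as in Keras Dense and Conv kernels); weight_norm_kernel is a hypothetical name, not a tfa or Keras API:

```python
import numpy as np


def weight_norm_kernel(v, g):
    # Hypothetical helper, not a tfa or Keras API.
    # Weight normalization: w = g * v / ||v||, where the L2 norm is taken
    # over every axis except the last (output-channel) axis.
    norm_axes = tuple(range(v.ndim - 1))
    v_norm = np.sqrt(np.sum(np.square(v), axis=norm_axes, keepdims=True))
    return g * v / v_norm


# Same kernel shape as Conv1D(64, 3) applied to 500-dim embeddings:
# (kernel_size, in_channels, out_channels)
rng = np.random.default_rng(0)
v = rng.normal(size=(3, 500, 64))
g = np.full(64, 2.0)
w = weight_norm_kernel(v, g)

# Each output channel's filter now has L2 norm |g| (here 2.0).
print(np.linalg.norm(w.reshape(-1, 64), axis=0)[:3])
```

In a custom Keras 3 layer, v and g would be trainable weights and the same expression could be written with keras.ops instead of NumPy.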

Kurdakov commented 3 months ago

What I tried, instead of serialization/deserialization:

               layer_config = self.layer.get_config()
               layer_config["trainable"] = False
               self._naked_clone_layer = type(self.layer).from_config(layer_config)

That avoids the build issues, but then the weights could not be set.

Kurdakov commented 3 months ago

With a functional Model, the weights can be set (assuming the serialization fix from my previous comment, so that _naked_clone_layer can be built, and the removal of as_list() in compute_output_shape):

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.layers import Conv1D, Embedding, MaxPooling1D, Dense, Flatten,Input
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.optimizers import Adam

max_words = 800

(Xtrain, ytrain), (Xtest, ytest) = imdb.load_data(num_words=1000)

Xtrain = sequence.pad_sequences(Xtrain, maxlen=max_words)
Xtest = sequence.pad_sequences(Xtest, maxlen=max_words)  

inputs = Input(shape=(max_words,))
x = Embedding(1000, 500)(inputs)
x = tfa.layers.WeightNormalization(Conv1D(64, 3, activation='relu'))(x)
x = MaxPooling1D(2, 2)(x)
x = tfa.layers.WeightNormalization(Conv1D(32, 3, activation='relu'))(x)
x = MaxPooling1D(2, 2)(x)
x = Flatten()(x)
x = Dense(10, activation='relu')(x)
out = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=[out])
model.summary()

model.compile(optimizer=Adam(.0001), metrics=['accuracy'], loss='binary_crossentropy')
model.fit(Xtrain, ytrain, validation_split=.2, epochs=10)

but there is an exception, "Can not convert a NoneType into a Tensor or Operation", which happens in:

        def _update_weights():
            # Ensure we read `self.g` after _update_weights.
            with tf.control_dependencies(self._initialize_weights(inputs)):

Kurdakov commented 3 months ago

Finally, after changing

with tf.control_dependencies(self._initialize_weights(inputs)):

to

with tf.control_dependencies([self._initialize_weights(inputs)]):

I see TensorFlow running.