Describe the bug
The activity regularizer does not work with quantization aware training (QAT). On further evaluation, I also saw that the activity regularization loss is created as a tensor of type <class 'tensorflow.python.framework.ops.Tensor'>, whereas the bias and kernel regularization losses are created as eager tensors, whose values are easy to access.
System information
TensorFlow version (installed from source or binary): TF 2.3
TensorFlow Model Optimization version (installed from source or binary): pip install
Python version: 3.6.9
Describe the expected behavior
The activity regularization loss should also be an eager tensor with an associated NumPy value, just like the bias and kernel regularization losses. The problem occurs only on quantized models; non-quantized models produce eager activity regularization losses.
Describe the current behavior
Currently, the activity regularization loss tensor is of type <class 'tensorflow.python.framework.ops.Tensor'>, which makes its values inaccessible even when the train function is running in eager mode.
Code to reproduce the issue
import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras import quantize
from tensorflow.python import keras
l = keras.layers
tf.config.run_functions_eagerly(True)
def layers_list():
    return [
        l.Conv2D(32, 5, padding='same', activation='relu',
                 input_shape=image_input_shape(), activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001)),
        l.MaxPooling2D((2, 2), (2, 2), padding='same'),
        # TODO(pulkitb): Add BatchNorm when transformations are ready.
        # l.BatchNormalization(),
        l.Conv2D(64, 5, padding='same', activation='relu', activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001)),
        l.MaxPooling2D((2, 2), (2, 2), padding='same'),
        l.Flatten(),
        l.Dense(1024, activation='relu'),
        l.Dropout(0.4),
        l.Dense(10, activation='softmax')
    ]
def sequential_model():
    return keras.Sequential(layers_list())
def functional_model():
    """Builds an MNIST functional model."""
    inp = keras.Input(image_input_shape())
    x = l.Conv2D(32, 5, padding='same', activation='relu', activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001))(inp)
    x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
    # TODO(pulkitb): Add BatchNorm when transformations are ready.
    # x = l.BatchNormalization()(x)
    x = l.Conv2D(64, 5, padding='same', activation='relu', activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001))(x)
    x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
    x = l.Flatten()(x)
    x = l.Dense(1024, activation='relu')(x)
    x = l.Dropout(0.4)(x)
    out = l.Dense(10, activation='softmax')(x)
    return keras.models.Model([inp], [out])
def image_input_shape(img_rows=28, img_cols=28):
    if tf.keras.backend.image_data_format() == 'channels_first':
        return 1, img_rows, img_cols
    else:
        return img_rows, img_cols, 1
def preprocessed_data(img_rows=28,
                      img_cols=28,
                      num_classes=10):
    """Get data for mnist training and evaluation."""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    if tf.keras.backend.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return x_train, y_train, x_test, y_test
model = functional_model() #sequential_model()
model.summary()
x_train, y_train, x_test, y_test = preprocessed_data()
model.compile(
    loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=500)
_, model_accuracy = model.evaluate(x_test, y_test, verbose=0)
print("Quantizing model")
quantized_model = quantize.quantize_model(model)
print(quantized_model.losses)
quantized_model.compile(
    loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
print(tf.math.add_n(model.losses).numpy())
quantized_model.fit(x_train, y_train, batch_size=500)
_, quantized_model_accuracy = quantized_model.evaluate(
    x_test, y_test, verbose=0)
Error Output
1/120 [..............................] - ETA: 0s - loss: 2.3153 - accuracy: 0.1040
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0017s vs `on_train_batch_end` time: 0.0110s). Check your callbacks.
120/120 [==============================] - 1s 11ms/step - loss: 2.2161 - accuracy: 0.3372
Quantizing model
[<tf.Tensor 'conv2d/ActivityRegularizer_2/truediv:0' shape=() dtype=float32>, <tf.Tensor: shape=(), dtype=float32, numpy=0.00021372623>, <tf.Tensor 'conv2d_1/ActivityRegularizer_2/truediv:0' shape=() dtype=float32>, <tf.Tensor: shape=(), dtype=float32, numpy=0.004322933>]
Traceback (most recent call last):
  File "<path to python file>/tempTrain.py", line 95, in <module>
    print(quantized_model.losses.numpy())
AttributeError: 'list' object has no attribute 'numpy'
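Note that the traceback above comes from calling .numpy() on the list returned by quantized_model.losses, not on an individual tensor. To inspect the entries one by one, something like the following sketch may help. It is not part of TensorFlow or TFMOT: split_losses is a hypothetical helper, and it assumes that an eager tensor exposes a callable numpy attribute while a symbolic tensor does not. The Fake* classes are stand-ins so the snippet runs without TensorFlow installed.

```python
def split_losses(losses):
    """Partition loss entries into (eager_values, symbolic) by duck typing.

    Assumption: eager tensors expose a callable `numpy` attribute,
    symbolic tensors do not.
    """
    eager_values, symbolic = [], []
    for loss in losses:
        numpy_fn = getattr(loss, "numpy", None)
        if callable(numpy_fn):
            # Concrete value is available; convert it to a plain float.
            eager_values.append(float(numpy_fn()))
        else:
            # No concrete value; keep the object for inspection.
            symbolic.append(loss)
    return eager_values, symbolic


# Hypothetical stand-ins for the two kinds of entries in model.losses.
class FakeEagerLoss:
    def __init__(self, value):
        self._value = value

    def numpy(self):
        return self._value


class FakeSymbolicLoss:
    def __init__(self, name):
        self.name = name


mixed = [
    FakeSymbolicLoss("conv2d/ActivityRegularizer_2/truediv:0"),
    FakeEagerLoss(0.25),
    FakeEagerLoss(0.5),
]
values, symbolic = split_losses(mixed)
print(values)
print([s.name for s in symbolic])
```

With a real model, the call would be split_losses(quantized_model.losses). In the output above, the two ActivityRegularizer entries would presumably land in the symbolic list and the two kernel regularization losses in the eager list.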
Additional context
I have also reported this issue on the main TensorFlow GitHub issue tracker, as I am not sure which team is responsible for this.
Link to the issue raised on the TensorFlow GitHub: [TensorFlow issue](https://github.com/tensorflow/tensorflow/issues/51680). Update: the person assigned asked me to report it to the model-optimization team, so the above issue has been closed.