tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
Apache License 2.0

Activity Regularizer not working with quantization aware training (QAT) #802

Open bayesian-mind opened 3 years ago

bayesian-mind commented 3 years ago

Describe the bug The activity regularizer does not work with quantization aware training (QAT). On further inspection, the activity regularization loss is created as a tensor of type <class 'tensorflow.python.framework.ops.Tensor'>, whereas the bias and kernel regularization losses are created as eager tensors, whose values are easy to access.

System information

TensorFlow version (installed from source or binary): TF 2.3

TensorFlow Model Optimization version (installed from source or binary): pip install

Python version: 3.6.9

Describe the expected behavior The activity regularization loss should also be an eager tensor with NumPy values associated with it, just like the bias and kernel losses. The problem appears only on quantized models; non-quantized models have an eager activity regularization loss.

Describe the current behavior Currently the activity regularization loss tensor is of type <class 'tensorflow.python.framework.ops.Tensor'> which makes the values inaccessible even when the train function is running in eager mode.
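One way to see which entries are affected is to group the entries of the model's losses list by their Python class. The helper below is a minimal sketch: group_losses_by_type is a hypothetical name, not part of TensorFlow or the Model Optimization toolkit, and it only inspects types without evaluating anything.

```python
from collections import defaultdict


def group_losses_by_type(losses):
    """Map each class name to the positions in `losses` where it occurs.

    `losses` is assumed to be a list such as `model.losses`; the function
    only looks at the type of each entry, so it works on eager and
    symbolic tensors alike.
    """
    groups = defaultdict(list)
    for index, loss in enumerate(losses):
        groups[type(loss).__name__].append(index)
    return dict(groups)
```

Calling group_losses_by_type(quantized_model.losses) on the model from this report should show which positions hold eager tensors and which hold symbolic ones.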

Code to reproduce the issue

import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras import quantize
from tensorflow.python import keras
l = keras.layers


def layers_list():
  return [
      l.Conv2D(32, 5, padding='same', activation='relu',
               input_shape=image_input_shape(), activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001)),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      # TODO(pulkitb): Add BatchNorm when transformations are ready.
      # l.BatchNormalization(),
      l.Conv2D(64, 5, padding='same', activation='relu', activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001)),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      # Flatten before the dense layers, as in functional_model().
      l.Flatten(),
      l.Dense(1024, activation='relu'),
      l.Dense(10, activation='softmax')
  ]

def sequential_model():
  return keras.Sequential(layers_list())

def functional_model():
  """Builds an MNIST functional model."""
  inp = keras.Input(image_input_shape())
  x = l.Conv2D(32, 5, padding='same', activation='relu', activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001))(inp)
  x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
  # TODO(pulkitb): Add BatchNorm when transformations are ready.
  # x = l.BatchNormalization()(x)
  x = l.Conv2D(64, 5, padding='same', activation='relu', activity_regularizer=tf.keras.regularizers.l2(l=0.0001), kernel_regularizer=tf.keras.regularizers.l2(l=0.0001))(x)
  x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
  x = l.Flatten()(x)
  x = l.Dense(1024, activation='relu')(x)
  x = l.Dropout(0.4)(x)
  out = l.Dense(10, activation='softmax')(x)

  return keras.models.Model([inp], [out])

def image_input_shape(img_rows=28, img_cols=28):
  if tf.keras.backend.image_data_format() == 'channels_first':
    return 1, img_rows, img_cols
  else:
    return img_rows, img_cols, 1

def preprocessed_data(img_rows=28,
                      img_cols=28,
                      num_classes=10):
  """Get data for mnist training and evaluation."""
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  if tf.keras.backend.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
  else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)

  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255

  # convert class vectors to binary class matrices
  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes)

  return x_train, y_train, x_test, y_test

model = functional_model() #sequential_model()
x_train, y_train, x_test, y_test = preprocessed_data()

model.compile(
    loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=500)
_, model_accuracy = model.evaluate(x_test, y_test, verbose=0)

print("Quantizing model")

quantized_model = quantize.quantize_model(model)
quantized_model.compile(
    loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

quantized_model.fit(x_train, y_train, batch_size=500)
_, quantized_model_accuracy = quantized_model.evaluate(
    x_test, y_test, verbose=0)

Error Output

  1/120 [..............................] - ETA: 0s - loss: 2.3153 - accuracy: 0.1040WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0017s vs `on_train_batch_end` time: 0.0110s). Check your callbacks.
120/120 [==============================] - 1s 11ms/step - loss: 2.2161 - accuracy: 0.3372
Quantizing model
[<tf.Tensor 'conv2d/ActivityRegularizer_2/truediv:0' shape=() dtype=float32>, <tf.Tensor: shape=(), dtype=float32, numpy=0.00021372623>, <tf.Tensor 'conv2d_1/ActivityRegularizer_2/truediv:0' shape=() dtype=float32>, <tf.Tensor: shape=(), dtype=float32, numpy=0.004322933>]
Traceback (most recent call last):
  File "<path to python file>/tempTrain.py", line 95, in <module>
AttributeError: 'list' object has no attribute 'numpy'

Additional context I have also reported this issue on the main TensorFlow GitHub issue tracker, as I am not sure which team is responsible for it. Link to the issue raised in the TF GitHub: [Tensorflow Issue](https://github.com/tensorflow/tensorflow/issues/51680). Update: the person assigned asked me to report it to the model-optimization team, so the above issue has been closed.

narduzzi commented 1 month ago

It seems that quantized_model.losses is a list of several losses, so .numpy() has to be called on each element rather than on the list. Can you try the following code instead?

print([x.numpy() for x in quantized_model.losses])
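Note that the list printed in the Error Output mixes eager tensors with symbolic ones (the 'conv2d/ActivityRegularizer_2/truediv:0' entries), so the comprehension above may still fail on the symbolic entries. A more defensive variant is sketched below. loss_values is a hypothetical helper name, and the set of exceptions caught is an assumption about how symbolic tensors fail in this TensorFlow version, not documented behavior.

```python
def loss_values(losses):
    """Return one value per loss, or None where no concrete value exists.

    Assumption: entries without a concrete value either lack a `numpy`
    method (AttributeError) or refuse to run it (NotImplementedError).
    """
    values = []
    for loss in losses:
        try:
            values.append(loss.numpy())
        except (AttributeError, NotImplementedError):
            # Symbolic entry: keep its position, but record no value.
            values.append(None)
    return values
```

With this helper, print(loss_values(quantized_model.losses)) would print the values of the eager entries and None in the positions of the symbolic activity regularization losses. It works around the AttributeError but does not address why those entries are symbolic in the first place.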