tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

TFOpLambda not supported in INT8 Quantization Aware Training (Mobilenetv3) #1145

Open pedrofrodenas opened 1 week ago

pedrofrodenas commented 1 week ago

Describe the bug

I cannot quantize MobileNetV3 from Keras 2 (tf_keras) because its hard-swish activation function is implemented with TFOpLambda layers.
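For context, hard-swish is defined as x · ReLU6(x + 3) / 6. As far as I can tell, the Keras MobileNetV3 code writes the `+ 3.0` and the scaling by 1/6 as plain tensor arithmetic around a `ReLU(6.0)` layer, and a functional model wraps each such bare TensorFlow op in a TFOpLambda layer. The plain-Python sketch below only spells out that arithmetic; the function names are hypothetical and not part of Keras.

```python
def relu6(x: float) -> float:
    """ReLU6: clamp x to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


def hard_sigmoid(x: float) -> float:
    """Piecewise-linear sigmoid approximation: ReLU6(x + 3) / 6.

    In the Keras graph, the `x + 3.0` step and the `* (1 / 6)` step are
    bare TensorFlow ops (assumed to become TFOpLambda layers); only the
    ReLU6 is a regular Keras layer.
    """
    return relu6(x + 3.0) * (1.0 / 6.0)


def hard_swish(x: float) -> float:
    """hard-swish(x) = x * hard_sigmoid(x)."""
    return x * hard_sigmoid(x)


if __name__ == "__main__":
    for v in (-5.0, -1.0, 0.0, 1.0, 3.0, 10.0):
        print(f"hard_swish({v}) = {hard_swish(v):.4f}")
```

The function is 0 for x ≤ -3, equal to x for x ≥ 3, and quadratic in between, which is why it needs both an addition and a multiplication around the ReLU6.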

System information

- TensorFlow version: 2.17
- tf_keras version: 2.17
- TensorFlow Model Optimization version: 0.8.0 (installed from pip)
- Python version: 3.9.19

Describe the expected behavior

Quantization-aware training can be applied to `keras.applications.MobileNetV3Small` using `tfmot.quantization.keras.quantize_model`.

Describe the current behavior

When the model contains a TFOpLambda layer, the following error is raised:

```
AttributeError: Exception encountered when calling layer "tf.__operators__.add" (type TFOpLambda).

'list' object has no attribute 'dtype'

Call arguments received by layer "tf.__operators__.add" (type TFOpLambda):
  • x=['tf.Tensor(shape=(None, 112, 112, 16), dtype=float32)']
  • y=3.0
  • name=None
```
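The `y=3.0` argument suggests the failing layer is the `x + 3.0` step of hard-swish. To see every layer of this kind before calling `quantize_model`, a small helper that filters `model.layers` by class name can help. This is a sketch that assumes tf_keras functional models expose their TFOpLambda layers in `model.layers`; the helper names are hypothetical, not part of tf_keras or TFMOT.

```python
from collections import Counter
from typing import Iterable, List


def find_layers_by_class_name(layers: Iterable, class_name: str) -> List[str]:
    """Return the names of layers whose class is named `class_name`.

    Matching on the class *name* avoids importing TFOpLambda from a
    private, version-dependent Keras module.
    """
    return [layer.name for layer in layers if type(layer).__name__ == class_name]


def count_layer_classes(layers: Iterable) -> Counter:
    """Count how many layers of each class a model contains."""
    return Counter(type(layer).__name__ for layer in layers)
```

With the model from the reproduction code below, `find_layers_by_class_name(model.layers, "TFOpLambda")` should list the layers that `quantize_model` cannot handle, and `count_layer_classes(model.layers)` gives a quick overview of which layer types would need a quantization config.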

Code to reproduce the issue

```python
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"  # set before any TensorFlow import

import tf_keras as keras

model = keras.applications.MobileNetV3Small(
    input_shape=(224, 224, 3),
    alpha=1.0,
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True,
)

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)  # raises the AttributeError shown above
```
pedrofrodenas commented 1 week ago

I was able to apply quantization, but only to selected layer types (the TFOpLambda layers stay in float), using:

```python
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"  # set before any TensorFlow import

import tf_keras as keras
import tensorflow as tf
#import keras
#from tensorflow_model_optimization.python.core.keras.compat import keras

model = keras.applications.MobileNetV3Small(
    input_shape=(32, 32, 3),
    alpha=1.0,
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True,
)

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# Layer types to annotate for quantization. Everything else, including the
# TFOpLambda layers used by hard-swish, is left in float.
QUANTIZABLE_LAYERS = (
    keras.layers.Conv2D,
    keras.layers.Add,
    keras.layers.BatchNormalization,
    keras.layers.DepthwiseConv2D,
    keras.layers.ReLU,
    keras.layers.GlobalAveragePooling2D,
)

# Annotate layers for quantization
def apply_qat_with_annotations(layer):
    if isinstance(layer, QUANTIZABLE_LAYERS):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

# Use `keras.models.clone_model` to apply `apply_qat_with_annotations`
# to the layers of the model.
annotated_model = keras.models.clone_model(
    model,
    clone_function=apply_qat_with_annotations,
)

quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# q_aware stands for quantization aware.
# q_aware_model = quantize_model(annotated_model)

# The model returned by `quantize_apply` must be (re)compiled.
quant_aware_model.compile(optimizer='adam',
                          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                          metrics=['accuracy'])

quant_aware_model.summary()  # summary() prints directly and returns None

from tf_keras.datasets import cifar10

# Load the dataset
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

quant_aware_model.fit(train_images, train_labels,
                      batch_size=256, epochs=1, validation_split=0.1)

converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

with open('quantized.tflite', 'wb') as f:
    f.write(quantized_tflite_model)

with open('float.tflite', 'wb') as f:
    f.write(float_tflite_model)
```

The resulting graph is cluttered with quantize and dequantize operations:

[Screenshot from 2024-10-08 13-06-02: the resulting quantized graph]
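One plausible reading of the cluttered graph: wherever an annotated (int8) layer feeds a layer left in float, such as the hard-swish TFOpLambda layers, the converted model needs a QUANTIZE or DEQUANTIZE step. The sketch below is a simplified, hypothetical model of that, not the TFLite converter's logic: a standard affine int8 quantize/dequantize round trip, plus a count of int8/float boundaries along a linear chain of layers. The chain is an illustrative assumption loosely based on one hard-swish block; the real graph is a DAG.

```python
from typing import List, Tuple

INT8_MIN, INT8_MAX = -128, 127


def quantize(x: float, scale: float, zero_point: int) -> int:
    """Affine int8 quantization: q = clamp(round(x / scale) + zero_point).

    Note: Python's round() rounds halves to even; real kernels may differ.
    """
    q = round(x / scale) + zero_point
    return max(INT8_MIN, min(INT8_MAX, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an int8 value back to float: x ~= (q - zero_point) * scale."""
    return (q - zero_point) * scale


def count_qdq_boundaries(chain: List[Tuple[str, bool]]) -> int:
    """Count int8 <-> float transitions along a linear chain of layers.

    Each entry is (layer_name, is_quantized). Every change of the flag
    between consecutive layers implies a QUANTIZE or DEQUANTIZE op.
    """
    return sum(1 for (_, a), (_, b) in zip(chain, chain[1:]) if a != b)


if __name__ == "__main__":
    scale, zero_point = 0.05, 0
    x = 1.234
    q = quantize(x, scale, zero_point)
    print(f"{x} -> q={q} -> {dequantize(q, scale, zero_point)}")

    # Hypothetical chain for one hard-swish block when only the layer
    # types from the snippet above (Conv2D, ReLU, ...) are annotated.
    chain = [
        ("conv", True),
        ("add_3", False),       # x + 3.0 (TFOpLambda, not annotated)
        ("relu6", True),        # ReLU(6.0) layer, annotated
        ("scale_1_6", False),   # * (1 / 6) (TFOpLambda, not annotated)
        ("multiply", False),    # Multiply layer, not in the annotated list
        ("next_conv", True),
    ]
    print("int8/float boundaries:", count_qdq_boundaries(chain))
```

For the hypothetical chain above, four boundaries are counted; each one adds a requantization step and its rounding error, so annotating the hard-swish ops as well would, in principle, remove them.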