tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Per-tensor QAT model Conv2d+BN+relu folding issue #1131

Open · sheh opened this issue 1 month ago


Describe the bug

I need to train a QAT (per-tensor) model and then convert it to TFLite, but I get the "folding issue" described here.

System information

TensorFlow version (installed from source or binary): 2.15.0

TensorFlow Model Optimization version (installed from source or binary): 0.8.0

Python version: 3.10.12

Describe the expected behavior

A 1-layer CNN (Conv2D+BN+ReLU) trained with QAT in per-tensor mode is folded and converted to TFLite without the computation graph being split into multiple Quantize/Dequantize parts.
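For reference, "folding" here means merging the BatchNormalization into the preceding Conv2D. The NumPy sketch below shows the standard folding algebra only. It is not the TFLite converter's or tfmot's internal code. It assumes inference-mode BN with Keras's default epsilon of 1e-3 and uses a naive stride-1 "valid" convolution (folding does not depend on padding). The helpers conv2d_valid, batch_norm and fold_bn_into_conv are hypothetical names introduced for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def conv2d_valid(x, kernel):
    """Naive stride-1 'valid' Conv2D. x: (N, H, W, Cin), kernel: (kh, kw, Cin, Cout)."""
    kh, kw = kernel.shape[:2]
    # windows has shape (N, H-kh+1, W-kw+1, Cin, kh, kw)
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))
    return np.einsum("nhwcij,ijco->nhwo", windows, kernel)


def batch_norm(y, gamma, beta, mean, var, eps=1e-3):
    """Inference-mode BatchNormalization over the last (channel) axis."""
    return gamma * (y - mean) / np.sqrt(var + eps) + beta


def fold_bn_into_conv(kernel, gamma, beta, mean, var, eps=1e-3):
    """Return (kernel', bias') such that conv(x, kernel') + bias' == BN(conv(x, kernel))."""
    factor = gamma / np.sqrt(var + eps)  # one multiplier per output channel
    folded_kernel = kernel * factor      # broadcasts over the Cout (last) axis
    folded_bias = beta - mean * factor
    return folded_kernel, folded_bias


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 8, 3))
kernel = rng.normal(size=(3, 3, 3, 3))  # same shape as Conv2D(3, 3) on a 3-channel input
gamma, beta = rng.uniform(0.5, 2.0, 3), rng.normal(size=3)
mean, var = rng.normal(size=3), rng.uniform(0.5, 2.0, 3)

reference = batch_norm(conv2d_valid(x, kernel), gamma, beta, mean, var)
k_f, b_f = fold_bn_into_conv(kernel, gamma, beta, mean, var)
assert np.allclose(conv2d_valid(x, k_f) + b_f, reference)
```

Because the folding factor differs per output channel, the folded kernel's channels can end up with quite different value ranges, which is exactly what a single per-tensor scale has to cover.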

Describe the current behavior

After folding the 1-layer CNN (Conv2D+BN+ReLU), the folded layer is not quantized.

Code to reproduce the issue

import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import \
    default_8bit_quantize_scheme
import tensorflow_model_optimization as tfmot

quantize_apply = tfmot.quantization.keras.quantize_apply
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model

def train_qat_convert_tflite(per_tensor):
    model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(128, 128, 3)),

        keras.layers.Conv2D(3, 3, padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.Activation('relu'),

        keras.layers.Softmax(),
    ])

    annotated_model = quantize_annotate_model(model)
    q_aware_model = quantize_apply(annotated_model,
                                   scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme(disable_per_axis=per_tensor))

    q_aware_model.compile(
        optimizer='Adam',
        loss=keras.losses.MeanAbsoluteError(),
        metrics=['accuracy'],
    )

    q_aware_model.fit(
        x=tf.random.normal((128, 128, 128, 3)),
        y=tf.random.normal((128, 128, 128, 3)),
        batch_size=16,
        epochs=1,
    )
    q_aware_model.save(f'{per_tensor=}.h5')

    converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    quantized_tflite_model = converter.convert()
    # Use a context manager so the file is flushed and closed.
    with open(f'{per_tensor=}.tflite', 'wb') as f:
        f.write(quantized_tflite_model)

train_qat_convert_tflite(per_tensor=True)
train_qat_convert_tflite(per_tensor=False)
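To see which tensors in the two converted files stay float and whether the quantized ones carry one scale (per-tensor) or several (per-axis), the tensor details exposed by tf.lite.Interpreter can be inspected. This is a minimal sketch: classify_quantization and inspect_tflite are hypothetical helper names, and TensorFlow is imported lazily so the classification logic can be exercised on its own.

```python
import numpy as np


def classify_quantization(tensor_details):
    """Map tensor name -> (dtype name, 'float' | 'per-tensor' | 'per-axis')."""
    report = {}
    for detail in tensor_details:
        scales = np.asarray(detail["quantization_parameters"]["scales"])
        if scales.size == 0:
            kind = "float"       # no quantization parameters attached
        elif scales.size == 1:
            kind = "per-tensor"  # a single scale for the whole tensor
        else:
            kind = "per-axis"    # one scale per slice along quantized_dimension
        report[detail["name"]] = (np.dtype(detail["dtype"]).name, kind)
    return report


def inspect_tflite(path):
    import tensorflow as tf  # lazy import: classify_quantization does not need TF

    interpreter = tf.lite.Interpreter(model_path=path)
    interpreter.allocate_tensors()
    return classify_quantization(interpreter.get_tensor_details())


# Example (needs TensorFlow and the .tflite files written by the script above):
# for path in ("per_tensor=True.tflite", "per_tensor=False.tflite"):
#     for name, (dtype, kind) in inspect_tflite(path).items():
#         print(path, dtype, kind, name)
```

In the per_tensor=True file, any weight tensor reported as per-axis, or a float Conv2D path where int8 is expected, would localize the problem.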

Screenshots

[Screenshot: saved Keras .h5 model (keras-h5), 2024-05-15_13-35]
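The screenshot shows the saved Keras .h5 model in a viewer. A minimal sketch for listing the quantizer range variables and their shapes programmatically follows. It assumes, as an unverified naming convention, that the min/max variables created by the quantizers have names ending in "_min" or "_max" (optionally followed by ":0"). range_variable_shapes and h5_named_shapes are hypothetical helpers, and h5py is imported lazily.

```python
def range_variable_shapes(named_shapes):
    """Filter (name, shape) pairs down to assumed quantizer min/max variables."""
    found = {}
    for name, shape in named_shapes:
        base = name.split(":")[0]  # drop a ':0'-style suffix if present
        if base.endswith("_min") or base.endswith("_max"):
            shape = tuple(shape)
            # A scalar (or single-element) range means one scale for the whole tensor.
            kind = "per-tensor" if shape in ((), (1,)) else "per-channel"
            found[name] = (shape, kind)
    return found


def h5_named_shapes(path):
    import h5py  # lazy import: range_variable_shapes does not need h5py

    pairs = []

    def visit(name, obj):
        if isinstance(obj, h5py.Dataset):
            pairs.append((name, obj.shape))

    with h5py.File(path, "r") as f:
        f.visititems(visit)
    return pairs


# Example (needs h5py and the .h5 files saved by the script above):
# print(range_variable_shapes(h5_named_shapes("per_tensor=True.h5")))
```

The same filter can be applied to the in-memory model with range_variable_shapes([(w.name, w.shape) for w in q_aware_model.weights]), since tuple() of a TF2 TensorShape yields plain integers.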

Additional context

I tested #552, but for a simple 1-layer CNN like the one in the code above there are no custom layers, so the if statement in the _replace function evaluates to False and execution continues with the next line. I also see that in the Keras .h5 model the BN layer is quantized per-channel: in both cases (per_tensor=True and per_tensor=False) its quantization parameters are tensors, not the scalars expected in per-tensor mode.
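To make the scalar-versus-tensor distinction concrete, here is a NumPy sketch of symmetric weight quantization for a Conv2D kernel of shape (kh, kw, Cin, Cout). It assumes narrow-range int8 (values in [-127, 127]) and max-abs calibration; this is an illustrative choice, not a statement of exactly what tfmot's default scheme computes. Per-tensor mode yields one scalar scale, while per-channel mode yields one scale per output channel (the kernel's last axis). symmetric_scales and fake_quantize are hypothetical helper names.

```python
import numpy as np

QMAX = 127  # assumption: symmetric, narrow-range int8 weights in [-127, 127]


def symmetric_scales(kernel, per_channel):
    """Scale(s) for a Conv2D kernel of shape (kh, kw, Cin, Cout)."""
    if per_channel:
        max_abs = np.max(np.abs(kernel), axis=(0, 1, 2))  # one value per output channel
    else:
        max_abs = np.max(np.abs(kernel))                   # a single scalar
    return max_abs / QMAX


def fake_quantize(kernel, scale):
    """Quantize-dequantize round trip; a per-channel scale broadcasts over Cout."""
    q = np.clip(np.round(kernel / scale), -QMAX, QMAX)
    return q * scale


rng = np.random.default_rng(0)
# Output channels with very different ranges, as can happen after BN folding.
kernel = rng.normal(size=(3, 3, 3, 3)) * np.array([0.01, 1.0, 10.0])

s_tensor = symmetric_scales(kernel, per_channel=False)
s_channel = symmetric_scales(kernel, per_channel=True)
print(np.shape(s_tensor), np.shape(s_channel))  # () (3,)

# Worst-case reconstruction error per output channel for each mode.
err_tensor = np.abs(fake_quantize(kernel, s_tensor) - kernel).max(axis=(0, 1, 2))
err_channel = np.abs(fake_quantize(kernel, s_channel) - kernel).max(axis=(0, 1, 2))
```

With a single scale, the small-range channel is dominated by the large one and mostly rounds to zero; per-channel scales avoid that, at the cost of storing a vector of parameters instead of a scalar.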