tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

QAT support for LayerNormalization #942

Status: Open (opened by tom-arm 2 years ago)


System information

Motivation

This would benefit models that use this layer; Transformer models, for example, rely on LayerNormalization.

Describe the feature

Allow quantization-aware training (QAT) to run on a model that contains a LayerNormalization layer.

Describe how existing APIs don't satisfy your use case (optional if obvious)

For example, the following code snippet fails:

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from tensorflow import keras

model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.LayerNormalization(axis=3),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Fails: LayerNormalization has no default QAT behavior, so quantize_model()
# raises an error reporting that this layer is not supported.
quant_model = tfmot.quantization.keras.quantize_model(model)
```