tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

`tf.split` or `tf.transpose` cause errors for quantize-aware training with `quantize_apply` #1062

Open Janus-Shiau opened 1 year ago

Janus-Shiau commented 1 year ago

Describe the bug

We are trying to implement a network like ShuffleNetV2, but we encounter an error when calling quantize_apply on the model.

I believe ShuffleNet and related ideas are popular on edge devices, so please kindly help us resolve this problem.

Any advice is welcome.

System information

TensorFlow version (installed from source or binary): 2.7.0

TensorFlow Model Optimization version (installed from source or binary): 0.7.0

Python version: 3.8.13

Describe the expected behavior

quantize_apply should simply add the quantization-aware operators to the model.

Describe the current behavior

When running the provided code, either tf.transpose or tf.split causes an error in TensorFlow Model Optimization.

The error message due to tf.split before the convolution layers:

ValueError: Exception encountered when calling layer "bn3" (type BatchNormalization).

Shape must be rank 4 but is rank 5 for '{{node bn3/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, bn3/ReadVariableOp, bn3/ReadVariableOp_1, bn3/FusedBatchNormV3/ReadVariableOp, bn3/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,?,128,128,32], [32], [32], [32], [32].

The error message due to tf.transpose:

ValueError: Exception encountered when calling layer "tf.compat.v1.transpose" (type TFOpLambda).

Dimension must be 6 but is 5 for '{{node tf.compat.v1.transpose/transpose}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](tf.compat.v1.transpose/transpose/a, tf.compat.v1.transpose/transpose/perm)' with input shapes: [1,?,128,128,2,32], [5].

Code to reproduce the issue

Running the following code will produce the error message caused by tf.split.

from __future__ import annotations

from typing import Callable, Optional

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras import layers

SKIP_LAYER = [
    "resize",
    "Resize",
    "reshape",
    "Reshape",
    "concat",
    "Concat" "ExpandDims",
    "Repeats",
    "Shape",
    "strided_slice",
    "Tile",
]

def quantize_model(
    model: tf.keras.Model,
    annotate: Optional[Callable] = None,
    quantize_scope: Optional[dict[str, tf.keras.layers.Layer]] = None,
) -> tf.keras.Model:
    quantize_scope = {} if quantize_scope is None else quantize_scope

    def default_annotate(layer):
        if any(name in layer.name for name in SKIP_LAYER):
            return layer
        return tfmot.quantization.keras.quantize_annotate_layer(layer)

    annotate = default_annotate if annotate is None else annotate

    anno_model = tf.keras.models.clone_model(model, clone_function=annotate)
    with tfmot.quantization.keras.quantize_scope(quantize_scope):
        model = tfmot.quantization.keras.quantize_apply(anno_model)

    return model

def channel_shuffle(tensor: tf.Tensor, groups: int = 2) -> tf.Tensor:
    """Channel shuffle operation."""
    _, height, width, num_channels = tensor.shape.as_list()
    assert num_channels % groups == 0

    tensor = tf.reshape(tensor, [-1, height, width, groups, num_channels // groups])
    tensor = tf.transpose(tensor, [0, 1, 2, 4, 3])
    tensor = tf.identity(tensor, name="channel_shuffle")

    tensor = tf.reshape(tensor, [-1, height, width, num_channels])
    return tensor

def simple_nn(img_input: tf.Tensor) -> tf.Tensor:
    latent = layers.Conv2D(32, 1, padding="same", use_bias=False, name="conv1")(img_input)
    latent = layers.BatchNormalization(name="bn1")(latent)
    latent = layers.ReLU(name="relu1")(latent)

    latent = layers.DepthwiseConv2D(3, 1, padding="same", name="conv2")(latent)
    latent = layers.BatchNormalization(name="bn2")(latent)

    latent = layers.Conv2D(32, 1, padding="same", use_bias=False, name="conv3")(latent)
    latent = layers.BatchNormalization(name="bn3")(latent)
    latent = layers.ReLU(name="relu3")(latent)

    return latent

def split_like_nn(img_input: tf.Tensor) -> tf.Tensor:
    latent = layers.Conv2D(64, 1, padding="same", use_bias=False, name="conv0")(img_input)
    latent = layers.BatchNormalization(name="bn0")(latent)
    latent = layers.ReLU(name="relu0")(latent)

    latent_0, latent_1 = tf.split(latent, 2, axis=-1)
    latent_0 = simple_nn(latent_0)
    latent = tf.concat([latent_0, latent_1], axis=-1)

    latent = channel_shuffle(latent)

    return latent

if __name__ == "__main__":
    img_input = tf.keras.Input((128, 128, 1), dtype=tf.float32, name="img")

    outputs = split_like_nn(img_input)

    model = tf.keras.Model(inputs=img_input, outputs=outputs, name="PoseNetV2")
    model.summary()

    model_qat = quantize_model(model)
    model_qat.summary()

If you comment out the following three lines of code, you will get the error message from tf.transpose instead.

 latent_0, latent_1 = tf.split(latent, 2, axis=-1)
 latent_0 = simple_nn(latent_0)
 latent = tf.concat([latent_0, latent_1], axis=-1)

guillem-ms commented 1 year ago

Hi! I'm also suffering from the same error using tf.split. Is there any fix coming soon?

DerryFitz commented 1 year ago

Hi, I'm getting the same error too with tf.transpose and tf.permute. Any update on a solution?

or-ims commented 7 months ago

Hi, tf.nn.depth_to_space causes the same error. I'd be very happy about any advice on how to solve this :)

robertatdm commented 3 months ago

I think I had the same issue. I could work around this error by wrapping tf.split() in a Keras layer:

import tensorflow as tf
import keras

@keras.saving.register_keras_serializable(package="MyLayers", name="SplitLayer")
class SplitLayer(keras.layers.Layer):
    def __init__(self, num_or_size_splits, axis, **kwargs):
        super(SplitLayer, self).__init__(**kwargs)
        self.num_or_size_splits = num_or_size_splits
        self.axis = axis

    def call(self, inputs):
        return tf.split(inputs, self.num_or_size_splits, axis=self.axis)

    def get_config(self):
        config = super(SplitLayer, self).get_config()
        config.update({
            'num_or_size_splits': self.num_or_size_splits,
            'axis': self.axis,
        })
        return config
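
For example (untested sketch, with a hypothetical layer name "split_latent"), the tf.split call in the reporter's split_like_nn above would become a call to this layer, so that quantize_apply only ever sees Keras layers:

# Untested: replaces the raw tf.split call in split_like_nn above
latent_0, latent_1 = SplitLayer(num_or_size_splits=2, axis=-1, name="split_latent")(latent)
latent_0 = simple_nn(latent_0)
latent = tf.concat([latent_0, latent_1], axis=-1)

Keeping "split" in the layer name also means the quantize_annotate function below still skips it.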

I quantized my model (RetinaNet) like this:

def quantize_model(model):        
    def quantize_annotate(layer):
        layer_types_to_avoid = (
            kcv.layers.AnchorGenerator,
            kcv.models.retinanet.LabelEncoder,
            kcv.layers.NonMaxSuppression,
            my_retinanet.MyPredictionDecoder,
        )
        if isinstance(layer, layer_types_to_avoid) or "split" in layer.name:
            return layer
        return tfmot.quantization.keras.quantize_annotate_layer(layer)

    annotated_model = tf.keras.models.clone_model(
        model,
        clone_function=quantize_annotate,
    )

    with tfmot.quantization.keras.quantize_scope():
        quantized_model = tfmot.quantization.keras.quantize_apply(annotated_model)

    return quantized_model
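
The same wrapping idea should presumably work for the other ops mentioned in this thread (tf.transpose in channel_shuffle, tf.nn.depth_to_space); I have only tried it for tf.split, but an untested sketch along the same lines would be:

import tensorflow as tf
import keras

@keras.saving.register_keras_serializable(package="MyLayers", name="TransposeLayer")
class TransposeLayer(keras.layers.Layer):
    # Wraps tf.transpose so quantize_apply sees a regular Keras layer
    # instead of a raw TFOpLambda node.
    def __init__(self, perm, **kwargs):
        super().__init__(**kwargs)
        self.perm = perm

    def call(self, inputs):
        return tf.transpose(inputs, perm=self.perm)

    def get_config(self):
        config = super().get_config()
        config.update({"perm": self.perm})
        return config

Like the split layers, such custom layers need to be excluded from quantize_annotate_layer (or given a custom QuantizeConfig), since tfmot does not know how to quantize an arbitrary custom layer out of the box.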

tf version: 2.15.1, keras version: 2.15.0, tfmot version: 0.7.5