tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0
1.5k stars 323 forks source link

Support for custom layers composed of standard keras layers #756

Open ikhatri opened 3 years ago

ikhatri commented 3 years ago

I'm wondering how I can create a quantization configuration for a custom layer (which implements the keras layer class) where my custom layer is composed of other standard keras layers (such as Conv2D, BatchNorm etc.)

Assuming that every sub-layer in my custom one is either supported or can be skipped for quantization, can I supply a quantization config that just recursively checks the existing default8bitquantizeregistry or something like that?

fredrec commented 3 years ago

You could create a custom QuantizeConfig that delegates the config for your standard Keras sublayers to existing QuantizeConfig instances.

Instead, a simpler way would be to quantize the sublayers based on type, as in the Quantize some layers example. Would that work for your use case ?

ikhatri commented 3 years ago

Could you provide an example on how to do the first option? That's exactly what I'd like to do, but I couldn't quite figure out how I would do so.

As far as the second option goes, I gave it a try but unfortunately it ends up being quite unwieldy. Here's an example to demonstrate. I was storing the sub-layers in my custom layer object but then instead of calling the layer directly I made a function that accepts them as args and calls them individually. While it works just fine, it makes the code a huge mess to read.

Here's the file where I define my custom keras layer (and the functional wrapper thing I mentioned above):

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Magic fix for RTX GPUs
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

class MnistFeedForward(layers.Layer):
    def __init__(self, name: str = "mnist_feedforward", **kwargs) -> None:
        super().__init__(name=name)
        self._build()

    def _build(self) -> None:
        self.conv2d = layers.Conv2D(
            filters=12,
            kernel_size=(3, 3),
            activation="relu",
            name="MFF_conv2d",
        )
        self.pool = layers.MaxPool2D(pool_size=(2, 2), name="MFF_maxpool")
        self.reshape = layers.Flatten()

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        with tf.compat.v1.variable_scope("subnet"):
            x = self.conv2d(inputs)
            x = self.pool(x)
            x = self.reshape(x)

        # This is a no-op, just to test if constant muls are supported by the quantization framework
        mask = np.ones(x.shape[1], np.float32)
        output = tf.constant(mask) * x
        return output

def mnist_feed_fw_func(inputs, conv2d, pool, reshape):
    with tf.compat.v1.variable_scope("subnet"):
        x = conv2d(inputs)
        x = pool(x)
        x = reshape(x)

    # This is a no-op, just to test if constant muls are supported by the quantization framework
    mask = np.ones(x.shape[1], np.float32)
    output = tf.constant(mask) * x
    return output

and here's a small program to test the code above:

from pathlib import Path
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot

from network import MnistFeedForward, mnist_feed_fw_func

# Magic fix for RTX GPUs
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model graph via the Keras Functional API

mff = MnistFeedForward()

inputs = keras.layers.Input(shape=(28, 28))
x = keras.layers.Reshape(target_shape=(28, 28, 1))(inputs)
x = mff(x) # This does not work
# x = mnist_feed_fw_func(x, mff.conv2d, mff.pool, mff.reshape) # replacing the line above with this, does work
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Train the digit classification model
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model.summary()
print()

model.fit(
    train_images,
    train_labels,
    epochs=1,
    validation_split=0.1,
)

quantize_model = tfmot.quantization.keras.quantize_model

# Helper function uses `quantize_annotate_layer` to annotate that only supported layers should be quantized
def apply_quantization(layer):
    registry = tfmot.quantization.keras.default_8bit.Default8BitQuantizeRegistry()
    if registry.supports(layer):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense`
# to the layers of the model.
custom_objects = {"mnist_feedforward": MnistFeedForward}
with tf.keras.utils.custom_object_scope(custom_objects):
    annotated_model = tf.keras.models.clone_model(
        model,
        clone_function=apply_quantization,
    )

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# `quantize_model` requires a recompile.
q_aware_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

q_aware_model.summary()
print()

# Fine tune the quantized model
train_images_subset = train_images[0:1000]  # out of 60000
train_labels_subset = train_labels[0:1000]

print("Fine-tuning the quantized model:")

q_aware_model.fit(
    train_images_subset,
    train_labels_subset,
    batch_size=500,
    epochs=1,
    validation_split=0.1,
)

_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(test_images, test_labels, verbose=0)

print("Baseline test accuracy:", baseline_model_accuracy)
print("Quant test accuracy:", q_aware_model_accuracy)

Running this code fails with the following traceback:

Traceback (most recent call last):
  File "mnist_quantize.py", line 79, in <module>
    q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 64, in inner
    raise error
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 59, in inner
    results = func(*args, **kwargs)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/quantize.py", line 465, in quantize_apply
    transformed_model, layer_quantize_map = quantize_transform.apply(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py", line 71, in apply
    return model_transformer.ModelTransformer(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py", line 613, in transform
    transformed_model = keras.Model.from_config(self._config, custom_objects)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2261, in from_config
    return functional.Functional.from_config(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 668, in from_config
    input_tensors, output_tensors, created_layers = reconstruct_from_config(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1275, in reconstruct_from_config
    process_layer(layer_data)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1257, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py", line 173, in deserialize
    return generic_utils.deserialize_keras_object(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 346, in deserialize_keras_object
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 296, in class_and_config_for_serialized_keras_object
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown layer: MnistFeedForward

Thank you in advance for any assistance you can provide :)