**ikhatri** opened this issue 3 years ago:

I'm wondering how I can create a quantization configuration for a custom layer (one that subclasses the Keras `Layer` class) when that custom layer is composed of other standard Keras layers (such as `Conv2D`, `BatchNorm`, etc.). Assuming that every sub-layer in my custom layer is either supported or can be skipped for quantization, can I supply a quantization config that just recursively checks the existing `Default8BitQuantizeRegistry`, or something like that?
You could create a custom `QuantizeConfig` that delegates the config for your standard Keras sub-layers to existing `QuantizeConfig` instances. Alternatively, a simpler way would be to quantize the sub-layers based on type, as in the "Quantize some layers" example. Would that work for your use case?
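Roughly, the pattern from that example looks like this (a minimal sketch; it assumes an already-built `model`, and which layer types you match on is up to you):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def apply_quantization_by_type(layer):
    # Annotate only the layer types you want quantized;
    # everything else is returned unchanged.
    if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

annotated_model = tf.keras.models.clone_model(
    model, clone_function=apply_quantization_by_type)
q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
```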
Could you provide an example of how to do the first option? That's exactly what I'd like to do, but I couldn't quite figure out how.
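For concreteness, here's roughly what I imagined such a delegating config would look like (a rough, untested sketch against the `tfmot.quantization.keras.QuantizeConfig` interface; the sub-layer attributes match my custom layer below, and the `set_quantize_*` methods are exactly where I get stuck):

```python
import tensorflow_model_optimization as tfmot

class DelegatingQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """Sketch: forward every QuantizeConfig call to the configs the
    default 8-bit registry would pick for each built-in sub-layer."""

    def __init__(self):
        self._registry = (
            tfmot.quantization.keras.default_8bit.Default8BitQuantizeRegistry())

    def _delegates(self, layer):
        # The sub-layers my custom layer stores on itself.
        for sub in (layer.conv2d, layer.pool, layer.reshape):
            if self._registry.supports(sub):
                yield sub, self._registry.get_quantize_config(sub)

    def get_weights_and_quantizers(self, layer):
        results = []
        for sub, config in self._delegates(layer):
            results.extend(config.get_weights_and_quantizers(sub))
        return results

    def get_activations_and_quantizers(self, layer):
        results = []
        for sub, config in self._delegates(layer):
            results.extend(config.get_activations_and_quantizers(sub))
        return results

    def set_quantize_weights(self, layer, quantize_weights):
        # This is the part I can't figure out: the flat list of quantized
        # weights has to be split back across the sub-layers somehow.
        raise NotImplementedError

    def set_quantize_activations(self, layer, quantize_activations):
        # Same problem as above.
        raise NotImplementedError

    def get_output_quantizers(self, layer):
        # No extra output quantization beyond the sub-layers' own
        # (not sure this is right).
        return []

    def get_config(self):
        return {}
```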
As far as the second option goes, I gave it a try, but unfortunately it ends up being quite unwieldy. Here's an example to demonstrate: I was storing the sub-layers in my custom layer object, but then instead of calling the layer directly, I wrote a function that accepts the sub-layers as arguments and calls them individually. While it works just fine, it makes the code a mess to read.
Here's the file where I define my custom Keras layer (and the functional wrapper I mentioned above):
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Magic fix for RTX GPUs
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)


class MnistFeedForward(layers.Layer):
    def __init__(self, name: str = "mnist_feedforward", **kwargs) -> None:
        super().__init__(name=name, **kwargs)
        self._build()

    def _build(self) -> None:
        self.conv2d = layers.Conv2D(
            filters=12,
            kernel_size=(3, 3),
            activation="relu",
            name="MFF_conv2d",
        )
        self.pool = layers.MaxPool2D(pool_size=(2, 2), name="MFF_maxpool")
        self.reshape = layers.Flatten()

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        with tf.compat.v1.variable_scope("subnet"):
            x = self.conv2d(inputs)
            x = self.pool(x)
            x = self.reshape(x)
            # This is a no-op, just to test whether constant muls are
            # supported by the quantization framework
            mask = np.ones(x.shape[1], np.float32)
            output = tf.constant(mask) * x
        return output


def mnist_feed_fw_func(inputs, conv2d, pool, reshape):
    with tf.compat.v1.variable_scope("subnet"):
        x = conv2d(inputs)
        x = pool(x)
        x = reshape(x)
        # This is a no-op, just to test whether constant muls are
        # supported by the quantization framework
        mask = np.ones(x.shape[1], np.float32)
        output = tf.constant(mask) * x
    return output
```
And here's a small program to test the code above:
```python
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

from network import MnistFeedForward, mnist_feed_fw_func

# Magic fix for RTX GPUs
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Load the MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model graph via the Keras functional API
mff = MnistFeedForward()
inputs = keras.layers.Input(shape=(28, 28))
x = keras.layers.Reshape(target_shape=(28, 28, 1))(inputs)
x = mff(x)  # This does not work
# x = mnist_feed_fw_func(x, mff.conv2d, mff.pool, mff.reshape)  # replacing the line above with this does work
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Train the digit-classification model
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()
print()
model.fit(
    train_images,
    train_labels,
    epochs=1,
    validation_split=0.1,
)

# Helper function that uses `quantize_annotate_layer` to annotate only the
# layers the default 8-bit registry supports for quantization
def apply_quantization(layer):
    registry = tfmot.quantization.keras.default_8bit.Default8BitQuantizeRegistry()
    if registry.supports(layer):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

# Use `tf.keras.models.clone_model` to apply `apply_quantization`
# to the layers of the model
custom_objects = {"mnist_feedforward": MnistFeedForward}
with tf.keras.utils.custom_object_scope(custom_objects):
    annotated_model = tf.keras.models.clone_model(
        model,
        clone_function=apply_quantization,
    )

    # Now that the supported layers are annotated,
    # `quantize_apply` actually makes the model quantization aware
    q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# `quantize_apply` requires a recompile
q_aware_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
q_aware_model.summary()
print()

# Fine-tune the quantized model on a subset of the training data
train_images_subset = train_images[0:1000]  # out of 60000
train_labels_subset = train_labels[0:1000]
print("Fine-tuning the quantized model:")
q_aware_model.fit(
    train_images_subset,
    train_labels_subset,
    batch_size=500,
    epochs=1,
    validation_split=0.1,
)

_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)
_, q_aware_model_accuracy = q_aware_model.evaluate(test_images, test_labels, verbose=0)
print("Baseline test accuracy:", baseline_model_accuracy)
print("Quant test accuracy:", q_aware_model_accuracy)
```
Running this code fails with the following traceback:
```
Traceback (most recent call last):
  File "mnist_quantize.py", line 79, in <module>
    q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 64, in inner
    raise error
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 59, in inner
    results = func(*args, **kwargs)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/quantize.py", line 465, in quantize_apply
    transformed_model, layer_quantize_map = quantize_transform.apply(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py", line 71, in apply
    return model_transformer.ModelTransformer(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py", line 613, in transform
    transformed_model = keras.Model.from_config(self._config, custom_objects)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2261, in from_config
    return functional.Functional.from_config(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 668, in from_config
    input_tensors, output_tensors, created_layers = reconstruct_from_config(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1275, in reconstruct_from_config
    process_layer(layer_data)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1257, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py", line 173, in deserialize
    return generic_utils.deserialize_keras_object(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 346, in deserialize_keras_object
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
  File "/home/ikhatri/miniconda3/envs/quantize/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 296, in class_and_config_for_serialized_keras_object
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown layer: MnistFeedForward
```
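One thing I did wonder about while debugging: whether the key in `custom_objects` needs to match the serialized class name rather than the layer's `name` attribute, since Keras deserialization looks classes up by `class_name`:

```python
# Hypothetical fix attempt: key the dict by the class name, not the layer name
custom_objects = {"MnistFeedForward": MnistFeedForward}
```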
Thank you in advance for any assistance you can provide :)