tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Allow quantization of tied weights #994

Open · hunse opened 2 years ago

hunse commented 2 years ago


Motivation

Numerous common networks use tied weights of some kind (i.e. the same weights used in more than one place in the model), for example autoencoders or language models with shared embedding/de-embedding weights. Currently, these models cannot be quantized because quantize_apply uses keras.models.clone_model internally, which "will not preserve the uniqueness of shared objects within the model" (per its docstring).
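For illustration, a minimal sketch (not from the original report) of the underlying behavior: quantize_apply's clone path serializes and deserializes each layer, and outside of Keras's shared-object scopes every round trip creates a fresh instance, so a layer referenced from another layer's config gets duplicated.

import tensorflow as tf
from keras.layers import deserialize as deserialize_layer
from keras.utils.generic_utils import serialize_keras_object

# Outside any shared-object scope, each serialize/deserialize round trip
# builds a brand-new layer, so shared references become independent copies.
emb = tf.keras.layers.Embedding(input_dim=100, output_dim=16)
copy_a = deserialize_layer(serialize_keras_object(emb))
copy_b = deserialize_layer(serialize_keras_object(emb))
assert copy_a is not copy_b  # the shared identity is lost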

Describe the feature

The feature is to support quantization of models with tied weights: the same underlying variable would be used in multiple locations, as in the original unquantized model. Ideally, different layers using the same variable would have separate control over quantization (i.e. some layers could use a quantized version of the shared weights, while others use them unquantized).

Describe how the feature helps achieve the use case

I don't have a clear idea of how this feature should be implemented, which is the motivation for this issue. TensorFlow does have some support for serializing shared objects (e.g. saving models in the "tf" format, which uses SharedObjectSavingScope internally), but I'm not sure whether any of that is compatible with clone_model, or whether quantize_apply would have to be reworked entirely to avoid clone_model.

Describe how existing APIs don't satisfy your use case (optional if obvious)

I've tried to quantize a model with shared weights and ran into various problems (depending on exactly how I do the sharing). All of these problems are expected, because clone_model does not support shared weights: it re-instantiates the model with a separate variable for each location where the shared variable occurs.

js1010 commented 2 years ago

Hi @hunse, thanks for your input! We haven't yet considered this feature for shared layers, but we'll consider including it in our next batch of updates.

Thanks!

hunse commented 2 years ago

Thanks @js1010.

I was able to get something working for my own code base by modifying the clone_model_with_weights function that's used in quantize_apply to use Keras's SharedObjectSavingScope and SharedObjectLoadingScope. I also had to modify Keras's SharedObjectConfig to do self[generic_utils.SHARED_OBJECT_KEY] = self.object_id right away (by default the shared-object ID is only attached once an object has been referenced more than once).

import keras
from keras.layers import deserialize as deserialize_layer
from keras.utils import generic_utils

class SharedObjectConfig(generic_utils.SharedObjectConfig):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # set object ID right away, so that objects can be shared even when part of
        # a "clone" operation with mixed order of serialization and deserialization
        self[generic_utils.SHARED_OBJECT_KEY] = self.object_id

class SharedObjectSavingScope(generic_utils.SharedObjectSavingScope):
    def create_config(self, base_config, obj):
        """Create a new SharedObjectConfig for a given object."""
        shared_object_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id += 1
        try:
            self._shared_objects_config[obj] = shared_object_config
        except TypeError:
            # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
            # that has not overridden `__hash__`), a `TypeError` will be thrown.
            # We'll just continue on without shared object support.
            pass
        return shared_object_config

def clone_model_with_weights(model_to_clone):
    def clone_fn(layer):
        # Round-trip each layer through serialize/deserialize; the enclosing
        # shared-object scopes map repeated references back to a single instance.
        serial = generic_utils.serialize_keras_object(layer)
        return deserialize_layer(serial)

    with SharedObjectSavingScope(), generic_utils.SharedObjectLoadingScope():
        cloned_model = keras.models.clone_model(model_to_clone, clone_function=clone_fn)
    cloned_model.set_weights(model_to_clone.get_weights())

    return cloned_model
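As a quick sanity check of the scope behavior this relies on (a sketch using the definitions above, not part of the original comment): inside the paired scopes, serializing the same layer twice and deserializing both configs should hand back a single instance, since the patched SharedObjectConfig attaches the shared-object ID on the first serialization.

# Both round trips resolve to one instance inside the paired scopes.
layer = keras.layers.Dense(4)
with SharedObjectSavingScope(), generic_utils.SharedObjectLoadingScope():
    copy_a = deserialize_layer(generic_utils.serialize_keras_object(layer))
    copy_b = deserialize_layer(generic_utils.serialize_keras_object(layer))
assert copy_a is copy_b  # both configs carried the same shared-object ID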

I'm using it to support a layer like the one below, which holds a reference to another layer (a tf.keras.layers.Embedding) and uses the transpose of that layer's weights for de-embedding.

import tensorflow as tf
from keras.layers import deserialize as deserialize_layer
from keras.utils.generic_utils import serialize_keras_object

class TiedDeembedding(tf.keras.layers.Layer):
    def __init__(self, embedding_layer, **kwargs):
        super().__init__(**kwargs)
        self.embedding_layer = embedding_layer

    def build(self, input_shape):
        super().build(input_shape)
        # Reuse the embedding layer's weight variable directly; the embedding
        # layer must already be built (true when it is called upstream of this
        # layer in the model graph).
        self.embeddings = self.embedding_layer.embeddings

    def call(self, x):
        # Project activations back to vocabulary logits with the transposed
        # embedding matrix.
        return tf.matmul(x, self.embeddings, transpose_b=True)

    def get_config(self):
        config = super().get_config()
        config.update(
            dict(
                embedding_layer=serialize_keras_object(self.embedding_layer),
            )
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        embedding_layer = config["embedding_layer"]
        # During deserialization the shared layer arrives as a config dict;
        # an already-constructed layer instance is passed through unchanged.
        if isinstance(embedding_layer, dict):
            config["embedding_layer"] = deserialize_layer(
                embedding_layer, custom_objects=custom_objects
            )

        return cls(**config)
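A usage sketch (hypothetical shapes, assuming the classes and clone_model_with_weights above are in scope): tie an Embedding to a TiedDeembedding, clone with the patched function, and check that both cloned layers still point at one variable. Note that the clone round trip needs TiedDeembedding to be resolvable by name, e.g. via tf.keras.utils.custom_object_scope.

# Tiny model with tied weights: the de-embedding reads the embedding's variable.
emb = tf.keras.layers.Embedding(input_dim=100, output_dim=16)
inp = tf.keras.Input(shape=(8,), dtype="int32")
logits = TiedDeembedding(emb)(emb(inp))
model = tf.keras.Model(inp, logits)

# Clone with the patched function; the custom layer must be resolvable by name.
with tf.keras.utils.custom_object_scope({"TiedDeembedding": TiedDeembedding}):
    cloned = clone_model_with_weights(model)

cloned_emb = next(l for l in cloned.layers if isinstance(l, tf.keras.layers.Embedding))
cloned_tied = next(l for l in cloned.layers if isinstance(l, TiedDeembedding))
assert cloned_tied.embedding_layer is cloned_emb  # the tie survives the clone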