tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Yamnet clustering AttributeError: Exception encountered when calling layer "tf.__operators__.add" (type TFOpLambda). #972

MATTYGILO closed this issue 2 years ago

MATTYGILO commented 2 years ago


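Describe the bug

I am using clustering to reduce the size of a YAMNet model. Calling tfmot.clustering.keras.strip_clustering on the clustered model raises the AttributeError shown below.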

System information

TensorFlow version (installed from source or binary): 2.7.0

TensorFlow Model Optimization version (installed from source or binary): installed from a wheel from GitHub, on an M1 Mac

Python version: 3.10

Describe the expected behavior

final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

This call should strip the clustering wrappers from the clustered model.

Describe the current behavior

The call fails with this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
      3 final_model.summary()

File ~/.local/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster.py:356, in strip_clustering(model)
    353   return layer
    355 # Just copy the model with the right callback
--> 356 return tf.keras.models.clone_model(
    357     model, input_tensors=None, clone_function=_strip_clustering_wrapper)

File ~/miniforge3/envs/net/lib/python3.10/site-packages/keras/models.py:456, in clone_model(model, input_tensors, clone_function)
    453   return _clone_sequential_model(
    454       model, input_tensors=input_tensors, layer_fn=clone_function)
    455 else:
--> 456   return _clone_functional_model(
    457       model, input_tensors=input_tensors, layer_fn=clone_function)

File ~/miniforge3/envs/net/lib/python3.10/site-packages/keras/models.py:197, in _clone_functional_model(model, input_tensors, layer_fn)
    193 model_configs, created_layers = _clone_layers_and_model_config(
    194     model, new_input_layers, layer_fn)
    195 # Reconstruct model from the config, using the cloned layers.
    196 input_tensors, output_tensors, created_layers = (
--> 197     functional.reconstruct_from_config(model_configs,
    198                                        created_layers=created_layers))
    199 metrics_names = model.metrics_names
    200 model = Model(input_tensors, output_tensors, name=model.name)

File ~/miniforge3/envs/net/lib/python3.10/site-packages/keras/engine/functional.py:1338, in reconstruct_from_config(config, custom_objects, created_layers)
   1336 while layer_nodes:
   1337   node_data = layer_nodes[0]
-> 1338   if process_node(layer, node_data):
   1339     layer_nodes.pop(0)
   1340   else:
   1341     # If a node can't be processed, stop processing the nodes of
   1342     # the current layer to maintain node ordering.

File ~/miniforge3/envs/net/lib/python3.10/site-packages/keras/engine/functional.py:1282, in reconstruct_from_config.<locals>.process_node(layer, node_data)
   1279 if not layer._preserve_input_structure_in_config:
   1280   input_tensors = (
   1281       base_layer_utils.unnest_if_single_tensor(input_tensors))
-> 1282 output_tensors = layer(input_tensors, **kwargs)
   1284 # Update node index map.
   1285 output_index = (tf.nest.flatten(output_tensors)[0].
   1286                 _keras_history.node_index)

File ~/miniforge3/envs/net/lib/python3.10/site-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~/miniforge3/envs/net/lib/python3.10/site-packages/tensorflow/python/ops/math_ops.py:1733, in _add_dispatch(x, y, name)
   1712 """The operation invoked by the `Tensor.__add__` operator.
   1713 
   1714   Purpose in the API:
   (...)
   1729   The result of the elementwise `+` operation.
   1730 """
   1731 if not isinstance(y, ops.Tensor) and not isinstance(
   1732     y, sparse_tensor.SparseTensor):
-> 1733   y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y")
   1734 if x.dtype == dtypes.string:
   1735   return gen_math_ops.add(x, y, name=name)

Code to reproduce the issue

... My custom yamnet, perhaps get one from tf hub ...

import tensorflow as tf
import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 16,
  'cluster_centroids_init': CentroidInitialization.LINEAR
}

# Cluster a whole model
clustered_model = cluster_weights(model, **clustering_params)

clustered_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    run_eagerly=True
)
clustered_model.summary()

final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

final_model.summary()

This is similar to #867, though my problem is with clustering.

MATTYGILO commented 2 years ago

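OK, I've fixed the problem by wrapping only an explicit list of layer types for clustering and leaving every other layer as it is, using this code: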

import tensorflow as tf
import tensorflow_model_optimization as tfmot


def cluster_model(model):

    # Info required for clustering
    cluster_weights = tfmot.clustering.keras.cluster_weights
    CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
    clustering_params = {
        'number_of_clusters': 16,
        'cluster_centroids_init': CentroidInitialization.LINEAR
    }

    # Only these layer types are wrapped for clustering
    accepted_layers = (
        tf.keras.layers.Dense,
        tf.keras.layers.GlobalAveragePooling2D,
        tf.keras.layers.ReLU,
        tf.keras.layers.Conv2D,
        tf.keras.layers.BatchNormalization,
        tf.keras.layers.DepthwiseConv2D,
        tf.keras.layers.Activation,
    )

    def apply_clustering_to_dense(layer):
        # Despite the name, this wraps every accepted layer type, not only Dense
        if isinstance(layer, accepted_layers):
            return cluster_weights(layer, **clustering_params)

        # Any other layer (such as the TFOpLambda from the error) is returned unwrapped
        return layer

    return tf.keras.models.clone_model(
        model,
        clone_function=apply_clustering_to_dense,
    )
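
To see which layers a filter like the one above would leave unwrapped, it may help to count the layer types that fall outside the accepted list. The following is only a sketch: the helper name find_unwrapped_layers is hypothetical, and it assumes nothing beyond the model exposing a layers attribute, as Keras models do. The stand-in classes exist only so the snippet runs without TensorFlow installed, and layers inside nested sub-models are not inspected.

```python
from collections import Counter


def find_unwrapped_layers(model, accepted_layers):
    """Count layers, by class name, that are not instances of accepted_layers.

    Hypothetical helper: `model` only needs a `layers` attribute.
    """
    accepted = tuple(accepted_layers)
    skipped = Counter()
    for layer in model.layers:
        if not isinstance(layer, accepted):
            skipped[type(layer).__name__] += 1
    return dict(skipped)


# Stand-in classes (assumptions for illustration, not the real Keras classes).
class Dense:
    pass


class TFOpLambda:
    pass


class ToyModel:
    def __init__(self, layers):
        self.layers = layers


toy = ToyModel([Dense(), TFOpLambda(), Dense(), TFOpLambda()])
report = find_unwrapped_layers(toy, [Dense])
print(report)
```

With a real Keras model, the same helper could be called with the model and the accepted layer classes from cluster_model to list the layer types that are passed through untouched.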