tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0
1.48k stars, 320 forks

Cannot save pruned model with MultiHeadAttention Layer #1077

Open christian-steinmeyer opened 1 year ago

christian-steinmeyer commented 1 year ago

Describe the bug

Saving a model in which a MultiHeadAttention layer is wrapped in PruneLowMagnitude fails with a duplicate dataset name error.

System information

TensorFlow version (installed from source or binary): 2.13.0rc1

TensorFlow Model Optimization version (installed from source or binary): 0.7.5

Python version: 3.10

Describe the expected behavior

The model saves successfully.

Describe the current behavior

Saving the pruned model raises ValueError: Unable to create dataset (name already exists) for "mask:0".

Code to reproduce the issue

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tempfile

if __name__ == '__main__':
    # model
    inputs = tf.keras.layers.Input(shape=(28, 28, 3))
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=3, activation='relu')(inputs)
    x = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=128)(query=x, value=x, key=x)
    outputs = tf.keras.layers.Flatten()(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    model.compile(optimizer='adam', loss='mse')

    # call model to initialize weights
    model(tf.ones((1, 28, 28, 3)))

    # prune model
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.5,
            final_sparsity=0.9,
            begin_step=0,
            end_step=1,
            frequency=1,
        ),
    }
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

    with tempfile.TemporaryDirectory() as temp_dir:
        model_for_pruning.save(temp_dir + '/model.h5')  # <-- fails

Potentially related to #661 and #944.
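
A minimal diagnostic sketch, in case it helps narrow this down. The error names "mask:0", which suggests that the pruning wrapper around MultiHeadAttention holds several weights with the same name and that the HDF5 writer creates one dataset per weight name inside a layer. That is an assumption, not something confirmed in this thread. The helper names and the example weight names below are hypothetical and only illustrate the duplicate check itself.

```python
from collections import Counter


def find_duplicate_names(names):
    """Return a dict of names that occur more than once, with their counts."""
    counts = Counter(names)
    return {name: n for name, n in sorted(counts.items()) if n > 1}


def duplicate_weight_names_per_layer(layer_weight_names):
    """Map each layer name to its duplicated weight names.

    `layer_weight_names` maps a layer name to the list of weight names in
    that layer. Layers without duplicates are omitted from the result.
    """
    report = {}
    for layer_name, names in layer_weight_names.items():
        dupes = find_duplicate_names(names)
        if dupes:
            report[layer_name] = dupes
    return report


if __name__ == "__main__":
    # Hypothetical weight names, made up for illustration only. Real names
    # would have to be read from the pruned model in the script above.
    example = {
        "conv2d": ["kernel:0", "bias:0"],
        "prune_low_magnitude_mha": [
            "mask:0", "threshold:0", "mask:0", "threshold:0",
        ],
    }
    print(duplicate_weight_names_per_layer(example))
```

With the reproduction script, the input could be built as {layer.name: [w.name for w in layer.weights] for layer in model_for_pruning.layers}. Whether the names reported there match the dataset names used by the HDF5 writer is also an assumption, so an empty report would not rule out the duplicate-name explanation.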

christian-steinmeyer commented 1 year ago

Tagging @Xhark, as you have worked on similar issues in the past.

dansuh17 commented 10 months ago

@Xhark Could you follow up on this one?