tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Stripping disconnects input layer from graph #1063

Open christian-steinmeyer opened 1 year ago

christian-steinmeyer commented 1 year ago

Describe the bug

Stripping the pruning layers seems to somehow disconnect the input layer from the graph.

System information

TensorFlow version (installed from source or binary): 2.11 (macos)

TensorFlow Model Optimization version (installed from source or binary): 0.7.4

Python version: 3.10

Describe the expected behavior

Pruning a model during training, stripping the pruning layers, and then creating a new model from a subset of its layers (e.g. to remove additional targets used only during training) should work, if I didn't miss anything.

Describe the current behavior

It fails. Doing the steps in a different order (pruning, creating the new model, and only then stripping) works.

Code to reproduce the issue

import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_model_optimization as tfmot

if __name__ == '__main__':
    # Load MNIST dataset
    mnist = keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # Normalize the input image so that each pixel value is between 0 and 1.
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    # Define the model architecture.
    model = keras.Sequential(
        [
            keras.layers.InputLayer(input_shape=(28, 28, 1)),
            keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(10),
        ]
    )
    model = tf.keras.Model(inputs=model.inputs, outputs=model.outputs)

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # Compute end step to finish pruning after `epochs` epochs.
    batch_size = 128
    epochs = 1
    validation_split = 0.1  # 10% of training set will be used for validation set.

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
    print("end step", end_step)
    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.05,
            final_sparsity=0.95,
            begin_step=1,
            end_step=end_step,
            frequency=422,
        )
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    model_for_pruning.summary()

    logdir = tempfile.mkdtemp()

    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
    ]

    model_for_pruning.fit(
        train_images,
        train_labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
    )
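
As a sanity check on the schedule above, the `end_step` arithmetic can be reproduced without TensorFlow. This is a minimal sketch; the only value not taken from the setup code is the size of the MNIST training split (60,000 images).

```python
import math

num_train_images = 60_000  # size of the MNIST training split
validation_split = 0.1     # same value as in the setup code
batch_size = 128
epochs = 1

# Images left for training after the validation split: 60000 * 0.9 = 54000
num_images = num_train_images * (1 - validation_split)

# Optimizer steps per epoch: ceil(54000 / 128) = ceil(421.875)
steps_per_epoch = math.ceil(num_images / batch_size)

end_step = steps_per_epoch * epochs
print("end step", end_step)
```

With these values `end_step` comes out to 422, which is the same number passed as `frequency` in the `PolynomialDecay` schedule.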

Given the above setup code, running the following snippet fails:

    pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    inputs = [pruned_model.get_layer("input_1").input]
    outputs = pruned_model.get_layer("dense").output
    # The next line raises:
    # ValueError: Graph disconnected: cannot obtain value for tensor
    # KerasTensor(type_spec=TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='input_1'),
    # name='input_1', description="created by layer 'input_1'") at layer "conv2d".
    # The following previous layers were accessed without issue: []
    _new_model = tf.keras.Model(inputs=inputs, outputs=outputs)

while the following snippets work:

    # Variant 1: build the new model from the still-wrapped model, then strip it.
    inputs = [model_for_pruning.get_layer("input_1").input]
    outputs = model_for_pruning.get_layer("prune_low_magnitude_dense").output
    _new_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    _new_model = tfmot.sparsity.keras.strip_pruning(_new_model)

    # Variant 2: strip first, but start from the first real layer's input.
    pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    inputs = [pruned_model.get_layer("conv2d").input]  # skipping the input layer
    outputs = pruned_model.get_layer("dense").output
    _new_model = tf.keras.Model(inputs=inputs, outputs=outputs)
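
A third option that might sidestep the problem, whatever the root cause, is to avoid `layer.input` / `layer.output` altogether and re-chain the layers on a fresh `tf.keras.Input`. The sketch below is an assumption, not something verified against tfmot 0.7.4: `rechain_until` is a hypothetical helper, it only handles a linear stack of layers, and a plain Keras model (with a made-up `extra_head` layer) stands in for the stripped model so that the example runs without training.

```python
import numpy as np
import tensorflow as tf


def rechain_until(model, last_layer_name, input_shape):
    """Rebuild a linear stack of layers on a fresh Input, stopping after `last_layer_name`.

    Hypothetical helper; assumes every layer has exactly one input and one output.
    """
    inputs = tf.keras.Input(shape=input_shape)
    x = inputs
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue  # the fresh Input above replaces it
        x = layer(x)  # reuses the layer object, so the weights are shared
        if layer.name == last_layer_name:
            return tf.keras.Model(inputs=inputs, outputs=x)
    raise ValueError(f"no layer named {last_layer_name!r}")


# Stand-in for the stripped model (assumption: same layer names as in the issue).
stand_in = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(12, (3, 3), activation="relu", name="conv2d"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, name="dense"),
        tf.keras.layers.Softmax(name="extra_head"),  # the part to cut off
    ]
)

sub_model = rechain_until(stand_in, "dense", (28, 28, 1))
logits = sub_model(np.zeros((2, 28, 28, 1), dtype="float32"))
```

In the setup from this issue the call would presumably be `rechain_until(pruned_model, "dense", (28, 28, 1))`. Because the layer objects are reused rather than copied, the new model shares its weights with the model it was built from.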