tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Pruning only works for small batch sizes #973

Open vvolhejn opened 2 years ago

vvolhejn commented 2 years ago

Describe the bug

When using prune_low_magnitude(), my model is not pruned if the batch size is large.

System information

TensorFlow version (installed from source or binary): 2.8.0 installed via pip

TensorFlow Model Optimization version (installed from source or binary): 0.7.2 installed via pip

Python version: 3.9.10

Describe the expected behavior

model_for_pruning.fit should sparsify the model regardless of the batch size.

Describe the current behavior

If the batch size is larger than 2 (the threshold in my example, at least), the network is not pruned.

Code to reproduce the issue

Based on the Pruning with Keras tutorial.

import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def main(batch_size):
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10)
    ])

    model.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer='adam',
        metrics=['accuracy']
    )

    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model)

    log_dir = tempfile.mkdtemp()
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
    ]

    model_for_pruning.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer='adam',
        metrics=['accuracy']
    )

    model_for_pruning.fit(
        np.random.randn(100, 28, 28).astype(np.float32),
        np.random.randn(100, 10).astype(np.float32),
        callbacks=callbacks,
        epochs=2,
        batch_size=batch_size,
        # validation_split=0.1,
        verbose=0,
    )

    weights = model_for_pruning.get_weights()[1]
    # A sanity check to show we're looking at the right weights.
    print(f"(Checking weights of shape {weights.shape})")
    # What part of the weights are zeros?
    print(
        f"Sparsity with batch size {batch_size}:",
        (weights == 0).mean(),
    )

main(batch_size=1)
main(batch_size=2)
main(batch_size=3)
main(batch_size=32)

This prints:

2022-05-25 16:44:16.340347: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
/Users/vaclav/prog/venv/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:233: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  self.pruning_step = self.add_variable(
/Users/vaclav/prog/venv/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:212: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  mask = self.add_variable(
/Users/vaclav/prog/venv/lib/python3.9/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:219: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  threshold = self.add_variable(
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 1: 0.5
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 2: 0.5
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 3: 0.0
(Checking weights of shape (3, 3, 1, 12))
Sparsity with batch size 32: 0.0

So when the batch size is 1 or 2, everything works fine. But for anything larger, the model is not pruned.
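For context, the 0.5 printed for batch sizes 1 and 2 is the fraction of zeroed kernel weights, and prune_low_magnitude's default schedule targets 50% sparsity. The snippet below is a simplified numpy sketch of magnitude-based masking on a kernel of the same (3, 3, 1, 12) shape; it is not tfmot's actual implementation, and `low_magnitude_mask` is a hypothetical helper.

```python
import numpy as np


def low_magnitude_mask(weights, sparsity):
    """Hypothetical helper: 0/1 mask that zeroes the `sparsity` fraction of
    entries with the smallest absolute value."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))  # how many weights to zero out
    mask = np.ones(flat.size, dtype=weights.dtype)
    if k > 0:
        # argsort puts the smallest magnitudes first; mask those k entries.
        mask[np.argsort(flat, kind="stable")[:k]] = 0
    return mask.reshape(weights.shape)


# Same shape as the Conv2D kernel checked in the reproduction script.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 1, 12)).astype(np.float32)
pruned = kernel * low_magnitude_mask(kernel, sparsity=0.5)
print(f"Sparsity: {(pruned == 0).mean()}")  # 54 of 108 weights zeroed
```

With 108 weights and a 50% target, 54 entries are zeroed, which is where the 0.5 comes from.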

thaink commented 2 years ago

@rino20 Could you take a look?

rino20 commented 2 years ago

Hi @vvolhejn,

Since you haven't set the pruning parameters, the default schedule is applied: ConstantSparsity, with a pruning frequency of 100. https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py#L141

That means your model is pruned every 100 steps. Your example runs fewer than 100 steps when the batch size is larger than 2, so training finishes before pruning is ever applied, which is why you don't get a pruned result.
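The step arithmetic behind this can be checked without TensorFlow. The sketch below assumes Keras fit() takes ceil(samples / batch_size) optimizer steps per epoch (the final partial batch still counts) and, per the explanation above, that the first mask update happens once the step count reaches the frequency of 100; `training_steps` is a hypothetical helper.

```python
import math


def training_steps(num_samples, batch_size, epochs):
    """Optimizer steps in Keras fit(): ceil(samples / batch_size) per epoch,
    since the last, partial batch still counts as a step."""
    return epochs * math.ceil(num_samples / batch_size)


FREQUENCY = 100  # default ConstantSparsity pruning frequency mentioned above

for batch_size in (1, 2, 3, 32):
    steps = training_steps(num_samples=100, batch_size=batch_size, epochs=2)
    # Assumes the first mask update happens once `steps` reaches FREQUENCY.
    print(f"batch_size={batch_size:>2}: {steps:>3} steps, "
          f"reaches step {FREQUENCY}: {steps >= FREQUENCY}")
```

This gives 200, 100, 68 and 8 steps for the four batch sizes, so only batch sizes 1 and 2 ever reach step 100, which matches the sparsity values reported in the issue.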

Hope this helps,

vvolhejn commented 2 years ago

Thank you for clearing this up. It makes sense explained like this, but it feels a bit unintuitive to me :/ The docs say "frequency: Only apply pruning every frequency steps.", which doesn't seem to imply that the first pruning happens only after frequency steps.
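To make the semantics of frequency concrete, here is a plain-Python sketch of a ConstantSparsity-style check: prune when the step lies in [begin_step, end_step] (a negative end_step meaning no end) and (step - begin_step) is a multiple of frequency. It is a re-implementation for illustration, not the library's code. It assumes, consistent with the explanation in this thread, that the step counter starts at 1 for the first training batch, so with begin_step=0 the first mask update lands at step frequency. `should_prune` and `pruning_steps` are hypothetical names.

```python
def should_prune(step, begin_step=0, end_step=-1, frequency=100):
    """Hypothetical plain-Python version of a ConstantSparsity-style check:
    prune inside [begin_step, end_step] (end_step < 0 means no end) whenever
    (step - begin_step) is a multiple of frequency."""
    in_range = step >= begin_step and (end_step < 0 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0


def pruning_steps(total_steps, **schedule):
    """Steps at which the mask would be updated, assuming the step counter
    starts at 1 for the first training batch."""
    return [s for s in range(1, total_steps + 1) if should_prune(s, **schedule)]


print(pruning_steps(200))               # [100, 200]
print(pruning_steps(68))                # []
print(pruning_steps(68, frequency=10))  # [10, 20, 30, 40, 50, 60]
```

Under these assumptions, one way to see pruning in short runs like the batch size 3 case is to pass an explicit schedule with a smaller frequency, for example prune_low_magnitude(model, pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=10)).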

rino20 commented 2 years ago

Sorry for the confusion. We hoped that "Only" in that sentence would imply it, but that might not be clear enough.