tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Strange behaviour pruning with small datasets #975

Open niciwalter99 opened 2 years ago

niciwalter99 commented 2 years ago

Describe the bug
prune_low_magnitude does not seem to update the weights to 0 (I am using ConstantSparsity) when training on a small dataset (1000 images).

System information

TensorFlow version: 2.8.1

TensorFlow Model Optimization version (installed from source or binary): 0.7.2

Python version: 3.8.10

Describe the expected behavior
Pruning should work on a small dataset the same way it works on larger datasets.

Describe the current behavior
The weights are not updated to 0 after a model_for_pruning.fit run (see the code example below). The exact same example works if you increase the dataset size (dataset_size) to 10000 or change the batch_size to 16. I don't think this is intended behavior of the ConstantSparsity feature, or am I doing something wrong here?

Code to reproduce the issue

# Random "images" for test prupose
dataset_size = 1000
train_images = np.random.rand(dataset_size, 128,128,3) 
train_labels = np.random.rand(dataset_size, 10)

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(128, 128, 3)),
  keras.layers.Reshape(target_shape=(128, 128, 3)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 64
epochs = 2
validation_split = 0.1

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.8, 0)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=['accuracy'])

# Log directory for the pruning summaries callback
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

print(model_for_pruning.get_weights())
# Output contains no pruned weights / zero values
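One way to check the actual sparsity per tensor instead of printing all weights (a minimal sketch; the 'kernel' name filter assumes the usual Keras variable naming and may need adjusting):

for v in model_for_pruning.weights:
    if 'kernel' in v.name:
        w = v.numpy()
        # Fraction of weights that are exactly zero
        print(v.name, 'sparsity:', float(np.mean(w == 0)))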
sngyhan commented 2 years ago

@rino20 Could you take a look?

YannPourcenoux commented 2 years ago

What if you look at the weights of the model after stripping the pruning wrappers with tfmot.sparsity.keras.strip_pruning?
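Something along these lines (a rough sketch, reusing model_for_pruning from the snippet above):

import numpy as np
import tensorflow_model_optimization as tfmot

# Remove the pruning wrappers so only plain Keras layers remain
model_stripped = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# Check how many weights are exactly zero in the stripped model
for w in model_stripped.get_weights():
    print(w.shape, 'zeros:', float(np.mean(w == 0)))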

rino20 commented 2 years ago

Hi @YannPourcenoux

Since you haven't set the frequency parameter, the default is applied: ConstantSparsity prunes with a frequency of 100 steps. https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py#L141

That means your model is pruned every 100 steps. Your example runs fewer than 100 steps (1000 / 64 * 2 ≈ 31 < 100), so you don't get a pruned result; training finishes before pruning is ever applied.

With a batch size of 16 you run more than 100 steps (1000 / 16 * 2 = 125 > 100), so the model is pruned as expected.
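If you need pruning to kick in within such a short run, you can set the frequency explicitly, for example (a minimal sketch; frequency=1 is just an illustrative choice):

import tensorflow_model_optimization as tfmot

# With ~1000 samples, batch_size=64 and 2 epochs there are only about
# 30 optimizer steps, which never reaches the default frequency of 100.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.8,
        begin_step=0,
        frequency=1)  # prune at every step instead of every 100 steps
}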

Hope this helps,

niciwalter99 commented 2 years ago

Okay, that makes sense. But why does the frequency argument exist at all in ConstantSparsity? It's useful in PolynomialDecay, but isn't it pointless to have a frequency when you just want constant-sparsity pruning and then fine-tune afterwards?
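For comparison, a sketch of a PolynomialDecay schedule where the frequency clearly matters because sparsity is ramped up over training (the begin/end steps here are made up for illustration):

import tensorflow_model_optimization as tfmot

# Sparsity is increased from 50% to 80% between step 0 and step 1000,
# re-pruning every 100 steps along the way.
schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.5,
    final_sparsity=0.8,
    begin_step=0,
    end_step=1000,
    frequency=100)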