tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0
1.49k stars, 321 forks

Is it possible to implement post-training pruning? #621

Open PhoeniXuzoo opened 3 years ago

PhoeniXuzoo commented 3 years ago

OS: macOS

Motivation

The documentation only shows how to prune a model during training. Every pruning pass requires a call to model.fit, which also modifies the other weights. We just want to set some weights to zero without changing the rest.

This feature has already been implemented in PyTorch.
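
To make the request concrete, here is a minimal NumPy sketch of one-shot magnitude pruning on a single weight array: the smallest-magnitude entries are set to zero and every other entry is left untouched. The function name `magnitude_prune` and the example values are hypothetical and only for illustration; this is not how tfmot implements pruning internally.

```python
import numpy as np


def magnitude_prune(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries zeroed.

    Hypothetical helper for illustration; not part of tfmot.
    """
    weights = np.asarray(weights)
    n_prune = int(round(target_sparsity * weights.size))
    if n_prune == 0:
        return weights.copy()
    flat = weights.ravel()
    # Indices of the n_prune entries with the smallest absolute value.
    prune_idx = np.argpartition(np.abs(flat), n_prune - 1)[:n_prune]
    keep = np.ones(flat.size, dtype=bool)
    keep[prune_idx] = False
    # Surviving entries keep their exact original values.
    return np.where(keep, flat, 0.0).reshape(weights.shape)


w = np.array([[0.5, -0.1], [0.03, -2.0]])
pruned = magnitude_prune(w, target_sparsity=0.5)
# Keeps 0.5 and -2.0; zeroes -0.1 and 0.03. `w` itself is not modified.
print(pruned)
```

With a Keras model, the same idea could in principle be applied per layer through `layer.get_weights()` and `layer.set_weights()`, though that bypasses the tfmot pruning wrappers entirely.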

daverim commented 3 years ago

We are currently working on this. @fredrec has more details.

Gojo1729 commented 3 years ago

So pruning doesn't work if we can't retrain the model after applying prune_low_magnitude?

Assia17 commented 2 years ago

Hello guys, any news on this topic?

Black3rror commented 2 months ago

Hello.

Here is a workaround: create a dummy input and output that match your model's input and output shapes, set the learning rate to zero (with SGD as the optimizer), and train the model. This way, the weights won't change, but they will get pruned.

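import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot


def post_training_pruning(model, target_sparsity):
    """
    Prune the model using post-training pruning.

    Args:
        model (tf.keras.Model): The model to prune.
        target_sparsity (float): The target sparsity, between 0 and 1.

    Returns:
        tf.keras.Model: The pruned model.
    """
    # Apply the target sparsity from step 0 and update the masks on every step.
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=target_sparsity, begin_step=0, frequency=1
    )
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)

    # One dummy sample matching the model's input and output shapes
    # (assumes a single input and a single output with fixed dimensions).
    train_x = np.random.rand(1, *model.input.shape[1:])
    train_y = np.random.rand(1, *model.output.shape[1:])

    # UpdatePruningStep is required when fitting a model wrapped for pruning.
    callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
    # With a learning rate of zero, the SGD update leaves the weights unchanged.
    # Note: state updated outside the optimizer (e.g. BatchNormalization
    # moving statistics) can still change during fit.
    opt = tf.keras.optimizers.SGD(learning_rate=0.0)

    model_for_pruning.compile(optimizer=opt, loss='mse')
    model_for_pruning.fit(train_x, train_y, epochs=1, callbacks=callbacks)

    # Remove the pruning wrappers and keep the masked (sparse) weights.
    pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

    return pruned_model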
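
To check that the workaround actually zeroed weights, one option is to count the zero entries in the returned model. The sketch below is a minimal example under stated assumptions: `sparsity_report` is a hypothetical helper, and the arrays are stand-ins so the snippet runs without TensorFlow.

```python
import numpy as np


def sparsity_report(weight_arrays):
    """Return (zero_count, total_count, zero_fraction) over a list of arrays.

    Hypothetical helper for illustration; not part of tfmot.
    """
    arrays = [np.asarray(w) for w in weight_arrays]
    zeros = sum(int(np.count_nonzero(a == 0)) for a in arrays)
    total = sum(a.size for a in arrays)
    fraction = zeros / total if total else 0.0
    return zeros, total, fraction


# Stand-in arrays for illustration; with a real model you would pass
# pruned_model.get_weights() instead.
example_weights = [np.array([[0.0, 1.5], [0.0, -0.2]]), np.array([0.1, 0.0])]
zeros, total, fraction = sparsity_report(example_weights)
print(f"{zeros}/{total} weights are zero ({fraction:.0%})")  # 3/6 weights are zero (50%)
```

As far as I know, prune_low_magnitude by default prunes only the kernels of supported layers and leaves biases alone, so the fraction over all weights will usually come out below target_sparsity. Checking each kernel separately gives a clearer picture.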