Open PhoeniXuzoo opened 3 years ago
Currently we are working on this; @fredrec has more details.
So pruning doesn't work if we are not able to retrain the model after wrapping it with `prune_low_magnitude`?
Hello guys, any news on this topic?
Hello.
Here is a workaround: create dummy input/output data matching your model's input/output shapes, set the learning rate to zero (with SGD as the optimizer), and train the model. This way the weights won't change, but they will still get pruned.
```python
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def post_training_pruning(model, target_sparsity):
    """Prune the model using post-training pruning.

    Args:
        model (tf.keras.Model): The model to prune.
        target_sparsity (float): The target sparsity.

    Returns:
        tf.keras.Model: The pruned model.
    """
    # Constant sparsity from step 0, applied at every step.
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=target_sparsity, begin_step=0, frequency=1)
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
        model, pruning_schedule=pruning_schedule)

    # Dummy data matching the model's input/output shapes.
    train_x = np.random.rand(1, *model.input.shape[1:])
    train_y = np.random.rand(1, *model.output.shape[1:])

    # Learning rate 0 keeps the surviving weights unchanged;
    # the pruning callback still zeroes out the low-magnitude ones.
    callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
    opt = tf.keras.optimizers.SGD(learning_rate=0)
    model_for_pruning.compile(optimizer=opt, loss='mse')
    model_for_pruning.fit(train_x, train_y, epochs=1, callbacks=callbacks)

    # Remove the pruning wrappers, keeping the zeroed weights.
    pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    return pruned_model
```
MacOS
Motivation
The documentation only shows how to prune a model during training: each time we need to call model.fit, and other weights get modified in the process. We just want to set some weights to zero while leaving the other weights unchanged.
This feature has already been implemented in PyTorch.
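The underlying operation being requested — zero out the smallest-magnitude fraction of a weight tensor without touching the rest — can be sketched in plain NumPy. This is a minimal illustration of the concept, not the tfmot or PyTorch API; `magnitude_prune` is a hypothetical helper name.

```python
import numpy as np

def magnitude_prune(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude
    entries set to zero, so that roughly `target_sparsity` of
    the entries are zero. Other entries are left unchanged."""
    flat = np.abs(weights).ravel()
    k = int(round(target_sparsity * flat.size))
    if k == 0:
        return weights.copy()
    # k-th smallest absolute value becomes the pruning threshold.
    threshold = np.partition(flat, k - 1)[k - 1]
    # Keep only entries strictly above the threshold; ties at the
    # threshold are pruned, so the result may be slightly sparser.
    mask = np.abs(weights) > threshold
    return weights * mask
```

Applied to a 2x2 matrix with `target_sparsity=0.5`, the two smallest-magnitude entries become zero while the two largest keep their original values.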