tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Add a default PruningPolicy that filters out any layers not supported by the API #1098

Open annietllnd opened 9 months ago

annietllnd commented 9 months ago

Hi TFMOT team! I have created a workaround for an edge case in my project, and I believe it could be made the default behavior of the API. I'm creating this feature request as a suggestion - let me know what you think!

System information

Running on TF 2.11. Unfortunately, I currently don't have the bandwidth to contribute an implementation of this feature myself.

Motivation

This feature request is for an implementation of PruningPolicy that allows pruning only for layers that are supported by the PruneRegistry.

Some short background: when calling prune_low_magnitude or similar functions, it's possible to ignore certain layers according to a pruning policy. By implementing the abstract class PruningPolicy, you can check that the model and its layers fulfill certain requirements. One built-in implementation already exists, namely PruneForLatencyOnXNNPack. A call can look like this:

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_model_optimization.sparsity.keras import (
    prune_low_magnitude, PruneForLatencyOnXNNPack)

pruning_params = {}  # e.g. a pruning_schedule; library defaults are used here

model = prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]),
    pruning_policy=PruneForLatencyOnXNNPack(),
    **pruning_params)

Currently, since no pruning policy is applied by default, the API will try to prune layers that are not compatible. In one of our use cases, we had to implement a policy that checks whether a layer is supported (to avoid trying to prune a TFOpLambda layer, for example). In effect, we are guarding the API against input that it already knows it does not support. I'm suggesting adding another PruningPolicy implementation to the API, one which simply calls the registry's supports function. If possible, I'd also use it as the default value for the pruning_policy parameter (this part may cause issues for some use cases, so consider it an optional part of this feature request).
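To make the proposal concrete, here is a minimal, framework-free sketch of the idea. The class and registry names below are hypothetical stand-ins for tfmot's real PruningPolicy (with its allow_pruning method) and PruneRegistry (with its supports classmethod); this is an illustration of the filtering behavior, not the actual tfmot implementation:

```python
class FakeRegistry:
    # Stand-in for tfmot's PruneRegistry: it knows which layer
    # types the pruning API supports.
    _SUPPORTED = {"Dense", "Conv2D"}

    @classmethod
    def supports(cls, layer):
        return type(layer).__name__ in cls._SUPPORTED


class SupportedLayersOnlyPolicy:
    """Sketch of the proposed default policy: only allow pruning
    of layers the registry reports as supported."""

    def allow_pruning(self, layer):
        # Skip layers the registry cannot handle (e.g. TFOpLambda)
        # instead of letting prune_low_magnitude fail on them.
        return FakeRegistry.supports(layer)

    def ensure_model_supports_pruning(self, model):
        # Unlike PruneForLatencyOnXNNPack, there is no whole-model
        # constraint here: unsupported layers are simply left unpruned.
        pass


# Hypothetical layer stand-ins for demonstration.
class Dense: pass
class TFOpLambda: pass

policy = SupportedLayersOnlyPolicy()
print(policy.allow_pruning(Dense()))       # True: registry supports Dense
print(policy.allow_pruning(TFOpLambda()))  # False: silently skipped, not an error
```

With this as the default, unsupported layers would be passed through unwrapped rather than raising an error deep inside the pruning wrapper.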

I think it would help people using the API to avoid confusing bugs in edge cases. If there's an existing way to do this that we have overlooked, I'm happy to get that feedback. Let me know if the suggestion needs elaboration.

Thank you for your time! Annie

doyeonkim0 commented 9 months ago

@cdh4696 Could you take a look at this? Thank you! :)