tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Pruning error for transfer learning models from TF/Keras API #1017

Status: Open. Opened by frost-is 1 year ago.

frost-is commented 1 year ago

System information

TensorFlow version (installed from source or binary): 2.10

TensorFlow Model Optimization version (installed from source or binary): 0.7.3

Python version: 3.8

Bug

Hi,

I have been experimenting with ILSVRC2012 lately and tried pruning pre-trained models taken from the Keras/TF API. I used the pruning examples from the tfmot documentation, and they worked well for my "homemade" networks as well as for the NASNet models. However, with all the other models I tried (not every single one, but quite a few), I got the following error:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'keras.layers.preprocessing.image_preprocessing.Rescaling'>
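The error message names three ways a layer can be accepted: it is known to the PruneRegistry, it is a `PrunableLayer` instance, or it has its own `get_prunable_weights` method. The sketch below is a simplified, framework-free model of that rule as the message states it, not tfmot's actual implementation; `SUPPORTED_LAYER_NAMES`, `PrunableLayerBase`, `is_supported` and the stand-in `Dense`/`Rescaling` classes are all hypothetical.

```python
# Simplified model of the acceptance rule described by the error message.
# All names here are hypothetical stand-ins, not tfmot's real classes.

# Stand-in for the PruneRegistry: class names of built-in layers it knows about.
SUPPORTED_LAYER_NAMES = {"Dense", "Conv2D", "DepthwiseConv2D"}


class PrunableLayerBase:
    """Stand-in for a PrunableLayer-style interface."""

    def get_prunable_weights(self):
        raise NotImplementedError


class Dense:
    """Stand-in for a layer the registry knows about."""


class Rescaling:
    """Stand-in for a layer the registry does not know about."""


def is_supported(layer) -> bool:
    """Return True if the layer passes any of the three checks in the message."""
    in_registry = type(layer).__name__ in SUPPORTED_LAYER_NAMES
    is_prunable_layer = isinstance(layer, PrunableLayerBase)
    has_custom_method = hasattr(layer, "get_prunable_weights")
    return in_registry or is_prunable_layer or has_custom_method


print(is_supported(Dense()))      # True: known to the registry
print(is_supported(Rescaling()))  # False: the case that raises the ValueError
```

Under this model, giving a `Rescaling` instance its own `get_prunable_weights` attribute is enough to make it pass the third check.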

It apparently should have been fixed by this PR, yet, well, here we are. It seems I am not the only one who had this problem.

You can find the code I used in this Colab notebook. I included an example that worked, for reference.

Thank you for your time!

dansuh17 commented 1 year ago

Hi @rino20 , could you take a look at this issue?

rino20 commented 1 year ago

Hi, sorry for late reply.

I think some Keras layer dependencies have changed, and those changes have not been reflected in the PruneRegistry yet. Once the updates have settled down, we will take a look and update it.

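In the meantime, you can use the code below as a temporary workaround to get pruning working:

    def get_prunable_weights():
        # Report nothing to prune for these layers.
        return []

    # mobilenet_small is the pre-trained Keras model being pruned.
    for layer in mobilenet_small.layers:
        if layer.__class__.__name__ in ['TFOpLambda', 'Rescaling']:
            layer.get_prunable_weights = get_prunable_weights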
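The workaround relies on a plain Python detail: a function assigned to an instance attribute is not turned into a bound method, so calling `layer.get_prunable_weights()` passes no arguments. That is why `get_prunable_weights` can take no `self`. The framework-free illustration below uses a hypothetical stand-in class named `Rescaling` in place of the Keras layer, and a hypothetical `BoundLayer` class to show what would happen if the same function were attached to a class instead.

```python
class Rescaling:
    """Hypothetical stand-in for a weight-free Keras layer."""


def get_prunable_weights():
    # Nothing to prune, so report an empty list.
    return []


layers = [Rescaling(), Rescaling()]

# Patch each instance, mirroring the loop over mobilenet_small.layers.
for layer in layers:
    if layer.__class__.__name__ in ["TFOpLambda", "Rescaling"]:
        layer.get_prunable_weights = get_prunable_weights

# Instance attributes are not bound, so no `self` is passed on the call.
print(layers[0].get_prunable_weights())  # []


class BoundLayer:
    """Hypothetical class used to show the class-attribute case."""


# Attached to the class, the function becomes a method and receives `self`,
# which a zero-argument function cannot accept.
BoundLayer.get_prunable_weights = get_prunable_weights
try:
    BoundLayer().get_prunable_weights()
except TypeError as exc:
    print("TypeError:", exc)
```

Patching instances also keeps the change local to the layers of the model being pruned, rather than altering every `Rescaling` layer in the process.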

Hope this helps.