tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0
1.48k stars, 320 forks

Pruning: Keras subclassed model increased support #155

Open alanchiao opened 4 years ago

alanchiao commented 4 years ago

Currently the pruning API will throw an error when a subclassed model is passed to it. Users can get around this by diving into the subclassed models and applying pruning to individual Sequential/Functional models and tf.keras layers.

Better support is important for various cases (e.g. Object Detection, BERT examples) and issues such as this one.

We can provide better support for pruning an entire subclassed model.

Implementation-wise, we can iterate through the layers of a subclassed model (and nested models) and apply pruning to all of them. Replacing a layer in an already created model will be tricky, and we'd have to do this without clone_model.

nutsiepully commented 4 years ago

It seems to me we are mixing a few issues. I want to make sure I understand the problem correctly. Please correct me if I'm wrong.

The issue is about handling models recursively (a Keras Model which contains another model), not about subclassed models.

If one of the layers in the model is a Sequential or Functional model, we should ideally traverse that model further and prune all the layers within it.

So, if the user provides the following:

model1 = tf.keras.Sequential([tf.keras.layers.Dense(3)])
model2 = tf.keras.Sequential([model1, tf.keras.layers.Dense(2)])

we should recursively traverse the model and prune both Dense layers. I agree with that.
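The recursive traversal described here can be sketched without any Keras dependency. This is only an illustration: the helper name iter_leaf_layers is hypothetical, the objects are plain stand-ins rather than real Keras layers, and it assumes that container models expose their children through a `layers` attribute (as Keras models do) while ordinary layers do not.

```python
from types import SimpleNamespace
from typing import Any, Iterator


def iter_leaf_layers(layer: Any) -> Iterator[Any]:
    """Yield every non-container layer reachable from `layer`, depth first.

    Assumption: anything with a `layers` attribute is a container model
    (Sequential, Functional, or subclassed); anything without one is a leaf.
    """
    children = getattr(layer, "layers", None)
    if children is None:
        yield layer
        return
    for child in children:
        yield from iter_leaf_layers(child)


# Stand-ins that mirror the nesting of model1 and model2; not real Keras objects.
dense_3 = SimpleNamespace(name="dense_3")
dense_2 = SimpleNamespace(name="dense_2")
model1 = SimpleNamespace(name="model1", layers=[dense_3])
model2 = SimpleNamespace(name="model2", layers=[model1, dense_2])

print([layer.name for layer in iter_leaf_layers(model2)])
# ['dense_3', 'dense_2']
```

The sketch only locates the layers that would need pruning. Substituting wrapped layers back into an already built model is the harder part raised at the start of this thread.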

I am not quite sure how subclassed models play a role here. It is possible that a user combines Keras Sequential, Functional, and Layer objects within a subclassed model. In that case, the user has to explicitly prune whatever they want.

Since we can't reliably clone a subclass model, it isn't possible for us to apply pruning to the model.

Hackerman28 commented 3 years ago

Hi @nutsiepully. I am currently having problems with pruning nested Keras models. My model is based on matterport's mask-rcnn repo. It contains an inner model which is called multiple times within the outer model but has shared weights. The inner model is defined here. When I use the prune_low_magnitude API for pruning, it gives me the following error:

    ValueError: Please initialize Prune with a supported layer. Layers should either be a PrunableLayer instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.functional.Functional'>

So I defined a clone_function which returns prune_low_magnitude(inner_model):

    def clone_func(layer):
        if isinstance(layer, tensorflow.python.keras.engine.base_layer.AddLoss):
            return layer
        if isinstance(layer, tensorflow.python.keras.engine.base_layer.AddMetric):
            return layer
        if isinstance(layer, tensorflow.python.keras.engine.functional.Functional):
            return tfmot.sparsity.keras.prune_low_magnitude(layer, pruning_params)
        return tfmot.sparsity.keras.prune_low_magnitude(layer, pruning_params)

In the above code, when the inner functional model is passed as layer, it is passed to the prune_low_magnitude API to prune the inner model, or that's what I hoped for. But I'm getting the following error:
    Traceback (most recent call last):
      File "nucoco.py", line 510, in
        augmentation=augmentation
      File "/backup/Radar-RGB-Attentive-Multimodal-Object-Detection/Radar_RGB_Camera_Object_Detection/mrcnn/model.py", line 2972, in prune_train
        self.keras_model = tf.keras.models.clone_model(self.keras_model, clone_function=clone_func)
      File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/models.py", line 429, in clone_model
        model, input_tensors=input_tensors, layer_fn=clone_function)
      File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/models.py", line 201, in _clone_functional_model
        created_layers=created_layers))
      File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py", line 1214, in reconstruct_from_config
        process_node(layer, node_data)
      File "/home/mcw/miniconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py", line 1162, in process_node
        output_tensors = layer(input_tensors, **kwargs)
    UnboundLocalError: local variable 'kwargs' referenced before assignment

Can you help me with this error?

teijeong commented 3 years ago

Hi @liyunlu0618, can you check the current status?

liyunlu0618 commented 3 years ago

We recently added support for pruning nested models; see this PR.

For subclassed models, since Keras doesn't support cloning them, we still don't have a model-level API. You can still reconstruct the model and wrap the layers you want to prune with the pruning API.

didadida-r commented 2 years ago

It seems that pruning subclassed models is still not supported in 0.7.1.

gnhearx commented 2 years ago

Good day everyone :)

I would also like to chip in to this discussion as I too have been struggling to get Nested Model Pruning to work correctly.

I couldn't help noticing that support has been added, but for some reason my implementation does not seem to work.

Expected behaviour:

Observed behaviour:

It seems to me like the pruning is not even taking place. However, if I instantiate a model without nested models and run the exact same logic pipeline, then pruning acts exactly as expected and all layers are pruned successfully.
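One way to check whether pruning took place is to measure the fraction of exactly-zero entries in each weight array. The following is a dependency-free sketch: the helper names are hypothetical, and the kernel is a hand-written stand-in for an array that, with Keras, would come from something like layer.get_weights().

```python
from typing import Any, Iterator


def _flatten(values: Any) -> Iterator[Any]:
    """Yield the scalar entries of an arbitrarily nested sequence."""
    try:
        items = iter(values)
    except TypeError:  # a scalar: nothing left to unpack
        yield values
        return
    for item in items:
        yield from _flatten(item)


def zero_fraction(values: Any) -> float:
    """Return the fraction of entries in `values` that are exactly zero."""
    entries = list(_flatten(values))
    if not entries:
        return 0.0
    zeros = sum(1 for entry in entries if entry == 0)
    return zeros / len(entries)


# Stand-in for one kernel: three of its four entries are zero.
kernel = [[0.0, 1.5], [0.0, 0.0]]
print(zero_fraction(kernel))  # 0.75
```

If the zero fraction of the layers that should be pruned stays near 0 after training, the pruning masks were most likely never applied to them.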

Am I perhaps missing a step that is not mentioned in the documentation for tensorflow/keras model pruning? Any help would be greatly appreciated.

Side notes:

LIU-FAYANG commented 2 years ago

@alanchiao Hi, I have recently been working on pruning subclassed models, and the model does not seem to converge with pruning. I tried to apply the prune_low_magnitude API directly to the layers to be pruned within the subclassed model, and the pruning schedule applied to each layer is the same. https://github.com/tensorflow/model-optimization/issues/965 mentions that the pruning callback might cause this issue when this workaround is used. Could you share your opinion on this? Is there anything else I need to do to prune a subclassed model with the workaround you mentioned? I think I missed some steps needed to make it work. Thanks for your kind help!