tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Stripping Quantized Model #958

Open Khalil-2020 opened 2 years ago

Khalil-2020 commented 2 years ago

Describe the bug

I need help with this one, please. I want to re-implement the "strip_pruning" function described at this link (line 222): https://github.com/tensorflow/model-optimization/blob/v0.7.2/tensorflow_model_optimization/python/core/sparsity/keras/prune.py#L222-L270

This time, however, I want to apply it to a quantized model so that I can do the following: quantize a model, strip the quantization wrappers from the quantized model, and then apply pruning. That is the reverse of the order in the guide on the TensorFlow site, which applies pruning, strips the pruned model, and then applies quantization.

Code to reproduce the issue

import tensorflow as tf
from tensorflow import keras


def stripping_quantize(model):
    if not isinstance(model, keras.Model):
        raise ValueError(
            'Expected model to be a tf.keras.Model instance but got: ', model)

    def _strip_quant_wrap(layer):
        # Recurse into nested models.
        if isinstance(layer, tf.keras.Model):
            return keras.models.clone_model(
                layer, input_tensors=None, clone_function=_strip_quant_wrap)
        # Replace each QuantizeWrapperV2 with the layer it wraps.
        if layer.__class__.__name__ == "QuantizeWrapperV2":
            if not hasattr(layer.layer, '_batch_input_shape') and hasattr(
                    layer, '_batch_input_shape'):
                layer.layer._batch_input_shape = layer._batch_input_shape
            return layer.layer
        return layer

    return keras.models.clone_model(
        model, input_tensors=None, clone_function=_strip_quant_wrap)


model_q = stripping_quantize(quantized_model)

When I apply pruning to model_q, I get the following error:

Screenshots

AttributeError                            Traceback (most recent call last)
Input In [41], in <cell line: 1>()
----> 1 model_for_pruning = tf.keras.models.clone_model(
      2     model_q,
      3     clone_function=apply_pruning_to_layers,)

File ~\anaconda3\envs\base\lib\site-packages\keras\models.py:456, in clone_model(model, input_tensors, clone_function)
    453   return _clone_sequential_model(
    454       model, input_tensors=input_tensors, layer_fn=clone_function)
    455 else:
--> 456   return _clone_functional_model(
    457       model, input_tensors=input_tensors, layer_fn=clone_function)

File ~\anaconda3\envs\base\lib\site-packages\keras\models.py:197, in _clone_functional_model(model, input_tensors, layer_fn)
    193 model_configs, created_layers = _clone_layers_and_model_config(
    194     model, new_input_layers, layer_fn)
    195 # Reconstruct model from the config, using the cloned layers.
    196 input_tensors, output_tensors, created_layers = (
--> 197     functional.reconstruct_from_config(model_configs,
    198                                        created_layers=created_layers))
    199 metrics_names = model.metrics_names
    200 model = Model(input_tensors, output_tensors, name=model.name)

File ~\anaconda3\envs\pbase\lib\site-packages\keras\engine\functional.py:1338, in reconstruct_from_config(config, custom_objects, created_layers)
   1336 while layer_nodes:
   1337   node_data = layer_nodes[0]
-> 1338   if process_node(layer, node_data):
   1339     layer_nodes.pop(0)
   1340   else:
   1341     # If a node can't be processed, stop processing the nodes of
   1342     # the current layer to maintain node ordering.

File ~\anaconda3\envs\base\lib\site-packages\keras\engine\functional.py:1282, in reconstruct_from_config.<locals>.process_node(layer, node_data)
   1279 if not layer._preserve_input_structure_in_config:
   1280   input_tensors = (
   1281       base_layer_utils.unnest_if_single_tensor(input_tensors))
-> 1282 output_tensors = layer(input_tensors, **kwargs)
   1284 # Update node index map.
   1285 output_index = (tf.nest.flatten(output_tensors)[0].
   1286                 _keras_history.node_index)

File ~\anaconda3\envs\base\lib\site-packages\keras\utils\traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~\anaconda3\envs\base\lib\site-packages\tensorflow\python\autograph\impl\api.py:692, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    690 except Exception as e:  # pylint:disable=broad-except
    691   if hasattr(e, 'ag_error_metadata'):
--> 692     raise e.ag_error_metadata.to_exception(e)
    693   else:
    694     raise

AttributeError: Exception encountered when calling layer "prune_low_magnitude_stem_conv" (type PruneLowMagnitude).

in user code:

File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_wrapper.py", line 288, in call  *
    self.add_update(self.pruning_obj.weight_mask_op())
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_impl.py", line 254, in weight_mask_op  *
    return tf.group(self._weight_assign_objs())
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_impl.py", line 225, in update_var  *
    return tf_compat.assign(variable, reduced_value)
File "C:\Users\ASUS\anaconda3\envs\base\lib\site-packages\tensorflow_model_optimization\python\core\keras\compat.py", line 28, in assign  *
    return ref.assign(value, name=name)

AttributeError: 'Tensor' object has no attribute 'assign'

Call arguments received:
  • inputs=tf.Tensor(shape=(None, 151, 151, 3), dtype=float32)
  • training=False
  • kwargs=<class 'inspect._empty'>
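The final AttributeError says that pruning tried to call .assign() on a tf.Tensor. Only a tf.Variable supports in-place assignment; any value computed from a variable is an immutable Tensor. One plausible reading, not confirmed in this thread, is that QuantizeWrapperV2 sets the wrapped layer's kernel attribute to its fake-quantized Tensor when it is called, so after the wrapper is stripped, prune_low_magnitude picks up that Tensor instead of the original variable. The snippet below only demonstrates the Variable/Tensor difference; the rounding step is a stand-in for fake quantization, not tfmot's actual quantizer.

```python
import tensorflow as tf

# A tf.Variable is mutable state and exposes .assign().
kernel_var = tf.Variable([[1.0, -2.0], [0.5, 3.0]])

# A computation on the variable (here a crude stand-in for fake quantization)
# returns an immutable tf.Tensor, which has no .assign() method.
kernel_tensor = tf.round(kernel_var * 4.0) / 4.0

print(isinstance(kernel_var, tf.Variable))  # True
print(hasattr(kernel_tensor, "assign"))     # False

try:
    kernel_tensor.assign(tf.zeros_like(kernel_tensor))
except AttributeError as err:
    # Same kind of failure as in the traceback above.
    print("AttributeError:", err)
```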

Additional context

If there is any way to apply quantization and then pruning in TensorFlow, I would like to know how. Thank you!

dina00 commented 1 year ago

The following function can be used to strip the quantization wrappers from the model's layers. It is adapted from https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/quantization/keras/quantize.py

import tensorflow as tf
import tensorflow_model_optimization as tfmot


def extract_original_model(model_to_unwrap):
    """Extracts the original model by removing the quantization wrappers."""

    def _unwrap(layer):
        # Only unwrap QuantizeWrapper/QuantizeWrapperV2 layers. Checking
        # '"quant" in layer.name' would also match the input QuantizeLayer
        # (which has no .layer attribute) and any layer whose own name
        # contains "quant".
        if not isinstance(layer, tfmot.quantization.keras.QuantizeWrapper):
            return layer

        # Note: this returns the same layer object that was inside the
        # wrapper, not a fresh copy.
        return layer.layer

    unwrapped_model = tf.keras.models.clone_model(
        model_to_unwrap, input_tensors=None, clone_function=_unwrap)

    return unwrapped_model
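Whichever stripping function is used, a quick check before pruning is whether each layer's weight attribute is still a tf.Variable. The helper below is hypothetical (it is not part of tfmot) and assumes TF 2.x with Keras 2; the attribute names "kernel" and "depthwise_kernel" are, as far as we know, what tfmot's pruning reads for Dense/Conv and DepthwiseConv layers. A non-empty result on the stripped model would be consistent with the AttributeError reported in this issue.

```python
import tensorflow as tf


def layers_with_non_variable_weights(model, attrs=("kernel", "depthwise_kernel")):
    """Hypothetical helper: names of layers whose weight attribute is not a tf.Variable.

    Pruning reads these attributes and assigns masked values back into them,
    so each one needs to be a tf.Variable for that assignment to succeed.
    """
    offending = []
    for layer in model.layers:
        for attr in attrs:
            value = getattr(layer, attr, None)
            if value is not None and not isinstance(value, tf.Variable):
                offending.append(layer.name)
    return offending


# Toy model standing in for the stripped model from this thread.
toy = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(3,), name="dense_a"),
    tf.keras.layers.Dense(2, name="dense_b"),
])
before = layers_with_non_variable_weights(toy)
print(before)  # []

# Simulate a stripped layer whose kernel was left as a Tensor.
toy.get_layer("dense_b").kernel = tf.identity(toy.get_layer("dense_b").kernel)
after = layers_with_non_variable_weights(toy)
print(after)  # ['dense_b']
```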