tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Error reported when pruning one tf-keras layer #662

Open iamweiweishi opened 3 years ago

iamweiweishi commented 3 years ago

Prior to filing: check that this should be a bug instead of a feature request. Everything supported, including the compatible versions of TensorFlow, is listed in the overview page of each technique (for example, quantization-aware training has its own overview page). An issue for anything not supported should be a feature request.

Describe the bug: My model is built by subclassing the Keras Model class, so I cannot prune the whole model directly as the pruning documentation describes. To work around this, I prune only the tf.keras layers inside the model, like this:

    .......
    self.conv_1 = tf.keras.layers.Conv2D(
        filters=..., kernel_size=..., strides=...,
        padding="valid", name="pw_conv_1",
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
    )

    from tensorflow_model_optimization.sparsity import keras as sparsity

    # Wrap the Conv2D layer so its weights are pruned during training.
    pruning_prm = {
        'pruning_schedule': sparsity.PolynomialDecay(
            initial_sparsity=0, final_sparsity=0.5,
            begin_step=0, end_step=75000, frequency=100,
        )
    }
    self.conv_1 = sparsity.prune_low_magnitude(self.conv_1, **pruning_prm)
    ......
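To make the schedule above concrete, here is a plain-Python sketch of the sparsity target that PolynomialDecay is documented to follow, current_sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (1 - (step - begin_step) / (end_step - begin_step)) ** power, using the parameters from this issue. It is an illustrative re-implementation, not tfmot code; the default power=3 and the rule that masks are only updated every `frequency` steps inside [begin_step, end_step] are assumptions about the library's behaviour.

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.0, final_sparsity=0.5,
                              begin_step=0, end_step=75000, power=3):
    """Target sparsity at `step` under a polynomial-decay schedule.

    Plain-Python re-implementation of the formula documented for
    tfmot's PolynomialDecay; defaults are the values used in this issue.
    """
    # Fraction of the pruning window already elapsed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


def is_pruning_step(step, begin_step=0, end_step=75000, frequency=100):
    """Assumed rule: masks are updated every `frequency` steps inside [begin_step, end_step]."""
    in_window = begin_step <= step <= end_step
    return in_window and (step - begin_step) % frequency == 0


if __name__ == "__main__":
    for step in (0, 37500, 75000, 100000):
        print(step, polynomial_decay_sparsity(step), is_pruning_step(step))
```

With power=3 the target rises quickly: halfway through the window (step 37500) it is already 0.4375 of the final 0.5, and it stays at 0.5 after end_step.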

Then the following error is reported:

(0) Invalid argument: assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:] [x (conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/ReadVariableOp:0) = ] [-1] [y (conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/y:0) = ] [0] [[{{node conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/Assert/AssertGuard/else/_33/conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/Assert/AssertGuard/Assert}}]] [[conformer/conformer_encoder/conformer_encoder_block_2/conformer_encoder_block_2_conv_module/prune_low_magnitude_conformer_encoder_block_2_conv_module_pw_conv_1/assert_greater_equal/Assert/AssertGuard/output/_120/_131]]
(1) Invalid argument: assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:] [x (conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/ReadVariableOp:0) = ] [-1] [y (conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/y:0) = ] [0] [[{{node conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/Assert/AssertGuard/else/_33/conformer/conformer_encoder/conformer_encoder_block_0/conformer_encoder_block_0_conv_module/prune_low_magnitude_conformer_encoder_block_0_conv_module_pw_conv_1/assert_greater_equal/Assert/AssertGuard/Assert}}]]
0 successful operations. 0 derived errors ignored. [Op:__inference_train_function_264461]

Function call stack: train_function -> train_function
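The failing check compares x, the wrapper's step counter read from a variable (-1 in the log), against y = 0. Below is a minimal pure-Python model of that mechanism, assuming (as the -1 in the log suggests) that each pruning wrapper's counter starts at -1 and that UpdatePruningStep copies the global training step into every wrapper it knows about before each batch. FakePruneWrapper, FakeUpdatePruningStep and fit are hypothetical stand-ins for illustration, not tfmot classes.

```python
class FakePruneWrapper:
    """Hypothetical stand-in for a pruning wrapper; only tracks the step counter."""

    def __init__(self, name):
        self.name = name
        # Assumed initial value, matching x = -1 in the error log.
        self.pruning_step = -1

    def train_step(self):
        # Mirrors the failing check: "Condition x >= y did not hold", with y = 0.
        if not self.pruning_step >= 0:
            raise AssertionError(
                f"{self.name}: pruning_step={self.pruning_step}; "
                "UpdatePruningStep callback required")


class FakeUpdatePruningStep:
    """Hypothetical stand-in for the callback: copies the global step into each wrapper."""

    def __init__(self, wrappers, start_step=0):
        self.wrappers = wrappers
        self.step = start_step  # assumed to start from the optimizer's iteration count

    def on_train_batch_begin(self):
        for wrapper in self.wrappers:
            wrapper.pruning_step = self.step

    def on_train_batch_end(self):
        self.step += 1


def fit(wrappers, callback=None, num_batches=3):
    """Toy training loop: callback hook, then each wrapper's step check."""
    for _ in range(num_batches):
        if callback is not None:
            callback.on_train_batch_begin()
        for wrapper in wrappers:
            wrapper.train_step()
        if callback is not None:
            callback.on_train_batch_end()
```

In this model, calling fit without a callback fails on the first batch, like the log; with a callback that reaches the wrapper, the counter is 0 before the check runs and training proceeds.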

Even after adding 'tfmot.sparsity.keras.UpdatePruningStep()' to the callbacks, the error remains.
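One plausible explanation, which is an assumption and not confirmed in this thread: the callback finds wrappers by walking model.layers and recursing only into nested Keras Model objects, so a wrapper buried inside custom Layer subclasses (the log shows it under scopes like conformer_encoder_block_0_conv_module) is never updated and keeps step -1. The toy sketch below models that with hypothetical Layer, Model and PruneWrapper classes; collect_shallow is the assumed behaviour, while collect_deep walks every nested child, similar in spirit to iterating tf.Module.submodules.

```python
class Layer:
    """Hypothetical layer that may own sub-layers (like a custom Keras Layer)."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


class Model(Layer):
    """Hypothetical model; `.layers` lists only its direct children."""

    @property
    def layers(self):
        return self.children


class PruneWrapper(Layer):
    """Hypothetical pruning wrapper."""


def collect_shallow(model):
    """Assumed callback behaviour: recurse only into nested Models."""
    found = []
    for layer in model.layers:
        if isinstance(layer, PruneWrapper):
            found.append(layer)
        elif isinstance(layer, Model):
            found += collect_shallow(layer)
    return found


def collect_deep(module):
    """Alternative: visit every nested layer, whatever its type."""
    found = []
    for child in module.children:
        if isinstance(child, PruneWrapper):
            found.append(child)
        found += collect_deep(child)
    return found


if __name__ == "__main__":
    # Hypothetical nesting shaped like the scopes in the error log.
    wrapper = PruneWrapper("prune_low_magnitude_pw_conv_1")
    conv_module = Layer("conv_module", [wrapper])
    block = Layer("conformer_encoder_block_0", [conv_module])
    encoder = Layer("conformer_encoder", [block])
    model = Model("conformer", [encoder])
    print([w.name for w in collect_shallow(model)])
    print([w.name for w in collect_deep(model)])
```

If this hypothesis holds, collecting wrappers by walking all submodules rather than model.layers would be the thing to try; that approach is untested here.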

System information

TensorFlow version (installed from source or binary): pip install tensorflow-gpu==2.3.0

TensorFlow Model Optimization version (installed from source or binary): 0.5.0

Python version: 3.6

hajarasgari commented 8 months ago

Any update on this? I'm getting the same error. I'd appreciate any help!