tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Unable to use strip_pruning for subclass model #965

Open LIU-FAYANG opened 2 years ago

LIU-FAYANG commented 2 years ago

Hi, is there any workaround for using tfmot.sparsity.keras.strip_pruning on subclassed models? To prune a subclassed model, I applied prune_low_magnitude to the layers inside it as a workaround. Is there a similar method for using strip_pruning on subclassed models? Any help would be great :)

Xhark commented 2 years ago

AFAIK, even if you apply prune_low_magnitude to the sublayers of your subclassed model, the pruning scheduling logic (which is callback-based) does not work during training, because it can't find your pruned layers. So I think your model is not being trained with pruning as expected.

Unfortunately, support for subclassed models is very weak these days. There may be a workaround, but it would still be a brittle solution.

@rino20 Hi rino20, do you know of any recent workaround for supporting subclassed models, or do you have any recommendations?

LIU-FAYANG commented 2 years ago

Hi @Xhark, thanks for your kind help. This page https://github.com/tensorflow/model-optimization/issues/155 mentions applying the pruning API directly to the layers inside subclassed models as a workaround. I'm currently trying to prune an RNN-T model using prune_low_magnitude, and the pruning results are indeed not good: the model tends not to converge, and I'm not sure whether that is caused by this workaround. Could you elaborate on why the pruning scheduling logic doesn't work in this case? I tried to give all of the layers to be pruned the same pruning schedule, so I'm a bit confused about why the scheduling logic fails. Thanks for your kind help!

rino20 commented 2 years ago

Hi @LIU-FAYANG

  1. I am not sure how you applied pruning to the subclassed model, but if you applied it to the Keras layers inside the subclassed model, it should work.

  2. The general strip_pruning method won't work in that case, since you applied pruning in a custom way. You will need to write your own strip_pruning method that mirrors the way you applied pruning.

  3. Could you elaborate on what you mean by "the model is not converging"? If the loss fluctuates a lot, a longer pruning_frequency is recommended.

LIU-FAYANG commented 2 years ago

Hi @rino20, thanks for your help!

  1. Yes, I tried applying prune_low_magnitude to the Keras layers within the subclassed model. Taking this Dense layer as an example, https://github.com/TensorSpeech/TensorFlowASR/blob/main/tensorflow_asr/models/transducer/rnn_transducer.py#L63, I replaced that line with:

         self.projection = prune_low_magnitude(
             tf.keras.layers.Dense(
                 dmodel,
                 name=f"{self.name}_projection",
                 kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
             ),
             pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
         )

     Since this is an RNN Transducer model, most of the parameters are in LSTM and Dense layers, so I also applied the pruning API to the LSTM layers. Is this the right way to prune subclassed models?
  2. Yes, I need to create my own strip_pruning method, but I don't think it can work the same way the pruning was applied. To apply the prune_low_magnitude API to the subclassed model, the approach we discussed is to hack into the subclassed model and apply the API to each layer. But after the model is trained, there seems to be no way to apply strip_pruning to each layer the way we added the pruning masks. Could you elaborate on what you have in mind?
  3. I'm trying to prune an RNN Transducer model based on this repo: https://github.com/TensorSpeech/TensorFlowASR/tree/main/examples/rnn_transducer. With pruning, the training loss is quite high (roughly 200+), and I'm not sure why applying the pruning API leads to such a high training loss. I tried a lower target sparsity level and pruning one layer at a time, but none of these methods helped. I also tried setting begin_step to a nonzero value, but the training loss seems to explode right after training starts. I'll try a larger pruning_frequency and see if it helps.

Thanks for your kind help!

rino20 commented 2 years ago

Hi, thanks for providing the details.

  1. Yes, your approach works. In detail, you can achieve your goal by 1) pruning the target layer (wrapping it with the pruning wrapper) and 2) building the subclassed model.

  2. Your strip_pruning should do the same as #1: it should 1) get each pruning-wrapped layer and strip the wrapper, and 2) build the subclassed model. Alternatively, you can just swap the wrapped layer for the stripped layer; note, however, that this is not best practice. Since tfmot.sparsity.keras.strip_pruning expects a whole tf.keras.Model, the stripped layer here is the one held in the wrapper's layer attribute:

       run_transducer_block.projection = run_transducer_block.projection.layer

  3. The loss exploding even with begin_step != 0 seems weird. Could you test with a sparsity level of 0 and inspect the actual weight values during training? Does the loss decrease as training proceeds?

LIU-FAYANG commented 2 years ago

Hi @rino20, thanks for your kind help. I'll try the method you mentioned above in the next few days. Another thing I'd like to ask: the model I'm working on seems to have a custom training loop. I didn't merge the pruning callbacks into the custom training loop, but the model is still able to train, which is quite weird. Is it necessary to merge the callbacks into the training loop in this case, or should I just pass the callback list to model.fit() as in normal training? Maybe this is what Xhark meant when he raised the concern about pruning callbacks for subclassed models?

edit: I tried a small end-to-end subclassed model from https://www.tensorflow.org/guide/keras/custom_layers_and_models, which has its own custom training loop. I added the pruning callback to the custom training loop following the pruning comprehensive guide, but the masks and weights do not seem to change at all after pruning for a few epochs. Is this because the pruning callback does not support subclassed models?