tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Support for Recurrent layers for Quantization Aware Training. #1114

Open parth-desai opened 9 months ago

parth-desai commented 9 months ago

System information

Motivation: I am trying to train an RNN model with quantization-aware training for embedded devices.

Describe the feature: I am looking for a way to train with the default 8-bit weight and activation quantization using the quantize_apply API, without passing in a custom config.

Describe how the feature helps achieve the use case

Describe how existing APIs don't satisfy your use case (optional if obvious)

I tried to use the quantize_apply API, but I received this error:

RuntimeError: Layer gru:<class 'keras.src.layers.rnn.gru.GRU'> is not supported. You can quantize this layer by passing a `tfmot.quantization.keras.QuantizeConfig` instance to the `quantize_annotate_layer` API.
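For reference, here is roughly the call pattern that produces the error (a minimal sketch; the input shape and layer sizes are placeholders, not my actual model):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Placeholder model: a single GRU layer followed by a classifier head.
model = tf.keras.Sequential([
    tf.keras.layers.GRU(64, input_shape=(49, 10)),
    tf.keras.layers.Dense(12, activation="softmax"),
])

# Annotate the whole model for the default 8-bit scheme ...
annotated_model = tfmot.quantization.keras.quantize_annotate_model(model)

# ... and quantize_apply raises the RuntimeError above for the GRU layer,
# because there is no built-in QuantizeConfig registered for GRU.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)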

After using quantize_annotate_layer, I was able to train the model, but the model fails to save with the following error:

  keras.models.save_model(model, filepath=model_filename, save_format="h5")
Traceback (most recent call last):
  File "/workspaces/project-embedded/syntiant-ndp-model-converter/examples/train_audio_model.py", line 169, in <module>
    keras.models.save_model(model, filepath=model_filename, save_format="h5")
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py", line 167, in save_model
    return legacy_sm_saving_lib.save_model(
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/h5py/_hl/group.py", line 183, in create_dataset
    dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds)
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/h5py/_hl/dataset.py", line 163, in make_new_dset
    dset_id = h5d.create(parent.id, name, tid, sid, dcpl=dcpl, dapl=dapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5d.pyx", line 137, in h5py.h5d.create
ValueError: Unable to synchronously create dataset (name already exists)

I used the following QuantizeConfig:

import tensorflow_model_optimization as tfmot

LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer


class GruQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
        return [
            (
                layer.cell.kernel,
                LastValueQuantizer(
                    num_bits=8, symmetric=True, narrow_range=False, per_axis=False
                ),
            ),
            (
                layer.cell.recurrent_kernel,
                LastValueQuantizer(
                    num_bits=8, symmetric=True, narrow_range=False, per_axis=False
                ),
            ),
        ]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
        return [
            (
                layer.cell.activation,
                MovingAverageQuantizer(
                    num_bits=8, symmetric=False, narrow_range=False, per_axis=False
                ),
            ),
            (
                layer.cell.recurrent_activation,
                MovingAverageQuantizer(
                    num_bits=8, symmetric=False, narrow_range=False, per_axis=False
                ),
            ),
        ]

    def set_quantize_weights(self, layer, quantize_weights):
        # Set the quantized weights in the same order as they were
        # returned by `get_weights_and_quantizers`.
        layer.cell.kernel = quantize_weights[0]
        layer.cell.recurrent_kernel = quantize_weights[1]

    def set_quantize_activations(self, layer, quantize_activations):
        # Set the quantized activations in the same order as they were
        # returned by `get_activations_and_quantizers`.
        layer.cell.activation = quantize_activations[0]
        layer.cell.recurrent_activation = quantize_activations[1]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
        return []

    def get_config(self):
        return {}
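For completeness, this is roughly how I attach the config to the model before training (again a sketch; the architecture and shapes are placeholders):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

# Wrap only the GRU layer with the custom config; the other layers use the
# built-in default configs.
model = quantize_annotate_model(tf.keras.Sequential([
    quantize_annotate_layer(
        tf.keras.layers.GRU(64, input_shape=(49, 10)), GruQuantizeConfig()
    ),
    tf.keras.layers.Dense(12, activation="softmax"),
]))

# `quantize_scope` lets `quantize_apply` deserialize the custom QuantizeConfig.
with quantize_scope({"GruQuantizeConfig": GruQuantizeConfig}):
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

# Training on quant_aware_model works, but saving it afterwards with
# keras.models.save_model(..., save_format="h5") fails with the ValueError above.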

I looked at the source code. It seems that support for RNN layers has been disabled here for some reason.

I was wondering if this could be re-enabled?

chococigar commented 7 months ago

Thanks for filing this issue, Parth.

As you said, it looks like RNN support was disabled because it was unsupported and had yet to be verified on TFLite.

We'll keep track of this feature request, but please note that LSTM / RNN / GRU variant support is not prioritized at this moment, because it is less relevant to today's ML landscape compared to transformers.

Thanks, Jen