tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

QAT with trainable=False does not work as expected. #881

Open metinsuloglu opened 2 years ago

metinsuloglu commented 2 years ago

Describe the bug

System information

TensorFlow version (installed from source or binary): 2.6.0

TensorFlow Model Optimization version (installed from source or binary): 0.7.0

Python version: 3.7.10

Describe the expected behavior

After setting trainable=False on a layer with a quantisation wrapper applied, the weights in that layer should not change during training.

Describe the current behavior

The loss decreases during training even if all layers are set to be non-trainable.

Code to reproduce the issue

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

A = np.random.uniform(size=(10000, 10, 10))

print('Expected behaviour')
inp = tf.keras.Input(shape=(10, 10), batch_size=10)
out = tf.keras.layers.Dense(10)(inp)
model = tf.keras.Model(inp, out)
for layer in model.layers:
    layer.trainable = False
model.compile(loss='mse')
model.fit(A, A, batch_size=10, epochs=5)
print('{} trainable weights'.format(len(model.layers[1].trainable_weights)))

print('\nQuantised behaviour')
inp = tf.keras.Input(shape=(10, 10), batch_size=10)
out = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(inp)
quant_model = tfmot.quantization.keras.quantize_apply(tf.keras.Model(inp, out))
for layer in quant_model.layers:
    layer.trainable = False
quant_model.compile(loss='mse')
quant_model.fit(A, A, batch_size=10, epochs=5)
print('{} trainable weights'.format(len(quant_model.layers[2].trainable_weights)))
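
A decreasing loss is indirect evidence, so it can help to compare the weights themselves before and after training. The helper below is a minimal sketch of that check; `max_weight_change` is a hypothetical name, and the arrays are stand-ins for snapshots of `model.get_weights()` taken before and after `fit()`.

```python
import numpy as np


def max_weight_change(before, after):
    """Return the largest absolute element-wise change across paired weight arrays."""
    if len(before) != len(after):
        raise ValueError("weight lists differ in length")
    changes = [
        float(np.max(np.abs(np.asarray(a) - np.asarray(b)))) if np.size(a) else 0.0
        for b, a in zip(before, after)
    ]
    return max(changes, default=0.0)


# Stand-in arrays (assumption): in practice these would come from
# model.get_weights() before and after calling fit().
before = [np.zeros((10, 10)), np.zeros(10)]
frozen = [w.copy() for w in before]
updated = [before[0] + 0.01, before[1]]

print(max_weight_change(before, frozen))   # 0.0
print(max_weight_change(before, updated))  # 0.01
```

With the reproduction above, one would snapshot `quant_model.get_weights()` before and after `quant_model.fit(...)` and pass both lists to the helper; a non-zero result means some weight (kernel, bias, or a quantization parameter) was updated.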

xhae commented 2 years ago

@Xhark Can you take a look at this?

hunse commented 2 years ago

I'm also encountering this problem. Part of the problem appears to be that QuantizeWrapper doesn't account for self.trainable == False in trainable_weights and non_trainable_weights (cf. tf.keras.layers.Layer, which does account for this).
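
The pattern described here can be shown with a toy sketch in plain Python. This is not the Keras or tfmot source: `ToyLayer` and `ToyWrapper` are hypothetical stand-ins, and the sketch deliberately ignores everything else Keras does (for instance, any propagation of the flag to nested layers). It shows only what it means for a wrapper's weight properties to ignore its own `trainable` flag.

```python
class ToyLayer:
    """Stand-in for a layer that honours `trainable` when listing weights."""

    def __init__(self, weights):
        self.trainable = True
        self._trainable_weights = list(weights)
        self._non_trainable_weights = []

    @property
    def trainable_weights(self):
        # A frozen layer reports nothing as trainable.
        return list(self._trainable_weights) if self.trainable else []

    @property
    def non_trainable_weights(self):
        if self.trainable:
            return list(self._non_trainable_weights)
        # When frozen, previously trainable weights are reported here instead.
        return self._trainable_weights + self._non_trainable_weights


class ToyWrapper:
    """Stand-in for a wrapper that forwards the inner list without checking its own flag."""

    def __init__(self, layer):
        self.layer = layer
        self.trainable = True

    @property
    def trainable_weights(self):
        # Pattern being illustrated: self.trainable is never consulted.
        return self.layer.trainable_weights


wrapped = ToyWrapper(ToyLayer(["kernel", "bias"]))
wrapped.trainable = False           # only the wrapper's flag changes in this toy
print(wrapped.trainable_weights)    # ['kernel', 'bias']

wrapped.layer.trainable = False     # freezing the inner layer is what takes effect
print(wrapped.trainable_weights)    # []
```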

However, I've patched this as per below, and it still doesn't appear to stop the weights in these layers from training. Furthermore, I'm not sure how to stop quantization parameters from training (they're always non-trainable weights, but obviously update based on the input data to the layer, and I'd like to stop them from changing as well).

from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper

QuantizeWrapper = quantize_wrapper.QuantizeWrapper

def trainable_weights(self):
    if self.trainable:
        return self.layer.trainable_weights + self._trainable_weights
    else:
        return []

def non_trainable_weights(self):
    if self.trainable:
        return self.layer.non_trainable_weights + self._non_trainable_weights
    else:
        # Return layer weights first, and previously trainable before non-trainable,
        # to maintain the order in `self.weights`.
        # TODO: This won't be the correct order if `self.layer` has weights that are
        # always not trainable. To get the correct order, we would have to use
        # `layer._trainable_weights` and `layer._non_trainable_weights`, or switch
        # over to using QuantizeWrapperV2
        return (
            self.layer.trainable_weights +
            self.layer.non_trainable_weights +
            self._trainable_weights +
            self._non_trainable_weights
        )

QuantizeWrapper.trainable_weights = property(trainable_weights)
QuantizeWrapper.non_trainable_weights = property(non_trainable_weights)

EDIT: I got this patch working by patching QuantizeWrapperV2 as well, see below.

hunse commented 2 years ago

I've started work on a patch here: https://github.com/hunse/model-optimization/pull/1. It works for ensuring that the layer's trainable parameters (e.g. kernel, bias) do not get trained when trainable=False on the layer. @metinsuloglu, it works to get your example above pretty much passing, specifically having the trainable_weights be empty and having the loss stay almost exactly constant.

What it doesn't do yet is stop the quantization parameters from changing; for that reason, the loss does fluctuate slightly in your example. The quantization parameters are already considered non-trainable weights and aren't updated as part of the backprop pass (which is how trainable weights are updated), so it doesn't make sense that setting trainable=False would freeze them. To freeze them, we would probably want a different interface that sets them to be non-adjustable.
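
To make the distinction concrete, here is a toy sketch of a range tracker whose min/max are updated from the data it sees on the forward pass, with no gradients involved. It is an illustration under assumptions, not tfmot's implementation: `ToyRangeTracker`, its `momentum` default, and the `frozen` switch are all hypothetical, with `frozen` standing in for the kind of separate interface suggested above.

```python
class ToyRangeTracker:
    """Tracks a moving-average min/max of the inputs it observes."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.min_val = 0.0
        self.max_val = 0.0
        self.frozen = False  # hypothetical switch, independent of `trainable`

    def observe(self, batch, training=True):
        # The range is updated from input statistics, not by an optimizer,
        # so a `trainable` flag alone would have no effect on it.
        if training and not self.frozen:
            m = self.momentum
            self.min_val = m * self.min_val + (1 - m) * min(batch)
            self.max_val = m * self.max_val + (1 - m) * max(batch)
        return self.min_val, self.max_val


tracker = ToyRangeTracker()
tracker.observe([-1.0, 0.5, 2.0])      # range moves towards the batch min/max
after_first = (tracker.min_val, tracker.max_val)

tracker.frozen = True
tracker.observe([-10.0, 10.0])         # ignored while frozen
print((tracker.min_val, tracker.max_val) == after_first)  # True
```

In this toy, only the explicit `frozen` switch (or running with `training=False`) stops the range from drifting, which mirrors why the loss can still fluctuate slightly even when no trainable weights remain.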

hunse commented 2 years ago

OK, I've added a test to https://github.com/hunse/model-optimization/pull/1. Let me know if/when I can make a PR here (since the CONTRIBUTING.md document says not to make a PR until an issue is marked as "contributions welcome").

metinsuloglu commented 2 years ago

@xhae @Xhark @fredrec any updates on this?

MATTYGILO commented 1 year ago

@xhae @Xhark @fredrec any updates on this?

MATTYGILO commented 1 year ago

I'm having a problem where, as soon as I quantise my model, the trainable settings are reset and I can't change them back.