google / qkeras

QKeras: a quantization deep learning library for Tensorflow Keras

Params not quantized after model_save_quantized_weights function #105

Closed laumecha closed 1 year ago

laumecha commented 1 year ago

Hello all, I am trying to quantize a Keras model. For this, I load a Keras ResNet model trained on the CIFAR-10 dataset (which has an accuracy of 0.877), then define the quantization dictionary and perform the retraining. After the retraining, I call the model_save_quantized_weights function and save the model. However, when I inspect the model in Netron, all the parameters have more than 4 bits of precision (e.g., 0.1347484141588211). Also, when I print the return value of model_save_quantized_weights, all the outputs have the same form (e.g., 0.05311733). What am I doing wrong here?

Also, I see that I get the same accuracy with 4 bits as with 8 bits (around 0.5). How is this possible?

Here is my code:

# Keras / QKeras imports used below
from tensorflow import keras
from qkeras.utils import model_quantize, model_save_quantized_weights
from qkeras.estimate import print_qstats

model = keras.models.load_model('exported_m/cifar10_Resnet_unquantized_0-8773.h5')
model.summary()

q_dict = {
    "QConv2D": {
        "kernel_quantizer": "quantized_bits(4,0,alpha='auto',use_stochastic_rounding=True)",
        "bias_quantizer": "quantized_bits(4,0,alpha='auto',use_stochastic_rounding=True)"
    },
    "QDense": {
        "kernel_quantizer": "quantized_bits(4,0,alpha='auto',use_stochastic_rounding=True)",
        "bias_quantizer": "quantized_bits(4,0,alpha='auto',use_stochastic_rounding=True)"
    },
    "QBatchNormalization": {
        "mean_quantizer":"quantized_bits(8,4,alpha=1)", 
        "gamma_quantizer":"quantized_bits(8,4,alpha=1)", 
        "variance_quantizer":"quantized_bits(8,4,alpha=1)", 
        "beta_quantizer":"quantized_bits(8,4,alpha=1)",
        "inverse_quantizer":"quantized_bits(8,4,alpha=1)"
    }
}

qmodel = model_quantize(model, q_dict, 4, transfer_weights=True)

qmodel.summary()
EPOCHS = 1
print_qstats(qmodel)
qmodel.compile(optimizer='adam', loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])
from tensorflow.keras import backend as K
K.set_value(qmodel.optimizer.learning_rate, 0.0005)  # lower learning rate for fine-tuning
# train_images/train_labels and test_images/test_labels are the CIFAR-10
# arrays, assumed to be loaded and preprocessed elsewhere.
history = qmodel.fit(train_images, train_labels, batch_size=64, epochs=EPOCHS,
                     validation_data=(test_images, test_labels))

quantized_params = model_save_quantized_weights(qmodel)
qmodel.save('exported_m/quantized/test.h5')
print(quantized_params)
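
A quick sanity check (a sketch, assuming model_save_quantized_weights has written the quantized values back into qmodel): count how many distinct values each output channel of a kernel uses. With 4-bit weights there should be at most 2**4 = 16 codes per channel, even though the individual floats look arbitrary because of the per-channel scale.

import numpy as np

for layer in qmodel.layers:
    weights = layer.get_weights()
    if not weights:
        continue
    kernel = weights[0]
    # With per-channel scaling, each output channel should reuse a small set of codes.
    per_channel = [len(np.unique(kernel[..., i])) for i in range(kernel.shape[-1])]
    print(layer.name, kernel.shape, 'max codes per channel:', max(per_channel))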
laumecha commented 1 year ago

Searching through another issue (#60), I found that if I set alpha=1, the parameters of all the layers are quantized after the model_save_quantized_weights function.

However, is this correct? If so, why is this happening? And why are the batch normalization layers not quantized?
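
A minimal sketch of the difference (the tensor here is made up; quantized_bits is the stock QKeras quantizer): with alpha=1 the outputs land on the fixed 4-bit grid (multiples of 1/8 for quantized_bits(4, 0)), while with alpha='auto' each output channel is additionally multiplied by a data-derived scale, so the stored floats no longer look like "nice" fixed-point numbers.

import numpy as np
import tensorflow as tf
from qkeras import quantized_bits

w = tf.constant(np.random.uniform(-0.2, 0.2, size=(3, 4)).astype('float32'))

q_fixed = quantized_bits(4, 0, alpha=1)       # plain 4-bit fixed point
q_auto = quantized_bits(4, 0, alpha='auto')   # 4-bit grid times a per-channel scale

print(q_fixed(w).numpy())   # multiples of 1/8
print(q_auto(w).numpy())    # grid values times a scale derived from w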

jurevreca12 commented 1 year ago

When you save a model with model_save_quantized_weights, you might be expecting only integers, or fixed-point literals that are powers of two. But numbers can in general be quantized to values outside those representations, and some quantization approaches are more sensible than others.

The quantized_bits quantizer works at the tensor level and does something like Wq = alpha * W, where W are the real-valued weights and Wq are the quantized values. So if you are computing W*x + b, you can compute it as (Wq*x) / alpha + b. What do you gain here? If we limit Wq to 4-bit numbers, and x is an n-bit number, then we can use an integer multiplier instead of a floating-point multiplier. But we still need to divide by alpha, so in general this doesn't necessarily help us. If we limit ourselves to alphas that are powers of two, however, the division becomes a shift operation, which is far more efficient.

QKeras lets you restrict the scales to powers of two by setting the alpha parameter to "auto_po2". You can also manually set alpha=1, so the scale is always 1, but then you will likely get worse results with your network.
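
As a toy illustration of the rescaling argument (plain NumPy, made-up numbers, not QKeras internals): the weights are stored as small integer codes, the multiply-accumulate runs on integers, and the single division by alpha turns into a shift when alpha is a power of two.

import numpy as np

W = np.array([0.30, -0.12, 0.07])   # real-valued weights
x = np.array([1.0, 2.0, 3.0])       # activations
b = 0.05

alpha = 8.0                          # power-of-two scale (2**3)
Wq = np.round(alpha * W)             # 4-bit integer codes: [2., -1., 1.]

# Integer multiply-accumulate; dividing by alpha is a right shift by 3.
approx = (Wq @ x) / alpha + b
print(approx, W @ x + b)             # 0.425 vs 0.32: the gap is quantization error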

I recommend you read a survey paper on quantization techniques, something like this: https://arxiv.org/pdf/2106.08295.pdf. It will help you understand how quantization-aware training works.

jurevreca12 commented 1 year ago

Regarding batch normalization layers: they are typically folded into the preceding active layer (typically a dense or conv layer). This is also described in the paper I mentioned, in the subsection "Batch normalization folding".
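
For reference, the folding identity is the standard one (a sketch in plain NumPy, not QKeras code): the BatchNorm statistics are absorbed into the preceding layer's kernel and bias, so after folding there are no separate BN parameters left to quantize.

import numpy as np

def fold_bn(w, b, gamma, beta, mean, var, eps=1e-3):
    # Fold BatchNorm(gamma, beta, mean, var) into the preceding Conv/Dense (w, b).
    scale = gamma / np.sqrt(var + eps)   # one multiplier per output channel
    w_folded = w * scale                 # broadcasts over the last (output) axis
    b_folded = beta + (b - mean) * scale
    return w_folded, b_folded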

laumecha commented 1 year ago

Oh, I see. Now I understand. Thank you very much for your help!