tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

QAT aware training for mobilenetV2 not working #1086

Open christophezeinaty opened 1 year ago

christophezeinaty commented 1 year ago

Hello, I am trying to apply quantization-aware training to MobileNetV2, testing on the MNIST dataset. The floating-point model trains very well, but the moment I add quantization the loss becomes very large and the model does not converge.
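For intuition about what "adding quantization" changes: QAT wraps layers with fake-quantization ops that round weights and activations to an int8 grid during the forward pass, so training sees the rounding error that real int8 inference will introduce. A minimal numpy sketch of that round-trip (just the idea, not the tfmot implementation):

```python
import numpy as np

def fake_quantize(x, num_bits=8):
    # Affine-quantize to an integer grid, then immediately dequantize.
    # QAT layers do (roughly) this in the forward pass, so the model
    # learns weights that survive the rounding error.
    qmin, qmax = 0, 2 ** num_bits - 1
    lo, hi = float(x.min()), float(x.max())
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    q = np.clip(np.round((x - lo) / scale), qmin, qmax)
    return q * scale + lo

x = np.linspace(-1.0, 1.0, 11)
print(fake_quantize(x))  # close to x, but snapped to the int8 grid
```

The per-step error is bounded by half the quantization step, so a healthy QAT run should start near the float model's loss, not blow up.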

Definition of the model:

```python
from tensorflow.keras import applications
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.models import Model

def build_model(target_size):
    input_tensor = Input(shape=(target_size, target_size, 3))
    base_model = applications.MobileNetV2(
        include_top=False,
        weights='imagenet',
        input_tensor=input_tensor,
        input_shape=(target_size, target_size, 3),
        pooling='avg')

    for layer in base_model.layers:
        layer.trainable = False  # trainable has to be False in order to freeze the layers

    op = Dense(256, activation='relu')(base_model.output)
    op = Dropout(.25)(op)

    output_tensor = Dense(10, activation='softmax')(op)

    model = Model(inputs=input_tensor, outputs=output_tensor)

    return model
```

Quantization and training:

```python
import tensorflow_model_optimization as tfmot
from tensorflow.keras.utils import to_categorical

def quantized_model(model_fp32):
    quantize_model = tfmot.quantization.keras.quantize_model
    q_aware_model = quantize_model(model_fp32)
    return q_aware_model

q_aware_model = quantized_model(model_fp32)
q_aware_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['categorical_accuracy'])

print("training model with int8 precision")

train_images_subset = train_images[0:1000]  # out of 60000
train_labels_subset = train_labels[0:1000]
encoded_y_quant_train = to_categorical(train_labels_subset, num_classes=10, dtype='float32')

train_quant_generator = load_data_generator(train_images_subset, encoded_y_quant_train, batch_size=64)

q_aware_model.fit(train_quant_generator,
                  batch_size=500, epochs=1, steps_per_epoch=900)
```
Xhark commented 1 year ago

Could you please add some more details about how you evaluated the model and what the expected output is? Also, can you try the full dataset (0:60000) instead of 0:1000 for QAT? Thanks!