tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

perform inference after QAT #832

Open lovodkin93 opened 2 years ago

lovodkin93 commented 2 years ago

Hello, I would like to train my model with QAT. From what I understand, during QAT only the forward-pass calculations are done in simulated quantized mode, while the saved weights remain in the original format (for example, if I plan to quantize from 32 bits to 4 bits, only the forward pass is done in 4 bits, whereas the weights are saved in 32 bits). So normally, at inference time, I would need to quantize my trained model into the desired format (in this example, 4 bits) in advance and then run inference. So my question is, given the following snippet:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

def apply_mix_precision_QAT(layer):
  # annotate Conv2D layers with the default quantize config
  if isinstance(layer, tf.keras.layers.Conv2D):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  # annotate Dense layers with my custom quantize config
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer, quantize_config=ModifiedDenseQuantizeConfig())
  return layer

# annotate the model, then apply quantization inside the custom-config scope
annotated_model = tf.keras.models.clone_model(model, clone_function=apply_mix_precision_QAT)
with tfmot.quantization.keras.quantize_scope({'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    model = tfmot.quantization.keras.quantize_apply(annotated_model)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(momentum=0.9)
model.compile(optimizer, loss_fn, metrics=['accuracy'])
model.fit(...)

Do I need to do anything special at inference (test) time? In other words, do I need to quantize the trained model myself, or is that already built into the quantizer wrappers and handled as part of the "tfmot.quantization.keras.quantize_apply" function?
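
One way to inspect what quantize_apply actually injected (a rough sketch only; the class-name check is used because the wrapper class is not part of the public tfmot API, and the printed variable names may differ across tfmot versions):

for layer in model.layers:
    if layer.__class__.__name__.startswith('QuantizeWrapper'):
        print(layer.name)
        for w in layer.weights:
            # the kernel/bias variables stay float32; the extra min/max
            # variables hold the learned ranges used by the fake-quant ops
            print('  ', w.name, w.dtype.name, w.shape)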

Xhark commented 2 years ago

If you run inference with a QAT model, it already simulates 4 bits with fake-quant. The only difference is that we use float32 ops. Basically, the input and weights are quantized and dequantized by the injected fake-quant ops. Is this simulation in inference mode suitable for your case?
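
A minimal sketch of the quant-dequant idea described above, using the standalone fake-quant op (illustration only, with made-up min/max values; it is not the exact op tfmot injects and ignores per-channel quantization):

import tensorflow as tf

w = tf.constant([-1.3, -0.4, 0.0, 0.7, 1.9], dtype=tf.float32)

# quantize-dequantize: values snap onto one of the 2^4 levels in [min, max],
# but the result is still a float32 tensor
w_fq = tf.quantization.fake_quant_with_min_max_args(w, min=-2.0, max=2.0, num_bits=4)

print(w_fq.dtype)    # float32
print(w_fq.numpy())  # values rounded onto the 4-bit grid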

lovodkin93 commented 2 years ago

@Xhark What I want is to have quantized weights during inference. In fact, I am comparing it to post-training quantization, so I need the weights to actually be quantized. I didn't quite understand whether, in the QAT scenario with the quantizer wrapper, the weights are indeed quantized during inference. In other words, is it comparable to post-training quantization, with no extra quantization step needed at inference? Will the weights be in the same format (e.g., 4 bit) in both cases?
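
For the standard 8-bit case, the usual way to get weights that are actually stored as integers (rather than fake-quantized floats) is to convert the QAT model with the TFLite converter. A rough sketch, assuming model is the quantize_apply'd model from above (the output file name is arbitrary, and this path does not produce 4-bit weights):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# weights in the resulting flatbuffer are stored as int8
with open('qat_model.tflite', 'wb') as f:
    f.write(tflite_model)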

dhruven-god commented 5 months ago

I am having the same question: how do we check inference on a QAT model?
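
In case it helps, running inference with the QAT Keras model itself needs nothing special: the fake-quant simulation is already part of the forward pass. A minimal sketch, assuming model is the quantize_apply'd model and x_test / y_test are placeholder evaluation data:

# the forward pass already runs with fake-quant applied to inputs and weights
preds = model.predict(x_test)

# compare against the float baseline to see the effect of simulated quantization
loss, acc = model.evaluate(x_test, y_test, verbose=0)
print('QAT-simulated accuracy:', acc)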