huggingface / optimum-quanto

A pytorch quantization backend for optimum
Apache License 2.0

Weights Still in FP32 after Quantization #347

Open ClaraLovesFunk opened 2 weeks ago

ClaraLovesFunk commented 2 weeks ago

Dear quanto folks,

I implemented quantization as suggested in your coding example quantize_sst2_model.py. When printing the data types of the parameters, I found that after quantization all the weights remained in float32. Do you have an explanation for this?

Also, do you have an explanation for why I can't use bigger batch sizes when quantizing both weights and activations? I used PubMedBERT from Hugging Face, fine-tuned it myself, and applied static quantization (see code below).

And do you know why inference speed slows down significantly when I use the reloaded statically quantized model (code below) as opposed to the directly statically quantized model? I again followed the instructions in the coding example.

Any help is greatly appreciated, since I'm just wrapping up my soon-due master's thesis on this topic <3 Clara

Direct Static Quantization:

from transformers import AutoModelForSequenceClassification
from optimum.quanto import Calibration, freeze, qint8, quantize  # `from quanto import ...` on older versions

weights = qint8
activations = qint8

# model_path, label_mapping, device, dataset_val and evaluate_model are defined elsewhere
model_quantized_static = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(label_mapping)).to(device)
quantize(model_quantized_static, weights=weights, activations=activations)
if activations is not None:
    print("Calibrating ...")
    with Calibration():
        evaluate_model(model_quantized_static, dataset_val, device, batch_size=64)

freeze(model_quantized_static)

# Check the data type of the model parameters
for name, param in model_quantized_static.named_parameters():
    print(f"Parameter: {name}, Data Type: {param.dtype}")

Reloading statically quantized model:

import torch

model_quantized_reloaded = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(label_mapping)).to(device)
quantize(model_quantized_reloaded, weights=weights, activations=activations)
state_dict = torch.load(model_quantized_path)  # model_quantized_path holds the saved quantized state dict
model_quantized_reloaded.load_state_dict(state_dict)
freeze(model_quantized_reloaded)
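
The save step that produces model_quantized_path is not shown here; a minimal sketch, assuming the frozen model and the same path as above, could be:

# Assumption: model_quantized_path is the file later loaded in the reload snippet
torch.save(model_quantized_static.state_dict(), model_quantized_path)
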
ClaraLovesFunk commented 2 weeks ago

I just tested your example file quantize_sst2_model.py and printed the parameters of the reloaded model; there, too, all the parameters are still in float32.

for name, param in model_reloaded.named_parameters():
    print(f"Parameter: {name}, Data Type: {param.dtype}")

Float model
872 sentences evaluated in 2.08 s. accuracy = 0.9105504587155964
Calibrating ...
872 sentences evaluated in 3.12 s. accuracy = 0.893348623853211
Quantized model (w: quanto.qint8, a: quanto.qint8)
872 sentences evaluated in 1.85 s. accuracy = 0.8979357798165137

:68: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(b)

Serialized quantized model
872 sentences evaluated in 1.98 s. accuracy = 0.8864678899082569

Parameter: distilbert.embeddings.word_embeddings.weight, Data Type: torch.float32
Parameter: distilbert.embeddings.position_embeddings.weight, Data Type: torch.float32
Parameter: distilbert.embeddings.LayerNorm.weight, Data Type: torch.float32
Parameter: distilbert.embeddings.LayerNorm.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.q_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.q_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.k_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.k_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.v_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.v_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.out_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.out_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.sa_layer_norm.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.sa_layer_norm.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin1.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin1.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin2.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin2.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.output_layer_norm.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.output_layer_norm.bias, Data Type: torch.float32
(transformer layers 1-5 repeat the same pattern, every parameter torch.float32)
Parameter: pre_classifier.weight, Data Type: torch.float32
Parameter: pre_classifier.bias, Data Type: torch.float32
Parameter: classifier.weight, Data Type: torch.float32
Parameter: classifier.bias, Data Type: torch.float32
dacorvo commented 1 week ago

@ClaraLovesFunk thank you for your feedback. The parameters' dtype is still float32, but if you check their type, you will see that they are now QTensor subclasses instead of plain Tensor. QTensor subclasses preserve the external dtype, but their internal data is quantized. You can check the qtype property to verify that it is correct.
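
For reference, a minimal sketch of such a check (assuming the model was quantized with optimum.quanto as in the snippets above, and that QTensor is importable from the package):

from optimum.quanto import QTensor

for name, param in model_quantized_static.named_parameters():
    if isinstance(param, QTensor):
        # dtype stays float32 (the external dtype); qtype reveals the quantized storage
        print(f"{name}: type={type(param).__name__}, dtype={param.dtype}, qtype={param.qtype}")
    else:
        print(f"{name}: type={type(param).__name__}, dtype={param.dtype} (not quantized)")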

ClaraLovesFunk commented 1 week ago

Thank you so much for the explanation, David! Will do.

ClaraLovesFunk commented 1 week ago

Do you maybe also have an explanation for why I can't use bigger batch sizes after applying quantization and verifying that my model shrank from 413.44 MB to 169.11 MB?
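
For context, one way such a size comparison can be made (an assumption, not code from the thread) is to serialize the state dict to an in-memory buffer and measure it:

import io
import torch

def serialized_size_mb(model):
    # Serialize the state dict into memory and report its size in MB
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6

print(f"Quantized model: {serialized_size_mb(model_quantized_static):.2f} MB")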