crinex opened this issue 3 weeks ago
Hi @crinex, thank you for checking out the Qualcomm tutorial!
Here is the process with code pointers: in export_llama_lib.py, we will call .pt2e_quantize(). Inside the .pt2e_quantize function, we do three things: step 1, run prepare_pt2e to insert observers; step 2, calibrate, which updates the params in the observers; step 3, convert the observers to actual quant/dequant operators. Does that answer your question?
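For reference, here is a minimal, self-contained sketch of those three steps on a toy model. It is only an illustration of the flow, not the export_llama code path: it uses XNNPACKQuantizer as a stand-in for the QNN quantizer (depending on your version it may live under executorch rather than torch.ao) and torch.export.export_for_training in place of the pre-autograd capture used in export_llama_lib.py.

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Toy model standing in for the Llama graph module.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 16),)
graph_module = torch.export.export_for_training(model, example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())

# Step 1: insert observers into the graph.
prepared = prepare_pt2e(graph_module, quantizer)

# Step 2: calibration -- running representative inputs updates the observer stats.
for _ in range(8):
    prepared(torch.randn(1, 16))

# Step 3: replace the observers with actual quant/dequant ops.
converted = convert_pt2e(prepared)
```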
Hi @cccclai, thank you for the explanation. It helped me understand.
I have another question. We are converting the Llama-3.2-3B-Instruct model to qnn_8a8w. While running export_llama to convert the model into a .pte file, I wanted to verify the quantization state and how the model changes. So I executed the prepare_pt2e, pt2e_calibrate, convert_pt2e, DuplicateDynamicQuantChainPass, and export_to_edge functions, and then printed the model's parameter count and size (MB).
After checking, it turned out that the model’s size and parameter count remained unchanged through each of these steps. Specifically, the parameter count is 3,606,752,256, and the size is 13,758.66796875 MB.
It doesn’t seem like the model was actually quantized. I’m curious to know exactly where quantization takes place and where the model size should decrease.
Ultimately, we want to resolve the error that occurs when generating sentences with qnn_8a8w.
Thank you for always helping so kindly.
cc @kimishpatel @jerryzh168 on when the weights are actually converted.
I think most of the time in practice this happens in to_backend since most/all quantized ops today execute in delegates for ET. Quantized embedding might be an exception where we have a pass that replaces the pattern with a quantized op and packs the weight in the top level graph.
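Roughly, that lowering step looks like the sketch below. This is hedged and not the exact export_llama path: it reuses `converted` and `example_inputs` from the earlier sketch and uses the XNNPACK partitioner as a stand-in, since building the QNN partitioner requires Qualcomm compiler specs; the point is just that to_backend is where the delegated (quantized) subgraphs get packed into the backend payload.

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# `converted` / `example_inputs` come from the convert_pt2e sketch above.
exported = torch.export.export(converted, example_inputs)

# The quantized_decomposed q/dq ops are not core ATen, so IR validity is relaxed.
edge = to_edge(exported, compile_config=EdgeCompileConfig(_check_ir_validity=False))

# Delegation: matched (quantized) subgraphs are folded into the backend payload here.
edge = edge.to_backend(XnnpackPartitioner())

# Serialize and check the resulting program size.
et_program = edge.to_executorch()
print(f"serialized program: {len(et_program.buffer) / (1024 ** 2):.2f} MB")
```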
We quantize the weights in convert_pt2e: https://github.com/pytorch/pytorch/blob/891ba2ec8a3e2e71137fab4a8e91940a19c8272b/torch/ao/quantization/quantize_pt2e.py#L241
As @jerryzh168 said, the model after convert_pt2e should have quantized weights. If you serialize the model at that point, you should see the impact on file size. I don't know how you are measuring the size after the listed steps.
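For example, something like this minimal sketch would show the on-disk impact; `converted` here is assumed to be whatever convert_pt2e returned at your breakpoint.

```python
import os
import torch

def state_dict_size_mb(model, path="/tmp/model_state_dict.pt"):
    # Serialize everything the module carries (parameters AND buffers)
    # and measure the resulting file, instead of summing numel() in memory.
    torch.save(model.state_dict(), path)
    return os.path.getsize(path) / (1024 ** 2)

print("after convert_pt2e:", state_dict_size_mb(converted), "MB")
```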
Dear @kimishpatel @jerryzh168 @JacobSzwejbka
I called prepare_pt2e, self.pt2e_calibrate, and convert_pt2e within the pt2e_quantize() function mentioned above, then set a debugging breakpoint and measured the model size using the following code:
```python
def get_model_size(model):
    num_params = sum(p.numel() for p in model.parameters())
    param_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    param_size_megabytes = param_size_bytes / (1024 ** 2)
    return num_params, param_size_megabytes
```
For the model argument, I kept passing the value m, which is defined as m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer), and checked the values using get_model_size(m).
Is there any part where I might have made a mistake?
@jerryzh168 do you know? I suspect it might be related to how const prop works. If you do torch.export.save, does that reflect in the model size?
yeah I think the quantized weights are not in model.parameters(), they will be buffers: https://github.com/pytorch/pytorch/blob/c98ef0279e6eb968f5f9d22e1f193e7064594152/torch/_export/passes/constant_folding.py#L45
we typically just save the state_dict and check file size: https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html#checking-model-size-and-accuracy-evaluation
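If you want an in-memory number instead, a sketch that also counts buffers (a hypothetical helper, not something from the repo) could look like this:

```python
import itertools

def get_model_size_with_buffers(model):
    # Walk parameters *and* buffers; after constant folding the quantized
    # (e.g. int8) weights show up as buffers, not parameters.
    tensors = list(itertools.chain(model.parameters(), model.buffers()))
    num_elems = sum(t.numel() for t in tensors)
    size_megabytes = sum(t.numel() * t.element_size() for t in tensors) / (1024 ** 2)
    return num_elems, size_megabytes
```

Comparing that number before and after convert_pt2e should show the drop that a parameter-only count misses.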
Dear @cccclai
I'm reviewing the code while following the guide (Export with SpinQuant) you provided for converting the Llama-3.2-3B-Instruct model with Qualcomm SpinQuant. When I execute the _export_llama function in the export_llama_lib.py file, the pt2e_quantize(quantizers) function is called. Within this function, the pt2e_calibrate function is executed before the convert_pt2e function. Why is pt2e_calibrate performed before convert_pt2e here? Generally, wouldn't it make more sense to perform calibration after quantization? Thank you