xlang-ai / instructor-embedding

[ACL 2023] One Embedder, Any Task: Instruction-Finetuned Text Embeddings
Apache License 2.0

Quantization doesn't work on ARM devices #60

Closed samarthm closed 9 months ago

samarthm commented 1 year ago

On ARM devices, I believe QNNPACK is the best bet for the quantization engine. I'm running the following code to dynamically quantize the base model of instructor-large, but the embeddings I get back are still float32, which I don't think should happen.

import copy
import torch
from InstructorEmbedding import INSTRUCTOR

# use the QNNPACK backend on ARM
torch.backends.quantized.engine = 'qnnpack'
# non-quantized
model = INSTRUCTOR('hkunlp/instructor-large', device='cpu')
# quantized: replace every nn.Linear with a dynamically quantized version
qmodel = copy.deepcopy(model)
qmodel = torch.quantization.quantize_dynamic(qmodel, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)

I'm running a basic "encode" call afterwards:

embeddings = qmodel.encode([[instruction,sentence]])

Example embedding outputs: [-2.07384042e-02 -4.40651225e-03 9.42820404e-03 ... 8.98325684e-03]
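
For what it's worth, inspecting the module tree suggests the quantization itself did happen (a rough check, assuming qmodel is the dynamically quantized model from above):

for name, module in qmodel.named_modules():
    # dynamically quantized replacements live under torch's *quantized* module paths
    if 'quantized' in type(module).__module__:
        print(name, type(module))

This lists the Linear layers that were swapped for their dynamically quantized counterparts, so the weights are stored in int8 internally; it's only the returned embeddings that come back as float32.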

Is there a way to get an int8 output?

hongjin-su commented 1 year ago

Hi, thanks a lot for your interest in the INSTRUCTOR!

With your code above, even though the model's internal weights may be quantized, the output embeddings are still float32: dynamic quantization stores the Linear weights in int8 but computes and returns activations in float, so the encoder always emits float32 tensors. If you really need the embeddings in int8, you can quantize them manually as a post-processing step. Here is some code for your reference:

import torch

def quantize_tensor_to_int8(tensor, scale, zero_point):
    # Affine quantization. Note that despite the name, this maps to the
    # unsigned 8-bit range [0, 255] (uint8), the usual convention here.
    qtensor = tensor / scale + zero_point
    qtensor.clamp_(0, 255).round_()
    return qtensor.byte()

def dequantize_tensor_from_int8(qtensor, scale, zero_point):
    return (qtensor.float() - zero_point) * scale

# encode() returns a NumPy float32 array by default; convert it to a
# torch tensor so the tensor operations below work
embeddings = torch.from_numpy(embeddings)

# Compute the scale and zero_point for quantization based on min and max
# values, so that min_val maps to 0 and max_val maps to 255
min_val = embeddings.min()
max_val = embeddings.max()
q_scale = ((max_val - min_val) / 255).item()
q_zero_point = int((-min_val / q_scale).round().clamp(0, 255).item())

# Quantize and then de-quantize for demonstration
q_embeddings = quantize_tensor_to_int8(embeddings, q_scale, q_zero_point)
deq_embeddings = dequantize_tensor_from_int8(q_embeddings, q_scale, q_zero_point)
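
As a sanity check (a small sketch, assuming the torch-tensor conversion above), you can verify that the round trip preserves the embedding up to roughly half a quantization step:

import torch.nn.functional as F

# the max absolute error should be on the order of q_scale / 2,
# and the cosine similarity should stay very close to 1
max_err = (embeddings - deq_embeddings).abs().max().item()
cos_sim = F.cosine_similarity(embeddings.flatten().unsqueeze(0),
                              deq_embeddings.flatten().unsqueeze(0)).item()
print(f'max abs error: {max_err:.6f} (quantization step: {q_scale:.6f})')
print(f'cosine similarity after round trip: {cos_sim:.6f}')

PyTorch also ships torch.quantize_per_tensor for the same affine scheme, if you prefer a built-in over the manual helpers.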

hongjin-su commented 9 months ago

Feel free to re-open the issue if you have any further questions or comments!