ahron1 opened this issue 4 months ago (status: Open)
Regarding int8 quantization: to perform the quantization, the function needs the minimum and maximum value of each dimension of the embeddings for calibration. This is why `quantize_embeddings` has arguments like `calibration_embeddings` and `ranges`:

https://github.com/UKPLab/sentence-transformers/blob/0a32ec8445ef46b2b5d4f81af4931e293d42623f/sentence_transformers/quantization.py#L421

If these arguments are not provided, the function derives the per-dimension minimum and maximum from the embeddings passed to `quantize_embeddings` themselves. Consequently, if this is done on a single embedding, the minimum and maximum are identical in every dimension. These minima and maxima are used to calculate a difference (the quantization step) in the following code:

https://github.com/UKPLab/sentence-transformers/blob/0a32ec8445ef46b2b5d4f81af4931e293d42623f/sentence_transformers/quantization.py#L423

As a result, the difference between the minimum and maximum becomes 0, leading to NaN values in the subsequent division. When these NaN values are converted with `astype(np.int8)`, they become 0. Therefore, when using the `quantize_embeddings` function, it's necessary to either:

- pass the embedding results of the dataset you want to use as `embeddings` or `calibration_embeddings`, or
- if you know the value range, pass it as `ranges`.
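The failure mode can be reproduced with plain NumPy. This is a minimal sketch of the min/max scaling described above, not the library's exact code:

```python
import numpy as np

emb = np.random.rand(1, 1024).astype(np.float32)  # a single embedding

# With only one row, the per-dimension minimum and maximum coincide
lo, hi = emb.min(axis=0), emb.max(axis=0)

with np.errstate(invalid="ignore"):
    scaled = (emb - lo) / (hi - lo)       # 0 / 0 -> NaN in every dimension
    quantized = scaled.astype(np.int8)    # NaN becomes 0 under astype(np.int8)

print(np.isnan(scaled).all())  # True
print((quantized == 0).all())  # True
```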
Example:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings

model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

e1 = model.encode(["this is a test", "this is another test"])
e2 = quantize_embeddings(e1, precision="int8")
e3 = quantize_embeddings(e1, precision="binary")

print(e1.shape)
print(e2.shape)
print(e3.shape)
print(e1[0][1])
print(e2)
print(e3)
```
Execution result:

```
(2, 1024)
(2, 1024)
(2, 128)
0.42505217
[[-128 -128 127 ... 127 -128 127]
 [ 127 127 -128 ... -128 127 -128]]
[[ 112 120 -60 -36 -88 -32 -6 -98 9 29 99 52 45 -30
   84 69 -67 101 68 65 -43 55 55 -61 -44 -37 -66 -21
   19 -1 116 -4 64 -88 -120 -108 -62 109 -19 27 -116 22
   -7 65 -9 -15 42 33 -29 79 53 -109 -63 12 -114 15
 -118 -26 -55 -86 -32 113 -111 -105 -15 -98 -87 103 -33 73
  -48 30 -40 57 32 -88 122 -74 -64 -112 33 -69 -100 97
  -69 -80 26 -28 -40 -122 109 34 109 -94 26 88 -23 -81
  -73 23 22 80 14 -22 83 -43 31 108 34 -81 -74 -117
   69 108 26 -119 -41 -1 48 85 -20 -14 86 77 -67 26
  -50 -66]
 [ 80 120 -60 -51 -88 -64 122 -98 9 -115 -61 48 61 -45
   84 69 -115 5 100 -53 -42 63 55 64 -12 -46 50 -37
   21 127 118 124 66 -72 -120 -99 66 -116 -19 31 -92 22
  -39 7 -5 -15 34 37 102 14 -75 22 -63 44 -98 79
 -118 -26 -38 -18 -48 113 -80 -43 113 -98 -103 102 91 -55
  -64 30 -40 57 -32 -72 120 38 -96 16 43 -71 -115 99
  -85 -104 -102 -27 -40 -114 61 96 -17 35 27 8 -21 46
   55 95 22 64 46 -54 -61 -43 23 -4 32 -81 -74 -101
  -35 12 26 -120 -9 -21 17 83 -52 -6 22 -51 29 26
  -50 62]]
```
However, it might be a good idea to display a user warning when only a single embedding is passed without calibration data.
Regarding binary quantization: in the following code, the `embeddings > 0` part first binarizes the embeddings to 0 or 1 depending on whether each value exceeds 0. Then `np.packbits` compresses the binarized array by packing 8 dimensions into 1 byte. A 1024-dimensional vector thus becomes 1024 / 8 = 128 dimensions, which has the advantage of reducing the memory usage of the embeddings.

https://github.com/UKPLab/sentence-transformers/blob/0a32ec8445ef46b2b5d4f81af4931e293d42623f/sentence_transformers/quantization.py#L431
If you want to maintain the original number of dimensions, i.e. preserve the independence of each dimension, you can recover the bit-level values with the following code:

```python
import numpy as np

# Unpack the packed ubinary representation back to one 0/1 value per dimension
e3 = quantize_embeddings(e1, precision="ubinary")
unpacked_e3 = np.unpackbits(e3, axis=-1)

# Equivalent direct binarization of the original embeddings
quantize_binary_e3 = (e1 > 0).astype(np.int8)
```
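As a sanity check on the packing arithmetic, the 1024 → 128 compression and the losslessness of unpacking can be verified with NumPy alone (a standalone sketch, independent of sentence-transformers):

```python
import numpy as np

emb = np.random.default_rng(0).normal(size=(2, 1024)).astype(np.float32)

bits = (emb > 0).astype(np.uint8)    # one 0/1 value per dimension
packed = np.packbits(bits, axis=-1)  # 8 dimensions per byte

print(packed.shape)                  # (2, 128)

# unpackbits restores the original 0/1 vector exactly
restored = np.unpackbits(packed, axis=-1)
print(np.array_equal(restored, bits))  # True
```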
I am trying to quantize embeddings using the `quantize_embeddings` function. The results are a bit different from what I would expect:

- The original `e1` has 1024 dimensions.
- `e2`, quantized to int8, has 1024 dimensions, but the weights are all 0.
- `e3`, quantized to binary, is 128-dimensional and has weight values between -126 and +127. The datatype is `int8`.