Up to 200x Faster Inner Products and Vector Similarity for Python, JavaScript, Rust, C, and Swift, supporting f64, f32, f16 real & complex, i8, and binary vectors using SIMD for both x86 AVX2 & AVX-512 and Arm NEON & SVE
Divergence functions are a bit more complex than Cosine Similarity, primarily because they have to compute logarithms, which are relatively slow when using LibC's `logf`.
So, aside from minor patches, in this PR I've rewritten the Jensen-Shannon distances using several optimizations, focusing mainly on the AVX-512 and AVX-512FP16 extensions. This resulted in a 4.6x improvement over the auto-vectorized single-precision variant and a whopping 118x improvement over the half-precision code.
Optimizations
Logarithm Computation. Instead of multiple bitwise operations, `_mm512_getexp_ph` and `_mm512_getmant_ph` are now used to extract the exponent and the mantissa of the floating-point number, streamlining the process. I've also used Horner's method to evaluate the polynomial that approximates the logarithm of the mantissa.
Division Avoidance. To avoid expensive division operations, reciprocal approximations are used instead: `_mm512_rcp_ph` for half-precision and `_mm512_rcp14_ps` for single-precision. The more accurate `_mm512_rcp28_ps` was found to be unnecessary for this implementation.
Handling Zeros. `_mm512_cmp_ph_mask` computes a mask of close-to-zero values, so those components can be skipped instead of adding an "epsilon" to every component, which is both cleaner and more accurate.
Parallel Accumulation. The $KL(a||m)$ and $KL(b||m)$ terms are now accumulated in separate registers, and masked fused multiply-add (FMA) instructions replace distinct multiplication and addition operations, optimizing the calculation further.
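The logarithm trick above can be sketched in portable scalar C for readers without AVX-512FP16 hardware. This is a minimal illustration under stated assumptions, not the library's code. The name `approx_log2` is hypothetical. `frexpf` stands in for `_mm512_getexp_ph` and `_mm512_getmant_ph`. The polynomial coefficients are copied from the AVX-512FP16 kernel in the Implementation section.

```c
#include <math.h>

// Hypothetical scalar analogue of the AVX-512FP16 `log2` approximation.
// Assumes `x` is positive and finite.
static float approx_log2(float x) {
    int exponent;
    // frexpf returns a mantissa in [0.5, 1) with x = mantissa * 2^exponent;
    // rescale it to [1, 2), like _MM_MANT_NORM_1_2 does for getmant.
    float m = frexpf(x, &exponent) * 2.0f;
    float e = (float)(exponent - 1);
    // Degree-5 polynomial in `m`, evaluated with Horner's method
    float p = -3.4436006e-2f;
    p = p * m + 3.1821337e-1f;
    p = p * m - 1.2315303f;
    p = p * m + 2.5988452f;
    p = p * m - 3.3241990f;
    p = p * m + 3.1157899f;
    // log2(x) = e + log2(m), where log2(m) is approximated by p(m) * (m - 1)
    return p * (m - 1.0f) + e;
}
```

Because `m - 1` is a factor of the polynomial term, the approximation is exact whenever `x` is a power of two. For example, `approx_log2(8.0f)` returns exactly `3.0f`.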
Implementation
As a reminder, the Jensen-Shannon divergence is a symmetrized version of the Kullback-Leibler divergence, measured against the mixture $M$ of the two distributions:
$$JSD(P, Q) = \frac{1}{2} KL(P || M) + \frac{1}{2} KL(Q || M), \qquad M = \frac{1}{2}(P + Q)$$
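Let $a_i$ and $b_i$ be the components of $P$ and $Q$, and let $m_i = \frac{1}{2}(a_i + b_i)$. Expanding the definition with $\ln x = \ln 2 \cdot \log_2 x$ shows where the kernel's structure comes from: a single reciprocal $m_i^{-1}$ shared by both terms, two accumulators, and one $\frac{\ln 2}{2}$ factor applied after the reduction:

$$
\begin{aligned}
JSD(P, Q) &= \frac{1}{2} \sum_i a_i \ln \frac{a_i}{m_i} + \frac{1}{2} \sum_i b_i \ln \frac{b_i}{m_i} \\
&= \frac{\ln 2}{2} \sum_i \left( a_i \log_2 \left( a_i \, m_i^{-1} \right) + b_i \log_2 \left( b_i \, m_i^{-1} \right) \right)
\end{aligned}
$$

Terms with $a_i = 0$ or $b_i = 0$ vanish under the usual convention $0 \log 0 = 0$. This is why near-zero components can be skipped rather than shifted by an epsilon.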
For AVX-512FP16, the current implementation looks like this:
```c
#include <immintrin.h>
// `simsimd_f16_t`, `simsimd_f32_t`, and `simsimd_size_t` come from the SimSIMD headers.

// Approximates log2(x) in 32 half-precision lanes as exponent + log2(mantissa).
__attribute__((target("avx512f,avx512vl,avx512fp16")))
static inline __m512h simsimd_avx512_f16_log2(__m512h x) {
    // Extract the exponent and the mantissa, normalized to [1, 2)
    __m512h one = _mm512_set1_ph((_Float16)1);
    __m512h e = _mm512_getexp_ph(x);
    __m512h m = _mm512_getmant_ph(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src);
    // Compute the polynomial using Horner's method
    __m512h p = _mm512_set1_ph((_Float16)-3.4436006e-2f);
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)3.1821337e-1f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)-1.2315303f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)2.5988452f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)-3.3241990f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)3.1157899f));
    // log2(x) = e + log2(m), where log2(m) is approximated by p(m) * (m - 1)
    return _mm512_add_ph(_mm512_mul_ph(p, _mm512_sub_ph(m, one)), e);
}

__attribute__((target("avx512f,avx512vl,avx512fp16")))
inline static simsimd_f32_t simsimd_avx512_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
    __m512h sum_a_vec = _mm512_set1_ph((_Float16)0);
    __m512h sum_b_vec = _mm512_set1_ph((_Float16)0);
    __m512h epsilon_vec = _mm512_set1_ph((_Float16)1e-6f);
    for (simsimd_size_t i = 0; i < n; i += 32) {
        // Process 32 halves per iteration; masked loads zero-fill the tail
        __mmask32 mask = n - i >= 32 ? 0xFFFFFFFF : ((1u << (n - i)) - 1u);
        __m512h a_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a + i));
        __m512h b_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b + i));
        __m512h m_vec = _mm512_mul_ph(_mm512_add_ph(a_vec, b_vec), _mm512_set1_ph((_Float16)0.5f));

        // Avoid division-by-zero and log-of-zero problems for probabilities close to zero.
        // Masking is a nicer way to do this than adding an `epsilon` to every component.
        __mmask32 nonzero_mask_a = _mm512_cmp_ph_mask(a_vec, epsilon_vec, _CMP_GE_OQ);
        __mmask32 nonzero_mask_b = _mm512_cmp_ph_mask(b_vec, epsilon_vec, _CMP_GE_OQ);
        __mmask32 nonzero_mask = nonzero_mask_a & nonzero_mask_b & mask;

        // Division is an expensive operation. Instead of doing it twice,
        // we can approximate the reciprocal of `m` and multiply instead.
        __m512h m_recip_approx = _mm512_rcp_ph(m_vec);
        __m512h ratio_a_vec = _mm512_mul_ph(a_vec, m_recip_approx);
        __m512h ratio_b_vec = _mm512_mul_ph(b_vec, m_recip_approx);

        // Stay in `log2` inside the loop; the conversion to the natural logarithm,
        // ln(x) = log2(x) * ln(2), is applied once after the reduction.
        __m512h log_ratio_a_vec = simsimd_avx512_f16_log2(ratio_a_vec);
        __m512h log_ratio_b_vec = simsimd_avx512_f16_log2(ratio_b_vec);

        // Instead of separate multiplication and addition, invoke the FMA.
        // The `mask3` form keeps the accumulator unchanged in masked-out lanes,
        // while `maskz` would zero them and discard the partial sums of earlier iterations.
        sum_a_vec = _mm512_mask3_fmadd_ph(a_vec, log_ratio_a_vec, sum_a_vec, nonzero_mask);
        sum_b_vec = _mm512_mask3_fmadd_ph(b_vec, log_ratio_b_vec, sum_b_vec, nonzero_mask);
    }
    // JSD = (KL(a||m) + KL(b||m)) / 2, rescaled from log2 to the natural logarithm
    simsimd_f32_t log2_normalizer = 0.693147181f; // ln(2)
    return _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec)) * 0.5f * log2_normalizer;
}
```
Benchmarks
I conducted benchmarks at both the higher-level Python and lower-level C++ layers. They compare GCC 12 auto-vectorized code against the new implementation on an Intel Sapphire Rapids CPU on AWS.
The program was compiled with `-O3`, `-march=native`, and `-ffast-math` and ran on all cores of the 4-core instance, potentially favoring the non-vectorized solution. When normalized and tabulated, the results are as follows:
| Benchmark | Pairs/s | Gigabytes/s | Absolute Error | Relative Error |
| :--- | ---: | ---: | ---: | ---: |
| serial_f32_js_1536d | 0.243 M/s | 2.98 G/s | 0 | 0 |
| serial_f16_js_1536d | 0.018 M/s | 0.11 G/s | 0.123 | 0.035 |
| avx512_f32_js_1536d | 1.127 M/s | 13.84 G/s | 0.001 | 0.000345 |
| avx512_f16_js_1536d | 2.139 M/s | 13.14 G/s | 0.070 | 0.020 |
| avx2_f16_js_1536d | 0.547 M/s | 3.36 G/s | 0.011 | 0.003 |
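The Gigabytes/s column can be cross-checked against Pairs/s. The small sketch below reproduces the tabulated values under two assumptions: each pair streams two 1536-dimensional vectors, and G/s means 10^9 bytes per second. The helper names are hypothetical.

```c
#include <stddef.h>

// Hypothetical helper: memory bandwidth implied by a pairs-per-second throughput,
// assuming every pair reads two vectors of `dims` scalars, `scalar_bytes` bytes each.
static double gigabytes_per_second(double pairs_per_second, size_t dims, size_t scalar_bytes) {
    double bytes_per_pair = 2.0 * (double)dims * (double)scalar_bytes;
    return pairs_per_second * bytes_per_pair / 1e9; // decimal gigabytes
}

// Hypothetical helper: ratio of two throughputs, e.g. a SIMD kernel over its serial baseline.
static double speedup(double new_pairs_per_second, double baseline_pairs_per_second) {
    return new_pairs_per_second / baseline_pairs_per_second;
}
```

For example, `gigabytes_per_second(0.243e6, 1536, 4)` is about 2.986 for `serial_f32_js_1536d`. The speedups also follow from the Pairs/s column: `speedup(1.127, 0.243)` ≈ 4.64 and `speedup(2.139, 0.018)` ≈ 118.8 recover the 4.6x and 118x figures quoted at the top.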
Of course, the results will vary with the vector size. I generally use 1536 dimensions, matching the size of OpenAI Ada embeddings, which are standard in NLP workloads. The Jensen-Shannon divergence, however, is used broadly in other domains such as statistics, bioinformatics, and cheminformatics, so today I'm adding it to USearch as a new out-of-the-box metric.
This further accelerates approximate k-Nearest Neighbors search and the alignment-free clustering of billions of different protein sequences. Expect one more "Less Slow" post soon!