ashvardanian / SimSIMD

Up to 200x Faster Inner Products and Vector Similarity for Python, JavaScript, Rust, C, and Swift, supporting f64, f32, f16 real & complex, i8, and binary vectors using SIMD for both x86 AVX2 & AVX-512 and Arm NEON & SVE šŸ“
https://ashvardanian.com/posts/simsimd-faster-scipy/
Apache License 2.0

118x faster than GCC 12: KL & JS divergence with AVX-512FP16 #34

Closed: ashvardanian closed this 8 months ago

ashvardanian commented 8 months ago

Divergence functions are a bit more complex than Cosine Similarity, primarily because they have to compute logarithms, which are relatively slow when using LibC's `logf`.

So, aside from minor patches, this PR rewrites the Jensen-Shannon distances, leveraging several optimizations that mainly target the AVX-512 and AVX-512FP16 extensions. The result is a 4.6x improvement over the auto-vectorized single-precision variant and a whopping 118x improvement over the half-precision code.

Optimizations

The kernel combines a few independent tricks, each visible in the implementation below:

- a vectorized `log2` approximation (a Horner-evaluated polynomial over the mantissa) instead of calling LibC's `logf` per element;
- a single half-precision reciprocal approximation of `m` via `_mm512_rcp_ph`, replacing two divisions with two multiplications;
- masked loads, so the tails of the vectors need no separate scalar loop;
- masking out near-zero probabilities instead of adding an `epsilon` to every component;
- fused multiply-add when accumulating the partial sums.

Implementation

As a reminder, the Jensen-Shannon divergence is the symmetrized version of the Kullback-Leibler divergence:

$$
JSD(P, Q) = \frac{1}{2} D(P \parallel M) + \frac{1}{2} D(Q \parallel M), \quad
M = \frac{1}{2}(P + Q), \quad
D(P \parallel Q) = \sum_i P(i) \cdot \log \left( \frac{P(i)}{Q(i)} \right)
$$
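
For reference, here is a minimal scalar sketch of that formula in single precision; the `js_divergence_f32` name and the `epsilon` guard are illustrative rather than part of the SimSIMD API, and the guard mirrors the masking that the vectorized kernel performs.

```c
#include <math.h>
#include <stddef.h>

// Illustrative scalar JSD: assumes `a` and `b` are non-negative and each sums to one.
static float js_divergence_f32(float const* a, float const* b, size_t n) {
    float const epsilon = 1e-6f;
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        float m = 0.5f * (a[i] + b[i]);
        if (a[i] > epsilon) sum += a[i] * logf(a[i] / m); // contributes to D(P || M)
        if (b[i] > epsilon) sum += b[i] * logf(b[i] / m); // contributes to D(Q || M)
    }
    return 0.5f * sum;
}
```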

For AVX-512FP16, the current implementation looks like this:

```c
__attribute__((target("avx512f,avx512vl,avx512fp16")))
inline __m512h simsimd_avx512_f16_log2(__m512h x) {
    // Extract the exponent and mantissa
    __m512h one = _mm512_set1_ph((_Float16)1);
    __m512h e = _mm512_getexp_ph(x);
    __m512h m = _mm512_getmant_ph(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src);
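    // With x = 2^e * m and the mantissa m normalized to [1, 2), log2(x) = e + log2(m);
    // the degree-5 polynomial below approximates log2(m) / (m - 1) on that interval.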

    // Compute the polynomial using Horner's method
    __m512h p = _mm512_set1_ph((_Float16)-3.4436006e-2f);
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)3.1821337e-1f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)-1.2315303f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)2.5988452f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)-3.3241990f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)3.1157899f));

    return _mm512_add_ph(_mm512_mul_ph(p, _mm512_sub_ph(m, one)), e);
}

__attribute__((target("avx512f,avx512vl,avx512fp16")))
inline static simsimd_f32_t simsimd_avx512_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
    __m512h sum_a_vec = _mm512_set1_ph((_Float16)0);
    __m512h sum_b_vec = _mm512_set1_ph((_Float16)0);
    __m512h epsilon_vec = _mm512_set1_ph((_Float16)1e-6f);
    for (simsimd_size_t i = 0; i < n; i += 32) {
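        // When fewer than 32 elements remain, keep only the low (n - i) bits of the mask,
        // so the masked loads below never read past the end of `a` and `b`.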
        __mmask32 mask = n - i >= 32 ? 0xFFFFFFFF : ((1u << (n - i)) - 1u);
        __m512h a_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a + i));
        __m512h b_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b + i));
        __m512h m_vec = _mm512_mul_ph(_mm512_add_ph(a_vec, b_vec), _mm512_set1_ph((_Float16)0.5f));

        // Avoid division-by-zero and log-of-zero issues from near-zero probabilities down the road.
        // Masking those lanes out is nicer than adding an `epsilon` to every component.
        __mmask32 nonzero_mask_a = _mm512_cmp_ph_mask(a_vec, epsilon_vec, _CMP_GE_OQ);
        __mmask32 nonzero_mask_b = _mm512_cmp_ph_mask(b_vec, epsilon_vec, _CMP_GE_OQ);
        __mmask32 nonzero_mask = nonzero_mask_a & nonzero_mask_b & mask;

        // Division is an expensive operation. Instead of doing it twice,
        // we can approximate the reciprocal of `m` and multiply instead.
        __m512h m_recip_approx = _mm512_rcp_ph(m_vec);
        __m512h ratio_a_vec = _mm512_mul_ph(a_vec, m_recip_approx);
        __m512h ratio_b_vec = _mm512_mul_ph(b_vec, m_recip_approx);

        // The natural logarithm is just `log2` scaled by `ln(2)`, so that constant factor is applied once after the loop.
        __m512h log_ratio_a_vec = simsimd_avx512_f16_log2(ratio_a_vec);
        __m512h log_ratio_b_vec = simsimd_avx512_f16_log2(ratio_b_vec);

        // Instead of separate multiplication and addition, invoke the FMA.
        // The `mask3` variant keeps the running sum in masked-out lanes,
        // where a zero-masking FMA would discard what those lanes had accumulated.
        sum_a_vec = _mm512_mask3_fmadd_ph(a_vec, log_ratio_a_vec, sum_a_vec, nonzero_mask);
        sum_b_vec = _mm512_mask3_fmadd_ph(b_vec, log_ratio_b_vec, sum_b_vec, nonzero_mask);
    }
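    // Convert the base-2 sums back to the natural logarithm: ln(x) = log2(x) * ln(2).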
    simsimd_f32_t log2_normalizer = 0.693147181f;
    return _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec)) * 0.5f * log2_normalizer;
}
```
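
For completeness, here is a minimal sketch of how the kernel above might be called. The `example_jsd` wrapper, the 1536-dimension constant, and the assumption that `simsimd_f16_t` is bit-compatible with the compiler's `_Float16` type are illustrative, not part of the PR.

```c
// Illustrative usage only; `p` and `q` should each hold a normalized probability distribution.
enum { dims = 1536 };

simsimd_f32_t example_jsd(_Float16 const p[dims], _Float16 const q[dims]) {
    // Assumes `simsimd_f16_t` shares the binary16 layout with `_Float16`.
    return simsimd_avx512_f16_js((simsimd_f16_t const*)p, (simsimd_f16_t const*)q, dims);
}
```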

Benchmarks

I conducted benchmarks at both the higher-level Python and lower-level C++ layers, comparing the auto-vectorization on GCC 12 to our new implementation on an Intel Sapphire Rapids CPU on AWS:

[Screenshot of the raw benchmark output, 2023-10-23]

The program was compiled with `-O3`, `-march=native`, and `-ffast-math`, and it ran on all cores of the 4-core instance, potentially favoring the non-vectorized solution. When normalized and tabulated, the results are as follows:

| Benchmark | Pairs/s | Gigabytes/s | Absolute Error | Relative Error |
| :--- | ---: | ---: | ---: | ---: |
| `serial_f32_js_1536d` | 0.243 M/s | 2.98 G/s | 0 | 0 |
| `serial_f16_js_1536d` | 0.018 M/s | 0.11 G/s | 0.123 | 0.035 |
| `avx512_f32_js_1536d` | 1.127 M/s | 13.84 G/s | 0.001 | 345u |
| `avx512_f16_js_1536d` | 2.139 M/s | 13.14 G/s | 0.070 | 0.020 |
| `avx2_f16_js_1536d` | 0.547 M/s | 3.36 G/s | 0.011 | 0.003 |

Of course, the results will vary with the vector size. I generally use 1536 dimensions, matching the size of OpenAI Ada embeddings, which are standard in NLP workloads. The Jensen-Shannon divergence, however, is used broadly in other domains of statistics, bioinformatics, and cheminformatics, so I'm adding it as a new out-of-the-box supported metric in USearch today šŸ„³

This further accelerates approximate k-Nearest Neighbors search and the clustering of billions of protein sequences without alignment procedures. Expect one more "Less Slow" post soon! šŸ¤—

ashvardanian commented 8 months ago

šŸŽ‰ This PR is included in version 3.2.0 šŸŽ‰

The release is available on GitHub.

Your semantic-release bot šŸ“¦šŸš€