ashvardanian / SimSIMD

Up to 200x Faster Inner Products and Vector Similarity — for Python, JavaScript, Rust, C, and Swift, supporting f64, f32, f16 real & complex, i8, and binary vectors using SIMD for both x86 AVX2 & AVX-512 and Arm NEON & SVE 📐
https://ashvardanian.com/posts/simsimd-faster-scipy/
Apache License 2.0

Maybe a faster dot product (I'm not sure) #45

Closed: javiabellan closed this issue 3 months ago

javiabellan commented 6 months ago

Looking at the AVX-512 dot product, I tried to remove the if from inside the loop to make the code faster. Here is an untested sketch of the idea:

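#include <immintrin.h> // AVX-512F intrinsics; _bzhi_u32 additionally needs BMI2 (e.g. -mavx512f -mbmi2)

// Return type assumed to be simsimd_f32_t; the simsimd_* types come from the SimSIMD headers.
simsimd_f32_t simsimd_avx512_f32_ip(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n)
{
    __m512 ab_vec = _mm512_setzero_ps();
    __m512 a_vec, b_vec;

    int n_tail = n & 15; // same as n % 16 for unsigned n (optimizing compilers emit this AND for % anyway)
    n -= n_tail;         // n is now a multiple of 16

    while (n) // main loop: full 16-float blocks only, no "are we on the tail?" check
    {
        a_vec = _mm512_loadu_ps(a);
        b_vec = _mm512_loadu_ps(b);
        ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec);

        a += 16, b += 16, n -= 16;
    }
    if (n_tail) // remaining 1..15 floats: masked loads zero the unused lanes
    {
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, (unsigned)n_tail);
        a_vec = _mm512_maskz_loadu_ps(mask, a);
        b_vec = _mm512_maskz_loadu_ps(mask, b);
        ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec);
    }

    return _mm512_reduce_add_ps(ab_vec); // horizontal sum of the 16 lanes
}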
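Because the kernel above is untested and needs AVX-512 hardware, a portable scalar model of the same control flow can be used to check the loop logic on any machine. This is a sketch under stated assumptions: plain float arrays stand in for the simsimd types, a 16-step inner loop stands in for one 512-bit FMA, zero-filling stands in for the masked load, and the name split_loop_ip is hypothetical.

```c
#include <stddef.h>

/* Hypothetical scalar model of the proposed kernel:
   lanes[] plays the role of the 16-lane ab_vec accumulator. */
static float split_loop_ip(const float *a, const float *b, size_t n) {
    float lanes[16] = {0};
    size_t n_tail = n & 15; /* elements left after the full 16-wide blocks */
    n -= n_tail;            /* n is now a multiple of 16 */

    while (n) { /* main loop: full blocks only, no tail check */
        for (size_t i = 0; i < 16; ++i)
            lanes[i] += a[i] * b[i]; /* one "fmadd" across 16 lanes */
        a += 16, b += 16, n -= 16;
    }
    if (n_tail) { /* "masked" tail: lanes >= n_tail contribute zero */
        for (size_t i = 0; i < 16; ++i) {
            float av = i < n_tail ? a[i] : 0.0f; /* never reads past the end */
            float bv = i < n_tail ? b[i] : 0.0f;
            lanes[i] += av * bv;
        }
    }

    float sum = 0.0f; /* horizontal reduction, like _mm512_reduce_add_ps */
    for (size_t i = 0; i < 16; ++i)
        sum += lanes[i];
    return sum;
}
```

With small integer-valued inputs the float sums are exact, so the result can be compared directly against n(n + 1)/2 for every length from 0 to 40, covering empty input, exact multiples of 16, and every tail size.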
ashvardanian commented 6 months ago

Hi @javiabellan, thanks for the contribution! The while loop is effectively the same as a combination of if and goto, so the performance should be the same.
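For comparison, the single-loop shape discussed in this thread, one body that branches on n < 16 and jumps back while n is non-zero, can be modelled the same way. This is a sketch inferred from the two conditions named in the thread (if (n < 16) and if (n)), not code copied from SimSIMD; the name goto_loop_ip is hypothetical and scalar loops stand in for the AVX-512 intrinsics.

```c
#include <stddef.h>

/* Hypothetical scalar model of a single-loop kernel that decides
   on every pass whether it is processing the tail. */
static float goto_loop_ip(const float *a, const float *b, size_t n) {
    float lanes[16] = {0};
    float a_vec[16], b_vec[16];

cycle:
    if (n < 16) { /* tail: "masked" load, zero-filling unused lanes */
        for (size_t i = 0; i < 16; ++i) {
            a_vec[i] = i < n ? a[i] : 0.0f;
            b_vec[i] = i < n ? b[i] : 0.0f;
        }
        n = 0;
    } else { /* full 16-element block */
        for (size_t i = 0; i < 16; ++i) {
            a_vec[i] = a[i];
            b_vec[i] = b[i];
        }
        a += 16, b += 16, n -= 16;
    }
    for (size_t i = 0; i < 16; ++i) /* one "fmadd" across 16 lanes */
        lanes[i] += a_vec[i] * b_vec[i];
    if (n) /* second check per pass: anything left? */
        goto cycle;

    float sum = 0.0f; /* horizontal reduction */
    for (size_t i = 0; i < 16; ++i)
        sum += lanes[i];
    return sum;
}
```

On the same inputs this produces the same sums as the proposed split loop; the two shapes differ only in how many comparisons run per 16-element block, not in the arithmetic.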

ashvardanian commented 6 months ago

I am not sure the AVX-512 versions have room for improvement, but for Apple one can try to replace NEON with AMX. Let me know if you can test those 🤗

javiabellan commented 6 months ago

I see. I was thinking about how the condition looks in assembly: maybe checking for non-zero (JNZ) is faster than comparing n < 16, but I'm not sure about that.

I think the gain would come from each loop iteration: the proposed code evaluates one condition (while (n)) instead of two (if (n < 16) and if (n)). The main idea is to avoid checking on every iteration whether we have reached the tail. We know that in advance, so we can keep the tail out of the main loop by setting n -= n_tail before it starts.

The main disadvantages of the proposed code are the extra setup (int n_tail = n & 15 and n -= n_tail), although that is O(1) and runs once per call, and the slightly larger code, since the ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec); line is duplicated.
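To make the per-iteration argument concrete, the sketch below counts how many source-level loop-control conditions each shape evaluates for a given n. It is a plain C model with hypothetical function names; it only counts conditions and says nothing about their cost on real hardware, where the compiler may rearrange the loop and well-predicted branches are usually cheap next to the loads and FMAs.

```c
#include <stddef.h>

/* Single-loop (if + goto) shape: every pass evaluates
   if (n < 16) and if (n). */
static size_t checks_goto_loop(size_t n) {
    size_t checks = 0;
    do {
        checks += 2;             /* if (n < 16) and if (n) */
        n = n < 16 ? 0 : n - 16; /* tail pass consumes everything */
    } while (n);
    return checks;
}

/* Proposed split loop: while (n) is tested once per full block
   plus one final failing test, then if (n_tail) runs once. */
static size_t checks_split_loop(size_t n) {
    size_t checks = 0;
    size_t n_tail = n & 15;
    n -= n_tail;
    for (;;) {
        ++checks; /* while (n) */
        if (!n)
            break;
        n -= 16;
    }
    ++checks; /* if (n_tail) */
    return checks;
}
```

For n = 100 this gives 14 conditions for the single loop versus 8 for the split loop, and for n = 1024 it gives 128 versus 66, so for long vectors the ratio approaches 2:1. For very short inputs the split loop can evaluate one more condition: at n = 16 it is 3 versus 2.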