ashvardanian / SimSIMD

Up to 200x Faster Dot Products & Similarity Metrics for Python, Rust, C, JS, and Swift, supporting f64, f32, f16 real & complex, i8, and bit vectors using SIMD for AVX2, AVX-512, NEON, SVE, & SVE2 📐
https://ashvardanian.com/posts/simsimd-faster-scipy/
Apache License 2.0
988 stars, 59 forks

Using vbfdotq_f32 for `dot_bf16_neon` is faster? #167

Closed: MarkReedZ closed this issue 2 months ago

MarkReedZ commented 2 months ago

The following code runs in 11 ns, versus 15 ns for the current version that uses vbfmlaltq_f32. Does it make sense to use vbfdotq_f32? I'm not sure this code is correct. How do we typically test the results?

Should we add an unaligned vector size (a length that is not a multiple of 8) to the benchmarks?

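```c
#include <string.h> // memset, memcpy for the tail

SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
                                          simsimd_distance_t* result) {
    float32x4_t ab_vec = vdupq_n_f32(0);

    while (n >= 8) {
        bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a);
        bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b);
        // Each f32 lane i accumulates a[2i] * b[2i] + a[2i + 1] * b[2i + 1].
        ab_vec = vbfdotq_f32(ab_vec, a_vec, b_vec);
        n -= 8;
        a += 8;
        b += 8;
    }
    // Remainder: zero-pad the last n < 8 elements; the padded products add nothing.
    if (n > 0) {
        simsimd_bf16_t a_tail[8], b_tail[8];
        memset(a_tail, 0, sizeof(a_tail));
        memset(b_tail, 0, sizeof(b_tail));
        memcpy(a_tail, a, n * sizeof(simsimd_bf16_t));
        memcpy(b_tail, b, n * sizeof(simsimd_bf16_t));
        bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a_tail);
        bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b_tail);
        ab_vec = vbfdotq_f32(ab_vec, a_vec, b_vec);
    }
    *result = vaddvq_f32(ab_vec);
}
```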
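To see why the two kernels should agree up to rounding, here is a portable scalar model of the lane semantics, assuming the ACLE definitions: `vbfdotq_f32` adds the products of bf16 elements 2i and 2i+1 into f32 lane i, while `vbfmlalbq_f32` and `vbfmlaltq_f32` each add only the even ("bottom") or odd ("top") products, so a full dot product built on `vbfmlaltq_f32` also needs `vbfmlalbq_f32`. All helper names here are hypothetical, and the model deliberately ignores any differences in intermediate rounding between the instructions.

```c
#include <stdint.h>
#include <string.h>

/* bf16 is the top 16 bits of an IEEE-754 binary32, so widening is a shift. */
static float bf16_to_f32(uint16_t x) {
    uint32_t bits = (uint32_t)x << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Truncating f32 -> bf16; exact for small integers such as 1..8. */
static uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);
}

/* Hypothetical model of vbfdotq_f32: lane i += a[2i]*b[2i] + a[2i+1]*b[2i+1]. */
static void model_bfdot(float acc[4], const uint16_t a[8], const uint16_t b[8]) {
    for (int i = 0; i < 4; ++i)
        acc[i] += bf16_to_f32(a[2 * i]) * bf16_to_f32(b[2 * i]) +
                  bf16_to_f32(a[2 * i + 1]) * bf16_to_f32(b[2 * i + 1]);
}

/* Hypothetical model of vbfmlalbq_f32: lane i += a[2i]*b[2i] (even, "bottom" elements). */
static void model_bfmlalb(float acc[4], const uint16_t a[8], const uint16_t b[8]) {
    for (int i = 0; i < 4; ++i)
        acc[i] += bf16_to_f32(a[2 * i]) * bf16_to_f32(b[2 * i]);
}

/* Hypothetical model of vbfmlaltq_f32: lane i += a[2i+1]*b[2i+1] (odd, "top" elements). */
static void model_bfmlalt(float acc[4], const uint16_t a[8], const uint16_t b[8]) {
    for (int i = 0; i < 4; ++i)
        acc[i] += bf16_to_f32(a[2 * i + 1]) * bf16_to_f32(b[2 * i + 1]);
}

/* Horizontal add of the four lanes, analogous to vaddvq_f32. */
static float sum_lanes(const float acc[4]) { return acc[0] + acc[1] + acc[2] + acc[3]; }
```

For a = 1..8 and b = 8..1, `model_bfdot` leaves lanes 22, 38, 38, 22 (total 120), and `model_bfmlalb` plus `model_bfmlalt` give the same total. On real hardware the two instruction sequences may still round differently, which is why the measured accuracy delta matters.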
ashvardanian commented 2 months ago

> I'm not sure this code is correct. How do we typically test the results?

@MarkReedZ, the C++ benchmarks log the accuracy delta relative to the serial baseline, and the Python tests will fail if this instruction does something unexpected. Let's run both.
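A minimal sketch of that kind of check in plain C, assuming the idea described above: compare a kernel against a serial, higher-precision baseline on the same bf16 inputs. `dot_bf16_reference`, `dot_bf16_blocked_model`, and `worst_relative_error` are hypothetical names; the blocked model is a scalar stand-in for the NEON kernel (8-wide blocks, four f32 lanes, zero-padded tail), so sweeping lengths that are not multiples of 8 exercises the remainder handling.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* bf16 is the top 16 bits of a binary32. */
static float bf16_to_f32(uint16_t x) {
    uint32_t bits = (uint32_t)x << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Truncating conversion; exact for the short binary fractions used below. */
static uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);
}

/* Serial f64 baseline over the same bf16 inputs. */
static double dot_bf16_reference(const uint16_t *a, const uint16_t *b, size_t n) {
    double sum = 0;
    for (size_t i = 0; i < n; ++i) sum += (double)bf16_to_f32(a[i]) * (double)bf16_to_f32(b[i]);
    return sum;
}

/* Scalar stand-in for the 8-wide kernel: each lane takes a pair, the tail is zero-padded. */
static float dot_bf16_blocked_model(const uint16_t *a, const uint16_t *b, size_t n) {
    float lanes[4] = {0, 0, 0, 0};
    while (n > 0) {
        uint16_t a_blk[8] = {0}, b_blk[8] = {0};
        size_t take = n < 8 ? n : 8;
        memcpy(a_blk, a, take * sizeof *a);
        memcpy(b_blk, b, take * sizeof *b);
        for (int i = 0; i < 4; ++i)
            lanes[i] += bf16_to_f32(a_blk[2 * i]) * bf16_to_f32(b_blk[2 * i]) +
                        bf16_to_f32(a_blk[2 * i + 1]) * bf16_to_f32(b_blk[2 * i + 1]);
        a += take;
        b += take;
        n -= take;
    }
    return (lanes[0] + lanes[1]) + (lanes[2] + lanes[3]);
}

/* Worst relative error of the model vs. the baseline over lengths 1..max_n (max 64). */
static double worst_relative_error(size_t max_n) {
    uint16_t a[64], b[64];
    if (max_n > 64) max_n = 64;
    for (size_t i = 0; i < 64; ++i) {
        a[i] = f32_to_bf16(0.5f + 0.25f * (float)(i % 7));   /* 0.5 .. 2.0 */
        b[i] = f32_to_bf16(1.0f - 0.125f * (float)(i % 5));  /* 1.0 .. 0.5 */
    }
    double worst = 0;
    for (size_t n = 1; n <= max_n; ++n) {
        double expected = dot_bf16_reference(a, b, n);
        double diff = (double)dot_bf16_blocked_model(a, b, n) - expected;
        if (diff < 0) diff = -diff;
        if (diff / expected > worst) worst = diff / expected;
    }
    return worst;
}
```

On a BF16-capable machine the same sweep could call `simsimd_dot_bf16_neon` instead of the model; lengths 1..64 cover every remainder from 0 to 7.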

ashvardanian commented 2 months ago

> Does it make sense to use vbfdotq_f32?

@MarkReedZ, I'm not sure I've used that instruction before. If it doesn't affect compilation settings or CPU-capability requirements, sure! Otherwise, we can add a note wherever vbfmlaltq_f32 is used. For #163, though, it probably still makes sense.
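On the compilation question: as far as I can tell from the ACLE, `vbfdotq_f32`, `vbfmlalbq_f32`, and `vbfmlaltq_f32` all belong to the same BF16 extension and are guarded by the same `__ARM_FEATURE_BF16_VECTOR_ARITHMETIC` macro, so the swap should not change compiler flags or CPU requirements. Below is a hedged sketch of checking both sides, assuming Linux on AArch64 for the runtime part (`getauxval(AT_HWCAP2)` tested against `HWCAP2_BF16`); the helper names are hypothetical, and on other platforms both helpers return 0.

```c
#include <stdio.h>

#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>  /* getauxval, AT_HWCAP2 */
#include <asm/hwcap.h> /* HWCAP2_BF16 on recent kernel headers */
#endif

/* Hypothetical helper: 1 if this file was compiled with BF16 vector intrinsics enabled. */
static int bf16_intrinsics_compiled_in(void) {
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
    return 1;
#else
    return 0;
#endif
}

/* Hypothetical helper: 1 if the running Linux/AArch64 CPU reports BF16 support. */
static int bf16_supported_at_runtime(void) {
#if defined(__aarch64__) && defined(__linux__) && defined(HWCAP2_BF16)
    return (getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0;
#else
    return 0;
#endif
}

/* Prints both answers, e.g. to log alongside benchmark results. */
static void report_bf16_support(void) {
    printf("compiled with BF16 intrinsics: %d, CPU reports BF16: %d\n",
           bf16_intrinsics_compiled_in(), bf16_supported_at_runtime());
}
```

Keeping the compile-time and runtime answers separate mirrors the usual dispatch pattern: a kernel is only selected when it was both compiled in and reported by the CPU.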

MarkReedZ commented 2 months ago

PR here: https://github.com/ashvardanian/SimSIMD/pull/169

Tested only on AWS t4g and c7g instances, with GCC 12 and 13.