ashvardanian / SimSIMD

Up to 200x Faster Dot Products & Similarity Metrics — for Python, Rust, C, JS, and Swift, supporting f64, f32, f16 real & complex, i8, and bit vectors using SIMD for both AVX2, AVX-512, NEON, SVE, & SVE2 📐
https://ashvardanian.com/posts/simsimd-faster-scipy/
Apache License 2.0

Add `bf16` complex dot product for NEON #163

Open ashvardanian opened 2 months ago

ashvardanian commented 2 months ago

The vbfmlaltq_f32 and vbfmlalbq_f32 already have the benefit of skipping odd/even entries.

MarkReedZ commented 2 months ago

The complex dot product already exists for NEON, but it converts the inputs to f32 first, and we want to operate on the bf16 inputs directly. The complex vector is interleaved: real, imag, real, imag, ...

Original

// ab_real += ar * br - ai * bi;
// ab_imag += ar * bi + ai * br;
ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec);
ab_real_vec = vfmsq_f32(ab_real_vec, a_imag_vec, b_imag_vec);
ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec);
ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec);    

The new version might look like this (vbfmlaltq_f32 is an FMA over the odd entries, while vbfmlalbq_f32 covers the even ones):

    ab_real_vec = vbfmlaltq_f32(ab_real_vec, a_vec, b_vec);
    ab_real_vec = vbfmlalbq_f32(ab_real_vec, vnegq_f16(a_vec), b_vec);  // ar*br + (-ai*bi)

    ab_imag_vec = vbfmlaltq_f32(ab_imag_vec, a_vec, vrev32q_bf16(b_vec));   // vrev32q swaps imag and real
    ab_imag_vec = vbfmlalbq_f32(ab_imag_vec, a_vec, vrev32q_bf16(b_vec));  // ar * bi + ai * br;   
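
For reference, here is a portable scalar sketch of the even/odd trick that can be checked on any machine. It is only a model under stated assumptions: `fma_even` and `fma_odd` are hypothetical stand-ins for vbfmlalbq_f32 and vbfmlaltq_f32, they collapse the four f32 partial sums of the real instructions into one scalar, and all helper names are made up for illustration. With real parts in the even lanes, the subtracted term ai * bi comes from the odd lanes, so the sketch applies the sign flip in the odd-lane step.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

enum { LANES = 8 }; /* one 128-bit register: 8 bf16 values = 4 complex pairs */

/* bf16 holds the upper 16 bits of an IEEE-754 binary32 value. */
float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Truncating conversion; exact for values representable in bf16. */
uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);
}

/* Hypothetical scalar stand-in for vbfmlalbq_f32: even ("bottom") lanes only. */
float fma_even(float acc, const uint16_t *a, const uint16_t *b) {
    for (size_t i = 0; i < LANES; i += 2) acc += bf16_to_f32(a[i]) * bf16_to_f32(b[i]);
    return acc;
}

/* Hypothetical scalar stand-in for vbfmlaltq_f32: odd ("top") lanes only. */
float fma_odd(float acc, const uint16_t *a, const uint16_t *b) {
    for (size_t i = 1; i < LANES; i += 2) acc += bf16_to_f32(a[i]) * bf16_to_f32(b[i]);
    return acc;
}

/* Complex dot product over interleaved (real, imag) bf16 pairs. */
void complex_dot_bf16_model(const uint16_t *a, const uint16_t *b, size_t n_pairs,
                            float *re, float *im) {
    assert(n_pairs % 4 == 0); /* the sketch handles whole registers only */
    float ab_real = 0.0f, ab_imag = 0.0f;
    for (size_t p = 0; p < 2 * n_pairs; p += LANES) {
        uint16_t a_neg[LANES], b_swap[LANES];
        for (size_t i = 0; i < LANES; i += 2) {
            a_neg[i] = (uint16_t)(a[p + i] ^ 0x8000u);         /* flip sign bit */
            a_neg[i + 1] = (uint16_t)(a[p + i + 1] ^ 0x8000u);
            b_swap[i] = b[p + i + 1];                          /* swap real and imag */
            b_swap[i + 1] = b[p + i];
        }
        ab_real = fma_even(ab_real, a + p, b + p);  /* + ar * br */
        ab_real = fma_odd(ab_real, a_neg, b + p);   /* - ai * bi */
        ab_imag = fma_even(ab_imag, a + p, b_swap); /* + ar * bi */
        ab_imag = fma_odd(ab_imag, a + p, b_swap);  /* + ai * br */
    }
    *re = ab_real;
    *im = ab_imag;
}
```

Feeding it small integer inputs (exactly representable in bf16) and comparing against the plain ar * br - ai * bi and ar * bi + ai * br sums is an easy way to sanity-check the lane assignments before touching the intrinsics.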

ashvardanian commented 2 months ago

Indeed, you are right! I suppose the new version must be a lot faster, right?

MarkReedZ commented 2 months ago

10% faster. There are no bf16 versions of the neg and rev32 intrinsics, so we still have to jump through hoops. I confirmed that the new function's output matches the old one and that the tests pass. I will take a look to see if we can do this better before making a PR.

Assembly code: https://godbolt.org/z/4hzr9f943

        // ar*br + (-ai*bi)
        ab_real_vec = vbfmlaltq_f32(ab_real_vec, a_vec, b_vec);
        //ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_f16(vnegq_f16(vreinterpretq_f16_bf16(a_vec))), b_vec);  
        ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_u16(veorq_u16(vreinterpretq_u16_bf16(a_vec), vdupq_n_u16(0x8000))), b_vec);

        // vrev32q swaps imag and real
        // ar * bi + ai * br;   
        ab_imag_vec = vbfmlaltq_f32(ab_imag_vec, a_vec, vreinterpretq_bf16_u16(vrev32q_u16(vreinterpretq_u16_bf16(b_vec))));
        ab_imag_vec = vbfmlalbq_f32(ab_imag_vec, a_vec, vreinterpretq_bf16_u16(vrev32q_u16(vreinterpretq_u16_bf16(b_vec))));
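
To spell out why the integer reinterpretation works, here is a small scalar sketch of the two bit-level operations. The function names are hypothetical; the only facts relied on are that bit 15 of a bf16 value is its sign bit, and that vrev32q_u16 reverses the 16-bit elements inside each 32-bit word.

```c
#include <stdint.h>

/* Negate a bf16 value stored as raw bits by toggling the sign bit (bit 15).
   Example: 0x3F80 is 1.0 in bf16, 0xBF80 is -1.0. */
uint16_t bf16_negate_bits(uint16_t h) { return (uint16_t)(h ^ 0x8000u); }

/* Swap the two 16-bit halves of a 32-bit word: the per-word effect of
   vrev32q_u16, which exchanges the real and imaginary parts of one pair. */
uint32_t swap_halves_u32(uint32_t w) { return (w << 16) | (w >> 16); }
```

Both operations are their own inverse, so applying either one twice returns the original bits.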

ashvardanian commented 2 months ago

Interestingly, the godbolt.org snippet you've provided breaks Clang 18.1 if you add -O3. Without it, the assembly contains a lot of noise and is a bit hard to read. Still, @MarkReedZ, the source looks really good! I was hoping it would be at least 20% faster 😢

MarkReedZ commented 2 months ago

Good catch, I was playing around with another compiler on there.

Ubuntu 24.04's clang 18.1 hits the same bug when building this code. Issue opened: https://github.com/llvm/llvm-project/issues/107810

I'll try to move code around to avoid this later.

Code: https://github.com/MarkReedZ/SimSIMD/commit/09e89bb71e37b7e120a34f19012d3fc2b13183f4

MarkReedZ commented 2 months ago

Clang is choking on the flipping of the sign bit. I haven't come up with an alternative to these two. No amount of moving code around fixes the clang bug if veorq and vnegq are used to flip the bit.

       //ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_f16(vnegq_f16(vreinterpretq_f16_bf16(a_vec))), b_vec);  
       ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_u16(veorq_u16(vreinterpretq_u16_bf16(a_vec), vdupq_n_u16(0x8000))), b_vec);

ashvardanian commented 2 weeks ago

Hi @MarkReedZ! Any chance you have an update on this?

MarkReedZ commented 2 weeks ago

This is fixed in clang 19.1. I'm not sure what our approach to handling this should be, as 18 will remain the default for some time. We could check the clang version number defines, though apparently in some cases those may be overridden.
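
A rough sketch of what such a version check could look like. `EXAMPLE_BF16C_FAST_PATH` and `example_bf16c_fast_path_enabled` are hypothetical names, not existing SimSIMD macros or functions. The sketch assumes the fast path should be disabled on upstream Clang before 19, and it also stays conservative on Apple Clang, whose `__clang_major__` follows Xcode releases rather than upstream LLVM versions.

```c
#include <stdbool.h>

/* EXAMPLE_BF16C_FAST_PATH is a hypothetical name, not a SimSIMD macro.
   A build may predefine it (e.g. -DEXAMPLE_BF16C_FAST_PATH=0) to override
   the detection below when the version defines cannot be trusted. */
#ifndef EXAMPLE_BF16C_FAST_PATH
#if defined(__clang__) && (defined(__apple_build_version__) || __clang_major__ < 19)
/* Upstream Clang older than 19, or Apple Clang: fall back to the f32 path. */
#define EXAMPLE_BF16C_FAST_PATH 0
#else
#define EXAMPLE_BF16C_FAST_PATH 1
#endif
#endif

/* Runtime-visible view of the compile-time decision, handy for logging or tests. */
bool example_bf16c_fast_path_enabled(void) { return EXAMPLE_BF16C_FAST_PATH != 0; }
```

The explicit override is there because, as noted, the version defines are not always reliable, so a build flag gets the last word.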

ashvardanian commented 2 weeks ago

@MarkReedZ, can you please submit a PR that works with 19, and I'll try a few more ideas around your prototype?