Open ashvardanian opened 2 months ago
The complex dot product exists for NEON, but it converts the inputs to f32, and we want to operate on the bf16 inputs directly. The complex vector is laid out as real, imag, real, imag, ...
Original
//ab_real += ar * br - ai * bi;
//ab_imag += ar * bi + ai * br;
ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec);
ab_real_vec = vfmsq_f32(ab_real_vec, a_imag_vec, b_imag_vec);
ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec);
ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec);
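For checking any vectorized variant, a scalar reference over the interleaved layout is handy. This is only a minimal sketch: the type `complex_f32_t` and the function `dot_complex_interleaved` are hypothetical names made up for illustration, not SimSIMD's API, and the inputs are plain f32 rather than bf16.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical result type, not part of SimSIMD. */
typedef struct {
    float real;
    float imag;
} complex_f32_t;

/* Scalar reference for the (non-conjugated) complex dot product.
 * Inputs are interleaved: real, imag, real, imag, ... */
complex_f32_t dot_complex_interleaved(const float *a, const float *b, size_t count_pairs) {
    assert(a != NULL && b != NULL);
    complex_f32_t ab = {0.0f, 0.0f};
    for (size_t i = 0; i < count_pairs; ++i) {
        float ar = a[2 * i], ai = a[2 * i + 1];
        float br = b[2 * i], bi = b[2 * i + 1];
        ab.real += ar * br - ai * bi; /* same update as the comment above */
        ab.imag += ar * bi + ai * br;
    }
    return ab;
}
```

Comparing a vector kernel against something like this on a few small inputs catches sign and lane mix-ups early.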
The new version might look like this (vbfmlaltq is an FMA over the odd entries, while vbfmlalbq covers the even ones):
ab_real_vec = vbfmlaltq_f32(ab_real_vec, a_vec, b_vec);
ab_real_vec = vbfmlalbq_f32(ab_real_vec, vnegq_f16(a_vec), b_vec); // ar*br + (-ai*bi)
ab_imag_vec = vbfmlaltq_f32(ab_imag_vec, a_vec, vrev32q_bf16(b_vec)); // vrev32q swaps imag and real
ab_imag_vec = vbfmlalbq_f32(ab_imag_vec, a_vec, vrev32q_bf16(b_vec)); // ar * bi + ai * br;
Indeed, you are right! I suppose the new version must be a lot faster, right?
10% faster. There are no bf16 versions of neg and rev32, so we still have to jump through hoops. I confirmed that the new function's output matches the old one and that the tests pass. I will take a look to see if we can do this better before making a PR.
Assembly code: https://godbolt.org/z/4hzr9f943
// ar*br + (-ai*bi)
ab_real_vec = vbfmlaltq_f32(ab_real_vec, a_vec, b_vec);
//ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_f16(vnegq_f16(vreinterpretq_f16_bf16(a_vec))), b_vec);
ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_u16(veorq_u16(vreinterpretq_u16_bf16(a_vec), vdupq_n_u16(0x8000))), b_vec);
// vrev32q swaps imag and real
// ar * bi + ai * br;
ab_imag_vec = vbfmlaltq_f32(ab_imag_vec, a_vec, vreinterpretq_bf16_u16(vrev32q_u16(vreinterpretq_u16_bf16(b_vec))));
ab_imag_vec = vbfmlalbq_f32(ab_imag_vec, a_vec, vreinterpretq_bf16_u16(vrev32q_u16(vreinterpretq_u16_bf16(b_vec))));
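To reason about which lanes each step touches without an Arm machine, here is a portable scalar model of one 8-lane step. It is a sketch under stated assumptions, not NEON code: all function names are hypothetical, "bottom" is modeled as the even lanes and "top" as the odd lanes, and the float-to-bf16 helper simply truncates. In this model the plain product is taken on the even lanes and the negated one on the odd lanes, which is the pairing that yields ar*br - ai*bi when real parts sit in even lanes; if the negation sits on the other step, the real part comes out with the opposite sign, so that pairing is worth double-checking against the scalar reference.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Widen a bf16 bit pattern to f32: bf16 is the upper 16 bits of an f32. */
float bf16_to_f32(uint16_t x) {
    uint32_t bits = (uint32_t)x << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Truncating f32 -> bf16 (no rounding); exact for small integers. */
uint16_t f32_to_bf16_trunc(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);
}

/* Widening multiply-accumulate over one parity of the 8 lanes:
 * parity 0 = even ("bottom") lanes, parity 1 = odd ("top") lanes. */
void mlal_lanes(float acc[4], const uint16_t a[8], const uint16_t b[8], int parity) {
    assert(parity == 0 || parity == 1);
    for (int i = 0; i < 4; ++i)
        acc[i] += bf16_to_f32(a[2 * i + parity]) * bf16_to_f32(b[2 * i + parity]);
}

/* Flip the sign bit of every lane, like the XOR with 0x8000. */
void negate_lanes(uint16_t out[8], const uint16_t in[8]) {
    for (int i = 0; i < 8; ++i) out[i] = (uint16_t)(in[i] ^ 0x8000u);
}

/* Swap the two 16-bit lanes inside each 32-bit pair, like a 32-bit rev on u16. */
void swap_pairs(uint16_t out[8], const uint16_t in[8]) {
    for (int i = 0; i < 4; ++i) {
        out[2 * i] = in[2 * i + 1];
        out[2 * i + 1] = in[2 * i];
    }
}

/* One step over 4 complex pairs: real parts in even lanes, imag in odd lanes. */
void complex_step(float ab_real[4], float ab_imag[4], const uint16_t a[8], const uint16_t b[8]) {
    uint16_t a_neg[8], b_swapped[8];
    negate_lanes(a_neg, a);
    swap_pairs(b_swapped, b);
    mlal_lanes(ab_real, a, b, 0);         /* even lanes: + ar * br */
    mlal_lanes(ab_real, a_neg, b, 1);     /* odd lanes:  - ai * bi */
    mlal_lanes(ab_imag, a, b_swapped, 0); /* even lanes: + ar * bi */
    mlal_lanes(ab_imag, a, b_swapped, 1); /* odd lanes:  + ai * br */
}
```

Each accumulator lane ends up holding one pair's contribution, so a horizontal sum of `ab_real` and `ab_imag` gives the final result.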
Interestingly, the godbolt.org snippet you've provided breaks Clang 18.1 if you add -O3. Without it, the assembly contains a lot of noise and is a bit hard to read. Still, @MarkReedZ, the source looks really good! Was hoping it would be at least 20% 😢
Good catch, I was playing around with another compiler on there.
Ubuntu 24.04's clang 18.1 sees the same bug when building this code. Issue opened: https://github.com/llvm/llvm-project/issues/107810
I'll try to move code around to avoid this later.
Code: https://github.com/MarkReedZ/SimSIMD/commit/09e89bb71e37b7e120a34f19012d3fc2b13183f4
Clang is choking on the flipping of the sign bit. I haven't come up with an alternative to these two. No amount of moving code around fixes the clang bug if veorq or vnegq is used to flip the bit.
//ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_f16(vnegq_f16(vreinterpretq_f16_bf16(a_vec))), b_vec);
ab_real_vec = vbfmlalbq_f32(ab_real_vec, vreinterpretq_bf16_u16(veorq_u16(vreinterpretq_u16_bf16(a_vec), vdupq_n_u16(0x8000))), b_vec);
Hi @MarkReedZ! Any chance you have an update on this?
This is fixed in clang 19.1. I'm not sure what our approach to handling this should be, as 18 will remain the default for some time. We could check the clang version number defines, though apparently in some cases those may be overridden.
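A version check along those lines could be sketched as follows. `__clang__` and `__clang_major__` are Clang's predefined macros; the macro `USE_BF16_COMPLEX_KERNEL` and both functions are hypothetical names for illustration, not SimSIMD's actual switches. One known caveat is that vendor builds such as Apple's clang use their own version numbering, so the threshold of 19 is an assumption that only holds for upstream releases.

```c
#include <assert.h>

/* Hypothetical compile-time switch, not an existing SimSIMD macro. */
#if defined(__clang__) && defined(__clang_major__) && (__clang_major__ < 19)
#define USE_BF16_COMPLEX_KERNEL 0 /* clang before 19: keep the f32 path */
#else
#define USE_BF16_COMPLEX_KERNEL 1 /* other compilers, or clang 19 and newer */
#endif

/* Returns 1 when the bf16 kernel would be compiled in, 0 otherwise. */
int bf16_complex_kernel_enabled(void) {
    return USE_BF16_COMPLEX_KERNEL;
}

/* Combine the compile-time guard with a runtime CPU capability flag. */
int select_bf16_kernel(int cpu_has_bf16) {
    assert(cpu_has_bf16 == 0 || cpu_has_bf16 == 1);
    return cpu_has_bf16 && USE_BF16_COMPLEX_KERNEL;
}
```

With a guard like this, builds on older clang would silently fall back to the existing f32 kernel instead of hitting the compiler bug.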
@MarkReedZ, can you please submit a PR that works with 19, and I'll try a few more ideas around your prototype?
The vbfmlaltq_f32 and vbfmlalbq_f32 already have the benefit of skipping odd/even entries.