ashvardanian / SimSIMD

Up to 200x Faster Dot Products & Similarity Metrics for Python, Rust, C, JS, and Swift, supporting f64, f32, f16 real & complex, i8, and bit vectors using SIMD for AVX2, AVX-512, NEON, SVE, & SVE2 📐
https://ashvardanian.com/posts/simsimd-faster-scipy/
Apache License 2.0
988 stars 59 forks

Fused-Multiply-Add #214

Closed by ashvardanian 3 weeks ago

ashvardanian commented 1 month ago

SimSIMD is expanding and moving closer to a fully-fledged BLAS library. It covers only BLAS Level 1 for now, but it's a start! SimSIMD will prioritize mixed- and low-precision vector math, favoring modern AI workloads. For image and media processing workloads, the new fma and wsum kernels approach 65 GB/s per core on Intel Sapphire Rapids. For u8 inputs with f32 scaling and accumulation, that is roughly 100x faster than the serial code.
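You can check the 100x figure against the u8 rows of the Sapphire Rapids benchmark. Divide the `sapphire` kernel's throughput by the `serial` kernel's throughput. The sketch below does that arithmetic with numbers copied from the table. It gives about 94x for fma and 110x for wsum.

```python
# u8 throughput in GB/s, copied from the Sapphire Rapids benchmark table.
simd_gbps = {"fma_u8": 64.9867, "wsum_u8": 60.6775}      # *_u8_sapphire rows
serial_gbps = {"fma_u8": 0.688583, "wsum_u8": 0.550797}  # *_u8_serial rows (688.583M/s, 550.797M/s)

# Speedup of the SIMD kernel over the scalar baseline.
speedups = {name: simd_gbps[name] / serial_gbps[name] for name in simd_gbps}
for name, ratio in speedups.items():
    print(f"{name}: {ratio:.1f}x faster than serial")
# fma_u8: 94.4x faster than serial
# wsum_u8: 110.2x faster than serial
```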

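This change adds the following two element-wise operations:

- wsum: weighted sum of two vectors, Alpha * A + Beta * B
- fma: fused multiply-add of three vectors, Alpha * A * B + Beta * C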

In NumPy terms:

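```python
import numpy as np

def wsum(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
    # Weighted sum: Alpha * A + Beta * B, cast back to the input dtype.
    assert A.dtype == B.dtype, "Input types must match and determine the output type"
    return (Alpha * A + Beta * B).astype(A.dtype)

def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
    # Fused multiply-add: Alpha * A * B + Beta * C, cast back to the input dtype.
    # For integer dtypes, the final .astype truncates and does not clip out-of-range values.
    assert A.dtype == B.dtype and A.dtype == C.dtype, "Input types must match and determine the output type"
    return (Alpha * A * B + Beta * C).astype(A.dtype)
```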
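The reference above casts the floating-point result straight back to the input dtype. For u8 inputs, a sum such as 200 + 100 is therefore not guaranteed to clip to 255. The description mentions f32 scaling and accumulation for u8.

Under that reading, a saturating reference might look like the sketch below. The function names are hypothetical. The rounding policy and the clip-to-[0, 255] policy are assumptions for illustration, not SimSIMD's documented behavior.

```python
import numpy as np

def wsum_u8_saturating(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
    # Hypothetical u8 reference: scale and accumulate in f32, then round and saturate.
    assert A.dtype == B.dtype == np.uint8
    acc = np.float32(Alpha) * A.astype(np.float32) + np.float32(Beta) * B.astype(np.float32)
    # Assumed policy: round to nearest (ties to even, as np.rint does), clamp to [0, 255].
    return np.clip(np.rint(acc), 0, 255).astype(np.uint8)

def fma_u8_saturating(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
    # Hypothetical u8 reference for Alpha * A * B + Beta * C with the same assumed policy.
    assert A.dtype == B.dtype == C.dtype == np.uint8
    acc = (np.float32(Alpha) * A.astype(np.float32) * B.astype(np.float32)
           + np.float32(Beta) * C.astype(np.float32))
    return np.clip(np.rint(acc), 0, 255).astype(np.uint8)

A = np.array([10, 200, 250], dtype=np.uint8)
B = np.array([20, 100, 250], dtype=np.uint8)
print(wsum_u8_saturating(A, B, 1.0, 1.0))  # [ 30 255 255]
```

Clamping before the narrowing cast keeps every result inside the representable u8 range. NumPy's own float-to-uint8 `.astype` does not saturate.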

This tiny set of operations is enough to implement a wide range of algorithms:
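Take the NumPy reference definitions above as given. Several common kernels then reduce to a single wsum or fma call with fixed coefficients. The helper names scale, axpy, lerp, and multiply are hypothetical illustrations, not SimSIMD functions.

```python
import numpy as np

# Reference definitions copied from above so the sketch runs on its own.
def wsum(A, B, Alpha, Beta):
    assert A.dtype == B.dtype
    return (Alpha * A + Beta * B).astype(A.dtype)

def fma(A, B, C, Alpha, Beta):
    assert A.dtype == B.dtype == C.dtype
    return (Alpha * A * B + Beta * C).astype(A.dtype)

# Hypothetical helpers, each one wsum or fma call with fixed coefficients.
# When a coefficient is 0, the matching operand only needs the right shape and dtype,
# and must be finite, since 0 * inf is NaN.
def scale(A, alpha):
    return wsum(A, A, alpha, 0.0)      # alpha * A

def axpy(alpha, X, Y):
    return wsum(X, Y, alpha, 1.0)      # BLAS Level 1 axpy: alpha * X + Y

def lerp(A, B, t):
    return wsum(A, B, 1.0 - t, t)      # linear interpolation: (1 - t) * A + t * B

def multiply(A, B):
    return fma(A, B, A, 1.0, 0.0)      # element-wise product A * B

X = np.array([1.0, 2.0, 3.0], dtype=np.float32)
Y = np.array([4.0, 5.0, 6.0], dtype=np.float32)
print(axpy(2.0, X, Y))   # [ 6.  9. 12.]
print(lerp(X, Y, 0.5))   # [2.5 3.5 4.5]
print(multiply(X, Y))    # [ 4. 10. 18.]
```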

Benchmarks

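On Intel Sapphire Rapids. Every row was measured on the same machine. The haswell, skylake, and sapphire suffixes name SIMD kernel variants, and serial is the scalar baseline: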

```
Run on (16 X 3900 MHz CPU s)
CPU Caches:
  L1 Data 48 KiB (x8)
  L1 Instruction 32 KiB (x8)
  L2 Unified 2048 KiB (x8)
  L3 Unified 61440 KiB (x1)
Load Average: 0.79, 0.75, 0.56
-------------------------------------------------------------------------------------------------------------
Benchmark                                                   Time             CPU   Iterations UserCounters...
-------------------------------------------------------------------------------------------------------------
fma_f64_haswell<1536d>/min_time:10.000/threads:1         1344 ns         1344 ns     10391897 abs_delta=0 bytes=27.4208G/s pairs=743.836k/s relative_error=0
wsum_f64_haswell<1536d>/min_time:10.000/threads:1        1040 ns         1040 ns     13465261 abs_delta=0 bytes=23.6376G/s pairs=961.815k/s relative_error=0
fma_f32_haswell<1536d>/min_time:10.000/threads:1          651 ns          651 ns     21534450 abs_delta=23.597n bytes=28.3033G/s pairs=1.53555M/s relative_error=47.0002n
wsum_f32_haswell<1536d>/min_time:10.000/threads:1         392 ns          392 ns     36225731 abs_delta=19.6436n bytes=31.3326G/s pairs=2.54985M/s relative_error=54.2672n
fma_f16_haswell<1536d>/min_time:10.000/threads:1          188 ns          188 ns     74334715 abs_delta=9.24044u bytes=49.1302G/s pairs=5.33097M/s relative_error=18.3975u
wsum_f16_haswell<1536d>/min_time:10.000/threads:1         130 ns          129 ns    106997523 abs_delta=12.015u bytes=47.4441G/s pairs=7.72203M/s relative_error=33.1896u
fma_bf16_haswell<1536d>/min_time:10.000/threads:1         225 ns          225 ns     62443286 abs_delta=1.91338m bytes=41.0221G/s pairs=4.45118M/s relative_error=3.81108m
wsum_bf16_haswell<1536d>/min_time:10.000/threads:1        161 ns          161 ns     86471812 abs_delta=1.36093m bytes=38.1318G/s pairs=6.20635M/s relative_error=3.75961m
fma_u8_sapphire<1536d>/min_time:10.000/threads:1         70.9 ns         70.9 ns    197232316 abs_delta=9.2812 bytes=64.9867G/s pairs=14.103M/s relative_error=2.45142m
wsum_u8_sapphire<1536d>/min_time:10.000/threads:1        50.6 ns         50.6 ns    276672248 abs_delta=8.89144 bytes=60.6775G/s pairs=19.7518M/s relative_error=3.28203m
fma_i8_sapphire<1536d>/min_time:10.000/threads:1         94.0 ns         94.0 ns    149003863 abs_delta=10.1192 bytes=49.0403G/s pairs=10.6424M/s relative_error=6.98359m
wsum_i8_sapphire<1536d>/min_time:10.000/threads:1        70.4 ns         70.4 ns    198873173 abs_delta=9.76862 bytes=43.613G/s pairs=14.197M/s relative_error=9.3472m
fma_f64_skylake<1536d>/min_time:10.000/threads:1         1340 ns         1340 ns     10460553 abs_delta=39.3003a bytes=27.5182G/s pairs=746.479k/s relative_error=78.2836a
wsum_f64_skylake<1536d>/min_time:10.000/threads:1        1036 ns         1036 ns     13484768 abs_delta=28.4608a bytes=23.717G/s pairs=965.047k/s relative_error=78.6298a
fma_f32_skylake<1536d>/min_time:10.000/threads:1          626 ns          626 ns     22261554 abs_delta=25.3818n bytes=29.4286G/s pairs=1.5966M/s relative_error=50.5553n
wsum_f32_skylake<1536d>/min_time:10.000/threads:1         386 ns          386 ns     35032887 abs_delta=19.7444n bytes=31.8146G/s pairs=2.58908M/s relative_error=54.5454n
fma_bf16_skylake<1536d>/min_time:10.000/threads:1         188 ns          188 ns     74667249 abs_delta=415.805u bytes=48.9511G/s pairs=5.31154M/s relative_error=827.962u
wsum_bf16_skylake<1536d>/min_time:10.000/threads:1        147 ns          147 ns     95128759 abs_delta=269.793u bytes=41.8834G/s pairs=6.81696M/s relative_error=745.331u
fma_f16_serial<1536d>/min_time:10.000/threads:1           900 ns          900 ns     15592180 abs_delta=2.97965u bytes=10.2444G/s pairs=1.11159M/s relative_error=5.93995u
wsum_f16_serial<1536d>/min_time:10.000/threads:1          821 ns          821 ns     17058449 abs_delta=1.11521u bytes=7.48594G/s pairs=1.21841M/s relative_error=3.07961u
fma_u8_serial<1536d>/min_time:10.000/threads:1           6692 ns         6692 ns      2089290 abs_delta=1.66854 bytes=688.583M/s pairs=149.432k/s relative_error=440.882u
wsum_u8_serial<1536d>/min_time:10.000/threads:1          5577 ns         5577 ns      2508971 abs_delta=2.32787 bytes=550.797M/s pairs=179.296k/s relative_error=859.403u
fma_i8_serial<1536d>/min_time:10.000/threads:1           6874 ns         6874 ns      2039761 abs_delta=5.14013 bytes=670.367M/s pairs=145.479k/s relative_error=3.54862m
wsum_i8_serial<1536d>/min_time:10.000/threads:1          5851 ns         5851 ns      2394538 abs_delta=6.36953 bytes=525.018M/s pairs=170.904k/s relative_error=6.09231m
```
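The output does not say how the bytes and pairs counters are derived. The figures are consistent with two readings. First, bytes is the input bytes read per call divided by the mean time: three 1536-dimensional vectors for fma and two for wsum, with the output write apparently not counted. Second, pairs is calls per second.

The sketch below recomputes a few rows under that inferred assumption.

```python
# Inferred, not documented: "bytes" counts only input bytes read per call,
# and "pairs" counts calls per second.
INPUTS_PER_CALL = {"fma": 3, "wsum": 2}  # fma reads A, B, C; wsum reads A, B
DIMS = 1536

def input_gbps(op: str, bytes_per_scalar: int, ns_per_call: float) -> float:
    total_bytes = INPUTS_PER_CALL[op] * DIMS * bytes_per_scalar
    return total_bytes / ns_per_call  # 1 byte per nanosecond == 1 GB/s (decimal)

def calls_per_second(ns_per_call: float) -> float:
    return 1e9 / ns_per_call

# (benchmark, op, bytes per scalar, mean time in ns, reported GB/s, reported pairs/s)
rows = [
    ("fma_f64_haswell", "fma", 8, 1344, 27.4208, 743.836e3),
    ("fma_f16_haswell", "fma", 2, 188, 49.1302, 5.33097e6),
    ("fma_u8_sapphire", "fma", 1, 70.9, 64.9867, 14.103e6),
    ("wsum_u8_sapphire", "wsum", 1, 50.6, 60.6775, 19.7518e6),
]

for name, op, width, ns, gbps_reported, pairs_reported in rows:
    print(f"{name}: {input_gbps(op, width, ns):.2f} GB/s (reported {gbps_reported}), "
          f"{calls_per_second(ns) / 1e6:.3f}M calls/s (reported {pairs_reported / 1e6:.3f}M)")
```

The small mismatches, well under 1%, come from the printed times being rounded.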