yahoojapan / NGT

Nearest Neighbor Search with Neighborhood Graph and Tree for High-dimensional Data
Apache License 2.0

Feature SIMD support for Hamming Distance. #160

Closed kpango closed 5 months ago

kpango commented 5 months ago

I have added full support for SSE2, AVX2, and AVX512 to speed up the Hamming distance type. For each instruction set, the loop is unrolled to the required number of chunks per iteration by macros parameterized on the register width. The code assumes that the data dimension is a multiple of 16.
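The unrolling scheme can be sketched without macros or intrinsics. The portable C++ below is a minimal sketch, not the PR's code: the names `popcount64`, `chunk16`, `processLoop` and `hammingCascade` are hypothetical, and each 16-byte chunk is read as two 64-bit words instead of one 128-bit register. The unroll factors run 8, 4, 2, 1, in the same order as `PROCESS_ALL_LOOPS` in the test code, and, as in the PR, bytes after the last full 16-byte chunk are ignored.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Portable popcount (Kernighan): each step clears the lowest set bit.
static unsigned popcount64(uint64_t x)
{
    unsigned n = 0;
    while (x != 0)
    {
        x &= x - 1;
        ++n;
    }
    return n;
}

// Hamming distance of one 16-byte chunk, read as two 64-bit words.
// Stands in for one 128-bit XOR + popcount in the SIMD version.
static size_t chunk16(const uint8_t *a, const uint8_t *b)
{
    size_t count = 0;
    for (int w = 0; w < 2; ++w)
    {
        uint64_t x, y;
        std::memcpy(&x, a + 8 * w, sizeof x); // memcpy avoids unaligned/aliasing issues
        std::memcpy(&y, b + 8 * w, sizeof y);
        count += popcount64(x ^ y);
    }
    return count;
}

// Processes `factor` chunks per iteration while at least 16 * factor bytes remain
// (the role of PROCESS_LOOP); advances pa and pb past the bytes it consumed.
static void processLoop(const uint8_t *&pa, const uint8_t *&pb, const uint8_t *last,
                        size_t factor, size_t &count)
{
    const size_t step = 16 * factor;
    while (static_cast<size_t>(last - pa) >= step)
    {
        for (size_t k = 0; k < factor; ++k) // the PR's macros unroll this at compile time
        {
            count += chunk16(pa + 16 * k, pb + 16 * k);
        }
        pa += step;
        pb += step;
    }
}

// Hypothetical macro-free equivalent of the unroll cascade in PROCESS_ALL_LOOPS.
size_t hammingCascade(const uint8_t *a, const uint8_t *b, size_t size)
{
    const uint8_t *last = a + size;
    size_t count = 0;
    processLoop(a, b, last, 8, count);
    processLoop(a, b, last, 4, count);
    processLoop(a, b, last, 2, count);
    processLoop(a, b, last, 1, count);
    return count; // bytes after the last full 16-byte chunk are not counted
}
```

Here the inner `for` loop runs at run time; the PR's macros instead expand it at compile time so that every chunk becomes its own straight-line XOR and popcount sequence.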

The AVX2 path processes the remaining tail with SSE2, and the AVX512 path processes it with AVX2 and then SSE2. The SSE2 path has no tail processing, because the input length is assumed to be a multiple of 16 bytes, so no remainder is left.
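To see how the unrolled loops and the tail steps split an input, the hypothetical helper `stagePlan` below replays the loop conditions arithmetically and records how many iterations each stage runs, without touching any data. It is a sketch that assumes the stage order of the test code: the register width times 8, 4, 2 and 1, followed by the narrower tail widths.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Returns (step_bytes, iterations) for each loop stage, in execution order.
// width: register width in bytes (16 = SSE, 32 = AVX2, 64 = AVX-512).
// tail_widths: narrower widths used for the remainder, e.g. {16} for AVX2.
std::vector<std::pair<size_t, size_t>> stagePlan(size_t size, size_t width,
                                                 const std::vector<size_t> &tail_widths)
{
    std::vector<std::pair<size_t, size_t>> plan;
    size_t remaining = size;
    auto run = [&](size_t step) {
        size_t iters = 0;
        while (remaining >= step) // same condition as "uinta + STEP <= last"
        {
            remaining -= step;
            ++iters;
        }
        plan.emplace_back(step, iters);
    };
    const size_t factors[] = {8, 4, 2, 1};
    for (size_t factor : factors)
    {
        run(width * factor);
    }
    for (size_t w : tail_widths)
    {
        run(w);
    }
    return plan;
}
```

For example, `stagePlan(8176, 32, {16})` (the AVX2 configuration on an 8176-byte vector) gives 31 iterations of 256 bytes, then one iteration each of 128, 64, 32 and 16 bytes. The final 16-byte step is needed because 8176 is a multiple of 16 but not of 32.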

I also wrote the following standalone test program, which compares the new function with the existing NGT functions and with a straightforward __builtin_popcount loop used as the reference. Both the AVX2 and SSE2 builds report 0 errors, so the new function should be accurate. On speed, in the AVX2 build the new function is about 5 to 7% faster than the existing NGT SIMD implementation, while in the SSE2 build it is about 6 to 9% slower, as shown in the results below.

However, I have not been able to test the AVX512 path because I do not currently have a test environment for it.

Test code:


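// Build note: despite the "SSE2" label, the 128-bit path also uses _mm_extract_epi32 (SSE4.1)
// and _mm_popcnt_u32/_mm_popcnt_u64 (POPCNT, x86-64), so compile it with e.g. -msse4.1 -mpopcnt;
// add -mavx2 or -mavx512f to select the wider paths.
#include <emmintrin.h>
#include <immintrin.h>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <chrono>
#include <random>
#include <vector>
#include <iostream>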

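// UNROLL_MACRO_N(MACRO, BIT_SIZE) emits MACRO(k, k * BIT_SIZE) for k = 0..N-1.
// Despite its name, BIT_SIZE is the register width in bytes (16, 32 or 64),
// so the second argument is the byte offset of the k-th chunk.
#define UNROLL_MACRO_1(MACRO, BIT_SIZE) MACRO(0, 0)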
#define UNROLL_MACRO_2(MACRO, BIT_SIZE) UNROLL_MACRO_1(MACRO, BIT_SIZE) MACRO(1, BIT_SIZE)
#define UNROLL_MACRO_4(MACRO, BIT_SIZE) UNROLL_MACRO_2(MACRO, BIT_SIZE) MACRO(2, BIT_SIZE * 2) MACRO(3, BIT_SIZE * 3)
#define UNROLL_MACRO_8(MACRO, BIT_SIZE) UNROLL_MACRO_4(MACRO, BIT_SIZE) MACRO(4, BIT_SIZE * 4) MACRO(5, BIT_SIZE * 5) MACRO(6, BIT_SIZE * 6) MACRO(7, BIT_SIZE * 7)

#define UNROLL_BODY_SSE2(i, BIT_SIZE)                                                                                                                                            \
    __m128i vres##i = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(uinta + BIT_SIZE)), _mm_loadu_si128(reinterpret_cast<const __m128i *>(uintb + BIT_SIZE))); \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 0));                                                                                                                      \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 1));                                                                                                                      \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 2));                                                                                                                      \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 3));

#define UNROLL_BODY_AVX2(i, BIT_SIZE)                                                                                                                                                     \
    __m256i vres##i = _mm256_xor_si256(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(uinta + BIT_SIZE)), _mm256_loadu_si256(reinterpret_cast<const __m256i *>(uintb + BIT_SIZE))); \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 0));                                                                                                                            \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 1));                                                                                                                            \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 2));                                                                                                                            \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 3));

// There is no _mm512_extract_epi64 intrinsic, so split the 512-bit XOR result into two
// 256-bit halves with _mm512_extracti64x4_epi64 (AVX512F) and extract four 64-bit lanes from each.
#define UNROLL_BODY_AVX512(i, BIT_SIZE)                                                                                                                                                   \
    __m512i vres##i = _mm512_xor_si512(_mm512_loadu_si512(reinterpret_cast<const __m512i *>(uinta + BIT_SIZE)), _mm512_loadu_si512(reinterpret_cast<const __m512i *>(uintb + BIT_SIZE))); \
    __m256i vlo##i = _mm512_extracti64x4_epi64(vres##i, 0);                                                                                                                               \
    __m256i vhi##i = _mm512_extracti64x4_epi64(vres##i, 1);                                                                                                                               \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vlo##i, 0));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vlo##i, 1));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vlo##i, 2));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vlo##i, 3));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vhi##i, 0));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vhi##i, 1));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vhi##i, 2));                                                                                                                             \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vhi##i, 3));

#define UNROLL_LOOP_SSE2(FACTOR) UNROLL_MACRO_##FACTOR(UNROLL_BODY_SSE2, 16)
#define UNROLL_LOOP_AVX2(FACTOR) UNROLL_MACRO_##FACTOR(UNROLL_BODY_AVX2, 32)
#define UNROLL_LOOP_AVX512(FACTOR) UNROLL_MACRO_##FACTOR(UNROLL_BODY_AVX512, 64)

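// Runs FACTOR unrolled chunks per iteration while at least STEP bytes remain.
#define PROCESS_LOOP(INSTRUCTION_SET, FACTOR, STEP) \
    while (uinta + STEP <= last)                    \
    {                                               \
        UNROLL_LOOP_##INSTRUCTION_SET(FACTOR)       \
        uinta += STEP;                              \
        uintb += STEP;                              \
    }

// Unroll factors 8, 4, 2 and 1: after the 8x loop fewer than 8 chunks remain,
// so each of the smaller loops runs at most once.
#define PROCESS_ALL_LOOPS(INSTRUCTION_SET, BIT_SIZE) \
    PROCESS_LOOP(INSTRUCTION_SET, 8, BIT_SIZE * 8)   \
    PROCESS_LOOP(INSTRUCTION_SET, 4, BIT_SIZE * 4)   \
    PROCESS_LOOP(INSTRUCTION_SET, 2, BIT_SIZE * 2)   \
    PROCESS_LOOP(INSTRUCTION_SET, 1, BIT_SIZE)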

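// AVX512 tail: fewer than 64 bytes remain, so handle a 32-byte chunk with AVX2
// and then a 16-byte chunk with SSE.
#define PROCESS_REMAINING_DATA_WITH_SSE2_AND_AVX2_AVX512() \
    if (uinta < last)                                      \
    {                                                      \
        PROCESS_LOOP(AVX2, 1, 32)                          \
        PROCESS_LOOP(SSE2, 1, 16)                          \
    }

// AVX2 tail: fewer than 32 bytes remain, so at most one 16-byte SSE chunk.
#define PROCESS_REMAINING_DATA_WITH_SSE2_AVX2() \
    if (uinta < last)                           \
    {                                           \
        PROCESS_LOOP(SSE2, 1, 16)               \
    }

// SSE path: no tail handling (the length is assumed to be a multiple of 16 bytes).
#define DO_NOTHING()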

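// Bytes after the last full 16-byte chunk are not counted, so the vector
// length (size * sizeof(OBJECT_TYPE)) must be a multiple of 16 bytes.
#define COMPARE_HAMMING_DISTANCE(INSTRUCTION_SET, BIT_SIZE, PROCESS_REMAINING_DATA) \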
    const uint8_t *last = reinterpret_cast<const uint8_t *>(a + size);              \
    const uint8_t *uinta = reinterpret_cast<const uint8_t *>(a);                    \
    const uint8_t *uintb = reinterpret_cast<const uint8_t *>(b);                    \
    size_t count = 0;                                                               \
    PROCESS_ALL_LOOPS(INSTRUCTION_SET, BIT_SIZE)                                    \
    PROCESS_REMAINING_DATA                                                          \
    return static_cast<double>(count);

template <typename OBJECT_TYPE>
inline static double compareHammingDistance(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
#if defined(__AVX512F__)
    COMPARE_HAMMING_DISTANCE(AVX512, 64, PROCESS_REMAINING_DATA_WITH_SSE2_AND_AVX2_AVX512())
#elif defined(__AVX2__)
    COMPARE_HAMMING_DISTANCE(AVX2, 32, PROCESS_REMAINING_DATA_WITH_SSE2_AVX2())
#else
    COMPARE_HAMMING_DISTANCE(SSE2, 16, DO_NOTHING())
#endif
}

template <typename OBJECT_TYPE>
inline static double compareHammingDistanceBuiltin(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
    size_t count = 0;
    for (size_t i = 0; i < size; ++i)
    {
        count += __builtin_popcount(a[i] ^ b[i]);
    }
    return static_cast<double>(count);
}

template <typename OBJECT_TYPE>
inline static double compareHammingDistanceNGT(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
    const uint64_t *last = reinterpret_cast<const uint64_t *>(a + size);

    const uint64_t *uinta = reinterpret_cast<const uint64_t *>(a);
    const uint64_t *uintb = reinterpret_cast<const uint64_t *>(b);
    size_t count = 0;
    while (uinta < last)
    {
        count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
        count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
    }

    return static_cast<double>(count);
}

inline static uint32_t popCount(uint32_t x)
{
    // SWAR popcount: sum adjacent 1-, 2-, 4-, 8- and 16-bit fields in turn.
    x = (x & 0x55555555) + (x >> 1 & 0x55555555);
    x = (x & 0x33333333) + (x >> 2 & 0x33333333);
    x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F);
    x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF);
    x = (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF);
    return x;
}

template <typename OBJECT_TYPE>
inline static double compareHammingDistanceNoSIMD(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
    const uint32_t *last = reinterpret_cast<const uint32_t *>(a + size);
    const uint32_t *uinta = reinterpret_cast<const uint32_t *>(a);
    const uint32_t *uintb = reinterpret_cast<const uint32_t *>(b);
    size_t count = 0;
    while (uinta < last)
    {
        count += popCount(*uinta++ ^ *uintb++);
    }
    return static_cast<double>(count);
}

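void generateTestData(std::vector<uint8_t> &data, std::vector<size_t> &dimensions, size_t num_tests, size_t max_dimension)
{
    // One generator shared across calls, so the second call (for b) continues
    // the random stream instead of reproducing the bytes generated for a.
    static std::mt19937 gen(42);
    std::uniform_int_distribution<int> byte_dis(0, 255);
    std::uniform_int_distribution<size_t> dim_dis(1, max_dimension / 16);
    for (size_t i = 0; i < num_tests; ++i)
    {
        // Draw the dimension only on the first call (dimensions starts zero-filled),
        // so rows of a and b with the same index have the same length.
        if (dimensions[i] == 0)
        {
            dimensions[i] = dim_dis(gen) * 16; // multiple of 16 in [16, max_dimension]
        }
        for (size_t j = 0; j < dimensions[i]; ++j)
        {
            data[i * max_dimension + j] = static_cast<uint8_t>(byte_dis(gen));
        }
    }
}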

int main()
{
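    // Each of the two buffers below holds num_tests * max_dimension bytes
    // (6,000,000 * 8192, about 49 GB), so the benchmark needs roughly 100 GB of RAM.
    const size_t num_tests = 6000000;
    const size_t max_dimension = 8192;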

    std::vector<uint8_t> test_data_a(max_dimension * num_tests);
    std::vector<uint8_t> test_data_b(max_dimension * num_tests);
    std::vector<size_t> dimensions(num_tests);

    generateTestData(test_data_a, dimensions, num_tests, max_dimension);
    generateTestData(test_data_b, dimensions, num_tests, max_dimension);

    std::vector<double> primitive_distances(num_tests);
    for (size_t i = 0; i < num_tests; ++i)
    {
        primitive_distances[i] = compareHammingDistanceBuiltin(test_data_a.data() + i * max_dimension, test_data_b.data() + i * max_dimension, dimensions[i]);
    }

    auto benchmark = [&](auto hammingDistanceFunc, const char *label)
    {
        double total_time = 0.0;
        size_t errors = 0;

        for (size_t n = 0; n < 100; ++n)
        {
            auto start = std::chrono::high_resolution_clock::now();
            for (size_t i = 0; i < num_tests; ++i)
            {
                double simd_distance = hammingDistanceFunc(test_data_a.data() + i * max_dimension, test_data_b.data() + i * max_dimension, dimensions[i]);
                if (simd_distance != primitive_distances[i])
                {
                    ++errors;
                }
            }
            auto end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> elapsed = end - start;
            total_time += elapsed.count();
        }

        std::cout << label << ": " << total_time << " seconds, Errors: " << errors << "\n";
    };

    std::cout << "start benchmarking \n";
    benchmark(compareHammingDistanceBuiltin<uint8_t>, "Builtin POPCNT Hamming Distance");
    benchmark(compareHammingDistanceNoSIMD<uint8_t>, "NGT Original Non-SIMD Hamming Distance");
    benchmark(compareHammingDistanceNGT<uint8_t>, "NGT Original Hamming Distance");
    benchmark(compareHammingDistance<uint8_t>, "New Macro and SIMD Hamming Distance");
    std::cout << "benchmark finished \n";

    return 0;
}
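Despite its name, the 128-bit path above extracts lanes with _mm_extract_epi32, which is an SSE4.1 instruction. If a strictly SSE2 (plus popcount) path is wanted, one option is to move the two 64-bit halves out with _mm_cvtsi128_si64 and _mm_unpackhi_epi64, both available with SSE2 on x86-64. The sketch below is a hypothetical, non-unrolled `hammingSse2Only`, not part of the PR; it assumes an x86-64 target and uses the GCC/Clang builtin `__builtin_popcountll` for the bit count.

```cpp
#include <emmintrin.h> // SSE2 intrinsics only
#include <cstddef>
#include <cstdint>

// Hamming distance over `size` bytes; bytes after the last full 16-byte chunk
// are ignored, matching the PR's assumption of 16-byte multiples.
// __builtin_popcountll (GCC/Clang) becomes POPCNT when -mpopcnt is enabled
// and a software fallback otherwise.
inline size_t hammingSse2Only(const uint8_t *a, const uint8_t *b, size_t size)
{
    size_t count = 0;
    for (size_t off = 0; off + 16 <= size; off += 16)
    {
        __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i *>(a + off));
        __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i *>(b + off));
        __m128i x = _mm_xor_si128(va, vb);
        // Low 64 bits directly; high 64 bits after moving them down with unpackhi.
        uint64_t lo = static_cast<uint64_t>(_mm_cvtsi128_si64(x));
        uint64_t hi = static_cast<uint64_t>(_mm_cvtsi128_si64(_mm_unpackhi_epi64(x, x)));
        count += __builtin_popcountll(lo) + __builtin_popcountll(hi);
    }
    return count;
}
```

Within the PR's macro scheme, the same idea would amount to replacing the four _mm_extract_epi32 lines in UNROLL_BODY_SSE2 with these two 64-bit extractions.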

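Test results (the last two runs also include a "New Template and SIMD Hamming Distance" variant, whose code is not shown above):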

> ./hamming_distance_avx2
start benchmarking
Builtin POPCNT Hamming Distance: 1481.64 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2569.08 seconds, Errors: 0
NGT Original Hamming Distance: 393.814 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 367.782 seconds, Errors: 0
benchmark finished

> ./hamming_distance_sse2
start benchmarking
Builtin POPCNT Hamming Distance: 1474.34 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2581.55 seconds, Errors: 0
NGT Original Hamming Distance: 397.84 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 433.393 seconds, Errors: 0
benchmark finished

> ./hamming_distance_avx2
start benchmarking
Builtin POPCNT Hamming Distance: 1477.9 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2585.72 seconds, Errors: 0
NGT Original Hamming Distance: 402.682 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 381.26 seconds, Errors: 0
New Template and SIMD Hamming Distance: 380.898 seconds, Errors: 0
benchmark finished

> ./hamming_distance_sse2
start benchmarking
Builtin POPCNT Hamming Distance: 1488.5 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2572.27 seconds, Errors: 0
NGT Original Hamming Distance: 407.198 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 429.891 seconds, Errors: 0
New Template and SIMD Hamming Distance: 433.764 seconds, Errors: 0
benchmark finished
masajiro commented 5 months ago

Thanks!