Closed lemire closed 7 years ago
Something like this, using 64-bit masks (the two 32-byte comparison masks of each 64-byte block are merged into one 64-bit word), which can reduce the number of branches quite a bit:
/*
 * Find the first occurrence of needle[0..k) in s[0..n) using AVX2.
 *
 * Each iteration examines 64 candidate start positions: the first and last
 * needle bytes are compared against two 32-byte blocks each, and the two
 * 32-bit match masks are merged into one 64-bit mask so a single loop walks
 * all candidates of the block. Only candidates whose first and last bytes
 * match are confirmed with memcmp() on the interior bytes.
 *
 * Parameters:
 *   s, n      - haystack and its length in bytes (n > 0, asserted)
 *   needle, k - pattern and its length in bytes (k > 0, asserted)
 *
 * Returns the index of the first match, or n when there is none
 * (including when k > n).
 *
 * No byte outside s[0..n) or needle[0..k) is read: the vector loop only
 * handles blocks whose loads are fully in bounds, and the remaining (at most
 * 63) start positions are checked by a scalar loop.
 *
 * The target attribute lets this function build without a global -mavx2;
 * the caller is still responsible for running it only on CPUs with AVX2.
 */
__attribute__((target("avx2")))
size_t avx2_strstr_anysize64(const char* s, size_t n, const char* needle, size_t k)
{
    enum { BLOCK = 64, HALF = 32 };

    assert(k > 0);
    assert(n > 0);

    if (k > n) {
        return n;
    }

    /* Largest index at which a match can start. */
    const size_t last_start = n - k;

    /*
     * A block starting at i reads s[i .. i + k + 62], so it is in bounds
     * exactly when i + 63 <= last_start, i.e. i < vec_limit.
     */
    const size_t vec_limit =
        (last_start >= BLOCK - 1) ? last_start - (BLOCK - 1) + 1 : 0;

    const __m256i first = _mm256_set1_epi8(needle[0]);
    const __m256i last = _mm256_set1_epi8(needle[k - 1]);

    size_t i = 0;
    for (; i < vec_limit; i += BLOCK) {
        const __m256i block_first1 =
            _mm256_loadu_si256((const __m256i*)(s + i));
        const __m256i block_last1 =
            _mm256_loadu_si256((const __m256i*)(s + i + k - 1));
        const __m256i block_first2 =
            _mm256_loadu_si256((const __m256i*)(s + i + HALF));
        const __m256i block_last2 =
            _mm256_loadu_si256((const __m256i*)(s + i + k - 1 + HALF));

        const __m256i eq_first1 = _mm256_cmpeq_epi8(first, block_first1);
        const __m256i eq_last1 = _mm256_cmpeq_epi8(last, block_last1);
        const __m256i eq_first2 = _mm256_cmpeq_epi8(first, block_first2);
        const __m256i eq_last2 = _mm256_cmpeq_epi8(last, block_last2);

        const uint32_t mask1 =
            (uint32_t)_mm256_movemask_epi8(_mm256_and_si256(eq_first1, eq_last1));
        const uint32_t mask2 =
            (uint32_t)_mm256_movemask_epi8(_mm256_and_si256(eq_first2, eq_last2));

        /* Bit b set <=> s[i + b] == needle[0] && s[i + b + k - 1] == needle[k - 1]. */
        uint64_t mask = (uint64_t)mask1 | ((uint64_t)mask2 << 32);

        while (mask != 0) {
            const size_t pos = i + (size_t)__builtin_ctzll(mask);

            /*
             * First and last bytes already match. For k <= 2 there is no
             * interior, and k - 2 would underflow for k == 1.
             */
            if (k <= 2 || memcmp(s + pos + 1, needle + 1, k - 2) == 0) {
                return pos;
            }

            mask &= mask - 1; /* clear the lowest set bit */
        }
    }

    /* Scalar tail: start positions that do not fill a whole in-bounds block. */
    for (; i <= last_start; i++) {
        if (s[i] == needle[0] && s[i + k - 1] == needle[k - 1] &&
            (k <= 2 || memcmp(s + i + 1, needle + 1, k - 2) == 0)) {
            return i;
        }
    }

    return n;
}
Done, thanks
Something like this... using 64-bit masks... which can reduce the number of branches quite a bit...