svilupp / PromptingTools.jl

Streamline your life using PromptingTools.jl, the Julia package that simplifies interacting with large language models.
https://svilupp.github.io/PromptingTools.jl/dev/
MIT License

[FR] Improve performance of Bool embeddings #144

Closed: svilupp closed this 1 month ago

svilupp commented 2 months ago

domluna posted a really nice gist on the generativeAI Slack using Bool embeddings (held in Int8) + StaticArrays (here).

It seems to provide huge performance benefits compared to my fairly trivial Bool retrieval implementation in RAGTools (here).
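
For context, a deliberately naive Bool retrieval looks roughly like this (a minimal sketch for contrast, not the actual RAGTools code):

# Naive Bool retrieval sketch (illustrative only, not the RAGTools code):
# hamming distance as the number of mismatching Bool entries,
# brute-forced over every row.
naive_hamming(a::AbstractVector{Bool}, b::AbstractVector{Bool}) = count(a .!= b)

function naive_top_k(db::Vector{Vector{Bool}}, q::Vector{Bool}, k::Int)
    dists = [naive_hamming(v, q) for v in db]
    return partialsortperm(dists, 1:k)  # indices of the k closest rows
end

Each query allocates a temporary per row (the broadcasted a .!= b), which is part of why the packed-byte approach below wins.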

Playing around with RAG and binary vectors (https://huggingface.co/blog/embedding-quantization). The idea of 64 bytes is from this post: https://www.mixedbread.ai/blog/binary-mrl. This is a brute-force parallel implementation of search that assumes the data is stored as binary in byte elements (Int8, UInt8), so 512 bits is 64 Int8 elements, or a 64-element static vector using StaticArrays. For a 100M vector dataset I get < 1s comparison time on my M1 MacBook Air. I'm wondering if there's anything I can do to make it faster.
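
To make the packing concrete, here is a minimal sketch of turning a Bool embedding into bytes (the helper name and bit layout are my assumptions, not from the gist):

using StaticArrays

# Hypothetical helper (not from the gist): pack a Bool embedding into
# bytes, 8 bits per UInt8, so a 512-dim Bool vector becomes 64 bytes.
function pack_bits(emb::AbstractVector{Bool})
    @assert length(emb) % 8 == 0 "embedding length must be a multiple of 8"
    bytes = zeros(UInt8, length(emb) ÷ 8)
    for (i, bit) in enumerate(emb)
        bytes[(i - 1) ÷ 8 + 1] |= UInt8(bit) << ((i - 1) % 8)
    end
    return bytes
end

# 512 bits -> 64 bytes, optionally as a 64-element static vector:
# q = SVector{64,UInt8}(pack_bits(rand(Bool, 512)))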

It would be excellent to:

- Ideally, also integrate into AIHelpMe so everyone downstream can benefit (these TODOs would be transferred to the other repo).

svilupp commented 2 months ago

For reference, I did some benchmarking a few weeks ago and Bool embeddings performed really well (especially assuming that we would use a reranker/cross-encoder downstream).

With top_k=20, 1024 dims in Bool still retains 97% recall and a very competitive MRR:

[image: benchmark table with recall and MRR results]
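
For anyone re-running this, the metrics can be computed along these lines (a minimal sketch; the function names and evaluation setup are assumptions, not the benchmark code behind the table):

# Recall@k: overlap between the binary top-k and the "true" top-k
# from the full float embeddings.
recall_at_k(true_topk::Vector{Int}, binary_topk::Vector{Int}) =
    length(intersect(true_topk, binary_topk)) / length(true_topk)

# MRR: reciprocal rank of the true best hit within the binary ranking.
function mrr(true_best::Int, binary_ranking::Vector{Int})
    pos = findfirst(==(true_best), binary_ranking)
    return pos === nothing ? 0.0 : 1.0 / pos
end
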
domluna commented 2 months ago

nice!

StaticArrays makes it faster, but it's not massive, ~2x at most (on 1M vectors), so technically not explicitly required:

# no StaticArrays
julia> @b $k_closest_parallel(X1, q1, 10)
21.741 ms (52 allocs: 5.250 KiB)

# q2 is a StaticArray
julia> @b $k_closest_parallel(X1, q2, 10)
12.545 ms (52 allocs: 5.500 KiB)

# q2 is a StaticArray, X2 is a list of static arrays
julia> @b $k_closest_parallel(X2, q2, 10)
9.180 ms (52 allocs: 5.500 KiB)

The cool thing about the binary embeddings is that you can keep everything in memory and you don't need an enormously powerful computer. For 1 billion rows you would need 64GB instead of 1TB, which greatly decreases costs. Furthermore, you can potentially use this as part of a reranking pipeline, where you keep a higher-dimensional embedding version on disk and then seek the relevant rows found via the binary embedding similarity.
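
That two-stage idea can be sketched roughly as follows (all names here are hypothetical; load_float_rows stands in for whatever on-disk lookup is used):

using LinearAlgebra

# Rough sketch of the two-stage pipeline (illustrative, not from the gist):
# 1) cheap binary top-k in memory, 2) rescore candidates with the
# full-precision embeddings fetched from disk.
function two_stage_search(binary_db, binary_q, float_q, k; oversample = 10)
    # Stage 1: over-retrieve with the fast hamming search from this thread.
    candidates = k_closest_parallel(binary_db, binary_q, k * oversample)
    idxs = [i for (_, i) in candidates]
    # Stage 2: rescore only the candidates with float embeddings.
    floats = load_float_rows(idxs)              # hypothetical disk lookup
    scores = [dot(float_q, f) for f in floats]  # cosine if normalized
    return idxs[partialsortperm(scores, 1:k; rev = true)]
end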

domluna commented 2 months ago
# INT is Union{Int8,UInt8} (defined in the full listing below)
function hamming_distance(x1::AbstractArray{T}, x2::AbstractArray{T})::Int where {T<:INT}
    s = 0
    for i in eachindex(x1, x2)
        s += hamming_distance(x1[i], x2[i])
    end
    s
end

Changing the sum calculation to the explicit loop above now produces these timings (adding @simd or @inbounds macros seems to have no effect):

julia> @b k_closest_parallel(X1, q1, 10)
4.710 ms (52 allocs: 5.250 KiB)

julia> @b k_closest_parallel(X1, q2, 10)
4.080 ms (52 allocs: 5.500 KiB)

julia> @b k_closest_parallel(X2, q2, 10)
4.070 ms (52 allocs: 5.500 KiB)

so using StaticArrays doesn't add much.
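
For anyone reproducing these numbers, the four variants could be set up along these lines (the construction is my assumption; the thread doesn't show it):

using StaticArrays

# Assumed benchmark setup (not shown in the thread): 1M random
# 64-byte vectors, in plain and static-array flavors.
X1 = [rand(Int8, 64) for _ in 1:1_000_000]  # plain vectors
q1 = rand(Int8, 64)                         # plain query
X2 = [SVector{64,Int8}(v) for v in X1]      # static vectors
q2 = SVector{64,Int8}(q1)                   # static query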

julia> versioninfo()
Julia Version 1.11.0-beta1
Commit 08e1fc0abb9 (2024-04-10 08:40 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m1)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_STACKTRACE_MINIMAL = true
  DYLD_LIBRARY_PATH = /Users/lunaticd/.wasmedge/lib
  JULIA_EDITOR = nvim
domluna commented 2 months ago
using StaticArrays
using Base.Threads

const INT = Union{Int8,UInt8}

# Hamming distance between two bytes: count the differing bits.
function hamming_distance(x1::T, x2::T)::Int where {T<:INT}
    c = 0
    for i = 0:7
        c += ((x1 >> i) & 1) ⊻ ((x2 >> i) & 1)
    end
    return Int(c)
end

# Hamming distance between two byte vectors: sum of per-byte distances.
function hamming_distance(x1::AbstractArray{T}, x2::AbstractArray{T})::Int where {T<:INT}
    s = 0
    @inbounds @simd for i in eachindex(x1, x2)
        s += hamming_distance(x1[i], x2[i])
    end
    s
end

# Split the database into one chunk per thread, search each chunk in
# parallel, and merge the partial top-k results.
function k_closest_parallel(
    db::AbstractArray{V},
    query::AbstractVector{T},
    k::Int,
) where {T<:INT,V<:AbstractVector{T}}
    n = length(db)
    t = nthreads()
    chunk = max(1, n ÷ t)  # guard against more threads than rows
    task_ranges = [(i:min(i + chunk - 1, n)) for i = 1:chunk:n]
    tasks = map(task_ranges) do r
        Threads.@spawn begin
            part = k_closest(view(db, r), query, k)
            # map chunk-local indices back to global indices in db
            [d => (first(r) + i - 1) for (d, i) in part]
        end
    end
    results = fetch.(tasks)
    sort!(vcat(results...), by = x -> x[1])[1:k]
end

# Brute-force top-k within one chunk, maintained as a sorted insertion
# buffer of (distance => index) pairs.
function k_closest(
    db::AbstractVector{V},
    query::AbstractVector{T},
    k::Int,
) where {T<:INT,V<:AbstractVector{T}}
    results = Vector{Pair{Int,Int}}(undef, k)
    m = typemax(Int)
    fill!(results, (m => -1))

    @inbounds for i in eachindex(db)
        d = hamming_distance(db[i], query)
        for j = 1:k
            if d < results[j][1]
                old = results[j]
                results[j] = d => i
                # shift displaced entries down; the loop runs to k (not
                # k - 1) so the last slot is updated as well
                for l = j+1:k
                    old, results[l] = results[l], old
                end
                break
            end
        end
    end

    return results
end

The core operation takes ~20ns on a static array, but when everything is combined we actually average even less than that per vector.

On a 1M-element vector where each element is a 64-element vector of Int8:

julia> @b k_closest_parallel(X, q, 1)
2.816 ms (50 allocs: 3.547 KiB)

julia> @b k_closest_parallel(X, q, 5)
3.142 ms (52 allocs: 4.516 KiB)

julia> @b k_closest_parallel(X, q, 10)
3.560 ms (52 allocs: 5.500 KiB)

julia> @b k_closest_parallel(X, q, 50)
7.449 ms (54 allocs: 13.938 KiB)

julia> @b k_closest_parallel(X, q, 100)
11.626 ms (54 allocs: 23.781 KiB)

Naively looping and doing the distance op 1M times without any additional work would take 20ms, but parallelized over 4 cores we come in under that.
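
One related micro-optimization not benchmarked in this thread: Julia's built-in count_ones lowers to a hardware popcount, so the per-bit loop in hamming_distance can be collapsed to a single XOR plus popcount:

# Drop-in alternative to the bit loop above (not benchmarked here):
# XOR the bytes, then count the set bits with hardware popcount.
hamming_distance_popcnt(x1::T, x2::T) where {T<:Union{Int8,UInt8}} =
    count_ones(x1 ⊻ x2)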

svilupp commented 2 months ago

I'm actually slightly hesitant to enforce threading under the hood, because:

Btw, I'm a bit lost in the references above -- what's your recommendation wrt StaticArrays? Do you think they are a valuable addition, or should we keep it simple?

domluna commented 2 months ago

It doesn't seem that StaticArrays is absolutely necessary: the performance is better, but not drastically so (for 1M rows, ~17.5% faster). Parallelism isn't necessary either, but this is a situation where the problem is easily parallelizable (a mapreduce pattern), so we get very close to perfect scaling with cores, i.e., 4 cores makes it ~4x faster, 16 cores ~16x faster.

svilupp commented 1 month ago

Linking a great writeup by @domluna here: https://github.com/domluna/tinyrag

In particular, this function looks the same as the inner function here.

It needs some benchmarking and potentially a mini PR if someone is interested!

EDIT: I should have said the PR could be:

svilupp commented 1 month ago

Closed by https://github.com/svilupp/PromptingTools.jl/pull/152