vespa-engine / vespa

AI + Data, online. https://vespa.ai
Apache License 2.0

Optimize MaxSim with hamming (sum of max inverted hamming distances) #32232

Open jobergum opened 3 weeks ago

jobergum commented 3 weeks ago

Multi-vector MaxSim is increasingly important, and we have optimizations for float cell precision, but I think we should also consider optimizing for int8 with hamming, as the Hamming distance between binarized vectors approximates the dot product for normalized vectors.
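A minimal sketch of the link between Hamming distance and the dot product, assuming each float dimension is binarized by its sign and packed 8 bits per int8 cell, most significant bit first (so 128 dimensions become 16 cells, matching the `v[128]` and `v[16]` shapes in the schema below). The threshold, bit order, and helper names (`binarize`, `hamming`, `sign_dot`) are illustrative assumptions. For the resulting +1/-1 sign vectors of dimension d, the dot product is exactly `d - 2 * hamming`, so ranking by small Hamming distance is ranking by large sign-vector dot product.

```python
import random


def binarize(vec):
    """Sign-binarize a float vector and pack 8 bits per signed int8 cell (MSB first, assumed)."""
    cells = []
    for i in range(0, len(vec), 8):
        byte = 0
        for x in vec[i:i + 8]:
            byte = (byte << 1) | (1 if x > 0 else 0)
        # Reinterpret the unsigned byte as a signed int8, as stored in tensor<int8>.
        cells.append(byte - 256 if byte > 127 else byte)
    return cells


def hamming(a_cells, b_cells):
    """Number of differing bits across all int8 cells."""
    return sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a_cells, b_cells))


def sign_dot(a, b):
    """Dot product of the +1/-1 sign vectors of a and b (0 is treated as negative)."""
    return sum((1 if x > 0 else -1) * (1 if y > 0 else -1) for x, y in zip(a, b))


random.seed(0)
d = 128
a = [random.gauss(0, 1) for _ in range(d)]
b = [random.gauss(0, 1) for _ in range(d)]
h = hamming(binarize(a), binarize(b))
print(len(binarize(a)), h, sign_dot(a, b), d - 2 * h)  # the last two values are equal
```

The identity holds because the sign vectors agree in `d - h` positions and disagree in `h`, giving `(d - h) - h`. How well this tracks the float dot product depends on the embeddings; the identity itself is exact.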

It could help scale both ColBERT and ColPali.

The following table documents proton query latency for a query that ranks 500 documents with the schema below (one thread per search), together with ranking accuracy on the DocVQA benchmark.

| Query Cell Type | Document Cell Type | nDCG@5 | Latency | Scoring function | Rank-Profile |
|---|---|---|---|---|---|
| float | float | 53.7 | 97 | sum of max dotproducts | full |
| float | float (unpacked from int8) | 51.5 | 105 | sum of max dotproducts | binary |
| int8 | int8 | 48.6 | 506 | sum of max inverted hamming | binary-hamming |

Schema with the rank profiles referenced above. Note that without the cell_cast from bfloat16 to float in the full profile, it is much slower (8-10x).

schema pdf_page {
    document pdf_page {
        field id type string {
            indexing: summary | attribute
        }
        field embedding type tensor<bfloat16>(patch{}, v[128]) {
            indexing: attribute
        }
        field binary_embedding type tensor<int8>(patch{}, v[16]) {
            indexing: attribute
        }
    }
    rank-profile full {
        inputs {
            query(qt) tensor<float>(querytoken{}, v[128])
        }
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            query(qt) * cell_cast(attribute(embedding), float), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary inherits full {
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            query(qt) * unpack_bits(attribute(binary_embedding)), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary-hamming {
        inputs {
            query(qtb) tensor<int8>(querytoken{}, v[16])
        }
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            1/(1+hamming(query(qtb), attribute(binary_embedding))), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
}
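For reference, the three rank profiles above can be mirrored in plain Python to show what each computes per document. This is a sketch, not Vespa's implementation: tensors are modeled as dicts of lists, the helper names are hypothetical, and `unpack_bits` is assumed to expand each int8 cell most significant bit first into float 0/1 values.

```python
def dot(a, b):
    """Sum over dimension v of the elementwise product."""
    return sum(x * y for x, y in zip(a, b))


def unpack_bits(cells):
    """Expand int8 cells into 0.0/1.0 values, most significant bit first (assumed order)."""
    bits = []
    for c in cells:
        byte = c & 0xFF  # reinterpret signed int8 as an unsigned byte
        bits.extend(float((byte >> shift) & 1) for shift in range(7, -1, -1))
    return bits


def popcount8(x):
    """Number of set bits in the low 8 bits of x."""
    return bin(x & 0xFF).count("1")


def max_sim(query, doc, sim):
    """Sum over query tokens of the max over patches of sim(token, patch)."""
    return sum(max(sim(q, p) for p in doc.values()) for q in query.values())


def sim_full(q, p):
    """Profile "full": float query token vs float (cast from bfloat16) patch."""
    return dot(q, p)


def sim_binary(q, p_cells):
    """Profile "binary": float query token vs unpack_bits(binary_embedding)."""
    return dot(q, unpack_bits(p_cells))


def sim_binary_hamming(q_cells, p_cells):
    """Profile "binary-hamming" (first version): sum over v of 1/(1 + per-cell hamming)."""
    return sum(1.0 / (1 + popcount8(x ^ y)) for x, y in zip(q_cells, p_cells))


# Tiny example with v[2] instead of v[16] to keep it readable.
query_bits = {"0": [1, -1], "1": [127, 0]}
doc_bits = {"p0": [1, -1], "p1": [0, 0]}
print(max_sim(query_bits, doc_bits, sim_binary_hamming))  # 2.0 + 1.125 = 3.125
```

Token "0" matches patch "p0" exactly (2 cells, each scoring 1/(1+0)), and token "1" scores best against "p1" (1/8 + 1/1), which gives the sum 3.125.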
jobergum commented 2 weeks ago
| Query Cell Type | Document Cell Type | nDCG@5 | Latency | Scoring function | Rank-Profile |
|---|---|---|---|---|---|
| float | float | 54.4 | x | sum of max dotproducts | full |
| float | float (unpacked from int8 to {0, 1}) | 52.4 | x | sum of max dotproducts | binary |
| float | float (unpacked from int8, scaled to {-1, 1}) | 52.4 | x | sum of max dotproducts | binary-sign |
| int8 | int8 | 48.7 | x | sum of max partial inverted hamming | binary-hamming2 |
| int8 | int8 | 49.8 | x | sum of max full inverted hamming | binary-hamming |
| int8 | int8 | 34.86 | x | binary dotproduct | binary-dotproduct |
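The identical 52.4 for binary and binary-sign is what the algebra predicts. For a query token q and patch bits b in {0, 1}, `q . (2b - 1) = 2 (q . b) - sum_v(q)`. The last term does not depend on the patch, so it passes through the max over patches and the sum over query tokens: the binary-sign score equals 2 × the binary score minus a constant that depends only on the query, and both profiles order documents identically. The Python sketch below checks this on random data; the helper names are hypothetical, and integer-valued query weights are used only to keep the arithmetic exact.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def max_sim(query, doc_bits, to_doc_vector):
    """Sum over query tokens of the max over patches of q . to_doc_vector(patch_bits)."""
    return sum(max(dot(q, to_doc_vector(b)) for b in doc_bits) for q in query)


def as_01(bits):
    """Profile "binary": unpack_bits(...) yields 0/1 values."""
    return bits


def as_pm1(bits):
    """Profile "binary-sign": 2*unpack_bits(...) - 1 yields -1/+1 values."""
    return [2 * x - 1 for x in bits]


def ranking(scores):
    """Document indices ordered by descending score."""
    return sorted(range(len(scores)), key=lambda i: -scores[i])


random.seed(1)
d, n_tokens, n_patches, n_docs = 16, 4, 5, 20
# Integer-valued query weights keep the arithmetic exact in this illustration.
query = [[random.randint(-5, 5) for _ in range(d)] for _ in range(n_tokens)]
docs = [[[random.randint(0, 1) for _ in range(d)] for _ in range(n_patches)]
        for _ in range(n_docs)]

scores_01 = [max_sim(query, doc, as_01) for doc in docs]
scores_pm1 = [max_sim(query, doc, as_pm1) for doc in docs]
offset = sum(sum(q) for q in query)  # depends only on the query
print(all(p == 2 * s - offset for s, p in zip(scores_01, scores_pm1)))  # True
print(ranking(scores_01) == ranking(scores_pm1))  # True
```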

Here, binary-hamming is defined as:

sum(
  reduce(
    1/(1 + sum(hamming(query(qt), attribute(binary_embedding)), v)),
    max, patch
  ),
  querytoken
)

That is, we compute the Hamming distance over the entire 128-bit vector (16 int8 cells) and convert it to a similarity score with 1/(1 + hamming_full_vector), instead of summing a score over each cell. This is also easier to optimize, and it scores better (49.8 vs 48.7 nDCG@5) than the initial suggestion, which sums the per-cell scores.
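One way to see why the full-vector form is easier to optimize: the distance is a single XOR-and-popcount reduction over the whole 128-bit vector, which can be computed over two 64-bit words instead of producing 16 separate per-cell scores. The Python sketch below compares the two scoring functions and checks that a word-level popcount gives the same distance as summing the per-cell distances. It is an illustration only, it does not describe how the optimization merged later in Vespa works, and the helper names and word packing are hypothetical.

```python
import random


def popcount(x):
    """Number of set bits in a non-negative integer."""
    return bin(x).count("1")


def hamming_per_cell(q_cells, p_cells):
    """Per-cell Hamming distances (before any reduction over v)."""
    return [popcount((x ^ y) & 0xFF) for x, y in zip(q_cells, p_cells)]


def score_partial(q_cells, p_cells):
    """binary-hamming2: sum over v of 1/(1 + per-cell hamming)."""
    return sum(1.0 / (1 + h) for h in hamming_per_cell(q_cells, p_cells))


def score_full(q_cells, p_cells):
    """binary-hamming: 1/(1 + hamming over the whole 128-bit vector)."""
    return 1.0 / (1 + sum(hamming_per_cell(q_cells, p_cells)))


def to_words(cells):
    """Pack int8 cells into 64-bit unsigned words, 8 cells per word (illustrative)."""
    words = []
    for i in range(0, len(cells), 8):
        w = 0
        for c in cells[i:i + 8]:
            w = (w << 8) | (c & 0xFF)
        words.append(w)
    return words


def hamming_words(q_words, p_words):
    """Full-vector Hamming distance as XOR + popcount over whole words."""
    return sum(popcount(a ^ b) for a, b in zip(q_words, p_words))


def max_sim(query, doc, score):
    """Sum over query tokens of the max over patches."""
    return sum(max(score(q, p) for p in doc) for q in query)


def random_cells(n=16):
    return [random.randint(-128, 127) for _ in range(n)]


random.seed(2)
query = [random_cells() for _ in range(3)]
doc = [random_cells() for _ in range(4)]
q, p = query[0], doc[0]
# The word-level distance matches the sum of the per-cell distances.
print(hamming_words(to_words(q), to_words(p)) == sum(hamming_per_cell(q, p)))  # True
print(max_sim(query, doc, score_partial), max_sim(query, doc, score_full))
```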

Full schema with the rank profiles referenced in the table above.


schema pdf_page {
    document pdf_page {
        field id type string {
            indexing: summary | attribute
        }
        field embedding type tensor<bfloat16>(patch{}, v[128]) {
            indexing: attribute
        }
        field binary_embedding type tensor<int8>(patch{}, v[16]) {
            indexing: attribute
        }
    }
    rank-profile full {
        inputs {
            query(qt) tensor<float>(querytoken{}, v[128])
        }
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            query(qt) * cell_cast(attribute(embedding), float), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary inherits full {
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            query(qt) * unpack_bits(attribute(binary_embedding)), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary-hamming {
        inputs {
            query(qt) tensor<int8>(querytoken{}, v[16])
        }
        function max_sim() {
            expression {
                sum(
                    reduce(
                        1/(1 + sum(
                            hamming(query(qt), attribute(binary_embedding)), v
                        )),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary-dotproduct {
        inputs {
            query(qt) tensor<int8>(querytoken{}, v[16])
        }
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            unpack_bits(query(qt)) * unpack_bits(attribute(binary_embedding)), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary-sign inherits full {
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            query(qt) * (2*unpack_bits(attribute(binary_embedding)) - 1), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
    rank-profile binary-hamming2 {
        inputs {
            query(qt) tensor<int8>(querytoken{}, v[16])
        }
        function max_sim() {
            expression {
                sum(
                    reduce(
                        sum(
                            1/(1+hamming(query(qt), attribute(binary_embedding))), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            }
        }
        first-phase {
            expression {
                max_sim
            }
        }
    }
}
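The binary-dotproduct profile multiplies two 0/1 vectors, so it counts only the bit positions where both the query token and the patch have a 1; positions where both are 0 contribute nothing. Per bit, `hamming(a, b) = ones(a) + ones(b) - 2 * shared_ones(a, b)`, so for a fixed query token this score also grows with the number of set bits in the patch, independent of how well the patch agrees with the query. The sketch below checks that identity and shows an exact match and an all-ones patch receiving the same binary dot product. Helper names are hypothetical, and most-significant-bit-first unpacking is assumed.

```python
import random


def popcount8(x):
    """Number of set bits in the low 8 bits of x (int8 cells reinterpreted as bytes)."""
    return bin(x & 0xFF).count("1")


def unpack_bits(cells):
    """0/1 bits per int8 cell, most significant bit first (assumed order)."""
    return [((c & 0xFF) >> s) & 1 for c in cells for s in range(7, -1, -1)]


def binary_dot(q_cells, p_cells):
    """binary-dotproduct: unpack_bits(q) . unpack_bits(p), i.e. the number of shared 1-bits."""
    return sum(x * y for x, y in zip(unpack_bits(q_cells), unpack_bits(p_cells)))


def hamming(q_cells, p_cells):
    """Full-vector Hamming distance."""
    return sum(popcount8(x ^ y) for x, y in zip(q_cells, p_cells))


def ones(cells):
    """Total number of set bits."""
    return sum(popcount8(c) for c in cells)


q = [0x0F] * 16       # 64 of 128 bits set
p_same = [0x0F] * 16  # identical to the query
p_all = [-1] * 16     # every bit set (0xFF in each int8 cell)
print(binary_dot(q, p_same), binary_dot(q, p_all))  # 64 64
print(hamming(q, p_same), hamming(q, p_all))        # 0 64

random.seed(3)
a = [random.randint(-128, 127) for _ in range(16)]
b = [random.randint(-128, 127) for _ in range(16)]
print(2 * binary_dot(a, b) == ones(a) + ones(b) - hamming(a, b))  # True
```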
havardpe commented 1 week ago

PR https://github.com/vespa-engine/vespa/pull/32320 merged in Vespa version 8.404.8