qdrant / qdrant-client

Python client for Qdrant vector search engine
https://qdrant.tech
Apache License 2.0

How to change distance model when using FastEmbed? #734

Open paluigi opened 1 month ago

paluigi commented 1 month ago

Issue

Unable to change the distance metric when creating a collection with FastEmbed.

Minimal steps to reproduce

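from qdrant_client import QdrantClient
from qdrant_client.models import Distance

# Local on-disk storage: a bare string is treated as a server URL, so use path=.
client = QdrantClient(path="coicop.db")
# set_model() has no distance parameter, so this keyword does not change the metric.
client.set_model("sentence-transformers/paraphrase-multilingual-mpnet-base-v2", distance=Distance.DOT)
print(client.get_fastembed_vector_params())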

Result

{'fast-paraphrase-multilingual-mpnet-base-v2': VectorParams(size=768, distance=<Distance.COSINE: 'Cosine'>, hnsw_config=None, quantization_config=None, on_disk=None, datatype=None, multivector_config=None)}

Expected result

Distance should be Distance.DOT

Environment

OS: Ubuntu 22.04.1
qdrant-client==1.10.1
fastembed==0.3.4

More details

When creating a collection with Sentence Transformers, I am able to set a different distance metric (DOT, EUCLID). With FastEmbed, the distance seems to be cosine only. Also, looking at the code, it seems all models are initialized with cosine distance only.

joein commented 1 month ago

Hi @paluigi

Methods like add and query are just convenience methods that provide some default configuration. The only workaround for using them with another metric is to create the collection yourself with a create_collection call, before calling either add or query. Keep in mind, however, that add and query have specific rules for vector names: the names must match the ones returned by get_vector_field_name().

paluigi commented 1 month ago

Hi @joein, thanks for the reply. If I want to use FastEmbed, what would be the correct way to use the Distance.DOT metric in a collection?

From what I can see, in my example above the collection would be initialized with the Distance.COSINE metric, even though I tried to set another metric.

hash-f commented 1 month ago

The set_model method does not have a distance parameter.

def set_model(
    self,
    embedding_model_name: str,
    max_length: Optional[int] = None,
    cache_dir: Optional[str] = None,
    threads: Optional[int] = None,
    providers: Optional[Sequence["OnnxProvider"]] = None,
    **kwargs: Any,
):
    ...  # method body omitted

You can specify a different metric when creating a collection, but as far as I know you cannot set a default metric for all operations on the client.

When creating the collection, the vector name and the vector size have to match the model's specs. One way to do this is to write a helper function using pieces from QdrantFastembedMixin:

import uuid

from fastembed import TextEmbedding
from qdrant_client import QdrantClient, models


def add_points(
    client,
    collection_name,
    documents,
    model_name,
    ids=None,
    distance=models.Distance.DOT,
):
    # model_name is required so that the client and TextEmbedding below
    # always refer to the same model.
    client.set_model(model_name)

    # Get the vector field name and the vector size for the chosen model.
    # Using the exact name is important because client.query() searches the
    # {vector_field_name} vector.
    vector_field_name = client.get_vector_field_name()
    vector_params = client.get_fastembed_vector_params()

    # Create the collection with the requested distance if it does not exist.
    if not client.collection_exists(collection_name):
        client.create_collection(
            collection_name=collection_name,
            vectors_config={
                vector_field_name: models.VectorParams(
                    size=vector_params[vector_field_name].size,
                    distance=distance,
                )
            },
        )

    # Load the same embedding model from FastEmbed.
    embedding_model = TextEmbedding(model_name)

    # Generate random UUIDs if ids are not passed; otherwise accept any iterable.
    if ids is None:
        ids = iter(lambda: uuid.uuid4().hex, None)
    else:
        ids = iter(ids)

    # Upload points. The text is stored under "document", the payload key
    # client.add() uses, so that client.query() can return it.
    client.upload_points(
        collection_name=collection_name,
        points=[
            models.PointStruct(
                id=next(ids),
                vector={vector_field_name: embedding.tolist()},
                payload={"document": doc},
            )
            # Embed the documents with the FastEmbed model.
            for doc, embedding in zip(documents, embedding_model.embed(documents))
        ],
    )


documents = [
    "This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc.",
    "fastembed is supported by and maintained by Qdrant.",
]

model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# One client for the whole script: a local storage folder can only be
# opened by one client at a time.
client = QdrantClient(path="./db/")
add_points(
    client,
    collection_name="test_collection",
    documents=documents,
    model_name=model_name,
)

Because we have followed the naming conventions, we can use the query method out of the box:

# The client must use the same model the collection was built with.
client.set_model(model_name)

query_text = "Which library is maintained by Qdrant?"  # example query
search_result = client.query(
    collection_name="test_collection",
    query_text=query_text,
)
for hit in search_result:
    print(hit.score, hit.document)

This is a minimal implementation and there may be better ways to do this, but it can be adapted to suit your purpose.