chroma-core / chroma

the AI-native open-source embedding database
https://www.trychroma.com/
Apache License 2.0

[Bug]: EmbeddingFunction not working as documented in migration docs #2835

Open davidtbo opened 6 days ago

davidtbo commented 6 days ago

What happened?

Followed instructions here exactly: https://docs.trychroma.com/deployment/migration#migration-to-0.4.16---november-7,-2023

from chromadb.api.types import Documents, Embeddings, Embeddable, Images, Protocol
from transformers import pipeline
from typing import TypeVar, Union

config = {
    "embedding_model": "microsoft/Multilingual-MiniLM-L12-H384",
    # ...other stuff omitted for brevity
}

pipeline = pipeline(
    task="feature-extraction", 
    model=config['embedding_model']
)

Embeddable = Union[Documents, Images]
D = TypeVar("D", bound=Embeddable, contravariant=True)

class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D) -> Embeddings:
        return pipeline(input, return_tensors=True)

I got the error shown in the log output below, which tells me to do exactly what I just did. Any help would be greatly appreciated. :)

Versions

chromadb 0.5.7, chroma-hnswlib 0.7.6 (installed by chroma, not by me directly), Python 3.10.12, Ubuntu 22.04

Relevant log output

ValueError: Expected EmbeddingFunction.__call__ to have the following signature: odict_keys(['self', 'input']), got odict_keys(['self', 'args', 'kwargs'])
E           Please see https://docs.trychroma.com/guides/embeddings for details of the EmbeddingFunction interface.
E           Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/deployment/migration#migration-to-0.4.16---november-7,-2023

tazarov commented 4 days ago

@davidtbo, can you try this:

from transformers import pipeline
from typing import Any
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings
)

class MyCustomEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
            self,
            **kwargs: Any
    ):
        """Initialize the embedding function."""
        # Build the Hugging Face pipeline once and reuse it across calls.
        self._pipeline = pipeline(
            task="feature-extraction",
            model=kwargs.get('embedding_model')
        )

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the input documents."""
        # Caveat (not verified in this thread): a feature-extraction pipeline
        # returns per-token features, so pooling them into one vector per
        # document may be needed before Chroma accepts the result.
        return self._pipeline(input, return_tensors=True)

if __name__ == "__main__":
    embedding_function = MyCustomEmbeddingFunction(embedding_model="microsoft/Multilingual-MiniLM-L12-H384")
    print(embedding_function(["Hello, world!"]))

You can inherit directly from chromadb's EmbeddingFunction, parameterized with the input type(s) you support (here, Documents), instead of redefining the protocol yourself.
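
For context on the log output, here is a small standard-library sketch of one way a signature check can end up reporting `['self', 'args', 'kwargs']`: when the class itself is inspected instead of an instance. This is an assumption about the cause, not something confirmed in this thread. `check_call_signature` and `GoodEmbeddingFunction` are hypothetical names, and chromadb's real validation may differ.

```python
import inspect


def check_call_signature(embedding_function) -> list:
    """Hypothetical stand-in for a library-side check: return the parameter
    names of __call__ as looked up on the object's class."""
    return list(inspect.signature(type(embedding_function).__call__).parameters)


class GoodEmbeddingFunction:
    def __call__(self, input):
        # Dummy embedding: one 2-dimensional vector per document.
        return [[float(len(doc)), 0.0] for doc in input]


# Passing an instance: the class's own __call__(self, input) is inspected.
instance_params = check_call_signature(GoodEmbeddingFunction())

# Passing the class itself: type(GoodEmbeddingFunction) is `type`, so the
# generic type.__call__ is inspected instead of the method defined above.
class_params = check_call_signature(GoodEmbeddingFunction)

print(instance_params)
print(class_params)
```

If the two printed lists differ in the same way as the error message, check that an instance (for example `MyCustomEmbeddingFunction(...)`) is being passed to Chroma, not the class.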