michaelfeil / hf-hub-ctranslate2

Connecting Transformers on HuggingFace Hub with CTranslate2
https://michaelfeil.github.io/hf-hub-ctranslate2/
MIT License

Embedding function for LangChain + Chroma #16

Closed kripper closed 9 months ago

kripper commented 1 year ago

We are testing with LangChain and Chroma, and need an embedding_function for:

db = Chroma.from_documents(texts, embedding_function)

How can we get the embedding_function?

Normally we would get it with:

embedding_function = SentenceTransformerEmbeddings(model_name = "all-MiniLM-L6-v2")

or

embedding_function = LlamaCppEmbeddings(model_path = model_dir + "model.bin")
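One option, assuming hf_hub_ctranslate2's CT2SentenceTransformer exposes an encode() method the way sentence-transformers does, is a small duck-typed wrapper providing the embed_documents / embed_query interface that Chroma.from_documents expects. The class name CT2Embedder below is made up for illustration, not part of any library:

```python
from typing import List

class CT2Embedder:
    """Hypothetical wrapper exposing the embed_documents / embed_query
    interface that LangChain's Chroma integration duck-types against."""

    def __init__(self, model):
        # `model` is anything with an .encode(list_of_str) method that
        # returns one vector per input, e.g. a CT2SentenceTransformer.
        self.model = model

    def embed_documents(self, docs: List[str]) -> List[List[float]]:
        # Convert each vector to a plain list of floats for Chroma.
        return [list(map(float, vec)) for vec in self.model.encode(docs)]

    def embed_query(self, query: str) -> List[float]:
        # Embed a single query by reusing the batch path.
        return self.embed_documents([query])[0]
```

Then `db = Chroma.from_documents(texts, CT2Embedder(model))` should work the same way as with SentenceTransformerEmbeddings, provided the encode() assumption holds.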

kripper commented 1 year ago

Maybe something like this:

import sentencepiece as spm

class SentencePieceProcessorEmbedder:
    def __init__(self, model_path):
        self.sp = spm.SentencePieceProcessor(model_path)

    def embed_documents(self, docs):
        # TODO: Check -- encode() returns token IDs (lists of ints),
        # not dense embedding vectors, so this may not work as a
        # Chroma embedding_function as-is.
        return self.sp.encode(docs)

    def embed_query(self, query):
        # TODO: Check -- same caveat as embed_documents
        return self.sp.encode(query)

embedding_function = SentencePieceProcessorEmbedder(model_dir + "/tokenizer.model")

This can't be tested until https://github.com/michaelfeil/hf-hub-ctranslate2/issues/17 is resolved.

BBC-Esq commented 1 year ago

Here's a script, not sure if it helps...

import os
from hf_hub_ctranslate2 import CT2SentenceTransformer
from chromadb import PersistentClient, Collection
from typing import Optional

# Specify the input and output files
input_txt_file = "relevant_contexts.txt"
log_txt_file = "log.txt"

# Specify the model name
model_name = "BAAI/bge-large-en-v1.5"

# Get the current directory
current_directory = os.path.dirname(os.path.realpath(__file__))

# Specify the new name of the database file
db_file_name = "test_db.sqlite3"

# Create the full path for the new database file
db_file_path = os.path.join(current_directory, db_file_name)

def main():
    # Load the text from the input file
    with open(input_txt_file, 'r') as file:
        lines = file.readlines()

    # Create the CT2SentenceTransformer instance
    model = CT2SentenceTransformer(model_name, compute_type="float32", device="cuda")

    # Generate embeddings
    embeddings = model.encode(
        lines,
        batch_size=32,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    # Create a PersistentClient pointing at the database path
    client = PersistentClient(path=db_file_path)

    # Get the collection if it already exists, otherwise create it
    # (create_collection raises an error when the name is already taken)
    collection = client.get_or_create_collection("my_collection")

    # Add embeddings to the collection as plain Python lists
    ids = [str(i) for i in range(len(embeddings))]
    collection.add(ids=ids, embeddings=embeddings.tolist())

    # Convert the embeddings to a string representation
    embeddings_str = "\n".join([str(embed) for embed in embeddings])

    # Write the embeddings to the log file
    with open(log_txt_file, 'w') as file:
        file.write(embeddings_str)

if __name__ == "__main__":
    main()
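Since the script normalizes its embeddings (normalize_embeddings=True), cosine similarity between a query vector and each stored vector reduces to a plain dot product. A minimal, library-free sketch of ranking documents this way (rank_by_similarity is a hypothetical helper, not part of hf_hub_ctranslate2 or chromadb):

```python
def rank_by_similarity(query_vec, doc_vecs):
    """Return document indices sorted from most to least similar.

    With L2-normalized vectors, the dot product equals cosine
    similarity, so higher scores mean more similar documents."""
    scores = [sum(q * d for q, d in zip(query_vec, dv)) for dv in doc_vecs]
    return sorted(range(len(doc_vecs)), key=lambda i: scores[i], reverse=True)
```

This is essentially what Chroma does internally when querying a collection with a query embedding, so it can help sanity-check the stored vectors.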

kripper commented 1 year ago

Please note that CTranslate2 is now supported by LangChain: https://python.langchain.com/docs/integrations/llms/ctranslate2

BBC-Esq commented 1 year ago

I am aware of this.