Closed kripper closed 9 months ago
Maybe something like this:
import sentencepiece as spm
class SentencePieceProcessorEmbedder:
    """Adapter exposing a SentencePiece model through an embedding-style API.

    NOTE(review): ``SentencePieceProcessor.encode`` returns token-ID
    sequences, not dense embedding vectors — the original author marked
    both methods "TODO: Check"; confirm this satisfies the embedding
    interface the caller expects.
    """

    def __init__(self, model_path):
        # Load the SentencePiece model from the given .model file.
        self.sp = spm.SentencePieceProcessor(model_path)

    def embed_documents(self, docs):
        """Encode a batch of documents; one encoded sequence per document."""
        return self.sp.encode(docs)

    def embed_query(self, query):
        """Encode a single query string."""
        return self.sp.encode(query)
embedding_function = SentencePieceProcessorEmbedder(model_dir + "/tokenizer.model")
Depends on https://github.com/michaelfeil/hf-hub-ctranslate2/issues/17 to be able to test.
Here's a script, not sure if it helps...
import os
from hf_hub_ctranslate2 import CT2SentenceTransformer
from chromadb import PersistentClient, Collection
from typing import Optional
# Input corpus (one document per line) and the log file the embeddings
# are dumped to for inspection.
input_txt_file = "relevant_contexts.txt"
log_txt_file = "log.txt"
# Hugging Face model id used for the sentence embeddings.
model_name = "BAAI/bge-large-en-v1.5"
# Directory containing this script (so the DB lands next to it regardless
# of the current working directory).
current_directory = os.path.dirname(os.path.realpath(__file__))
# File name of the persistent Chroma database.
db_file_name = "test_db.sqlite3"
# Full path handed to PersistentClient below.
db_file_path = os.path.join(current_directory, db_file_name)
def main():
    """Embed each line of ``input_txt_file`` and persist the results.

    Side effects: reads ``input_txt_file``, stores the embeddings in the
    Chroma collection ``my_collection`` inside ``db_file_path``, and
    writes a human-readable dump to ``log_txt_file``.
    """
    # Load the text from the input file; pin the encoding so behavior
    # does not depend on the platform default.
    with open(input_txt_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    # Create the CTranslate2-accelerated sentence-transformer.
    model = CT2SentenceTransformer(model_name, compute_type="float32", device="cuda")

    # One normalized embedding per input line.
    embeddings = model.encode(
        lines,
        batch_size=32,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    # BUGFIX: chromadb's PersistentClient is not a context manager, so the
    # original ``with PersistentClient(...) as client:`` fails at runtime.
    # Instantiate it directly instead.
    client = PersistentClient(path=db_file_path)

    # BUGFIX: ``create_collection`` raises if the collection already exists
    # (it does not return None as the original comment claimed).
    # ``get_or_create_collection`` is idempotent, so re-running the script
    # works.
    collection: Optional[Collection] = client.get_or_create_collection("my_collection")
    if collection:
        # Chroma requires one string id per embedding.
        ids = [str(i) for i in range(len(embeddings))]
        collection.add(ids=ids, embeddings=embeddings)

    # Dump the embeddings to the log file for manual inspection.
    embeddings_str = "\n".join(str(embed) for embed in embeddings)
    with open(log_txt_file, 'w', encoding='utf-8') as file:
        file.write(embeddings_str)


if __name__ == "__main__":
    main()
Please note that CTranslate2 is now supported by LangChain: https://python.langchain.com/docs/integrations/llms/ctranslate2
I am aware of this.
We are testing with LangChain and Chroma, and need an
embedding_function
for: db = Chroma.from_documents(texts, embedding_function)
How can we get the
embedding_function
? Normally we would get it with:
embedding_function = SentenceTransformerEmbeddings(model_name = "all-MiniLM-L6-v2")
or
embedding_function = LlamaCppEmbeddings(model_path = model_dir + "/model.bin")