chroma-core / chroma

the AI-native open-source embedding database
https://www.trychroma.com/
Apache License 2.0
14.74k stars 1.23k forks

Querying with a document name stored verbatim in chromadb sometimes fails to find the document? #1259

Open leaf-ygq opened 11 months ago

leaf-ygq commented 11 months ago

I have 15,000 pieces of data stored in chromadb; each piece contains (document name, document content). I queried with the document names of these 15,000 entries, completely unmodified, but 66 of them did not retrieve the correct document.

tazarov commented 11 months ago

@leaf-ygq, unfortunately, names as such are not ideal candidates for semantic search, which is probably why you see this kind of discrepancy. Instead, I suggest embedding a summary of each document's content.
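
A minimal sketch of that suggestion (summarize() and the docs structure are hypothetical placeholders, not part of Chroma's API):

# Hypothetical sketch: embed a content summary, keep the exact name in metadata.
collection.add(
    documents=[summarize(d.content) for d in docs],  # summarize() is your own helper
    metadatas=[{"name": d.name} for d in docs],
    ids=[d.id for d in docs],
)
# Exact lookups by name can then use a metadata filter instead of nearest-neighbor search:
collection.get(where={"name": "some document name"})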

leaf-ygq commented 11 months ago

@tazarov Thanks! Maybe I didn't explain the problem clearly. [screenshot of the add command] I used the command in the screenshot above to add 15,000 unique document names as "documents". Then I queried with results = collection.query(query_texts=["doc1 to doc 15000"], n_results=1). The query text is identical to the stored text, yet 66 of the queries did not retrieve the correct stored text. Why does this happen? Is it caused by the underlying retrieval mechanism of chromadb?
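
(The referenced screenshot is not preserved. Based on the description, the calls likely looked roughly like this; the variable names are guesses, not the original code.)

# Reconstruction of the described add/query calls (illustrative only).
collection.add(
    documents=doc_names,  # the 15,000 unique document names
    ids=[str(i) for i in range(len(doc_names))],
)
results = collection.query(query_texts=doc_names, n_results=1)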

tazarov commented 11 months ago

@leaf-ygq, the challenge you face comes from the embedding function. I did some testing with higher M and ef_construction values for the HNSW index (you can see the experiments in this notebook). As you'll notice, strings like doc2329 and doc2321 appear to be closely related in the latent space of your embedding function.

If you elaborate a little on your use case, we may be able to suggest a better approach.
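
For reference, those HNSW build parameters go in the collection metadata and only take effect when the collection is first created (a sketch; the exact values are illustrative, not the ones from the notebook):

collection = client.create_collection(
    "my_collection",
    metadata={"hnsw:M": 64, "hnsw:construction_ef": 200},
)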

zh2024 commented 1 month ago

same here

tazarov commented 1 month ago

> same here

@zh2024, can you try creating your collection with "hnsw:search_ef": 100 in your collection metadata (this setting only takes effect when the collection is first created):

collection = client.get_or_create_collection("my_collection", metadata={"hnsw:search_ef": 100})

Then, import your data and query it.

Let me know the results.
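
For context, search_ef is HNSW's query-time candidate-list size: larger values improve recall at the cost of some query latency. The end-to-end flow under that setting would look like this (a sketch; docs and ids are placeholders):

collection = client.get_or_create_collection(
    "my_collection",
    metadata={"hnsw:search_ef": 100},  # only applied when the collection is created
)
collection.add(documents=docs, ids=ids)
results = collection.query(query_texts=docs, n_results=1)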

zh2024 commented 1 month ago

@tazarov I can't share my raw data, so I generated some random strings and got similar results.

I think this is a bug in chromadb; I don't see the same behavior in FAISS.
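
For comparison, a brute-force FAISS index performs exact search, so every self-query returns itself (a minimal sketch, assuming faiss-cpu and sentence-transformers are installed):

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb = np.asarray(model.encode(documents), dtype="float32")  # same random strings as in the script below

index = faiss.IndexFlatL2(emb.shape[1])  # brute-force, exact search
index.add(emb)
_, nearest = index.search(emb, 1)        # query every stored vector against the index

# Exact search must return each vector as its own nearest neighbor
# (barring exact-duplicate strings), so this prints 1.0.
print((nearest[:, 0] == np.arange(len(documents))).mean())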

# %%
import chromadb

path = "./chromadb_kb_test"

client = chromadb.PersistentClient(path=path)

# %%
collection_name = "test_collection"

# Drop any existing copy of the test collection so each run starts clean.
collections = client.list_collections()
print("collections:", collections)
for collection in collections:
    if collection.name == collection_name:
        client.delete_collection(collection_name)
        break
client.list_collections()

# %%
import random

words = [
    "the", "of", "and", "to", "a", "in", "that", "is", "was", "he", "for", "it", "with", "as", "his", "on", "be", "at",
    "by", "I", "this", "had", "not", "but", "what", "all", "were", "we", "when", "your", "can", "there", "use", "an",
    "each", "which", "she", "do", "how", "their", "if", "will", "up", "about", "one", "out", "them", "could", "so",
    "my", "did", "me", "like", "him", "her", "over", "some", "say", "see", "two", "than", "more", "been", "no", "now",
    "then", "its", "would", "make", "time", "into", "has", "look", "who", "know", "go", "come", "people", "just", "could",
    "year", "because", "good", "new", "very", "give", "our", "under", "name", "little", "work", "man", "show", "well",
    "back", "even", "most", "through", "after", "life", "day", "same", "think", "last", "right", "use", "tell", "while",
    "child", "world", "over", "still", "try", "ask", "men", "need", "should", "three", "find"
]

def gen_random_str(length):
    # Join `length` random common words into a pseudo-sentence.
    return ' '.join(random.choice(words) for _ in range(length))

def generate_str_arr():
    # Build 10,000 [text, id] pairs; each text is 5-50 random words long.
    string_array = [[gen_random_str(random.randint(5, 50)), index] for index in range(10000)]
    return string_array

# generate random strings
chunks = generate_str_arr()
print(chunks)

# %%
from chromadb import Documents, EmbeddingFunction, Embeddings
from sentence_transformers import SentenceTransformer
import torch

class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self) -> None:
        super().__init__()
        # Load the model directly onto GPU when available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)

    def __call__(self, input: Documents) -> Embeddings:
        return self.model.encode(input).tolist()

emb_fn = MyEmbeddingFunction()

collection = client.create_collection(name=collection_name, embedding_function=emb_fn, metadata={"hnsw:space": "ip"})
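
One thing worth checking with "hnsw:space": "ip": inner-product distance only ranks a vector as its own nearest neighbor when embeddings are unit-normalized. all-MiniLM-L6-v2 ships with a normalization layer, but it is cheap to verify (a small sanity check, not part of the original script):

import numpy as np
# ~1.0 confirms unit-normalized outputs, so "ip" behaves like cosine here.
print(np.linalg.norm(emb_fn(["a quick test sentence"])[0]))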

# %%
documents = [chunk[0] for chunk in chunks]
metadatas = [{"id": chunk[1]} for chunk in chunks]
print(documents)
print(metadatas)

# %%
collection.add(documents=documents, metadatas=metadatas, ids=[f"{chunk[1]}" for chunk in chunks])

# %%
from typing import List
import torch

def get_query_results(query_texts: List[str], n_results=1):
    results = collection.query(
        query_texts=query_texts,
        n_results=n_results,
        include=["metadatas", "distances", "documents"]
    )

    res = {}
    res["documents"] = results["documents"]
    # With "ip" space the reported distance is 1 - inner product, so
    # 1 - distance recovers the similarity; scale it to a percentage.
    res["similarities"] = ((1 - torch.tensor(results["distances"])) * 100).tolist()
    res["ids"] = [[metadata["id"] for metadata in metadata_list] for metadata_list in results["metadatas"]]
    return res

get_query_results(documents[:5])

# %%
documents = [chunk[0] for chunk in chunks]
id_list = [chunk[1] for chunk in chunks]

print(documents)
print(id_list)

# %%
import itertools

# Query in batches of 1,000 and record every document whose top-1 hit
# is not its own id.
doc_it = iter(documents)
id_it = iter(id_list)
group_size = 1000
pred_ids = []

unmatch_docs = []
unmatch_ids = []
while True:
    doc_chunks = list(itertools.islice(doc_it, group_size))
    id_chunks = list(itertools.islice(id_it, group_size))
    if not doc_chunks:
        break
    results = get_query_results(doc_chunks, n_results=1)
    ids = [item for sublist in results["ids"] for item in sublist]

    unmatch_index = [i for i, (x, y) in enumerate(zip(id_chunks, ids)) if x != y]
    unmatch_docs.extend([doc_chunks[i] for i in unmatch_index])
    unmatch_ids.extend([id_chunks[i] for i in unmatch_index])

    pred_ids.extend(ids)
print(pred_ids)

# %%
# get accuracy
sum(x == y for x, y in zip(id_list, pred_ids)) / len(id_list)

# %%
unmatch_pairs = list(zip(unmatch_docs, unmatch_ids))
len(unmatch_pairs)

# %%
unmatch_pairs

# %%
print(unmatch_pairs[0][0])
results = get_query_results([unmatch_pairs[0][0]], n_results=1)
results

The results show that the accuracy is 99.88% (sometimes lower than 98%), but in my case it should be 100%, because every query text is already in the db.

And I find that if you change n_results to 10, you will get the right result among the returned hits.
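
That observation is consistent with HNSW being approximate rather than exact: hnswlib searches with an effective ef of max(ef, k), so asking for more results widens the candidate list much like raising "hnsw:search_ef" does. A quick way to confirm this on the misses (a sketch reusing the variables above):

# Re-check each miss with a wider result list; per the observation above,
# the true id should now appear within the top 10.
for doc, true_id in unmatch_pairs:
    res = get_query_results([doc], n_results=10)
    print(true_id, true_id in res["ids"][0])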