netease-youdao / QAnything

Question and Answer based on Anything.
https://qanything.ai
Apache License 2.0

[BUG] Python install version: vector search returns no results after switching knowledge bases #401

Open xuzhenjun130 opened 3 weeks ago

xuzhenjun130 commented 3 weeks ago

是否已有关于该错误的issue或讨论? | Is there an existing issue / discussion for this?

该问题是否在FAQ中有解答? | Is there an existing answer for this in FAQ?

当前行为 | Current Behavior

Right after startup, Q&A works fine, but as soon as I switch to another knowledge base, the vector search returns no results.

期望行为 | Expected Behavior

Switching knowledge bases should work normally.

运行环境 | Environment

- OS: centos 7
- NVIDIA Driver: 535.161.07
- CUDA: 12.2
- NVIDIA GPU: rtx 4090
- NVIDIA GPU Memory: 32G

QAnything日志 | QAnything logs

No response

复现方法 | Steps To Reproduce

No response

备注 | Anything else?

qanything_kernel/connector/database/faiss/faiss_client.py

I changed it to load all knowledge bases once up front, then search each one separately and merge the results at the end; after that change it works correctly.

from langchain_community.vectorstores import FAISS
from langchain_community.docstore import InMemoryDocstore
from langchain_core.documents import Document
from qanything_kernel.configs.model_config import VECTOR_SEARCH_TOP_K, FAISS_LOCATION, FAISS_CACHE_SIZE
from typing import Optional, Union, Callable, Dict, Any, List, Tuple
from langchain_community.vectorstores.faiss import dependable_faiss_import
from qanything_kernel.utils.custom_log import debug_logger
from qanything_kernel.connector.database.mysql.mysql_client import KnowledgeBaseManager
from qanything_kernel.utils.general_utils import num_tokens
from functools import lru_cache
import shutil
import stat
import os
import platform

os_system = platform.system()

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # likely only needed on macOS

class SelfInMemoryDocstore(InMemoryDocstore):
    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to in memory dictionary.

        Args:
            texts: dictionary of id -> document.

        Returns:
            None
        """
        self._dict.update(texts)

@lru_cache(FAISS_CACHE_SIZE)
def load_vector_store(faiss_index_path, embeddings):
    debug_logger.info(f'load faiss index: {faiss_index_path}')
    return FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)

class FaissClient:
    def __init__(self, mysql_client: KnowledgeBaseManager, embeddings):
        self.mysql_client: KnowledgeBaseManager = mysql_client
        self.embeddings = embeddings
        self.faiss_clients: Dict[str, FAISS] = {}  # one FAISS client per kb_id

    def _load_all_kbs_to_memory(self):
        for kb_id in os.listdir(FAISS_LOCATION):
            faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
            if os.path.exists(faiss_index_path):
                faiss_client: FAISS = load_vector_store(faiss_index_path, self.embeddings)
            else:
                faiss = dependable_faiss_import()
                index = faiss.IndexFlatL2(768)
                docstore = SelfInMemoryDocstore()
                debug_logger.info(f'init FAISS kb_id: {kb_id}')
                faiss_client: FAISS = FAISS(self.embeddings, index, docstore, index_to_docstore_id={})
            self.faiss_clients[kb_id] = faiss_client
        debug_logger.info(f'FAISS loaded all kb_ids')

    async def search(self, kb_ids, query, filter: Optional[Union[Callable, Dict[str, Any]]] = None,
                     top_k=VECTOR_SEARCH_TOP_K):
        if not self.faiss_clients:
            self._load_all_kbs_to_memory()

        all_docs_with_score = []
        for kb_id in kb_ids:
            faiss_client = self.faiss_clients.get(kb_id)
            if not faiss_client:
                continue

            if filter is None:
                filter = {}
            debug_logger.info(f'FAISS search: {query}, {filter}, {top_k} for kb_id: {kb_id}')
            docs_with_score = await faiss_client.asimilarity_search_with_score(query, k=top_k, filter=filter,
                                                                               fetch_k=200)
            all_docs_with_score.extend(docs_with_score)

        all_docs_with_score.sort(key=lambda x: x[1])  # sort by score (L2 distance, ascending)
        merged_docs_with_score = self.merge_docs(all_docs_with_score[:top_k])  # keep only the global top_k results
        return merged_docs_with_score

    def merge_docs(self, docs_with_score):
        """Merge adjacent chunks of the same file into a single passage of at most 800 tokens."""
        merged_docs = []
        # Re-sort the hits into document order so consecutive chunks of a file sit next to each other.
        docs_with_score = sorted(docs_with_score, key=lambda x: (x[0].metadata['file_id'], x[0].metadata['chunk_id']))
        for doc, score in docs_with_score:
            doc.metadata['score'] = score
            if not merged_docs or merged_docs[-1].metadata['file_id'] != doc.metadata['file_id']:
                merged_docs.append(doc)
            else:
                # Only concatenate a chunk that directly follows the previous one,
                # and only while the combined text stays within 800 tokens.
                if merged_docs[-1].metadata['chunk_id'] == doc.metadata['chunk_id'] - 1:
                    if num_tokens(merged_docs[-1].page_content + doc.page_content) <= 800:
                        merged_docs[-1].page_content += '\n' + doc.page_content
                        merged_docs[-1].metadata['chunk_id'] = doc.metadata['chunk_id']
                    else:
                        merged_docs.append(doc)
                else:
                    merged_docs.append(doc)
        return merged_docs

    async def add_document(self, docs):
        kb_id = docs[0].metadata['kb_id']
        if kb_id not in self.faiss_clients:
            self._load_all_kbs_to_memory()
        faiss_client = self.faiss_clients.get(kb_id)

        if not faiss_client:
            raise ValueError(f"KB with id {kb_id} not found")

        add_ids = await faiss_client.aadd_documents(docs)
        chunk_id = 0
        for doc, add_id in zip(docs, add_ids):
            self.mysql_client.add_document(add_id, chunk_id, doc.metadata['file_id'], doc.metadata['file_name'],
                                           doc.metadata['kb_id'])
            chunk_id += 1

        debug_logger.info(f'add documents number: {len(add_ids)}')
        faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
        faiss_client.save_local(faiss_index_path)
        debug_logger.info(f'save faiss index: {faiss_index_path}')
        os.chmod(os.path.dirname(faiss_index_path), stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        return add_ids

    def delete_documents(self, kb_id, file_ids=None):
        if kb_id not in self.faiss_clients:
            self._load_all_kbs_to_memory()
        faiss_client = self.faiss_clients.get(kb_id)

        if not faiss_client:
            raise ValueError(f"KB with id {kb_id} not found")

        if file_ids is None:
            kb_index_path = os.path.join(FAISS_LOCATION, kb_id)
            if os.path.exists(kb_index_path):
                shutil.rmtree(kb_index_path)
                del self.faiss_clients[kb_id]
                debug_logger.info(f'delete kb_id: {kb_id}, {kb_index_path}')
            # Return unconditionally here; otherwise doc_ids below would be unbound.
            return
        else:
            doc_ids = self.mysql_client.get_documents_by_file_ids(file_ids)
            doc_ids = [doc_id[0] for doc_id in doc_ids]

        if not doc_ids:
            debug_logger.info(f'no documents to delete')
            return

        try:
            res = faiss_client.delete(doc_ids)
            debug_logger.info(f'delete documents: {res}')
            faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
            faiss_client.save_local(faiss_index_path)
            debug_logger.info(f'save faiss index: {faiss_index_path}')
            os.chmod(os.path.dirname(faiss_index_path), stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        except ValueError as e:
            debug_logger.warning(f'delete documents: docs not found: {e}')
jiumar19 commented 2 weeks ago

The root cause of this issue is that faiss_client is a singleton whose internal FAISS index is merged into over and over, so different knowledge bases, and different versions of the same knowledge base, end up contaminating each other. Either you search each index separately and then merge the results, as you do here, although that weakens the meaning of top-k, or you cache a FAISS index per bot and merge the indexes inside that cache. The open-sourced Python code is a mess; the quality is quite worrying.
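
A rough sketch of that per-bot cache idea, assuming langchain's FAISS.merge_from and reusing the SelfInMemoryDocstore class from the code above; get_merged_index, _merged_cache and the hard-coded 768-dim IndexFlatL2 are illustrative choices, not QAnything's actual implementation:

from typing import Dict, List, Tuple
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.faiss import dependable_faiss_import

# Hypothetical per-bot cache: sorted tuple of kb_ids -> one merged FAISS index.
_merged_cache: Dict[Tuple[str, ...], FAISS] = {}

def get_merged_index(kb_ids: List[str], faiss_clients: Dict[str, FAISS], embeddings) -> FAISS:
    """Merge the per-kb indexes into a brand-new index, cached per kb_id set.

    Merging into a fresh index leaves the per-kb indexes untouched, so repeated
    searches cannot cross-contaminate knowledge bases, and a single
    similarity_search on the merged index gives a true global top-k again.
    The cache entry must be dropped whenever any of its kb_ids is updated or deleted.
    """
    key = tuple(sorted(kb_ids))
    if key not in _merged_cache:
        faiss = dependable_faiss_import()
        merged = FAISS(embeddings, faiss.IndexFlatL2(768), SelfInMemoryDocstore(),
                       index_to_docstore_id={})
        for kb_id in kb_ids:
            client = faiss_clients.get(kb_id)
            if client is not None:
                merged.merge_from(client)  # langchain FAISS supports merge_from
        _merged_cache[key] = merged
    return _merged_cache[key]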

jiumar19 commented 2 weeks ago

Also, judging from the code, persisting FAISS to disk does not account for concurrency at all. I suggest adding your own locks to avoid all kinds of strange behavior.
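
A rough sketch of that locking idea, assuming one asyncio.Lock per kb_id; the helper name and lock table are illustrative only, not existing QAnything code:

import asyncio
from collections import defaultdict

# Hypothetical per-kb locks guarding both the in-memory index and its on-disk copy.
_kb_locks = defaultdict(asyncio.Lock)

async def add_documents_locked(faiss_client, docs, kb_id, faiss_index_path):
    """Serialize index mutation and save_local per kb_id so concurrent requests
    cannot interleave an in-memory update with a half-written index on disk.
    save_local is synchronous; if it blocks the event loop for large indexes,
    it could be moved to a worker thread with asyncio.to_thread.
    """
    async with _kb_locks[kb_id]:
        add_ids = await faiss_client.aadd_documents(docs)
        faiss_client.save_local(faiss_index_path)
    return add_ids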