eosphoros-ai / DB-GPT

AI Native Data App Development framework with AWEL(Agentic Workflow Expression Language) and Agents
http://docs.dbgpt.cn
MIT License
13.8k stars 1.86k forks

[New Feature] ES VectorStore #1483

Closed IamWWT closed 6 months ago

IamWWT commented 6 months ago

Features verified to work:

1) Creating a new knowledge space works (English names only; Chinese names are not supported).
2) Uploading documents for embedding works.
3) Each uploaded document can be deleted individually.
4) Knowledge-base chat/search works.

The modified files are as follows:

1) Add the following to .env:

    VECTOR_STORE_TYPE=ElasticSearch
    ElasticSearch_URL=127.0.0.1
    ElasticSearch_PORT=9200
    ElasticSearch_USERNAME=elastic
    ElasticSearch_PASSWORD=i=+iLw9y0Jduq86XTi6W

2) Add the following to dbgpt/_private/config.py:

    self.ElasticSearch_URL = os.getenv("ElasticSearch_URL", "127.0.0.1")
    self.ElasticSearch_PORT = os.getenv("ElasticSearch_PORT", "9200")
    self.ElasticSearch_USERNAME = os.getenv("ElasticSearch_USERNAME", None)
    self.ElasticSearch_PASSWORD = os.getenv("ElasticSearch_PASSWORD", None)
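
For reference, a minimal sketch (illustrative, not part of the patch) of how these settings can then be read through the Config singleton to build the connection URL:

    from dbgpt._private.config import Config

    CFG = Config()
    es_url = f"http://{CFG.ElasticSearch_URL}:{CFG.ElasticSearch_PORT}"
    # basic_auth is only needed when Elasticsearch security is enabled
    auth = (CFG.ElasticSearch_USERNAME, CFG.ElasticSearch_PASSWORD)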

3) In dbgpt/app/knowledge/service.py, modify def delete_document() as follows:

    def delete_document(self, space_name: str, doc_name: str):
        """Delete a document.

        Args:
            - space_name: knowledge space name
            - doc_name: document name
        """
        document_query = KnowledgeDocumentEntity(doc_name=doc_name, space=space_name)
        documents = knowledge_document_dao.get_documents(document_query)
        if len(documents) != 1:
            raise Exception(f"There is no document named {doc_name}, or the name is ambiguous")
        vector_ids = documents[0].vector_ids 
        if vector_ids is not None:
            ## wwt add
            embedding_factory = CFG.SYSTEM_APP.get_component("embedding_factory", EmbeddingFactory)
            embedding_fn = embedding_factory.create(model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL])
            ### wwt modified
            if CFG.VECTOR_STORE_TYPE == "Milvus":
                config = VectorStoreConfig(name=space_name,            
                                        embedding_fn=embedding_fn,  ## wwt add
                                        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,   ## wwt add
                                        user=CFG.MILVUS_USERNAME,
                                        password=CFG.MILVUS_PASSWORD,
                                        )
            elif CFG.VECTOR_STORE_TYPE == "ElasticSearch":
                logger.info(f"wwt add 正在删除ES类型的文档。")
                config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn,  ## wwt add
                                        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,   ## wwt add
                                        user=CFG.ElasticSearch_USERNAME,
                                        password=CFG.ElasticSearch_PASSWORD,
                                        )
            elif CFG.VECTOR_STORE_TYPE == "Chroma":
                config = VectorStoreConfig(name=space_name)
            else:
                config = VectorStoreConfig(name=space_name) 
            vector_store_connector = VectorStoreConnector(
                vector_store_type=CFG.VECTOR_STORE_TYPE,
                vector_store_config=config,
            ) 
            # delete vector by ids 
            vector_store_connector.delete_by_ids(vector_ids)
        # delete chunks 
        document_chunk_dao.raw_delete(documents[0].id)
        # delete document
        return knowledge_document_dao.raw_delete(document_query)
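
A hypothetical call of the patched method (assuming the surrounding KnowledgeService class in service.py; the space and document names are placeholders):

    # Illustrative only: delete one uploaded document from a knowledge space
    service = KnowledgeService()
    service.delete_document(space_name="my_space", doc_name="report.pdf")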

4) Add the following to dbgpt/storage/vector_store/__init__.py:

def _import_elastic() -> Any:
    from dbgpt.storage.vector_store.elastic_store import ElasticStore

    return ElasticStore

def __getattr__(name: str) -> Any:
    if name == "Chroma":
        return _import_chroma()
    elif name == "Milvus":
        return _import_milvus()
    elif name == "Weaviate":
        return _import_weaviate()
    elif name == "PGVector":
        return _import_pgvector()
    elif name == "ElasticSearch":
        return _import_elastic()
    else:
        raise AttributeError(f"Could not find: {name}")

__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector", "ElasticSearch"]
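
The __getattr__ hook above uses PEP 562 module-level attribute access, so the elasticsearch dependency is imported only when the store is first requested. A short demonstration (illustrative, not part of the patch):

    # Importing the package is cheap; elasticsearch is imported only here:
    from dbgpt.storage import vector_store

    store_cls = vector_store.ElasticSearch  # calls __getattr__("ElasticSearch")
    print(store_cls.__name__)               # -> ElasticStore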

5) Add a new file elastic_store.py under dbgpt/storage/vector_store/:

"""Elasticsearch vector store for 全文索引---- for 全文检索."""
from __future__ import annotations

import json
import logging
import os
from typing import Any, Iterable, List, Optional

from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
    _COMMON_PARAMETERS,
    VectorStoreBase,
    VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters
from dbgpt.util import string_utils
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)

try:
    import jieba
    import jieba.analyse
    from elasticsearch import Elasticsearch
    from langchain.schema import Document
    from langchain.vectorstores.elasticsearch import ElasticsearchStore
except ImportError:
    raise ValueError(
        "Could not import the elasticsearch, langchain or jieba python packages. "
        "Please install them with `pip install elasticsearch langchain jieba`."
    )

@register_resource(
    _("ElasticSearch Vector Store"),
    "elasticsearch_vector_store",
    category=ResourceCategory.VECTOR_STORE,
    parameters=[
        *_COMMON_PARAMETERS,
        Parameter.build_from(
            _("Uri"),
            "uri",
            str,
            description=_(
                "The uri of elasticsearch store, if not set, will use the default " "uri."
            ),
            optional=True,
            default="localhost",
        ),
        Parameter.build_from(
            _("Port"),
            "port",
            str,
            description=_(
                "The port of elasticsearch store, if not set, will use the default " "port."
            ),
            optional=True,
            default="9200",
        ),
        Parameter.build_from(
            _("Alias"),
            "alias",
            str,
            description=_(
                "The alias of elasticsearch store, if not set, will use the default " "alias."
            ),
            optional=True,
            default="default",
        ),
        Parameter.build_from(
            _("Index Name"),
            "index_name",
            str,
            description=_(
                "The index name of elasticsearch store, if not set, will use the "
                "default index name."
            ),
            optional=True,
            default="index_name_test",
        ),
        Parameter.build_from(
            _("Text Field"),
            "text_field",
            str,
            description=_(
                "The text field of elasticsearch store, if not set, will use the "
                "default text field."
            ),
            optional=True,
            default="content",
        ),
        Parameter.build_from(
            _("Embedding Field"),
            "embedding_field",
            str,
            description=_(
                "The embedding field of elasticsearch store, if not set, will use the "
                "default embedding field."
            ),
            optional=True,
            default="vector",
        ),
    ],
    description=_("Elasticsearch vector store."),
)
class ElasticsearchVectorConfig(VectorStoreConfig):
    """Elasticsearch vector store config."""

    class Config:
        """Config for BaseModel."""

        arbitrary_types_allowed = True

    uri: str = Field(
        default="localhost",
        description="The uri of elasticsearch store, if not set, will use the default uri.",
    )
    port: str = Field(
        default="9200",
        description="The port of elasticsearch store, if not set, will use the default port.",
    )

    alias: str = Field(
        default="default",
        description="The alias of elasticsearch store, if not set, will use the default "
        "alias.",
    )
    index_name: str = Field(
        default="index_name_test",
        description="The index name of elasticsearch store, if not set, will use the "
        "default index name.",
    )
    text_field: str = Field(
        default="content",
        description="The text field of elasticsearch store, if not set, will use the default "
        "text field.",
    )
    embedding_field: str = Field(
        default="vector",
        description="The embedding field of elasticsearch store, if not set, will use the "
        "default embedding field.",
    )
    metadata_field: str = Field(
        default="metadata",
        description="The metadata field of elasticsearch store, if not set, will use the "
        "default metadata field.",
    )
    secure: str = Field(
        default="",
        description="Whether to use a secure (TLS) connection to the elasticsearch "
        "store; empty (insecure) by default.",
    )
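
# Example (illustrative only): constructing the config with explicit values.
# Fields left out fall back to the defaults declared above; embedding_fn is
# an Embeddings instance assumed to be created elsewhere.
#
#   config = ElasticsearchVectorConfig(
#       name="my_space",
#       uri="127.0.0.1",
#       port="9200",
#       embedding_fn=embedding_fn,
#   )
#   store = ElasticStore(config)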

class ElasticStore(VectorStoreBase):
    """Elasticsearch vector store."""

    def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None:
        """Create an ElasticStore instance.

        Args:
            vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
        """

        connect_kwargs = {}
        elasticsearch_vector_config = vector_store_config.dict()
        self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
            "ElasticSearch_URL", "localhost"
        )
        self.port = elasticsearch_vector_config.get("post") or os.getenv(
            "ElasticSearch_PORT", "9200"
        )
        self.username = elasticsearch_vector_config.get("username") or os.getenv("ElasticSearch_USERNAME")
        self.password = elasticsearch_vector_config.get("password") or os.getenv(
            "ElasticSearch_PASSWORD"
        ) 

        self.collection_name = (
            elasticsearch_vector_config.get("name") or vector_store_config.name
        )
        if string_utils.is_all_chinese(self.collection_name):
            bytes_str = self.collection_name.encode("utf-8")
            hex_str = bytes_str.hex()
            self.collection_name = hex_str
        if vector_store_config.embedding_fn is None:
            # Perform runtime checks on self.embedding to
            # ensure it has been correctly set and loaded
            raise ValueError("embedding_fn is required for ElasticSearchStore")
        self.index_name = self.collection_name.lower()
        self.embedding: Embeddings = vector_store_config.embedding_fn
        self.fields: List = [] 

        if (self.username is None) != (self.password is None):
            raise ValueError(
                "Both username and password must be set to use authentication for "
                "ElasticSearch"
            )

        if self.username:
            connect_kwargs["username"] = self.username
            connect_kwargs["password"] = self.password

        # Index settings for a single-node cluster
        self.index_settings = {
            "settings": {
                "number_of_shards": 1,
                "number_of_replicas": 0,  # no replicas on a single node
            }
        }

        # Native Elasticsearch Python client (connection only)
        try:
            if self.username != "" and self.password != "":
                self.es_client_python = Elasticsearch(
                    f"http://{self.uri}:{self.port}",
                    basic_auth=(self.username, self.password),
                )
                # Do not create the index here, otherwise it raises an error:
                # if not self.vector_name_exists():
                #     self.es_client_python.indices.create(index=self.index_name, body=self.index_settings)
            else:
                logger.warning("Elasticsearch username/password not configured")
                self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}")
                # if not self.vector_name_exists():
                #     self.es_client_python.indices.create(index=self.index_name, body=self.index_settings)
        except ConnectionError:
            logger.error("Failed to connect to Elasticsearch!")
        except Exception as e:
            logger.error(f"Native Elasticsearch client connection error: {e}")

        # LangChain Elasticsearch connection (also creates the index)
        try:
            if self.username != "" and self.password != "":
                self.db_init = ElasticsearchStore(
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    query_field="context",
                    vector_query_field="dense_vector",
                    embedding=self.embedding,
                    es_user=self.username,
                    es_password=self.password,
                )
            else:
                logger.warning("Elasticsearch username/password not configured")
                self.db_init = ElasticsearchStore(
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    query_field="context",
                    vector_query_field="dense_vector",
                    embedding=self.embedding,
                )
        except ConnectionError:
            logger.error("Failed to connect to Elasticsearch!")
        except Exception as e:
            logger.error(f"LangChain Elasticsearch connection/index creation error: {e}")

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Write the given chunks into Elasticsearch.

        Returns:
            List[str]: ids of the chunks that were written.
        """
        logger.info("ElasticStore load document")
        try:
            # Connect and write the documents in one call
            texts = [chunk.content for chunk in chunks]
            metadatas = [chunk.metadata for chunk in chunks]
            ids = [chunk.chunk_id for chunk in chunks]
            if self.username != "" and self.password != "":
                logger.info(f"wwt docs metadatas[0] === ElasticsearchStore.from_texts:::{metadatas[0]}: len={len(metadatas)}")
                self.db = ElasticsearchStore.from_texts(
                    texts=texts,
                    embedding=self.embedding,
                    metadatas=metadatas,
                    ids=ids,
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    distance_strategy="COSINE",  # Defaults to COSINE. Can be one of COSINE, EUCLIDEAN_DISTANCE, or DOT_PRODUCT.
                    query_field="context",  ## Name of the field to store the texts in.
                    vector_query_field="dense_vector", # Optional. Name of the field to store the embedding vectors in. 
                    es_user=self.username,
                    es_password=self.password,
                ) 
                logger.info(f"wwt add Embedding success.......")
            else:
                # Use from_texts here as well; from_documents expects Document
                # objects rather than the texts/metadatas/ids passed below.
                self.db = ElasticsearchStore.from_texts(
                    texts=texts,
                    embedding=self.embedding,
                    metadatas=metadatas,
                    ids=ids,
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    distance_strategy="COSINE",
                    query_field="context",
                    vector_query_field="dense_vector",
                )
            return ids
        except ConnectionError as ce:
            logger.error(f"Failed to connect to Elasticsearch: {ce}")
        except Exception as e:
            logger.error(f"load_document error: {e}")

    def delete_by_ids(self, ids):
        """Delete vectors by ids.

        Args:
            ids: comma-separated string of chunk ids to delete.
        """
        logger.info(f"begin delete elasticsearch ids, type={type(ids)}")
        ids = ids.split(",")
        logger.info(f"deleting {len(ids)} ids from elasticsearch")
        try:
            self.db_init.delete(ids=ids)
            self.es_client_python.indices.refresh(index=self.index_name)
        except Exception as e:
            logger.error(f"delete_by_ids error: {e}")

    def similar_search(
        self, text: str, topk: int, score_threshold: float, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Perform a search on a query string and return results.
        # TODO: 语义分词后期配置可换
        """
        query = text
        print(
            f" similar_search 输入的query参数为:{query}") 
        query_list = jieba.analyse.textrank(query, topK=20, withWeight=False)
        if len(query_list) == 0:
            query_list = [query]
        body = {
            "query": {
                "match": {
                    "context": " ".join(query_list)
                }
            }
        }
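        # The keyword list built above is sent as a single BM25 "match" query
        # against the "context" field, i.e. lexical full-text retrieval rather
        # than dense-vector similarity search.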
        search_results = self.es_client_python.search(index=self.index_name, body=body, size=topk)
        search_results = search_results['hits']['hits']

        # Return early if the search produced no hits
        if not search_results:
            return []

        info_docs = []
        byte_count = 0

        for result in search_results:
            index_name = result["_index"]
            doc_id = result["_id"]
            source = result["_source"]
            context = source["context"]
            metadata = source["metadata"]
            # Dense vector of the text; hits expose it under _source, not at
            # the top level of the hit dict.
            vector_doc = source.get("dense_vector")
            score = result["_score"]

            # Truncate the context if it would exceed the total byte budget
            VS_TYPE_PROMPT_TOTAL_BYTE_SIZE = 3000  # max prompt bytes per vector store; TODO: move to .env
            if (byte_count + len(context)) > VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                context = context[: VS_TYPE_PROMPT_TOTAL_BYTE_SIZE - byte_count]

            doc_with_score = [Document(page_content=context, metadata=metadata), score, doc_id]
            info_docs.append(doc_with_score)

            byte_count += len(context)

            # 如果字节数已经达到限制,则结束循环
            if byte_count >= VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                break
        print(f"ES搜索到{len(info_docs)}个结果:")
        # 将结果写入文件
        result_file = open("es_search_results.txt", "w", encoding="utf-8")
        result_file.write(f"query:{query}")
        result_file.write(f"ES搜索到{len(info_docs)}个结果:\n")
        for item in info_docs:
            doc = item[0]
            result_file.write(doc.page_content + "\n")
            result_file.write("*" * 20)
            result_file.write("\n")
            result_file.flush()
            print(doc.page_content + "\n")
            print("*" * 20)
            print("\n")
        result_file.close()

        return [
            Chunk(
                # metadata is already a dict taken from _source above
                metadata=doc.metadata,
                content=doc.page_content,
            )
            for doc, score, doc_id in info_docs
        ]

    def similar_search_with_scores(
        self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Perform a search on a query string and return results with score.

        For more information about the search parameters, take a look at the pyElasticSearch
        documentation found here:
        https://ElasticSearch.io/api-reference/pyElasticSearch/v2.2.6/Collection/search().md

        Args:
            text (str): The query text.
            topk (int): The number of similar documents to return.
            score_threshold (float): Optional, a floating point value between 0 to 1.
            filters (Optional[MetadataFilters]): Optional, metadata filters.
        Returns:
            List[Tuple[Document, float]]: Result doc and score.
        """ 

        query = text
        logger.info(f"similar_search_with_scores input query: {query}")
        query_list = jieba.analyse.textrank(query, topK=20, withWeight=False)
        if len(query_list) == 0:
            query_list = [query]
        body = {
            "query": {
                "match": {
                    "context": " ".join(query_list)
                }
            }
        }
        search_results = self.es_client_python.search(index=self.index_name, body=body, size=topk)
        search_results = search_results['hits']['hits']
        # Return early if the search produced no hits
        if not search_results:
            return []

        info_docs = []
        byte_count = 0

        for result in search_results:
            # logger.info(f"wwt add query result==={result}")
            index_name = result["_index"]
            # vector_doc = result["_source"].get("dense_vector")  # dense vector of the text
            doc_id = result["_id"]
            source = result["_source"]
            context = source["context"]    # text content
            metadata = source["metadata"]  # source path of the text
            score = result["_score"] / 100  # normalize the 100-point-scale score to [0, 1]

            # Truncate the context if it would exceed the total byte budget
            VS_TYPE_PROMPT_TOTAL_BYTE_SIZE = 3000  # max prompt bytes per vector store; TODO: move to .env
            if (byte_count + len(context)) > VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                context = context[: VS_TYPE_PROMPT_TOTAL_BYTE_SIZE - byte_count]

            doc_with_score = [Document(page_content=context, metadata=metadata), score, doc_id]
            info_docs.append(doc_with_score)

            byte_count += len(context)

            # 如果字节数已经达到限制,则结束循环
            if byte_count >= VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                break
        print(f"ES搜索到{len(info_docs)}个结果:")
        logger.info(f"ES搜索到{len(info_docs)}个结果:")
        # 将结果写入文件
        result_file = open("es_search_results.txt", "w", encoding="utf-8")
        result_file.write(f"query:{query} \n")
        result_file.write(f"ES搜索到{len(info_docs)}个结果:\n")
        for item in info_docs:
            doc = item[0]
            result_file.write(doc.page_content + "\n")
            result_file.write("*" * 50)
            result_file.write("\n")
            result_file.flush()
            print(doc.page_content + "\n")
            print("*" * 50)
            print("\n\n")
        result_file.close()

        if any(score < 0.0 or score > 1.0 for _, score, _ in info_docs):
            logger.warning(
                f"similarity scores should be between 0 and 1, got {info_docs}"
            )

        logger.info(f"wwt add score_threshold: {score_threshold}")
        if score_threshold is not None:
            docs_and_scores = [
                Chunk(
                    metadata=doc.metadata,
                    content=doc.page_content,
                    score=score,
                    chunk_id=id,
                )
                for doc, score, id in info_docs
                if score >= score_threshold
            ]
            if len(docs_and_scores) == 0:
                logger.warning(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )
        return docs_and_scores

    def vector_name_exists(self):
        """Whether the vector store (index) name exists."""
        return self.es_client_python.indices.exists(index=self.index_name)

    def delete_vector_name(self, vector_name: str):
        """Delete all vectors by dropping the index.

        The index name is the lower-cased knowledge space name.
        """
        if self.es_client_python.indices.exists(index=self.index_name):
            self.es_client_python.indices.delete(index=self.index_name)
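
For anyone who wants to try this end to end, here is a minimal sketch (illustrative only, not part of the patch). It assumes a running Elasticsearch at localhost:9200 and an already-created embedding function embedding_fn; the space name and texts are placeholders:

    from dbgpt.core import Chunk
    from dbgpt.storage.vector_store.elastic_store import (
        ElasticsearchVectorConfig,
        ElasticStore,
    )

    # embedding_fn: an Embeddings instance created elsewhere (assumption)
    config = ElasticsearchVectorConfig(name="demo_space", embedding_fn=embedding_fn)
    store = ElasticStore(config)

    # Write one chunk, run a full-text search, then delete by id
    ids = store.load_document([Chunk(content="DB-GPT supports AWEL workflows.")])
    chunks = store.similar_search("What does DB-GPT support?", 5, 0.0)
    store.delete_by_ids(",".join(ids))  # expects a comma-separated id string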
Aries-ckt commented 6 months ago

Hi @IamWWT, amazing feature! Can you make a pull request for ElasticSearchStore?