Open NagatoYuki0943 opened 1 month ago
创建以下文件 `BCERerank.py`,并将 `from BCEmbedding.tools.langchain import BCERerank` 替换为 `from BCERerank import BCERerank`,即可暂时解决该问题:
BCERerank.py
from BCERerank import BCERerank
from BCEmbedding.tools.langchain import BCERerank
from __future__ import annotations from typing import Dict, Optional, Sequence, Any from langchain_core.documents import Document from pydantic import model_validator from langchain.callbacks.manager import Callbacks from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from pydantic import PrivateAttr class BCERerank(BaseDocumentCompressor): """Document compressor that uses `BCEmbedding RerankerModel API`.""" client: str = "BCEmbedding" top_n: int = 3 """Number of documents to return.""" model: str = "maidalun1020/bce-reranker-base_v1" """Model to use for reranking.""" _model: Any = PrivateAttr() class Config: """Configuration for this pydantic object.""" extra = "forbid" arbitrary_types_allowed = True def __init__( self, top_n: int = 3, model: str = "maidalun1020/bce-reranker-base_v1", device: Optional[str] = None, **kwargs, ): super().__init__(top_n=top_n, model=model) try: from BCEmbedding.models import RerankerModel except ImportError: raise ImportError( "Cannot import `BCEmbedding` package,", "please `pip install BCEmbedding>=0.1.2`", ) self._model = RerankerModel(model_name_or_path=model, device=device, **kwargs) # @model_validator(mode="before") # def validate_environment(cls, values: Dict) -> Dict: # """Validate that api key and python package exists in environment.""" # values["client"] = "BCEmbedding.models.RerankerModel" # return values def compress_documents( self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """ Compress documents using `BCEmbedding RerankerModel API`. Args: documents: A sequence of documents to compress. query: The query to use for compressing the documents. callbacks: Callbacks to run during the compression process. Returns: A sequence of compressed documents. 
""" if len(documents) == 0: # to avoid empty api call return [] doc_list = list(documents) passages = [] valid_doc_list = [] invalid_doc_list = [] for d in doc_list: passage = d.page_content if isinstance(passage, str) and len(passage) > 0: passages.append(passage.replace("\n", " ")) valid_doc_list.append(d) else: invalid_doc_list.append(d) rerank_result = self._model.rerank(query, passages) final_results = [] for score, doc_id in zip( rerank_result["rerank_scores"], rerank_result["rerank_ids"] ): doc = valid_doc_list[doc_id] doc.metadata["relevance_score"] = score final_results.append(doc) for doc in invalid_doc_list: doc.metadata["relevance_score"] = 0 final_results.append(doc) final_results = final_results[: self.top_n] return final_results
创建以下文件 `BCERerank.py`,并将 `from BCEmbedding.tools.langchain import BCERerank` 替换为 `from BCERerank import BCERerank`,即可暂时解决该问题: