Closed: PhantomGrapes closed this issue 3 months ago.
To fix the validation error in the `DashScopeRerank` class, pass the `model_name` parameter when constructing the `ReRankStartEvent`. Here is the corrected code:
def _postprocess_nodes(
    self,
    nodes: List[NodeWithScore],
    query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
    """Rerank ``nodes`` against the query using the DashScope TextReRank API.

    Args:
        nodes: Candidate nodes to rerank.
        query_bundle: The query to rank against. Required; ``None`` is
            rejected after the start event has been emitted.

    Returns:
        New ``NodeWithScore`` objects, in the order the API returns them, each
        wrapping the original node with the API's ``relevance_score``.
        Returns an empty list when ``nodes`` is empty.

    Raises:
        ValueError: If ``query_bundle`` is ``None``.
        RuntimeError: If the DashScope rerank request does not return HTTP 200.
    """
    from http import HTTPStatus

    dispatcher.event(
        ReRankStartEvent(
            # ReRankStartEvent validates its fields by name and expects
            # ``model_name``; passing ``model`` raised a validation error.
            model_name=self.model,
            top_n=self.top_n,
            query=query_bundle,
            nodes=nodes,
        )
    )

    if query_bundle is None:
        raise ValueError("Missing query bundle in extra info.")
    if not nodes:
        return []

    with self.callback_manager.event(
        CBEventType.RERANKING,
        payload={
            EventPayload.NODES: nodes,
            EventPayload.MODEL_NAME: self.model,
            EventPayload.QUERY_STR: query_bundle.query_str,
            EventPayload.TOP_K: self.top_n,
        },
    ) as event:
        # Rerank on the embedding view of each node's content.
        texts = [
            node.node.get_content(metadata_mode=MetadataMode.EMBED)
            for node in nodes
        ]
        results = dashscope.TextReRank.call(
            model=self.model,
            top_n=self.top_n,
            query=query_bundle.query_str,
            documents=texts,
        )
        # A failed request carries no usable ``output``. Surface the API's
        # error details instead of an opaque AttributeError further down.
        if results.status_code != HTTPStatus.OK:
            raise RuntimeError(
                "DashScope rerank request failed "
                f"(status={results.status_code}, code={results.code}): "
                f"{results.message}"
            )
        # ``result.index`` points back into the ``nodes`` list sent above.
        new_nodes = [
            NodeWithScore(
                node=nodes[result.index].node, score=result.relevance_score
            )
            for result in results.output.results
        ]
        event.on_end(payload={EventPayload.NODES: new_nodes})

    dispatcher.event(ReRankEndEvent(nodes=new_nodes))
    return new_nodes
This change ensures that the `model_name` parameter is supplied when constructing the `ReRankStartEvent`, which resolves the validation error [1].
To continue talking to Dosu, mention @dosu.
Bug Description
`DashScopeRerank` passes `model` as a keyword argument to `ReRankStartEvent`, but `ReRankStartEvent` expects `model_name`. As a result, every call to `DashScopeRerank` fails with a validation error (`ValueError`).
Version
v0.10.34
Steps to Reproduce
# Minimal reproduction: on v0.10.34, postprocess_nodes() fails with a
# validation error because DashScopeRerank builds ReRankStartEvent with
# ``model`` instead of ``model_name``.
from llama_index.core.data_structs import Node
from llama_index.core.schema import NodeWithScore
from llama_index.postprocessor.dashscope_rerank import DashScopeRerank

nodes = [
    NodeWithScore(node=Node(text="text1"), score=0.7),
    NodeWithScore(node=Node(text="text2"), score=0.8),
]

dashscope_rerank = DashScopeRerank(top_n=5)
results = dashscope_rerank.postprocess_nodes(nodes, query_str="")

for res in results:
    print("Text: ", res.node.get_content(), "Score: ", res.score)
Relevant Logs/Tracebacks