schiaro98 opened this issue 1 month ago
Hello @schiaro98! I'm here to help you with any bugs, questions, or contributions while waiting for a human maintainer. Let's tackle this together!
Yes, you can avoid downloading/importing `torch` and its `nvidia-*` dependencies by modifying the `ColbertRerank` class to use a different backend, or by mocking the `torch` and `transformers` imports if you do not need the actual functionality during certain operations (e.g., testing or building).

Here is an example of how you might mock these imports:
```python
import sys
from unittest.mock import MagicMock

# Mock torch and transformers if they are not needed.
# These entries must be registered before anything imports the real packages.
sys.modules['torch'] = MagicMock()
sys.modules['transformers'] = MagicMock()

from typing import Any, List, Optional

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
    ReRankEndEvent,
    ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.utils import infer_torch_device

# Importing mocked torch and transformers
import torch
from transformers import AutoTokenizer, AutoModel

DEFAULT_COLBERT_MAX_LENGTH = 512

dispatcher = get_dispatcher(__name__)


class ColbertRerank(BaseNodePostprocessor):
    model: str = Field(description="Colbert model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 5,
        model: str = "colbert-ir/colbertv2.0",
        tokenizer: str = "colbert-ir/colbertv2.0",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ):
        device = infer_torch_device() if device is None else device
        super().__init__(
            top_n=top_n,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
            model=model,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self._model = AutoModel.from_pretrained(model)

    @classmethod
    def class_name(cls) -> str:
        return "ColbertRerank"

    def _calculate_sim(self, query: str, documents_text_list: List[str]) -> List[float]:
        query_encoding = self._tokenizer(query, return_tensors="pt")
        query_embedding = self._model(**query_encoding).last_hidden_state
        rerank_score_list = []

        for document_text in documents_text_list:
            document_encoding = self._tokenizer(
                document_text,
                return_tensors="pt",
                truncation=True,
                max_length=DEFAULT_COLBERT_MAX_LENGTH,
            )
            document_embedding = self._model(**document_encoding).last_hidden_state

            sim_matrix = torch.nn.functional.cosine_similarity(
                query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1
            )
            max_sim_scores, _ = torch.max(sim_matrix, dim=2)
            rerank_score_list.append(torch.mean(max_sim_scores, dim=1))

        return rerank_score_list

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        nodes_text_list = [
            str(node.node.get_content(metadata_mode=MetadataMode.EMBED))
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._calculate_sim(query_bundle.query_str, nodes_text_list)
            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    node.node.metadata["retrieval_score"] = node.score
                node.score = float(score)

            reranked_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: reranked_nodes})

        dispatcher.event(ReRankEndEvent(nodes=reranked_nodes))
        return reranked_nodes
```
This approach allows you to avoid the actual import and usage of `torch` and `transformers` during the build process, reducing build time and package size [1].
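A related pattern, if you'd rather not mock modules globally, is to defer the heavy imports until the reranker is actually constructed. The sketch below shows only the general lazy-import technique; the `LazyColbertRerank` class name and layout are illustrative, not the library's actual code:

```python
from typing import Any


class LazyColbertRerank:
    """Illustrative sketch: defer torch/transformers imports to first use."""

    def __init__(self, model: str = "colbert-ir/colbertv2.0"):
        # The heavy imports happen here rather than at module import time,
        # so merely importing this module stays cheap.
        from transformers import AutoModel, AutoTokenizer

        self._tokenizer: Any = AutoTokenizer.from_pretrained(model)
        self._model: Any = AutoModel.from_pretrained(model)
```

With this layout, tooling that only imports the module (e.g., during a build or test collection) never triggers the `torch`/`transformers` machinery.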
Since the LlamaIndex application runs inside a Flask environment, is there any way to avoid recreating an instance of the reranker model for every request?
Yes, you can avoid recreating an instance of the `ColbertRerank` model for each request in a Flask environment by initializing the model once and reusing it across requests. You can achieve this by creating a global instance of the `ColbertRerank` model when the Flask application starts and then using this instance in your request handlers.

Here is an example of how you can do this:
```python
from flask import Flask, request, jsonify

from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.postprocessor.colbert_rerank import ColbertRerank

app = Flask(__name__)

# Initialize the ColbertRerank model once, at application startup
colbert_rerank_model = ColbertRerank()


@app.route('/rerank', methods=['POST'])
def rerank():
    data = request.json
    query_bundle = QueryBundle(query_str=data['query'])
    nodes = [NodeWithScore(node=TextNode(text=text)) for text in data['nodes']]

    # Use the pre-initialized model via the public postprocess_nodes API
    reranked_nodes = colbert_rerank_model.postprocess_nodes(
        nodes, query_bundle=query_bundle
    )
    return jsonify(
        [{"text": n.node.get_content(), "score": n.score} for n in reranked_nodes]
    )


if __name__ == '__main__':
    app.run()
```
In this example, `colbert_rerank_model` is initialized once when the Flask application starts. The `/rerank` endpoint uses this pre-initialized model to process incoming requests, avoiding the need to recreate the model instance for each request [1].
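If you prefer to delay the (slow) model load until the first request rather than paying for it at import time, one common variant is a cached factory function. This is a sketch under the same assumptions as above; `get_reranker` is a hypothetical helper, not part of llama-index:

```python
from functools import lru_cache

from llama_index.postprocessor.colbert_rerank import ColbertRerank


@lru_cache(maxsize=1)
def get_reranker() -> ColbertRerank:
    # Constructed on the first call; every later call returns the same
    # cached instance, so the model weights are loaded only once.
    return ColbertRerank()
```

Note that each Flask worker process still loads its own copy of the weights, so with a multi-process server (e.g., Gunicorn) memory usage scales with the number of workers.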
Question Validation
Question
At the moment, importing llama-index-postprocessor-colbert-rerank requires torch and its nvidia-* dependencies. Is there any way to avoid this download/import in order to reduce build time and package size? Thanks