run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: lightweight colbert rerank installation #16262

Open schiaro98 opened 1 month ago

schiaro98 commented 1 month ago


Question

At the moment, importing llama-index-postprocessor-colbert-rerank requires torch and its nvidia-* dependencies. Is there any way to avoid downloading and importing them, in order to shorten build time and reduce package size? Thanks

dosubot[bot] commented 1 month ago

Hello @schiaro98! I'm here to help you with any bugs, questions, or contributions while waiting for a human maintainer. Let's tackle this together!

Yes, you can avoid downloading/importing torch and its nvidia-* dependencies by modifying the ColbertRerank class to use a different backend or by mocking the torch and transformers imports if you do not need the actual functionality during certain operations (e.g., testing or building).

Here is an example of how you might mock these imports:

import sys
from unittest.mock import MagicMock

# Mock torch and transformers if they are not needed
sys.modules['torch'] = MagicMock()
sys.modules['transformers'] = MagicMock()

from typing import Any, List, Optional
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
    ReRankEndEvent,
    ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.utils import infer_torch_device

# Importing mocked torch and transformers
import torch
from transformers import AutoTokenizer, AutoModel

DEFAULT_COLBERT_MAX_LENGTH = 512

dispatcher = get_dispatcher(__name__)

class ColbertRerank(BaseNodePostprocessor):
    model: str = Field(description="Colbert model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 5,
        model: str = "colbert-ir/colbertv2.0",
        tokenizer: str = "colbert-ir/colbertv2.0",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ):
        device = infer_torch_device() if device is None else device
        super().__init__(
            top_n=top_n,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
            model=model,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self._model = AutoModel.from_pretrained(model)

    @classmethod
    def class_name(cls) -> str:
        return "ColbertRerank"

    def _calculate_sim(self, query: str, documents_text_list: List[str]) -> List[float]:
        # Token-level embeddings for the query: shape (1, query_len, hidden)
        query_encoding = self._tokenizer(query, return_tensors="pt")
        query_embedding = self._model(**query_encoding).last_hidden_state
        rerank_score_list = []

        for document_text in documents_text_list:
            document_encoding = self._tokenizer(
                document_text,
                return_tensors="pt",
                truncation=True,
                max_length=DEFAULT_COLBERT_MAX_LENGTH,
            )
            document_embedding = self._model(**document_encoding).last_hidden_state
            # Cosine similarity of every query token against every document token
            sim_matrix = torch.nn.functional.cosine_similarity(
                query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1
            )

            # MaxSim: best-matching document token per query token, then the mean
            max_sim_scores, _ = torch.max(sim_matrix, dim=2)
            rerank_score_list.append(torch.mean(max_sim_scores, dim=1).item())

        return rerank_score_list

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        nodes_text_list = [
            str(node.node.get_content(metadata_mode=MetadataMode.EMBED))
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._calculate_sim(query_bundle.query_str, nodes_text_list)

            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    node.node.metadata["retrieval_score"] = node.score
                node.score = float(score)

            reranked_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: reranked_nodes})

        dispatcher.event(ReRankEndEvent(nodes=reranked_nodes))
        return reranked_nodes

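This approach avoids actually importing and running torch and transformers, for example during a build or in tests. Note that mocking only replaces the imports: pip still downloads torch whenever it is pulled in as a dependency, so build time and package size only shrink if the package is installed without its dependencies (for example with pip install --no-deps). The reranker also cannot compute real scores while the modules are mocked [1].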
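
As a minimal sketch of the mocking pattern in isolation, the snippet below registers a MagicMock in sys.modules only when the real module cannot be imported. The module name heavy_backend_placeholder and the helper import_or_mock are hypothetical and stand in for torch; nothing here is part of llama-index.

```python
import importlib
import sys
from unittest.mock import MagicMock

# Hypothetical module name, chosen so that it is not installed anywhere.
HEAVY_MODULE = "heavy_backend_placeholder"


def import_or_mock(name: str):
    """Return the real module if it is installed, else register and return a mock."""
    try:
        return importlib.import_module(name)
    except ImportError:
        stub = MagicMock(name=name)
        # Later "import <name>" statements will now resolve to the stub.
        sys.modules[name] = stub
        return stub


backend = import_or_mock(HEAVY_MODULE)
is_mocked = isinstance(backend, MagicMock)

# Any attribute access or call on the stub succeeds but returns another mock,
# so values computed through it are meaningless.
fake_result = backend.nn.functional.cosine_similarity([1.0], [1.0])
```

Because every call on the stub returns another mock, this is only suitable where the reranker is never actually executed, such as build steps or unit tests.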

To continue talking to Dosu, mention @dosu.

schiaro98 commented 1 month ago

Since the LlamaIndex application runs inside a Flask environment, is there any way to avoid recreating an instance of the reranker model for every request?

dosubot[bot] commented 1 month ago

Yes, you can avoid recreating an instance of the ColbertRerank model for each request in a Flask environment by initializing the model once and reusing it across requests. You can achieve this by creating a global instance of the ColbertRerank model when the Flask application starts and then using this instance in your request handlers.

Here is an example of how you can do this:

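from flask import Flask, request, jsonify
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.postprocessor.colbert_rerank.base import ColbertRerank

app = Flask(__name__)

# Initialize the ColbertRerank model once, at import time, so every request reuses it
colbert_rerank_model = ColbertRerank()

@app.route('/rerank', methods=['POST'])
def rerank():
    data = request.get_json()
    query = data['query']
    # Assumption: the request body carries each node as {"text": ..., "score": ...}
    nodes = [
        NodeWithScore(node=TextNode(text=item['text']), score=item.get('score'))
        for item in data['nodes']
    ]

    # Use the pre-initialized model through the public postprocess_nodes API
    reranked_nodes = colbert_rerank_model.postprocess_nodes(nodes, query_str=query)

    # NodeWithScore objects are not JSON-serializable, so return plain dicts
    return jsonify(
        [{'text': n.node.get_content(), 'score': n.score} for n in reranked_nodes]
    )

if __name__ == '__main__':
    app.run()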

In this example, the colbert_rerank_model is initialized once when the Flask application starts. The /rerank endpoint uses this pre-initialized model to process incoming requests, avoiding the need to recreate the model instance for each request [1].
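
If building the model at import time is undesirable (for instance, when tests or CLI commands import the app module without needing the reranker), one possible variation is to create it lazily on first use and cache it. The sketch below uses functools.lru_cache; FakeReranker, get_reranker and handle_request are hypothetical names, and FakeReranker stands in for ColbertRerank so that the snippet runs without llama-index installed.

```python
from functools import lru_cache
from typing import List


class FakeReranker:
    """Hypothetical stand-in for ColbertRerank; counts how often it is built."""

    instances_created = 0

    def __init__(self, top_n: int = 5) -> None:
        FakeReranker.instances_created += 1
        self.top_n = top_n

    def rerank(self, query: str, texts: List[str]) -> List[str]:
        # Placeholder scoring: prefer texts that share more words with the query.
        query_words = set(query.lower().split())
        ranked = sorted(
            texts,
            key=lambda text: len(query_words & set(text.lower().split())),
            reverse=True,
        )
        return ranked[: self.top_n]


@lru_cache(maxsize=1)
def get_reranker() -> FakeReranker:
    # The first call builds the object; later calls return the cached instance.
    return FakeReranker(top_n=2)


def handle_request(query: str, texts: List[str]) -> List[str]:
    # A request handler calls get_reranker() instead of the constructor.
    return get_reranker().rerank(query, texts)


first = handle_request(
    "colbert rerank", ["flask app", "colbert rerank model", "rerank"]
)
second = handle_request("flask", ["flask app", "other"])
```

In a real application, get_reranker would return ColbertRerank(...) instead. Each worker process still holds its own copy of the model, and under a threaded server two first requests arriving at the same moment may each build an instance before the cache is filled.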
