run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: Custom callback handler #14717

Open KhyatiNinad opened 2 months ago

KhyatiNinad commented 2 months ago

Question

Hi,

I am trying to build a custom callback handler in a hierarchical agent setup where one agent is wrapped as a tool for another agent.

How can I get the event_id for the current request inside the bottom-most agent's custom_query method?

dosubot[bot] commented 2 months ago

To get the event_id in the bottom-most agent's custom_query method for the current request in a hierarchical agent setup, you can use the on_event_start method from the UpTrainCallbackHandler class. This method handles the start of an event and returns the event_id.

Here is the relevant code snippet from the on_event_start method:

def on_event_start(
    self,
    event_type: CBEventType,
    payload: Any = None,
    event_id: str = "",
    parent_id: str = "",
    **kwargs: Any,
) -> str:
    """Run when an event starts and return id of event."""
    event = CBEvent(event_type, payload=payload, id_=event_id)
    self._event_pairs_by_id[event.id_].append(event)

    if event_type is CBEventType.QUERY:
        self.schema.question = payload["query_str"]
    if event_type is CBEventType.TEMPLATING and "template_vars" in payload:
        template_vars = payload["template_vars"]
        self.schema.context = template_vars.get("context_str", "")
    elif event_type is CBEventType.RERANKING and "nodes" in payload:
        self.schema.eval_types.add("reranking")
        # Store old context data
        self.schema.old_context = [node.text for node in payload["nodes"]]
    elif event_type is CBEventType.SUB_QUESTION:
        # For the first sub question, store parent question and parent id
        if "sub_question" not in self.schema.eval_types:
            self.schema.parent_question = self.schema.question
            self.schema.eval_types.add("sub_question")
        # Store sub question data - question and parent id
        self.schema.sub_question_parent_id = parent_id
    return event_id

The event_id is returned at the end of this method. Note that the callback machinery, not your own code, calls on_event_start, so a handler has to store the event_id somewhere your custom_query method can look it up for the current request.

Additionally, you can create custom event handlers to capture and process these events, including event_id. Here is an example of how to create a custom event handler:

from typing import List

from llama_index.core.instrumentation.events import BaseEvent
from llama_index.core.instrumentation.event_handlers import BaseEventHandler

class ExampleEventHandler(BaseEventHandler):
    """Example event handler."""

    events: List[BaseEvent] = []

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "ExampleEventHandler"

    def handle(self, event: BaseEvent) -> None:
        """Logic for handling event."""
        print("-----------------------")
        print(event.id_)
        print(event.timestamp)
        print(event.span_id)
        print(f"Event type: {event.class_name()}")
        self.events.append(event)
        print("-----------------------")

You can then attach this event handler to the dispatcher to watch events and spans come in:

from llama_index.core.instrumentation import get_dispatcher

# root dispatcher
root_dispatcher = get_dispatcher()

# register event handler
event_handler = ExampleEventHandler()
root_dispatcher.add_event_handler(event_handler)

By using these tools, you can effectively manage and retrieve the event_id in your hierarchical agent setup [1][2].
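For intuition, the dispatcher/handler pairing above is a plain observer pattern. Below is a dependency-free sketch of the same shape; the Event and Dispatcher classes here are illustrative stand-ins, not llama_index classes:

```python
from dataclasses import dataclass, field
from typing import Callable, List
import uuid

@dataclass
class Event:
    # Each event gets a unique id, mirroring event.id_ above
    id_: str = field(default_factory=lambda: str(uuid.uuid4()))

class Dispatcher:
    def __init__(self) -> None:
        self.handlers: List[Callable[[Event], None]] = []

    def add_event_handler(self, handler: Callable[[Event], None]) -> None:
        self.handlers.append(handler)

    def event(self, ev: Event) -> None:
        # Fan the event out to every registered handler
        for handler in self.handlers:
            handler(ev)

seen: List[str] = []
dispatcher = Dispatcher()
dispatcher.add_event_handler(lambda ev: seen.append(ev.id_))
dispatcher.event(Event())
print(len(seen))  # 1
```

Registered handlers see every event's id as it is fired, which is exactly how the ExampleEventHandler above observes event.id_.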

To continue talking to Dosu, mention @dosu.

KhyatiNinad commented 2 months ago

@dosu So, if the event_id is returned at the end, how do I retrieve it in my custom_query method for the current request? Please provide code for that.

dosubot[bot] commented 2 months ago

To retrieve the event_id in your custom_query method for the current request, you can use the CallbackManager to manage and access the event stack. Here is an example of how you can achieve this:

  1. Define your custom callback handler:
    • Create a custom callback handler that captures the event_id when an event starts.
from llama_index.core.callbacks.base import BaseCallbackHandler, CBEventType, CBEvent
from typing import Any, Dict, List
from collections import defaultdict

class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Any = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_id[event.id_].append(event)
        return event_id

    def get_current_event_id(self) -> str:
        # Dicts preserve insertion order, so the last key is the most recent event
        if self._event_pairs_by_id:
            return list(self._event_pairs_by_id.keys())[-1]
        return ""

    # No-op implementations for the remaining abstract methods of BaseCallbackHandler
    def on_event_end(self, event_type, payload=None, event_id="", **kwargs) -> None:
        pass

    def start_trace(self, trace_id=None) -> None:
        pass

    def end_trace(self, trace_id=None, trace_map=None) -> None:
        pass
  2. Attach the custom callback handler to the callback manager:
    • A BaseCallbackHandler is registered with the CallbackManager, not the instrumentation dispatcher (which expects BaseEventHandler subclasses).
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

# register custom callback handler globally
custom_callback_handler = CustomCallbackHandler()
Settings.callback_manager = CallbackManager([custom_callback_handler])
  3. Retrieve the event_id in your custom_query method:
    • Use the custom callback handler to get the current event_id.
from llama_index.core import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.llms.openai import OpenAI

class RAGStringQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        # Retrieve the current event_id
        current_event_id = custom_callback_handler.get_current_event_id()
        print(f"Current Event ID: {current_event_id}")

        nodes = self.retriever.retrieve(query_str)
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        response = self.llm.complete(
            self.qa_prompt.format(context_str=context_str, query_str=query_str)
        )

        return str(response)

In this setup, the CustomCallbackHandler captures the event_id when an event starts and provides a method to retrieve the current event_id. You can then use this handler in your custom_query method to get the event_id for the current request.
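A caveat on the get_current_event_id method above: it assumes Python dicts (including defaultdict) return keys in insertion order, which is guaranteed since Python 3.7. A minimal stdlib sketch of that assumption, using made-up event ids:

```python
from collections import defaultdict

# Simulate the handler's bookkeeping with illustrative event ids
pairs = defaultdict(list)
for eid in ("evt-1", "evt-2", "evt-3"):
    pairs[eid].append(f"start:{eid}")

# dict keys come back in insertion order, so the last key is the
# most recently started event
current = list(pairs.keys())[-1]
print(current)  # evt-3
```

This only tells you which event started most recently in this handler instance, not which event belongs to a particular request; the concurrency question below addresses that gap.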


KhyatiNinad commented 2 months ago

@dosu But will this work when we use the agent in a web API where multiple requests run in parallel?

dosubot[bot] commented 2 months ago

To ensure that the event_id retrieval is thread-safe and corresponds to the current request when a web API handles multiple requests in parallel, you can use context-local storage via Python's contextvars module. Here is an example:

  1. Define your custom callback handler with context-local storage:
from llama_index.core.callbacks.base import BaseCallbackHandler, CBEventType, CBEvent
from contextvars import ContextVar
from typing import Any, Dict, List
from collections import defaultdict

# Create a context-local storage object
event_id_var = ContextVar("event_id", default="")

class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Any = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_id[event.id_].append(event)
        event_id_var.set(event_id)  # Set the event_id in context-local storage
        return event_id

    def get_current_event_id(self) -> str:
        return event_id_var.get()

    # No-op implementations for the remaining abstract methods of BaseCallbackHandler
    def on_event_end(self, event_type, payload=None, event_id="", **kwargs) -> None:
        pass

    def start_trace(self, trace_id=None) -> None:
        pass

    def end_trace(self, trace_id=None, trace_map=None) -> None:
        pass
  2. Attach the custom callback handler to the callback manager:
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

# register custom callback handler globally
custom_callback_handler = CustomCallbackHandler()
Settings.callback_manager = CallbackManager([custom_callback_handler])
  3. Retrieve the event_id in your custom_query method:
class RAGStringQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        # Retrieve the current event_id
        current_event_id = custom_callback_handler.get_current_event_id()
        print(f"Current Event ID: {current_event_id}")

        nodes = self.retriever.retrieve(query_str)
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        response = self.llm.complete(
            self.qa_prompt.format(context_str=context_str, query_str=query_str)
        )

        return str(response)

By using ContextVar, each request's event_id is correctly isolated and thread-safe, ensuring that the event_id retrieval works correctly even when multiple requests are processed in parallel [1].
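To see why ContextVar gives per-request isolation, here is a small self-contained sketch with no llama_index involved: each asyncio task runs in its own copy of the context, so concurrent "requests" never see each other's event_id.

```python
import asyncio
from contextvars import ContextVar

event_id_var = ContextVar("event_id", default="")

async def handle_request(request_id: int) -> str:
    # Each task gets its own copy of the context at creation time
    event_id_var.set(f"event-{request_id}")
    await asyncio.sleep(0)  # yield so the tasks interleave
    return event_id_var.get()

async def main() -> list:
    # Run three "requests" concurrently
    return await asyncio.gather(*(handle_request(i) for i in range(3)))

print(asyncio.run(main()))  # ['event-0', 'event-1', 'event-2']
```

The same isolation holds across threads, since each thread has its own context as well; this is what makes the contextvars approach safer than a plain instance attribute on the handler, which is shared by all in-flight requests.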


KhyatiNinad commented 2 months ago

@dosu OK. It seems the event_id changes for child agents. How can I get the event_id of the top-most parent agent inside the child agent's custom_query tool method?

dosubot[bot] commented 2 months ago

To get the event_id of the top-most parent agent in your custom_query method of the child agent, you can track the hierarchy of events using a ContextVar to store the root event ID. Here’s how you can achieve this:

  1. Define your custom callback handler with context-local storage:
    • Track the root event ID when an event starts.
from llama_index.core.callbacks.base import BaseCallbackHandler, CBEventType, CBEvent
from contextvars import ContextVar
from typing import Any, Dict, List
from collections import defaultdict

# Create context-local storage objects
event_id_var = ContextVar("event_id", default="")
root_event_id_var = ContextVar("root_event_id", default=None)

class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Any = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_id[event.id_].append(event)
        event_id_var.set(event_id)  # Set the event_id in context-local storage

        # Set the root event ID if it's not already set
        if root_event_id_var.get() is None:
            root_event_id_var.set(event_id)

        return event_id

    def get_current_event_id(self) -> str:
        return event_id_var.get()

    def get_root_event_id(self) -> str:
        return root_event_id_var.get() or ""

    # No-op implementations for the remaining abstract methods of BaseCallbackHandler
    def on_event_end(self, event_type, payload=None, event_id="", **kwargs) -> None:
        pass

    def start_trace(self, trace_id=None) -> None:
        pass

    def end_trace(self, trace_id=None, trace_map=None) -> None:
        pass
  2. Attach the custom callback handler to the callback manager:
    • Register your custom callback handler globally.
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

# register custom callback handler globally
custom_callback_handler = CustomCallbackHandler()
Settings.callback_manager = CallbackManager([custom_callback_handler])
  3. Retrieve the root event_id in your custom_query method:
    • Use the custom callback handler to get the root event_id.
class RAGStringQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        # Retrieve the root event_id
        root_event_id = custom_callback_handler.get_root_event_id()
        print(f"Root Event ID: {root_event_id}")

        nodes = self.retriever.retrieve(query_str)
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        response = self.llm.complete(
            self.qa_prompt.format(context_str=context_str, query_str=query_str)
        )

        return str(response)

By using ContextVar to store the root event ID, you can ensure that the root event_id is correctly tracked and accessible in your custom_query method, even when dealing with hierarchical agents and parallel requests [1][2].
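One caveat: the root event ID is only set once per context, so if a server framework reuses a context across requests, the first request's root id can leak into the next one. A stdlib-only sketch of a reset pattern using the token returned by ContextVar.set — the set_root_if_unset and begin/teardown steps here are illustrative, not llama_index APIs:

```python
from contextvars import ContextVar

root_event_id_var = ContextVar("root_event_id", default=None)

def set_root_if_unset(event_id: str) -> None:
    # Mirrors the handler logic: only the first event in a context wins
    if root_event_id_var.get() is None:
        root_event_id_var.set(event_id)

# At the start of each request, clear any leftover root id and keep the
# token so the previous value can be restored afterwards
token = root_event_id_var.set(None)
set_root_if_unset("parent-event")  # top-level agent's event wins
set_root_if_unset("child-event")   # ignored: root already set
print(root_event_id_var.get())     # parent-event
root_event_id_var.reset(token)     # restore on request teardown
```

Calling set(None) at request start and reset(token) at teardown keeps a recycled context from carrying the previous request's root id.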
