run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: Global CallbackManager via one_click not propagated to RetrieverQueryEngine #8657

Closed: mikeldking closed this issue 8 months ago

mikeldking commented 10 months ago

Bug Description

Reporting on behalf of a user of llama_index and arize-phoenix.

They reported that the LLM steps of a custom query engine were not producing traces in Phoenix. Their example was as follows:

import phoenix as px
import llama_index
from pydantic import BaseModel, Field
from llama_index.callbacks import CallbackManager
from llama_index.llms import OpenAI
from llama_index.query_engine import RetrieverQueryEngine 
from llama_index.indices.list import SummaryIndexRetriever
from llama_index.response_synthesizers import CompactAndRefine
from llama_index import Document, SummaryIndex, ServiceContext
from phoenix.trace.llama_index import (
    OpenInferenceTraceCallbackHandler,
)

session = px.active_session()
if session is None:
    session = px.launch_app()
llama_index.set_global_handler("arize_phoenix")

# Initialize the callback handler
# NOTE: this doesn't work with the global handler nor with the explicit callback manager
# callback_handler = OpenInferenceTraceCallbackHandler()

service_ctx = ServiceContext.from_defaults(
    llm=(
        OpenAI(
            model="gpt-3.5-turbo-16k",
            max_retries=10
        )
    ),
    # callback_manager=CallbackManager(handlers=[callback_handler]),
)

class Person(BaseModel):
    """A class representing a person mentioned in the text"""
    name: str = Field(description="The person's name")
    age: int | None = Field(description="The person's age")

class People(BaseModel):
    """A class representing a collection of unique individuals"""
    people: list[Person] = Field(
        description="A list of all unique people mentioned in a text."
    )

from textwrap import dedent

doc = Document(
    text=dedent(
        """
        Bobby Shaftoe was only 16 when he went to sea as a deckhand. \
        He fell victim to the scurvy after just 2 months. \
        Unfazed, Captain Jack discharged his duties with extreme prejudice. \
        """
    )
)

response_synth = CompactAndRefine(
    service_context=service_ctx,
    output_cls=People,
)

query_eng = RetrieverQueryEngine(
    retriever=SummaryIndexRetriever(index=SummaryIndex.from_documents([doc])),
    response_synthesizer=response_synth
)

query_eng.query("Who is mentioned in this contrived, minimal 'story'?")

df = session.get_spans_dataframe()

Note that they are using the global callback setup. When spans are exported this way, the trace is not properly stitched together and the LLM spans are missing.
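A quick way to see the gap in the notebook is to group the exported spans by kind (a sketch, assuming the dataframe exposes a span_kind column, as recent arize-phoenix releases do):

df = session.get_spans_dataframe()
# In the broken setup this shows retriever and synthesis spans but no LLM rows.
print(df["span_kind"].value_counts())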

However, if you manually wire the same callback manager into each component, the traces show up correctly.

import llama_index
import phoenix as px
from llama_index import Document, ServiceContext, SummaryIndex
from llama_index.callbacks import CallbackManager
from llama_index.indices.list import SummaryIndexRetriever
from llama_index.llms import OpenAI
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import CompactAndRefine
from phoenix.trace.llama_index import (
    OpenInferenceTraceCallbackHandler,
)
from pydantic import BaseModel, Field

session = px.launch_app()

callback_handler = OpenInferenceTraceCallbackHandler()
cb_manager = CallbackManager(handlers=[callback_handler])

service_ctx = ServiceContext.from_defaults(
    llm=(OpenAI(model="gpt-3.5-turbo-16k", max_retries=10, callback_manager=cb_manager)),
    callback_manager=cb_manager,
)

class Person(BaseModel):
    """A class representing a person mentioned in the text"""

    name: str = Field(description="The person's name")
    age: int | None = Field(description="The person's age")

class People(BaseModel):
    """A class representing a collection of unique individuals"""

    people: list[Person] = Field(description="A list of all unique people mentioned in a text.")

from textwrap import dedent

doc = Document(
    text=dedent(
        """
        Bobby Shaftoe was only 16 when he went to sea as a deckhand. \
        He fell victim to the scurvy after just 2 months. \
        Unfazed, Captain Jack discharged his duties with extreme prejudice. \
        """
    )
)

response_synth = CompactAndRefine(
    service_context=service_ctx,
)

query_eng = RetrieverQueryEngine(
    retriever=SummaryIndexRetriever(index=SummaryIndex.from_documents([doc])),
    response_synthesizer=response_synth,
    callback_manager=cb_manager,
)

query_eng.query("Who is mentioned in this contrived, minimal 'story'?")

df = session.get_spans_dataframe()

It seems as though the callback manager in the RetrieverQueryEngine should default to the global one if none is provided: https://github.com/run-llama/llama_index/blob/56b54359b1e6584ae88fd75d2b5e75d44a13b03d/llama_index/query_engine/retriever_query_engine.py#L46

Let me know if this makes sense; I can put up a PR.
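For illustration, a minimal sketch of that fallback (resolve_callback_manager is a hypothetical helper, not the actual llama_index internals; llama_index.global_handler is the attribute that set_global_handler populates):

import llama_index
from llama_index.callbacks import CallbackManager

def resolve_callback_manager(callback_manager=None):
    # Prefer an explicitly provided manager.
    if callback_manager is not None:
        return callback_manager
    # Otherwise fall back to a manager wrapping the global handler, if one is set.
    if llama_index.global_handler is not None:
        return CallbackManager(handlers=[llama_index.global_handler])
    return CallbackManager(handlers=[])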

Version

0.8.59

Steps to Reproduce

Run the snippet from the Bug Description above in a notebook. You will see that the LLM calls are missing from the exported spans.

However, manually adding the same callback manager to the sub-components produces the correct trace, as shown in the second snippet above.

Relevant Logs/Tracebacks

No response

logan-markewich commented 10 months ago

This makes sense; I can likely fix this soon :) Thanks for reporting!

amindadgar commented 8 months ago

Has a fix been implemented for this?

logan-markewich commented 8 months ago

@amindadgar this is actually solved in the latest release of llama-index.
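For anyone verifying, a minimal check on a recent release (a sketch, under the assumption that the engine now falls back to the global handler; note that no callback_manager is passed anywhere):

import llama_index
import phoenix as px
from llama_index import Document, SummaryIndex
from llama_index.indices.list import SummaryIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine

session = px.launch_app()
llama_index.set_global_handler("arize_phoenix")

index = SummaryIndex.from_documents([Document(text="Bobby Shaftoe went to sea.")])
query_eng = RetrieverQueryEngine(
    retriever=SummaryIndexRetriever(index=index),
)
query_eng.query("Who is mentioned?")

# The LLM spans should now appear alongside the query and retrieve spans.
df = session.get_spans_dataframe()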