langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

Progress bar for LLMChain #6053

Closed louisoutin closed 11 months ago

louisoutin commented 1 year ago

Hello, is there a way to track progress when passing a list of inputs to an LLMChain object, using tqdm for example? I didn't see any parameter that would let me use tqdm. I also checked whether I could write a callback for this, but the hooks don't seem to allow it. Has anyone managed to get a progress bar working?

mukut03 commented 1 year ago

Can you provide a code snippet explaining the current behavior and the behavior you want?

louisoutin commented 1 year ago

Sorry for the lack of an example. We can use the example from the documentation.

Current Behavior

input_list = [
    {"product": "socks"},
    {"product": "computer"},
    {"product": "shoes"}
]

llm_chain.apply(input_list)

Output:

[{'text': '\n\nSocktastic!'},
 {'text': '\n\nTechCore Solutions.'},
 {'text': '\n\nFootwear Factory.'}]

Behavior I would like

input_list = [
    {"product": "socks"},
    {"product": "computer"},
    {"product": "shoes"}
]

llm_chain.apply(input_list, show_progress=True)

Output:

 33%|█████                    | 1/3 [00:02<00:09, Xit/s] (inference w batch size X)
[{'text': '\n\nSocktastic!'},
 {'text': '\n\nTechCore Solutions.'},
 {'text': '\n\nFootwear Factory.'}]

Basically, I'm just looking for an easy way to track progress over a long input list. Thanks for the help.
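Until something like `show_progress` exists, one workaround is to split the input list into chunks yourself and report progress after each chunk. The sketch below is only an illustration: `apply_in_chunks`, `chunk_size` and `on_progress` are hypothetical names that are not part of LangChain, and the helper assumes `apply_fn` takes a list of inputs and returns one result per input, as `llm_chain.apply` does in the example above.

```python
from typing import Any, Callable, Dict, List, Sequence


def apply_in_chunks(
    apply_fn: Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]],
    inputs: Sequence[Dict[str, Any]],
    chunk_size: int = 1,
    on_progress: Callable[[int], Any] = lambda n: None,
) -> List[Dict[str, Any]]:
    """Run apply_fn over inputs chunk by chunk, reporting progress after each chunk."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    results: List[Dict[str, Any]] = []
    for start in range(0, len(inputs), chunk_size):
        chunk = list(inputs[start:start + chunk_size])
        results.extend(apply_fn(chunk))
        # Report how many inputs were just finished (tqdm's update() accepts this count).
        on_progress(len(chunk))
    return results


# Hypothetical usage with tqdm and the chain from the example above:
#   bar = tqdm(total=len(input_list))
#   outputs = apply_in_chunks(llm_chain.apply, input_list, chunk_size=1, on_progress=bar.update)
#   bar.close()
```

A smaller `chunk_size` gives a smoother bar but gives up some of the batching that a single `apply` call over the whole list would provide.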

dosubot[bot] commented 12 months ago

Hi, @louisoutin! I'm Dosu, and I'm here to help the LangChain team manage their backlog. I wanted to let you know that we are marking this issue as stale.

From what I understand, you are requesting the addition of a progress bar to the LLMChain object in order to track progress when giving a list of inputs. You have provided a code snippet explaining the current behavior and the behavior you would like.

Before we proceed, we would like to confirm if this issue is still relevant to the latest version of the LangChain repository. If it is, please let us know by commenting on this issue. Otherwise, feel free to close the issue yourself or it will be automatically closed in 7 days.

Thank you for your understanding and contribution to the LangChain project!

solalatus commented 10 months ago

Any progress on this, or does anybody have a nice hack? For MapReduce-like chains over a large number of doc chunks, this would totally make sense!

BrandonStudio commented 5 months ago

You can use a callback to do this.

First, define your callback:

from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from tqdm.auto import tqdm


class BatchCallback(BaseCallbackHandler):
    def __init__(self, total: int):
        super().__init__()
        self.count = 0
        self.progress_bar = tqdm(total=total)  # define a progress bar

    # Override on_llm_end. It is called after every response from the LLM.
    def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
        self.count += 1
        self.progress_bar.update(1)

Then create an instance of the callback and pass it to batch:

# Assume your chain is `chain`, inputs is `inputs`
cb = BatchCallback(len(inputs)) # init callback
chain.batch(inputs, config={"callbacks": [cb]})
cb.progress_bar.close()

hxia-neos commented 2 months ago

This callback does not work if your chain makes more than one LLM call per input, because the bar then advances several times for a single input. ~I am frustrated that I could not find a way to listen only to the event at the beginning of processing each input, and not to every nested call to an LLM or chain.~

Here is how to get the correct behaviour:

    def on_chain_start(
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> Any:
        # Only top-level runs have no parent, so nested chain and LLM calls are skipped.
        if kwargs.get("parent_run_id") is None:
            self.count += 1
            self.progress_bar.update(1)
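The handler above advances the bar when a top-level run starts. A possible variation, sketched here under assumptions, advances it when a top-level run finishes or fails, so the bar reflects completed inputs. `RootRunProgress` is a hypothetical name, the bar is passed in as any object with an `update(n)` method (for example `tqdm(total=len(inputs))`), and the `ImportError` fallback exists only so the sketch can run without LangChain installed.

```python
import threading
from typing import Any, Optional
from uuid import UUID

try:
    from langchain_core.callbacks import BaseCallbackHandler
except ImportError:  # fallback so the sketch runs without LangChain installed
    BaseCallbackHandler = object


class RootRunProgress(BaseCallbackHandler):
    """Advance a progress bar once per finished top-level chain run."""

    def __init__(self, bar: Any) -> None:
        super().__init__()
        self.bar = bar  # any object exposing update(n), e.g. a tqdm instance
        self._lock = threading.Lock()  # batch runs may report from several threads

    def _finish(self, parent_run_id: Optional[UUID]) -> None:
        # Nested runs carry a parent_run_id; only root runs count as one input.
        if parent_run_id is None:
            with self._lock:
                self.bar.update(1)

    def on_chain_end(
        self, outputs: Any, *, run_id: UUID,
        parent_run_id: Optional[UUID] = None, **kwargs: Any
    ) -> Any:
        self._finish(parent_run_id)

    def on_chain_error(
        self, error: BaseException, *, run_id: UUID,
        parent_run_id: Optional[UUID] = None, **kwargs: Any
    ) -> Any:
        # Count failed inputs too, so the bar still reaches its total.
        self._finish(parent_run_id)
```

It would be passed to `batch` the same way as `BatchCallback`, via `config={"callbacks": [...]}`. This assumes each input's root run is a chain; if the runnable is a bare model, the chain hooks may not fire for it.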

thiswillbeyourgithub commented 2 months ago

Personally, my hackish code for tqdm progress bars in my chain seems to work on my setup:

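from typing import List, Union

from langchain_core.runnables import RunnableLambda, chain
from tqdm import tqdm

# Not shown here: the imports for ChatLiteLLM, ChatOpenAI and FakeListChatModel,
# and the definition of the optional_typecheck decorator.

@optional_typecheck
def pbar_chain(
    llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel],
    len_func: str,
    **tqdm_kwargs,
) -> RunnableLambda:
    "create a chain that just sets a tqdm progress bar"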

    @chain
    def actual_pbar_chain(
        inputs: Union[dict, List],
        llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel] = llm,
        ) -> Union[dict, List]:

        llm.callbacks[0].pbar.append(
            tqdm(
                total=eval(len_func),
                **tqdm_kwargs,
            )
        )
        assert llm.callbacks[0].pbar[-1].total

        return inputs

    return actual_pbar_chain

@optional_typecheck
def pbar_closer(
    llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel],
    ) -> RunnableLambda:
    "close a pbar created by pbar_chain"

    @chain
    def actual_pbar_closer(
        inputs: Union[dict, List],
        llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel] = llm,
        ) -> Union[dict, List]:
        pbar = llm.callbacks[0].pbar[-1]
        pbar.update(pbar.total - pbar.n)
        pbar.close()

        return inputs
    return actual_pbar_closer

With this in my callback:

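from typing import Any, Dict, List, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult

@optional_typecheck
class PriceCountingCallback(BaseCallbackHandler):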
    "source: https://python.langchain.com/docs/modules/callbacks/"
    def __init__(self, verbose, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.verbose = verbose
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.methods_called = []
        self.authorized_methods = [
            "on_llm_start",
            "on_chat_model_start",
            "on_llm_end",
            "on_llm_error",
            "on_chain_start",
            "on_chain_end",
            "on_chain_error",
        ]
        self.pbar = []

    def __repr__(self) -> str:
        # setting __repr__ and __str__ is important because it can
        # maybe be used for caching?
        return "PriceCountingCallback"

    def __str__(self) -> str:
        return "PriceCountingCallback"

    def _check_methods_called(self) -> bool:
        assert all(meth in dir(self) for meth in self.methods_called), (
            "unexpected method names!")
        wrong = [
            meth for meth in self.methods_called
            if meth not in self.authorized_methods]
        if wrong:
            raise Exception(
                f"Unauthorized methods were called: {','.join(wrong)}")
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""
        if self.verbose:
            print("Callback method: on_llm_start")
            print(serialized)
            print(prompts)
            print(kwargs)
            print("Callback method end: on_llm_start")
        self.methods_called.append("on_llm_start")
        self._check_methods_called()

    def on_chat_model_start(
        self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any
    ) -> Any:
        """Run when Chat Model starts running."""
        if self.verbose:
            print("Callback method: on_chat_model_start")
            print(serialized)
            print(messages)
            print(kwargs)
            print("Callback method end: on_chat_model_start")
        self.methods_called.append("on_chat_model_start")
        self._check_methods_called()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Run when LLM ends running."""
        if self.verbose:
            print("Callback method: on_llm_end")
            print(response)
            print(kwargs)
            print("Callback method end: on_llm_end")

        new_p = response.llm_output["token_usage"]["prompt_tokens"]
        new_c = response.llm_output["token_usage"]["completion_tokens"]
        self.prompt_tokens += new_p
        self.completion_tokens += new_c
        self.total_tokens += new_p + new_c
        assert self.total_tokens == self.prompt_tokens + self.completion_tokens
        self.methods_called.append("on_llm_end")
        self._check_methods_called()

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when LLM errors."""
        if self.verbose:
            print("Callback method: on_llm_error")
            print(error)
            print(kwargs)
            print("Callback method end: on_llm_error")
        self.methods_called.append("on_llm_error")
        self._check_methods_called()

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """Run when chain starts running."""
        if self.verbose:
            print("Callback method: on_chain_start")
            print(serialized)
            print(inputs)
            print(kwargs)
            print("Callback method end: on_chain_start")
        self.methods_called.append("on_chain_start")
        self._check_methods_called()

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Run when chain ends running."""
        if self.verbose:
            print("Callback method: on_chain_end")
            print(outputs)
            print(kwargs)
            print("Callback method end: on_chain_end")
        self.methods_called.append("on_chain_end")
        self._check_methods_called()
        if self.pbar:
            self.pbar[-1].update(1)

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when chain errors."""
        if self.verbose:
            print("Callback method: on_chain_error")
            print(error)
            print(kwargs)
            print("Callback method end: on_chain_error")
        self.methods_called.append("on_chain_error")
        self._check_methods_called()
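One fragile spot in the callback above is `on_llm_end`, which indexes `response.llm_output["token_usage"]` directly. `llm_output` is optional on `LLMResult`, and not every provider reports token usage, so the lookup can fail. A minimal sketch of a more defensive lookup follows; `token_usage` is a hypothetical helper name, and treating missing counts as zero is an assumption rather than the poster's choice.

```python
from typing import Any, Dict, Optional, Tuple


def token_usage(llm_output: Optional[Dict[str, Any]]) -> Tuple[int, int]:
    """Return (prompt_tokens, completion_tokens), using 0 for anything missing."""
    usage = (llm_output or {}).get("token_usage") or {}
    return (
        int(usage.get("prompt_tokens", 0)),
        int(usage.get("completion_tokens", 0)),
    )
```

Inside `on_llm_end` this could be used as `new_p, new_c = token_usage(response.llm_output)`.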

Here's how I use it:

rag_chain = (
    loaded_memory
    | standalone_question
    | retrieve_documents
    | pbar_chain(
            llm=self.eval_llm,
            len_func="len(inputs['unfiltered_docs'])",
            desc="LLM evaluation",
            unit="doc",
        )
    | refilter_documents
    | pbar_closer(llm=self.eval_llm)
    | pbar_chain(
            llm=self.llm,
            len_func="len(inputs['filtered_docs'])",
            desc="Answering each",
            unit="doc",
        )
    | answer_all_docs
    | pbar_closer(llm=self.llm)
)
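The `len_func` strings above are turned into a total with `eval`, which is easy to break with a typo and awkward to type-check. A possible alternative, sketched here under assumptions, passes a callable instead of a string. `pbar_step`, `get_total`, `bars` and `make_bar` are hypothetical names that do not come from the poster's code.

```python
from typing import Any, Callable, List


def pbar_step(
    get_total: Callable[[Any], int],
    bars: List[Any],
    make_bar: Callable[..., Any],
    **bar_kwargs: Any,
) -> Callable[[Any], Any]:
    """Build a pass-through step that opens a progress bar sized from its input."""

    def step(inputs: Any) -> Any:
        # Compute the total from the live input instead of eval-ing a string.
        bars.append(make_bar(total=get_total(inputs), **bar_kwargs))
        return inputs  # pass the input through to the next step unchanged

    return step


# Hypothetical wiring, mirroring the first pbar_chain call above:
#   RunnableLambda(
#       pbar_step(
#           lambda x: len(x["unfiltered_docs"]),
#           self.eval_llm.callbacks[0].pbar,
#           tqdm,
#           desc="LLM evaluation",
#           unit="doc",
#       )
#   )
```

Taking `make_bar` as a parameter keeps the step independent of tqdm, so it can be exercised with a stand-in during testing.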