langchain-ai / chat-langchain

https://chat.langchain.com
MIT License
4.98k stars 1.16k forks source link

BaseCallbackHandler is working fine with ChatOpenAI but raising error when we use ChatGoogleGenerativeAI LLM #277

Open Adil-Ashraf opened 4 months ago

Adil-Ashraf commented 4 months ago

This is my code for the OpenAI LLM, and it works fine:

from fastapi import FastAPI, Request, APIRouter import pdb import cohere from fastapi.responses import StreamingResponse from langchain.schema.messages import HumanMessage, AIMessage from collections.abc import Generator from queue import Queue, Empty from langchain.chat_models import ChatAnthropic, ChatOpenAI from chatbot.app.utils.chain import QueueCallback, create_chain from chatbot.app.utils.retriever import get_retriever, create_retriever_chain from threading import Thread from chatbot.app.settings import COHERE_API_KEY

router = APIRouter()

@router.post("/chat") async def chat_endpoint(request: Request): data = await request.json() question = data.get("message") chat_history = data.get("history", []) converted_chat_history = [] for message in chat_history: if message.get("human") is not None: converted_chat_history.append(HumanMessage(content=message["human"])) if message.get("ai") is not None: converted_chat_history.append(AIMessage(content=message["ai"])) data.get("conversation_id")

cohere_client = cohere.Client(COHERE_API_KEY)

def stream() -> Generator: q = Queue() job_done = object()

# Answering model: streaming is enabled and a QueueCallback bound to `q` is
# attached, so that newly generated tokens are put on the queue.
llm = ChatOpenAI(
  model="gpt-3.5-turbo-16k",
  streaming=True,
  temperature=0,
  callbacks=[QueueCallback(q)],
)

# Same model settings but with no callback attached. task() hands this
# instance to create_retriever_chain, so its output is not put on `q`.
llm_without_callback = ChatOpenAI(
  model="gpt-3.5-turbo-16k",
  streaming=True,
  temperature=0,
)

def task():
  """Run retrieval, reranking and answer generation, reporting through ``q``.

  Uses ``question``, ``chat_history``, ``converted_chat_history``, ``llm``,
  ``llm_without_callback``, ``cohere_client``, ``q`` and ``job_done`` from the
  enclosing scope.

  Items put on ``q``, in order:
    1. one "title:source" line per distinct source among the documents whose
       Cohere rerank relevance score is above 0.70;
    2. the ``SOURCES:`` marker string;
    3. whatever the callback attached to ``llm`` puts on the queue while the
       chain runs (presumably the streamed answer tokens);
    4. the ``job_done`` sentinel, always, even when a step above raises.
  """
  try:
    retriever_chain = create_retriever_chain(
      chat_history, llm_without_callback, get_retriever()
    )
    chain = create_chain(llm, retriever_chain)
    docs = retriever_chain.invoke(
      {"question": question, "chat_history": chat_history},
    )
    documents_for_rerank = [{"text": doc.page_content} for doc in docs]

    # Perform reranking using Cohere
    results = cohere_client.rerank(
      query=question,
      documents=documents_for_rerank,
      top_n=3,
      model='rerank-english-v2.0'
    )

    # Texts of the reranked documents with a relevance score greater than 0.70
    high_score_texts = {
      result.document['text'] for result in results
      if result.relevance_score > 0.70
    }

    # Keep the retrieved documents whose page_content is one of those texts
    filtered_docs = [doc for doc in docs if doc.page_content in high_score_texts]

    # Emit each source only once, in retrieval order
    url_set = set()
    for doc in filtered_docs:
      source = doc.metadata["source"]
      if source in url_set:
        continue
      q.put(doc.metadata["title"] + ":" + source + "\n")
      url_set.add(source)

    q.put("SOURCES:----------------------------")

    # NOTE(review): the chain receives the unfiltered `docs` as context, not
    # `filtered_docs` -- confirm this is intended.
    chain.invoke(
      {
        "question": question,
        "chat_history": converted_chat_history,
        "context": docs,
      },
    )
  finally:
    # The consumer loop only stops when it sees job_done. Without this, an
    # exception anywhere above would leave it polling the queue forever.
    q.put(job_done)

# Run task() in a background thread; it feeds `q` while the loop below drains it.
t = Thread(target=task)
t.start()

# Accumulates everything yielded so far; it is not read anywhere in this snippet.
content = ""

while True:
  try:
    # Block for at most one second waiting for the next queue item.
    next_token = q.get(True, timeout=1)
    # job_done is the sentinel task() puts on the queue when it has finished.
    if next_token is job_done:
      break
    content += next_token
    yield next_token
  except Empty:
      # Nothing arrived within the timeout: keep polling. The loop only ends
      # once job_done is received.
      continue

return StreamingResponse(stream())

This is the QueueCallback class:

from langchain.callbacks.base import BaseCallbackHandler

from typing import Any


class QueueCallback(BaseCallbackHandler):
  """Callback handler that forwards streamed LLM tokens to a queue."""

  def __init__(self, q):
    """Store the queue that receives the tokens.

    The method must be named ``__init__``: a method called plain ``init`` is
    not invoked by ``QueueCallback(q)``, so ``self.q`` would never be set.
    """
    super().__init__()
    self.q = q

  def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Put each newly generated token on the queue."""
    self.q.put(token)

  def on_llm_end(self, *args, **kwargs: Any) -> None:
    """Do nothing when the LLM finishes.

    ``Queue.empty()`` only reports whether the queue is empty (it does not
    drain it), so returning it from this ``-> None`` hook had no effect.
    """
    return None

But when I use it with ChatGoogleGenerativeAI, it raises an error. Below is the ChatGoogleGenerativeAI code.

from fastapi import FastAPI, Request, APIRouter import pdb import cohere from fastapi.responses import StreamingResponse from langchain.schema.messages import HumanMessage, AIMessage from collections.abc import Generator from queue import Queue, Empty from langchain.chat_models import ChatAnthropic, ChatOpenAI from chatbot.app.utils.chain import create_chain from chatbot.app.utils.retriever import get_retriever, create_retriever_chain from threading import Thread

from langchain_google_genai import ChatGoogleGenerativeAI from chatbot.app.settings import COHERE_API_KEY, GOOGLE_API_KEY from langchain.callbacks.manager import BaseCallbackManager

router = APIRouter()

from typing import Any


# NOTE(review): this class derives from BaseCallbackManager, whereas the
# ValidationError quoted further down expects each item of `callbacks` to be
# an instance of BaseCallbackHandler. The base class is left as reported so
# that the snippet still reproduces the error.
class QueueCallback(BaseCallbackManager):
  """Callback object that forwards streamed LLM tokens to a queue."""

  def __init__(self, q):
    """Store the queue that receives the tokens.

    The method must be named ``__init__``: a method called plain ``init`` is
    not invoked by ``QueueCallback(q)``, so ``self.q`` would never be set.
    """
    self.q = q

  def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Put each newly generated token on the queue."""
    self.q.put(token)

  def on_llm_end(self, *args, **kwargs: Any) -> None:
    """Do nothing when the LLM finishes.

    ``Queue.empty()`` only reports whether the queue is empty (it does not
    drain it), so returning it from this ``-> None`` hook had no effect.
    """
    return None

@router.post("/chat") async def chat_endpoint(request: Request): data = await request.json() question = data.get("message") chat_history = data.get("history", []) converted_chat_history = [] for message in chat_history: if message.get("human") is not None: converted_chat_history.append(HumanMessage(content=message["human"])) if message.get("ai") is not None: converted_chat_history.append(AIMessage(content=message["ai"])) data.get("conversation_id")

cohere_client = cohere.Client(COHERE_API_KEY)

def stream() -> Generator: q = Queue() job_done = object()

callback_manager = BaseCallbackManager([])

# callback_manager.add_callback(QueueCallback(q))

# Answering model: streaming is enabled and a QueueCallback bound to `q` is
# passed in `callbacks`.
# NOTE(review): the traceback quoted below is raised by this constructor call;
# the ValidationError reports that callbacks[0] is not a BaseCallbackHandler
# instance (QueueCallback in this snippet derives from BaseCallbackManager).
llm = ChatGoogleGenerativeAI(
  model="gemini-pro",
  streaming=True,
  temperature=0,
  google_api_key=GOOGLE_API_KEY,
  callbacks=[QueueCallback(q)],
)

# Same model settings but with no callback attached. task() hands this
# instance to create_retriever_chain, so its output is not put on `q`.
llm_without_callback = ChatGoogleGenerativeAI(
  model="gemini-pro",
  streaming=True,
  temperature=0,
  google_api_key=GOOGLE_API_KEY,
)

def task():
  """Run retrieval, reranking and answer generation, reporting through ``q``.

  Uses ``question``, ``chat_history``, ``converted_chat_history``, ``llm``,
  ``llm_without_callback``, ``cohere_client``, ``q`` and ``job_done`` from the
  enclosing scope.

  Items put on ``q``, in order:
    1. one "title:source" line per distinct source among the documents whose
       Cohere rerank relevance score is above 0.70;
    2. the ``SOURCES:`` marker string;
    3. whatever the callback attached to ``llm`` puts on the queue while the
       chain runs (presumably the streamed answer tokens);
    4. the ``job_done`` sentinel, always, even when a step above raises.
  """
  try:
    retriever_chain = create_retriever_chain(
      chat_history, llm_without_callback, get_retriever()
    )
    chain = create_chain(llm, retriever_chain)
    docs = retriever_chain.invoke(
      {"question": question, "chat_history": chat_history},
    )
    documents_for_rerank = [{"text": doc.page_content} for doc in docs]

    # Perform reranking using Cohere
    results = cohere_client.rerank(
      query=question,
      documents=documents_for_rerank,
      top_n=3,
      model='rerank-english-v2.0'
    )

    # Texts of the reranked documents with a relevance score greater than 0.70
    high_score_texts = {
      result.document['text'] for result in results
      if result.relevance_score > 0.70
    }

    # Keep the retrieved documents whose page_content is one of those texts
    filtered_docs = [doc for doc in docs if doc.page_content in high_score_texts]

    # Emit each source only once, in retrieval order
    url_set = set()
    for doc in filtered_docs:
      source = doc.metadata["source"]
      if source in url_set:
        continue
      q.put(doc.metadata["title"] + ":" + source + "\n")
      url_set.add(source)

    q.put("SOURCES:----------------------------")

    # NOTE(review): the chain receives the unfiltered `docs` as context, not
    # `filtered_docs` -- confirm this is intended.
    chain.invoke(
      {
        "question": question,
        "chat_history": converted_chat_history,
        "context": docs,
      },
    )
  finally:
    # The consumer loop only stops when it sees job_done. Without this, an
    # exception anywhere above would leave it polling the queue forever.
    q.put(job_done)

# Run task() in a background thread; it feeds `q` while the loop below drains it.
t = Thread(target=task)
t.start()

# Accumulates everything yielded so far; it is not read anywhere in this snippet.
content = ""

while True:
  try:
    # Block for at most one second waiting for the next queue item.
    next_token = q.get(True, timeout=1)
    # job_done is the sentinel task() puts on the queue when it has finished.
    if next_token is job_done:
      break
    content += next_token
    yield next_token
  except Empty:
      # Nothing arrived within the timeout: keep polling. The loop only ends
      # once job_done is received.
      continue

return StreamingResponse(stream()) below is the error that I got Exception in ASGI application Traceback (most recent call last): File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 408, in run_asgi result = await app( # type: ignore[func-returns-value] File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 84, in call return await self.app(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/fastapi/applications.py", line 292, in call await super().call(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/applications.py", line 122, in call await self.middleware_stack(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/errors.py", line 184, in call raise exc File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/errors.py", line 162, in call await self.app(scope, receive, _send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/cors.py", line 91, in call await self.simple_response(scope, receive, send, request_headers=headers) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/cors.py", line 146, in simple_response await self.app(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 79, in call raise exc File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 68, in call await 
self.app(scope, receive, sender) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 20, in call raise e File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 17, in call await self.app(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/routing.py", line 718, in call await route.handle(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/routing.py", line 276, in handle await self.app(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/routing.py", line 69, in app await response(scope, receive, send) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/responses.py", line 270, in call async with anyio.create_task_group() as task_group: File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 597, in aexit raise exceptions[0] File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/responses.py", line 273, in wrap await func() File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/responses.py", line 262, in stream_response async for chunk in self.body_iterator: File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/concurrency.py", line 63, in iterate_in_threadpool yield await anyio.to_thread.run_sync(_next, iterator) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync return await 
get_asynclib().run_sync_in_worker_thread( File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread return await future File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 807, in run result = context.run(func, *args) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/starlette/concurrency.py", line 53, in _next return next(iterator) File "/home/adil/Documents/Devbox/chat-langchain/chatbot/app/routers/chat.py", line 53, in stream llm = ChatGoogleGenerativeAI( File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/langchain_core/load/serializable.py", line 107, in __init__ super().__init__(**kwargs) File "/home/adil/.cache/pypoetry/virtualenvs/chatbot-L7SgJIPD-py3.10/lib/python3.10/site-packages/pydantic/v1/main.py", line 341, in __init__ raise validation_error pydantic.v1.error_wrappers.ValidationError: 2 validation errors for ChatGoogleGenerativeAI callbacks -> 0 instance of BaseCallbackHandler expected (type=type_error.arbitrary_type; expected_arbitrary_type=BaseCallbackHandler) callbacks instance of BaseCallbackManager expected (type=type_error.arbitrary_type; expected_arbitrary_type=BaseCallbackManager)

I have tried using `from langchain.callbacks.base import BaseCallbackHandler`, but it didn't help. How can I fix it?