langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

Issue with RunnableAssign<answer>: TypeError('can only concatenate str (not "ChatPromptValue") to str')Traceback (most recent call last): #23505

Closed akileshjayakumar closed 4 months ago

akileshjayakumar commented 4 months ago

Checked other resources

Example Code

The following code:

def test_chain(chain):
    test_queries = [
        "What is the capital of France?",
        "Explain the process of photosynthesis.",
    ]

    for query in test_queries:
        try:
            logging.info(f"Running query: {query}")
            response = chain.invoke(query)
            logging.info(f"Query: {query}")
            logging.info(f"Response: {response}")
            print(f"Query: {query}")
            print(f"Response: {response}\n")
        except Exception as e:
            logging.error(
                f"An error occurred while processing the query '{query}': {e}")
            traceback.print_exc()

if name == "main": chain = main() test_chain(chain)

Error Message and Stack Trace (if applicable)

TypeError('can only concatenate str (not "ChatPromptValue") to str')

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 1626, in _call_with_config
    context.run(
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/config.py", line 347, in call_func_with_variable_args
    return func(input, **kwargs)  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/passthrough.py", line 456, in _invoke
    **self.mapper.invoke(
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3142, in invoke
    output = {key: future.result() for key, future in zip(steps, futures)}
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3142, in <dictcomp>
    output = {key: future.result() for key, future in zip(steps, futures)}
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2499, in invoke
    input = step.invoke(
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3963, in invoke
    return self._call_with_config(
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 1626, in _call_with_config
    context.run(
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/config.py", line 347, in call_func_with_variable_args
    return func(input, **kwargs)  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3837, in _invoke
    output = call_func_with_variable_args(
  File "/usr/local/lib/python3.10/site-packages/langchain_core/runnables/config.py", line 347, in call_func_with_variable_args
    return func(input, **kwargs)  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/site-packages/transformers/pipelines/text_generation.py", line 263, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1243, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/usr/local/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1249, in run_single
    model_inputs = self.preprocess(inputs, **preprocess_params)
  File "/usr/local/lib/python3.10/site-packages/transformers/pipelines/text_generation.py", line 288, in preprocess
    prefix + prompt_text,

TypeError: can only concatenate str (not "ChatPromptValue") to str

Description

I expect to see an answer generated by the LLM, but I always run into this error: TypeError('can only concatenate str (not "ChatPromptValue") to str'), even though the chain itself is valid.
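For reference, here is a minimal sketch of where the type mismatch comes from (the prompt below is illustrative, not the one from my chain): a ChatPromptTemplate produces a ChatPromptValue, while a raw transformers text-generation pipeline expects a plain string and concatenates it with a string prefix internally (the `prefix + prompt_text` line in the traceback).

from langchain.prompts import ChatPromptTemplate

# Illustrative prompt only; the real template lives in setup_chain() in the full script below.
prompt = ChatPromptTemplate.from_template("Question: {question}\nAnswer:")
prompt_value = prompt.invoke({"question": "What is the capital of France?"})

print(type(prompt_value).__name__)  # ChatPromptValue, not str
print(prompt_value.to_string())     # the plain-text form a raw pipeline could accept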

System Info

pip freeze | grep langchain

langchain==0.1.13
langchain-community==0.0.31
langchain-core==0.1.52
langchain-openai==0.1.1
langchain-qdrant==0.1.1
langchain-text-splitters==0.0.2

akileshjayakumar commented 4 months ago

Entire script:

import json
import logging
import os
import time
import traceback
from typing import List, Any

import httpx
import torch
from accelerate import Accelerator
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate
from langchain.schema import Document
from langchain.schema.output_parser import StrOutputParser
from langchain_community.vectorstores import Qdrant
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableAssign
from langchain_core.vectorstores import VectorStoreRetriever
from langsmith import traceable
from qdrant_client import QdrantClient, models
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig

from dochub.src.embed_helper import EmbeddingClient
from dochub.src.reader_helper import PDFLoader, UnstructuredLoader, is_pdf_files
from generate.src.chain_helper import RetrievalChain
from generate.src.retriever_helper import RetrieverClient
from ragas_eval_helper import generate_answer, ragas_evaluate

load_dotenv()

@traceable
def load_documents(doc_dir: str) -> List[Document]:
    """
    Load documents from the specified directory, including JSON files and other supported formats.

    Args:
        doc_dir (str): Directory containing documents.

    Returns:
        List[Document]: List of loaded documents.

    Raises:
        FileNotFoundError: If the directory does not exist or contains no valid files.
        ValueError: If no valid documents are loaded.
    """
    logging.info(f"Loading documents from directory: {doc_dir}")

    if not doc_dir or not os.path.exists(doc_dir):
        raise FileNotFoundError(
            f"Document directory '{doc_dir}' does not exist or is not set.")

    supported_extensions = ('.pdf', '.txt', '.docx', '.json')
    file_list = [file for file in os.listdir(
        doc_dir) if file.lower().endswith(supported_extensions)]

    if not file_list:
        raise FileNotFoundError(
            f"No supported files ({', '.join(supported_extensions)}) found in the document directory '{doc_dir}'.")

    pdf_files, non_pdf_files = is_pdf_files(
        [os.path.join(doc_dir, file) for file in file_list])
    json_files = [
        file for file in non_pdf_files if file.lower().endswith('.json')]
    non_pdf_files = [
        file for file in non_pdf_files if not file.lower().endswith('.json')]

    pdf_loader = PDFLoader()
    unstructured_loader = UnstructuredLoader()
    documents = []

    try:
        if pdf_files:
            documents.extend(pdf_loader.run(pdf_files))
        if non_pdf_files:
            documents.extend(unstructured_loader.run(non_pdf_files))
        if json_files:
            for json_file in json_files:
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                documents.extend(
                    [Document(page_content=json.dumps(entry)) for entry in data])
    except Exception as e:
        logging.error(f"Error loading documents: {e}")
        traceback.print_exc()
        raise

    valid_documents = [
        doc for doc in documents if doc.page_content and doc.page_content.strip()]

    if not valid_documents:
        raise ValueError("No valid documents loaded.")

    return valid_documents

@traceable
def wait_for_qdrant(url: str, retries: int = 5, delay: int = 5) -> bool:
    """
    Wait for Qdrant to be available.

    Args:
        url (str): URL of the Qdrant service.
        retries (int): Number of retry attempts.
        delay (int): Delay between retries in seconds.

    Returns:
        bool: True if Qdrant is available, False otherwise.
    """
    for _ in range(retries):
        try:
            response = httpx.get(url)
            if response.status_code == 200:
                logging.info("Successfully connected to Qdrant")
                return True
        except Exception as e:
            logging.error(f"Failed to connect to Qdrant: {e}")
            traceback.print_exc()
        logging.info(f"Retrying in {delay} seconds...")
        time.sleep(delay)
    return False

@traceable
def load_retriever(embed_model: Any, documents: List[Document], top_k: int = 3, force_recreate: bool = True) -> RetrieverClient:
    """
    Load retriever with the specified embedding model and documents.

    Args:
        embed_model (Any): Embedding model.
        documents (List[Document]): List of documents to be indexed.
        top_k (int): Number of top results to retrieve.
        force_recreate (bool): Whether to force recreate the collection if the dimensions do not match.

    Returns:
        RetrieverClient: Initialized retriever client.
    """
    logging.info("Creating retriever client and storing documents.")

    try:
        embed_model = EmbeddingClient(os.getenv("LLM_TYPE")).get_embed()
    except AttributeError as e:
        logging.error(f"Error: {e}")
        traceback.print_exc()
        raise

    # The dimensions of the embeddings, should match the embeddings model
    dimensions = 384
    client = QdrantClient(url="http://qdrant:6333")
    collection_name = "my_documents"

    if not wait_for_qdrant("http://qdrant:6333/collections"):
        raise ConnectionError(
            "Failed to connect to Qdrant after multiple retries")

    try:
        # Check if collection exists
        try:
            existing_collection = client.get_collection(collection_name)
            existing_dimensions = existing_collection["vectors"]["size"]
            if existing_dimensions != dimensions:
                logging.warning(
                    f"Existing collection '{collection_name}' has dimensions {existing_dimensions}, expected {dimensions}. Recreating collection.")
                client.delete_collection(collection_name)
                client.create_collection(
                    collection_name=collection_name,
                    vectors_config=models.VectorParams(
                        size=dimensions, distance=models.Distance.COSINE
                    ),
                )
                logging.info(
                    f"Recreated collection '{collection_name}' with correct dimensions.")
            else:
                logging.info(
                    f"Collection '{collection_name}' already exists with correct dimensions.")
        except Exception as e:
            # If the collection does not exist, create it
            logging.info(
                f"Collection '{collection_name}' does not exist. Creating it.")
            client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=dimensions, distance=models.Distance.COSINE
                ),
            )
            logging.info(f"Created collection '{collection_name}'.")

    except Exception as e:
        logging.error(f"Error getting or creating Qdrant collection: {e}")
        traceback.print_exc()
        raise

    try:
        vectorstore = Qdrant.from_documents(
            documents,
            embed_model,
            location="http://qdrant:6333",
            collection_name=collection_name,
        )
        return vectorstore.as_retriever(search_kwargs={"k": top_k}, force_recreate=force_recreate)
    except Exception as e:
        logging.error(f"Error creating retriever: {e}")
        traceback.print_exc()
        raise

@traceable
def setup_embedding_model(type: str) -> EmbeddingClient:
    """
    Create an embedding model client.

    Args:
        type (str): Type of the embedding model.

    Returns:
        EmbeddingClient: Initialized embedding client.
    """
    try:
        logging.info(f"Creating embedding model client with type: {type}")
        return EmbeddingClient(type=type)
    except Exception as e:
        logging.error(f"Error creating embedding model: {e}")
        traceback.print_exc()
        raise

@traceable
def setup_chat_model(dir: str, max_length: int = 1024, temperature: float = 0.01, top_p: float = 0.95, repetition_penalty: float = 1.15, quantise: bool = True, device_map: str = "auto"):
    """
    Set up a chat model pipeline.

    Args:
        dir (str): Directory of the model.
        max_length (int): Maximum number of tokens.
        temperature (float): Temperature for sampling.
        top_p (float): Nucleus sampling threshold.
        repetition_penalty (float): Penalty for repeating tokens.
        quantise (bool): Whether to use quantization.
        device_map (str): Device mapping for model loading.

    Returns:
        pipeline: Initialized text generation pipeline.
    """
    try:
        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir)

        # Initialize accelerator
        accelerator = Accelerator()

        if quantise:
            # Quantization config
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            # Load model with quantization
            model = AutoModelForCausalLM.from_pretrained(
                dir, quantization_config=bnb_config, device_map=device_map)
        else:
            # Load model without quantization
            model = AutoModelForCausalLM.from_pretrained(
                dir, torch_dtype=torch.bfloat16, device_map=device_map)

        # Prepare model and tokenizer with accelerator
        model, tokenizer = accelerator.prepare(model, tokenizer)

        # Initialize the text generation pipeline
        text_gen_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            return_full_text=False,
            do_sample=True
        )

        return text_gen_pipeline
    except Exception as e:
        logging.error(f"Error setting up chat model pipeline: {e}")
        traceback.print_exc()
        raise

@traceable
def setup_chain(chain_type: str, llm: Any, retriever: Any = None):
    logging.info(f"Setting up {chain_type} chain.")
    try:
        if chain_type == "retrieval":
            # Define the QA prompt template
            qa_template = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
            Question: {question}
            Context: {context}
            Answer:
            """

            chat_prompt_template = ChatPromptTemplate.from_template(
                qa_template)

            # Define the chain components
            rag_chain = (
                RunnablePassthrough.assign(
                    context=(lambda x: format_docs(x["context"])))
                | chat_prompt_template
                | llm
                | StrOutputParser()
            )

            full_chain = RunnableParallel(
                {"context": retriever, "question": RunnablePassthrough()}
            ).assign(answer=rag_chain)

            logging.info("Chain setup complete.")
            return full_chain
        else:
            raise ValueError(
                "Invalid chain type or missing retriever for retrieval chain.")
    except Exception as e:
        logging.error(
            f"An error occurred while setting up {chain_type} chain: {e}")
        traceback.print_exc()
        raise

def format_docs(docs):
    logging.info("Formatting documents...")
    for doc in docs:
        logging.info(f"Document content: {doc.page_content}")
    formatted_docs = "\n\n".join(doc.page_content for doc in docs)
    logging.info(f"Formatted documents: {formatted_docs}")
    return formatted_docs

@traceable
def main():
    logging.basicConfig(level=logging.INFO)

    document_directory = os.getenv("DOC_DIR")
    embed_model_type = os.getenv("LLM_TYPE")
    chat_model_dir = os.getenv("CHAT_MODEL_DIR")
    chain_type = "retrieval"

    try:
        # Load documents
        documents = load_documents(document_directory)
        logging.info(f"Loaded {len(documents)} documents.")

        # Setup embedding model and retriever
        embed_model = setup_embedding_model(embed_model_type)
        retriever = load_retriever(embed_model, documents)

        # Setup chat model and chain
        chat_model = setup_chat_model(chat_model_dir)
        chain = setup_chain(chain_type, chat_model, retriever)
        logging.info("Main setup complete.")
        return chain

    except Exception as e:
        logging.error(f"An error occurred: {e}")
        traceback.print_exc()
        raise

def test_chain(chain):
    test_queries = [
        "What is the capital of France?",
        "Explain the process of photosynthesis.",
    ]

    for query in test_queries:
        try:
            logging.info(f"Running query: {query}")
            response = chain.invoke(query)
            logging.info(f"Query: {query}")
            logging.info(f"Response: {response}")
            print(f"Query: {query}")
            print(f"Response: {response}\n")
        except Exception as e:
            logging.error(
                f"An error occurred while processing the query '{query}': {e}")
            traceback.print_exc()

if name == "main": chain = main() test_chain(chain)

akileshjayakumar commented 4 months ago

I managed to find the fix. The main issue lies in the setup_chat_model() function; here is the updated code for it.

# Import needed for the wrapper used at the end of this function
from langchain_community.llms import HuggingFacePipeline


def setup_chat_model(dir: str, max_length: int = 1024, temperature: float = 0.01, top_p: float = 0.95, repetition_penalty: float = 1.15, quantise: bool = False, device_map: str = "auto"):
    """
    Set up a chat model pipeline.

    Args:
        dir (str): Directory of the model.
        max_length (int): Maximum number of tokens.
        temperature (float): Temperature for sampling.
        top_p (float): Nucleus sampling threshold.
        repetition_penalty (float): Penalty for repeating tokens.
        quantise (bool): Whether to use quantization.
        device_map (str): Device mapping for model loading.

    Returns:
        pipeline: Initialized text generation pipeline.
    """
    try:
        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(dir)

        # Initialize accelerator
        accelerator = Accelerator()

        if quantise:
            # Quantization config
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            # Load model with quantization
            model = AutoModelForCausalLM.from_pretrained(
                dir, quantization_config=bnb_config, device_map=device_map)
        else:
            # Load model without quantization
            model = AutoModelForCausalLM.from_pretrained(
                dir, torch_dtype=torch.bfloat16, device_map=device_map)

        # Prepare model and tokenizer with accelerator
        model, tokenizer = accelerator.prepare(model, tokenizer)

        # Initialize the text generation pipeline
        text_gen_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            return_full_text=False,
            do_sample=True,
            truncation=True
        )

        # Wrap the raw transformers pipeline in LangChain's HuggingFacePipeline so the
        # chain converts the prompt value to a plain string before generation
        lc_pipeline = HuggingFacePipeline(pipeline=text_gen_pipeline)

        return lc_pipeline
    except Exception as e:
        logging.error(f"Error setting up chat model pipeline: {e}")
        traceback.print_exc()
        raise

This updated function resolves the TypeError('can only concatenate str (not "StringPromptValue") to str') error.
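As a rough usage sketch (assuming the same environment variables and helper modules as the script above), the chain can now be invoked directly and returns a dict with the context, question, and answer:

documents = load_documents(os.getenv("DOC_DIR"))
retriever = load_retriever(setup_embedding_model(os.getenv("LLM_TYPE")), documents)
llm = setup_chat_model(os.getenv("CHAT_MODEL_DIR"))  # now returns a HuggingFacePipeline wrapper
chain = setup_chain("retrieval", llm, retriever)

result = chain.invoke("What is the capital of France?")
print(result["answer"])  # chain output is a dict with "context", "question", and "answer"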

akileshjayakumar commented 4 months ago

Thanks!