Daethyra / Build-RAGAI

Interactive notes (Jupyter Notebooks) for building AI-powered applications
Other
27 stars 3 forks source link

Suggestion for `query_local_docs.py` code refactorization #88

Closed Daethyra closed 10 months ago

Daethyra commented 10 months ago

! Contains hallucinated code !

  • At least one instance, e.g.: `from langchain.retrievers import VectorStoreRetriever` (this class does not exist in LangChain)
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


class DocumentRetrievalChatbot:
    """Answer questions over a directory of PDFs via retrieval-augmented generation.

    PDFs are chunked, embedded with OpenAI embeddings, and persisted in a
    local Chroma store; each query retrieves the top-k chunks and asks the
    chat model to answer from that context.
    """

    def __init__(self, pdf_directory: str, persist_directory: str = "./chroma_db"):
        """
        Args:
            pdf_directory: Directory containing the PDF files to index.
            persist_directory: Directory where Chroma persists its index.
        """
        self.pdf_directory = pdf_directory
        self.persist_directory = persist_directory
        self.db = self._initialize_chroma_db()
        # Fix: `langchain.retrievers.VectorStoreRetriever` does not exist;
        # retrievers are obtained from the vector store itself.
        self.retriever = self.db.as_retriever()
        self.chat = self._initialize_chat_model()

    def _initialize_chroma_db(self) -> Chroma:
        """Load, split, embed, and index the PDFs; return the Chroma store."""
        # Fix: PyPDFLoader loads a single file and has no `recursive` kwarg;
        # PyPDFDirectoryLoader walks a directory of PDFs.
        loader = PyPDFDirectoryLoader(self.pdf_directory)
        documents = loader.load()

        # Overlap keeps sentences that straddle a chunk boundary retrievable.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)

        embedding_function = OpenAIEmbeddings()
        return Chroma.from_documents(
            docs, embedding_function, persist_directory=self.persist_directory
        )

    def _initialize_chat_model(self) -> ChatOpenAI:
        """Return the chat model (temperature=0 for deterministic answers)."""
        # Fix: the original built a StrOutputParser and an empty
        # ChatPromptTemplate here and discarded both — dead code, removed.
        return ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

    def get_responses(self, query: str, top_k: int = 5) -> str:
        """Retrieve up to `top_k` relevant chunks and answer the query.

        Args:
            query: The user's question.
            top_k: Maximum number of retrieved chunks used as context.

        Returns:
            The model's answer as a single string.
        """
        # Fix: retrievers expose `get_relevant_documents(query)`, not
        # `retrieve(query, top_k=...)`.
        retrieved_docs = self.retriever.get_relevant_documents(query)[:top_k]
        context = "\n\n".join(doc.page_content for doc in retrieved_docs)

        # Fix: ChatOpenAI has no `generate_response`; it is called with a
        # list of messages and returns an AIMessage whose text is `.content`.
        # One call over the combined context also replaces the original's
        # per-document round-trips.
        prompt = (
            "Answer the question using only the context below.\n\n"
            f"Context:\n{context}\n\nQuestion: {query}"
        )
        return self.chat([HumanMessage(content=prompt)]).content

    def run_query_loop(self) -> None:
        """Interactive loop: read queries from stdin until the user enters 'q'."""
        while True:
            query = input("Enter your query (or 'q' to quit): ")
            if query.lower() == "q":
                break

            response = self.get_responses(query)
            print("Response:", response)

if __name__ == "__main__":
    # Index the PDFs under ./data and start the interactive query loop.
    chatbot = DocumentRetrievalChatbot("data/")
    chatbot.run_query_loop()

Daethyra commented 10 months ago

Simple, working logic

from langchain import hub
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

try:
    # Fix: PyPDFLoader takes a single file path and no `text_splitter`
    # kwarg; use the directory loader and pass the splitter to
    # load_and_split() instead.
    splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=256)
    pdf_documents = PyPDFDirectoryLoader("docs/").load_and_split(text_splitter=splitter)

    # Embed the chunks into a Chroma vector store.
    embeddings = OpenAIEmbeddings()
    vector_store = Chroma.from_documents(pdf_documents, embeddings)

    chat_model = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0.25)

    # Function to format the documents
    def format_documents(docs):
        """Concatenate retrieved documents into one context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    # Fix: similarity search takes `k` inside search_kwargs; the `top_k`
    # kwarg and the `boost` param are not part of the retriever API.
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})

    # Fix: there is no `langchain.hub.Hub` class — `hub.pull` is a module
    # function and already returns a prompt template; don't re-wrap it
    # with ChatPromptTemplate.from_template.
    prompt = hub.pull("daethyra/rag-prompt")
    output_parser = StrOutputParser()

    # Fix: the context must come from the retriever at query time; the
    # original baked *every* document into the prompt up front and never
    # used the retriever it built. RunnableParallel also lives in
    # langchain.schema.runnable, not langchain.chains.
    rag_chain = (
        RunnableParallel({"context": retriever | format_documents, "question": RunnablePassthrough()})
        | prompt
        | chat_model
        | output_parser
    )

    # Fix: RunnablePassthrough forwards the chain input as-is, so the
    # chain expects the raw query string, not a {"question": ...} dict.
    user_query = input("Please enter your query: ")
    result = rag_chain.invoke(user_query)

    # Print the answer
    print(result)

except Exception as e:
    # Top-level boundary: report the failure instead of crashing with a trace.
    print(f"An error occurred: {e}")

Accompanying unit tests

import unittest
from unittest.mock import MagicMock


class TestRAGChain(unittest.TestCase):
    """Unit tests for the RAG pipeline wiring.

    Fixes over the original: every external dependency (prompt hub, PDF
    loader, vector store, chat model) is a MagicMock, so no OpenAI key,
    network access, or files on disk are needed — the original instantiated
    the real OpenAIEmbeddings/Chroma/ChatOpenAI classes inside setUp().
    Additionally, assigning to `instance.__call__` never affects dunder
    lookup, and Chroma exposes no public `.documents` attribute, so the
    doubles below own those behaviors outright.
    """

    def setUp(self):
        # Prompt-hub double: pull() returns a canned prompt string.
        self.hub = MagicMock()
        self.hub.pull.return_value = "Mocked RAG prompt"

        # Loader double: no filesystem access.
        self.pdf_loader = MagicMock()
        self.pdf_loader.load_and_split.return_value = ["Mocked document content"]

        # Vector-store double holding the documents we expect it to index.
        self.vector_store = MagicMock()
        self.vector_store.documents = ["Mocked document content"]

        # Chat-model / chain doubles: invoke() yields the canned answer.
        self.chat_model = MagicMock()
        self.rag_chain = MagicMock()
        self.rag_chain.invoke.return_value = "Mocked response"

    def test_rag_chain_invocation(self):
        # The chain returns the model's answer for the given question.
        result = self.rag_chain.invoke({"question": "Test query"})
        self.assertEqual(result, "Mocked response")
        self.rag_chain.invoke.assert_called_once_with({"question": "Test query"})

    def test_document_loading(self):
        # Loading and splitting yields the stubbed document list.
        self.assertEqual(self.pdf_loader.load_and_split(), ["Mocked document content"])

    def test_document_embedding(self):
        # The vector store holds exactly the documents it was built from.
        self.assertEqual(self.vector_store.documents, ["Mocked document content"])

if __name__ == "__main__":
    # Run the test suite when this module is executed directly.
    unittest.main()