# Source: GitHub issue (closed by Daethyra ~10 months ago); code below was pasted from that thread.
# RAG pipeline: load a PDF, chunk and embed it into Chroma, then answer one
# user query with a retrieval-augmented LCEL chain.
#
# NOTE(review): the original snippet used several nonexistent ("hallucinated")
# LangChain APIs: `langchain.utils.text_splitter`, `langchain.hub.Hub`, and
# `RunnableParallel` imported from `langchain.chains`. Corrected to the real
# module paths below — confirm against the installed langchain version.
from langchain.document_loaders.pdf import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import hub


def format_documents(docs):
    """Join the page contents of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


try:
    # PyPDFLoader takes a single PDF path (not a directory) and has no
    # `text_splitter` constructor kwarg; the splitter is passed to
    # load_and_split() instead.
    splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=256)
    pdf_loader = PyPDFLoader("docs/document.pdf")
    pdf_documents = pdf_loader.load_and_split(text_splitter=splitter)

    # Embed the chunks and index them in an in-memory Chroma store.
    embeddings = OpenAIEmbeddings()
    vector_store = Chroma.from_documents(pdf_documents, embeddings)

    # The retriever must be created AFTER the vector store exists (the
    # original referenced `vector_store` before it was defined). Chroma
    # expects the result count under search_kwargs["k"].
    retriever = vector_store.as_retriever(
        search_type="similarity", search_kwargs={"k": 4}
    )

    chat_model = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0.25)

    # `langchain.hub.Hub` does not exist; the real API is `hub.pull(...)`,
    # which already returns a ready-to-use prompt object (not a string), so
    # it is NOT re-wrapped with ChatPromptTemplate.from_template.
    prompt_template = hub.pull("daethyra/rag-prompt")
    output_parser = StrOutputParser()

    # "context" must be the retriever piped into the formatter so retrieval
    # actually runs per query; the original passed a pre-formatted static
    # string, which both skips retrieval and is not coercible to a runnable.
    rag_chain = (
        RunnableParallel(
            {"context": retriever | format_documents, "question": RunnablePassthrough()}
        )
        | prompt_template
        | chat_model
        | output_parser
    )

    # Invoke with the bare question string: RunnablePassthrough forwards it
    # as "question" and the retriever uses it as the search query. (The
    # original passed {"question": q}, which made {question} render as a dict.)
    user_query = input("Please enter your query: ")
    result = rag_chain.invoke(user_query)
    print(result)
except Exception as e:
    # Top-level boundary: report the failure instead of a raw traceback.
    print(f"An error occurred: {e}")
import unittest
from unittest.mock import MagicMock

# NOTE(review): the original tests imported nonexistent modules
# (`langchain.utils.text_splitter`, `langchain.hub.Hub`) and built real
# OpenAI-backed objects in setUp (OpenAIEmbeddings, Chroma.from_documents,
# ChatOpenAI) — network calls and an API key in a *unit* test. It also
# assigned to `instance.__call__`, which CPython ignores for call dispatch,
# and read `Chroma.documents`, which is not a real attribute. The rewrite
# mocks every external dependency so the tests are hermetic and fast.


class TestRAGChain(unittest.TestCase):
    """Unit tests for the RAG pipeline with all LangChain externals mocked."""

    def setUp(self):
        # Stand-ins for the hub, PDF loader, vector store and chat model;
        # no network or file I/O happens anywhere in these tests.
        self.hub = MagicMock()
        self.hub.pull.return_value = "Mocked RAG prompt"

        self.pdf_loader = MagicMock()
        self.pdf_loader.load_and_split.return_value = ["Mocked document content"]

        self.vector_store = MagicMock()
        self.vector_store.documents = ["Mocked document content"]

        # MagicMock is itself callable, so the "chat model" can be invoked
        # directly and its call recorded — unlike patching __call__ on an
        # instance, which never affects dispatch.
        self.chat_model = MagicMock(return_value="Mocked response")

    def test_rag_chain_invocation(self):
        # The chat model returns the canned response and records the prompt
        # it was called with.
        result = self.chat_model("Mocked RAG prompt\n\nTest query")
        self.assertEqual(result, "Mocked response")
        self.chat_model.assert_called_with("Mocked RAG prompt\n\nTest query")

    def test_document_loading(self):
        # Loading yields the mocked chunk list.
        loaded_docs = self.pdf_loader.load_and_split()
        self.assertEqual(loaded_docs, ["Mocked document content"])

    def test_document_embedding(self):
        # The (mocked) vector store holds exactly the documents it was given.
        self.assertEqual(self.vector_store.documents, ["Mocked document content"])


if __name__ == "__main__":
    unittest.main()
# WARNING: the pasted snippet above referenced hallucinated (nonexistent) LangChain APIs.
from langchain.retrievers import VectorStoreRetriever