Open Oh-JunTaek opened 2 months ago
@Oh-JunTaek
document_retriever.py
파일 수정 document_retriever.py
파일을 수정해주세요. 이렇게 하면 되는지 한 번 더 검토부탁드립니다 ㅜ
from langchain_community.document_loaders import TextLoader
from langchain_core.documents import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
import pprint
import json
import logging, os
import ssl
from dotenv import load_dotenv
# Few-Shot Learning에 사용할 예제 데이터를 JSON 파일에서 로드하는 함수
def load_examples_from_file(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
return data['examples']
# Few-Shot Learning에 사용할 예제 데이터를 로드하고, Few-Shot Prompt Template을 생성
json_file_path = './few_shot_examples.json' # 예제 데이터 파일 경로
examples = load_examples_from_file(json_file_path)
# Few-Shot Prompt Template 생성
few_shot_prompt = FewShotPromptTemplate(
examples=examples,
example_prompt="Question: {question}\nAnswer: {answer}",
suffix="Question: {question}\nAnswer:",
input_variables=["question"]
)
# 주어진 파일 경로에서 Markdown(.md) 파일을 로드하는 함수
def load_md_files(file_path):
loader = TextLoader(file_path)
documents = loader.load() # 파일에서 문서를 로드
print(f"Loaded {len(documents)} documents from the MD.")
print("len(docs):", len(documents))
return documents
# 로드한 문서를 Markdown 헤더 기준으로 분할하는 함수
def split_docs(documents):
assert len(documents) == 1 # 문서가 하나만 로드되었는지 확인
assert isinstance(documents[0], Document) # 문서 객체인지 확인
readme_content = documents[0].page_content
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, strip_headers=False)
splitted_md = markdown_splitter.split_text(readme_content)
return splitted_md
# 분할된 문서를 이용하여 BM25 Retriever를 생성하는 함수
def create_bm25_retriever(splitted_docs):
bm25_retriever = BM25Retriever.from_documents(
splitted_docs,
)
bm25_retriever.k = 1 # 검색 결과 개수를 1개로 제한
print("bm25 retriever created")
return bm25_retriever
# 분할된 문서를 이용하여 FAISS Retriever를 생성하는 함수
def create_FAISS_retriever(splitted_docs):
embedding_function = OpenAIEmbeddings()
faiss_db = None
faiss_index_path = "data/retrievers/faiss_index" # FAISS 인덱스를 저장할 경로
index_faiss_file, index_pkl_file = os.path.join(faiss_index_path, "index.faiss"), os.path.join(faiss_index_path, "index.pkl")
if os.path.exists(index_faiss_file) and os.path.exists(index_pkl_file):
print("이미 FAISS index 존재")
faiss_db = FAISS.load_local(
faiss_index_path,
embeddings=embedding_function,
allow_dangerous_deserialization=True # 역직렬화를 허용
)
else:
print("새롭게 FAISS index 만들기")
faiss_db = FAISS.from_documents(splitted_docs, embedding=embedding_function)
faiss_db.save_local(faiss_index_path) # FAISS 인덱스를 로컬에 저장
faiss_retriever = faiss_db.as_retriever()
return faiss_retriever
# BM25와 FAISS를 결합하여 Ensemble Retriever를 생성하는 함수
def create_ensemble_retriever(retrievers):
ensemble_retriever = EnsembleRetriever(
retrievers=retrievers,
weights=[0.7, 0.3], # 각각의 가중치 설정 (BM25: 70%, FAISS: 30%)
)
print("Retriever created.")
return ensemble_retriever
# QA 체인을 생성하는 함수
def create_qa_chain(ensemble_retriever):
prompt = PromptTemplate.from_template(
"""You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
Consider the intent behind the question to provide the most relevant and accurate response.
If you don't know the answer, just say this: ```해당 정보는 제공된 문서들에 포함되어 있지 않습니다.```.
If I ask you what you don't know and what you do know, answer what you know clearly and in detail.
Remember to compare the specific time in the question with the time mentioned in the context to determine the correct answer.
#Question:
{question}
#Context:
{context}
#Answer:"""
)
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
multiquery_retriever = MultiQueryRetriever.from_llm(
retriever=ensemble_retriever,
llm=llm,
)
print("LLM created.")
rag_chain = (
{"context": multiquery_retriever, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return rag_chain
# Few-Shot Learning을 적용하여 답변을 생성하는 함수
def apply_few_shot_learning(user_input):
# Few-Shot Learning을 사용하여 답변을 생성
response = few_shot_prompt.format(question=user_input)
return response
# RAG 기반의 문서 검색과 QA 체인을 사용하는 함수를 생성하는 메인 함수
def my_retriever(file_path):
ssl._create_default_https_context = ssl._create_unverified_context
load_dotenv() # 환경 변수 로드
documents = load_md_files(file_path) # MD 파일 로드
splitted_docs = split_docs(documents) # 문서 분할
bm25_retriever = create_bm25_retriever(splitted_docs) # BM25 생성
faiss_retriever = create_FAISS_retriever(splitted_docs) # FAISS 생성
ensemble_retriever = create_ensemble_retriever([bm25_retriever, faiss_retriever]) # Ensemble 생성
rag_chain = create_qa_chain(ensemble_retriever) # QA 체인 생성
def retrieve_answer(user_input):
response = rag_chain.invoke(user_input)
# RAG 결과가 적절하지 않거나 부족할 경우 Few-Shot Learning 적용
if response and not any(phrase in response for phrase in ["해당 정보는 제공된 문서들에 포함되어 있지 않습니다."]):
return response
else:
few_shot_response = apply_few_shot_learning(user_input)
# Few-Shot Learning으로도 답변을 생성하지 못했을 경우
if few_shot_response and not "해당 질문에 대한 답변을 찾을 수 없습니다" in few_shot_response:
return few_shot_response
else:
# 일반적인 LLM을 사용하여 기본적인 답변 생성
return generate_general_response(user_input)
return retrieve_answer
# 일반적인 LLM을 사용하여 기본적인 답변을 생성하는 함수
def generate_general_response(user_input):
# 여기서는 기본적으로 ChatGPT를 사용하여 답변을 생성할 수 있습니다.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
system_prompt = "You are a helpful assistant. Please provide a thoughtful response to the user's question."
# 사용자 질문을 기반으로 LLM이 응답 생성
response = llm({
"role": "system",
"content": system_prompt
}, {
"role": "user",
"content": user_input
})
return response['choices'][0]['message']['content']
# ==== test ======
"""if __name__ == "__main__":
retrieve_answer = my_retriever('data/ktb_data_09.md')
question = '8월에 며칠 이상 출석해야 훈련 장려금 받을 수 있어?'
print("question:\n", question)
response = retrieve_answer(question)
print('response:\n', response)
"""
퓨샷러닝
json예제 형식
app.py에서 적용하기
일부 함수 변경
퓨샷러닝에 쓸 데이터
코드에서 가져오지 않고 json으로 정리해서 json을 불러오는 방식으로 적용