kakaotech-bootcamp-11 / kakaotech-bootcamp-11-ktb-11-project-1-chatbot-nlp-server

0 stars 0 forks source link

# 퓨샷프롬프트 적용하기 #24

Open Oh-JunTaek opened 2 months ago

Oh-JunTaek commented 2 months ago

퓨샷러닝

# 기존 import에 추가할 부분
from langchain_core.prompts.few_shot import FewShotPromptTemplate

# Few-Shot Learning에 사용할 예제 데이터 로드 함수 추가
def load_examples_from_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data['examples']

# Few-Shot Learning에 사용할 예제 데이터 로드
json_file_path = './few_shot_examples.json'  # 예제 데이터 파일 경로
examples = load_examples_from_file(json_file_path)

# Few-Shot Prompt Template 생성
few_shot_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt="Question: {question}\nAnswer: {answer}",
    suffix="Question: {question}\nAnswer:",
    input_variables=["question"]
)

json예제 형식

{
    "examples": [
        {
            "question": "9시 6분에 출결하면 지각이야?",
            "answer": "아니요. 9시 10분 이후에 출결하면 지각입니다."
        },
        {
            "question": "지각 6번을 하면 어떻게 돼?",
            "answer": "지각 6번이면 2일 결석 처리됩니다."
        }
    ]
}

app.py에서 적용하기

jieun-lim commented 2 months ago

@Oh-JunTaek

수정 요청 사항

Oh-JunTaek commented 2 months ago

수정안

이렇게 하면 되는지 한 번 더 검토부탁드립니다 ㅜ

document_retriever.py

from langchain_community.document_loaders import TextLoader
from langchain_core.documents import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever

import pprint
import json
import logging, os
import ssl
from dotenv import load_dotenv

# Few-Shot Learning에 사용할 예제 데이터를 JSON 파일에서 로드하는 함수
def load_examples_from_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data['examples']

# Few-Shot Learning에 사용할 예제 데이터를 로드하고, Few-Shot Prompt Template을 생성
json_file_path = './few_shot_examples.json'  # 예제 데이터 파일 경로
examples = load_examples_from_file(json_file_path)

# Few-Shot Prompt Template 생성
few_shot_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt="Question: {question}\nAnswer: {answer}",
    suffix="Question: {question}\nAnswer:",
    input_variables=["question"]
)

# 주어진 파일 경로에서 Markdown(.md) 파일을 로드하는 함수
def load_md_files(file_path):
    loader = TextLoader(file_path)
    documents = loader.load()  # 파일에서 문서를 로드
    print(f"Loaded {len(documents)} documents from the MD.")
    print("len(docs):", len(documents))
    return documents

# 로드한 문서를 Markdown 헤더 기준으로 분할하는 함수
def split_docs(documents):
    assert len(documents) == 1  # 문서가 하나만 로드되었는지 확인
    assert isinstance(documents[0], Document)  # 문서 객체인지 확인
    readme_content = documents[0].page_content

    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        ("###", "Header 3"),
    ]

    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, strip_headers=False)
    splitted_md = markdown_splitter.split_text(readme_content)
    return splitted_md

# 분할된 문서를 이용하여 BM25 Retriever를 생성하는 함수
def create_bm25_retriever(splitted_docs):
    bm25_retriever = BM25Retriever.from_documents(
        splitted_docs,
    )
    bm25_retriever.k = 1  # 검색 결과 개수를 1개로 제한
    print("bm25 retriever created")
    return bm25_retriever

# 분할된 문서를 이용하여 FAISS Retriever를 생성하는 함수
def create_FAISS_retriever(splitted_docs):
    embedding_function = OpenAIEmbeddings()
    faiss_db = None
    faiss_index_path = "data/retrievers/faiss_index"  # FAISS 인덱스를 저장할 경로
    index_faiss_file, index_pkl_file = os.path.join(faiss_index_path, "index.faiss"), os.path.join(faiss_index_path, "index.pkl")

    if os.path.exists(index_faiss_file) and os.path.exists(index_pkl_file):
        print("이미 FAISS index 존재")
        faiss_db = FAISS.load_local(
            faiss_index_path,
            embeddings=embedding_function,
            allow_dangerous_deserialization=True  # 역직렬화를 허용
        )
    else:
        print("새롭게 FAISS index 만들기")
        faiss_db = FAISS.from_documents(splitted_docs, embedding=embedding_function)
        faiss_db.save_local(faiss_index_path)  # FAISS 인덱스를 로컬에 저장
    faiss_retriever = faiss_db.as_retriever()
    return faiss_retriever

# BM25와 FAISS를 결합하여 Ensemble Retriever를 생성하는 함수
def create_ensemble_retriever(retrievers):
    ensemble_retriever = EnsembleRetriever(
        retrievers=retrievers,
        weights=[0.7, 0.3],  # 각각의 가중치 설정 (BM25: 70%, FAISS: 30%)
    )
    print("Retriever created.")
    return ensemble_retriever

# QA 체인을 생성하는 함수
def create_qa_chain(ensemble_retriever):
    prompt = PromptTemplate.from_template(
        """You are an assistant for question-answering tasks. 
        Use the following pieces of retrieved context to answer the question. 
        Consider the intent behind the question to provide the most relevant and accurate response. 
        If you don't know the answer, just say this: ```해당 정보는 제공된 문서들에 포함되어 있지 않습니다.```. 
        If I ask you what you don't know and what you do know, answer what you know clearly and in detail. 
        Remember to compare the specific time in the question with the time mentioned in the context to determine the correct answer.

        #Question: 
        {question} 
        #Context: 
        {context} 

        #Answer:"""
    )
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    multiquery_retriever = MultiQueryRetriever.from_llm(
        retriever=ensemble_retriever,
        llm=llm,
    )

    print("LLM created.")
    rag_chain = (
        {"context": multiquery_retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain

# Few-Shot Learning을 적용하여 답변을 생성하는 함수
def apply_few_shot_learning(user_input):
    # Few-Shot Learning을 사용하여 답변을 생성
    response = few_shot_prompt.format(question=user_input)
    return response

# RAG 기반의 문서 검색과 QA 체인을 사용하는 함수를 생성하는 메인 함수
def my_retriever(file_path):
    ssl._create_default_https_context = ssl._create_unverified_context     
    load_dotenv()  # 환경 변수 로드

    documents = load_md_files(file_path)  # MD 파일 로드
    splitted_docs = split_docs(documents)  # 문서 분할
    bm25_retriever = create_bm25_retriever(splitted_docs)  # BM25 생성
    faiss_retriever = create_FAISS_retriever(splitted_docs)  # FAISS 생성
    ensemble_retriever = create_ensemble_retriever([bm25_retriever, faiss_retriever])  # Ensemble 생성
    rag_chain = create_qa_chain(ensemble_retriever)  # QA 체인 생성

    def retrieve_answer(user_input):
        response = rag_chain.invoke(user_input)
        # RAG 결과가 적절하지 않거나 부족할 경우 Few-Shot Learning 적용
        if response and not any(phrase in response for phrase in ["해당 정보는 제공된 문서들에 포함되어 있지 않습니다."]):
            return response
        else:
            few_shot_response = apply_few_shot_learning(user_input)
            # Few-Shot Learning으로도 답변을 생성하지 못했을 경우
            if few_shot_response and not "해당 질문에 대한 답변을 찾을 수 없습니다" in few_shot_response:
                return few_shot_response
            else:
                # 일반적인 LLM을 사용하여 기본적인 답변 생성
                return generate_general_response(user_input)

    return retrieve_answer

# 일반적인 LLM을 사용하여 기본적인 답변을 생성하는 함수
def generate_general_response(user_input):
    # 여기서는 기본적으로 ChatGPT를 사용하여 답변을 생성할 수 있습니다.
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    system_prompt = "You are a helpful assistant. Please provide a thoughtful response to the user's question."

    # 사용자 질문을 기반으로 LLM이 응답 생성
    response = llm({
        "role": "system",
        "content": system_prompt
    }, {
        "role": "user",
        "content": user_input
    })

    return response['choices'][0]['message']['content']

# ==== test ======
"""if __name__ == "__main__":
    retrieve_answer = my_retriever('data/ktb_data_09.md')
    question = '8월에 며칠 이상 출석해야 훈련 장려금 받을 수 있어?'
    print("question:\n", question)
    response = retrieve_answer(question)
    print('response:\n', response)
"""

변경사항 및 주요 함수 설명

  1. load_examples_from_file : 예제 데이터를 json으로 불러옴. data양식은 이슈1 참고
  2. few_shot_prompt: 로드된 예제 데이터를 기반으로 Template을 생성. 질문에 대한 적절한 답변을 생성하는 데 사용.
  3. apply_few_shot_learning 함수 : Few-Shot Learning을 사용하여 질문에 대한 답변을 생성. RAG 기반 검색 결과가 충분하지 않을 때 사용.
  4. my_retriever 함수 : 전체 프로세스를 조합하여 RAG 기반의 문서 검색과 QA 체인을 생성. retrieve_answer 내부 함수에서 RAG 기반의 응답을 제공, RAG 결과가 부족할 경우 Few-Shot Learning을 적용. 퓨샷러닝도 실패할 경우 일반적인 답변 생성