tascheidt / jdmgpt

RAG-based GPT built on open-access PubMed data related to Juvenile Dermatomyositis
MIT License
0 stars 0 forks source link

Test auto merging retriever #6

Open tascheidt opened 7 months ago

tascheidt commented 7 months ago

https://learn.deeplearning.ai/building-evaluating-advanced-rag/lesson/5/auto-merging-retrieval

import os

from llama_index import ( ServiceContext, StorageContext, VectorStoreIndex, load_index_from_storage, ) from llama_index.node_parser import HierarchicalNodeParser from llama_index.node_parser import get_leaf_nodes from llama_index import StorageContext, load_index_from_storage from llama_index.retrievers import AutoMergingRetriever from llama_index.indices.postprocessor import SentenceTransformerRerank from llama_index.query_engine import RetrieverQueryEngine

def build_automerging_index(
    documents,
    llm,
    embed_model="local:BAAI/bge-small-en-v1.5",
    save_dir="merging_index",
    chunk_sizes=None,
):
    """Build a vector index over hierarchical leaf nodes, or reload it from disk.

    If ``save_dir`` already exists, the index persisted there is loaded and
    returned; ``documents`` and ``chunk_sizes`` are not used in that case.
    Otherwise the documents are split with a ``HierarchicalNodeParser``, every
    node is added to the docstore, only the leaf nodes are indexed, and the
    result is persisted to ``save_dir``.

    Args:
        documents: Documents passed to the node parser's
            ``get_nodes_from_documents``.
        llm: LLM handed to ``ServiceContext.from_defaults``.
        embed_model: Embedding model handed to ``ServiceContext.from_defaults``.
        save_dir: Directory the index is persisted to / loaded from.
        chunk_sizes: Chunk sizes for the hierarchical parser; a falsy value
            selects ``[2048, 512, 128]``.

    Returns:
        The newly built or reloaded index.
    """
    merging_context = ServiceContext.from_defaults(
        llm=llm,
        embed_model=embed_model,
    )

    # Reload path first: parsing the documents is pointless when the persisted
    # index is going to be used, so none of that work is done here.
    if os.path.exists(save_dir):
        return load_index_from_storage(
            StorageContext.from_defaults(persist_dir=save_dir),
            service_context=merging_context,
        )

    chunk_sizes = chunk_sizes or [2048, 512, 128]
    node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=chunk_sizes)
    nodes = node_parser.get_nodes_from_documents(documents)
    leaf_nodes = get_leaf_nodes(nodes)

    # The docstore receives every node (parents included) while the vector
    # index below is built from the leaves only.
    # NOTE(review): presumably the retriever looks parents up in the docstore
    # when merging - confirm against the AutoMergingRetriever docs.
    storage_context = StorageContext.from_defaults()
    storage_context.docstore.add_documents(nodes)

    automerging_index = VectorStoreIndex(
        leaf_nodes, storage_context=storage_context, service_context=merging_context
    )
    automerging_index.storage_context.persist(persist_dir=save_dir)
    return automerging_index

def get_automerging_query_engine(
    automerging_index,
    similarity_top_k=12,
    rerank_top_n=6,
):
    """Create a query engine that retrieves with auto-merging and then reranks.

    Args:
        automerging_index: Index providing ``as_retriever`` and
            ``storage_context``.
        similarity_top_k: Passed to ``as_retriever`` as ``similarity_top_k``.
        rerank_top_n: ``top_n`` given to the ``SentenceTransformerRerank``
            post-processor.

    Returns:
        A ``RetrieverQueryEngine`` built from the auto-merging retriever with
        the reranker as its only node post-processor.
    """
    leaf_retriever = automerging_index.as_retriever(
        similarity_top_k=similarity_top_k
    )
    merging_retriever = AutoMergingRetriever(
        leaf_retriever,
        automerging_index.storage_context,
        verbose=True,
    )
    reranker = SentenceTransformerRerank(
        top_n=rerank_top_n,
        model="BAAI/bge-reranker-base",
    )
    return RetrieverQueryEngine.from_args(
        merging_retriever,
        node_postprocessors=[reranker],
    )

new section

from llama_index.llms import OpenAI

# Build the auto-merging index for one document, or reload it if
# ./merging_index already exists.
# NOTE(review): `document` is not defined in the visible code - presumably it
# is loaded in an earlier cell; confirm.
index = build_automerging_index(
    [document],
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
    save_dir="./merging_index",
)

new section

# Query engine over the index with similarity_top_k lowered from its default
# of 12 to 6; rerank_top_n is left at its default.
query_engine = get_automerging_query_engine(
    index,
    similarity_top_k=6,
)