gusye1234 / nano-graphrag

A simple, easy-to-hack GraphRAG implementation
MIT License

Empty Entity extraction results from Llama 3.1 8B #15

Closed: NumberChiffre closed this issue 2 months ago

NumberChiffre commented 3 months ago

Description

Hey guys,

I'm trying out the repo with the same structure as the DeepSeek example, using a local llama 3.1 8B as both the best and the cheap model. The problem is that I get empty dictionaries from the entity extraction results. I first thought it was a token context problem, so I reduced the max token size to 4096 and then to 1024, but the outcome was the same, so that was probably not the cause. I need your help figuring out why this isn't working with llama 3.1 8B; so far I suspect it comes down to some combination of the model's size and the prompt format.

I can confirm that this works with either GPT-4 or DeepSeek-v2 chat. Do we need some kind of prompt format specifically for smaller models?

Updates:

Simplifying the entity extraction prompt to a bare-bones version for llama 3.1 8B did not help either:

entity_extract_prompt = """
Extract entities from the following text. For each entity, provide:
1. Entity type (PERSON, LOCATION, ORGANIZATION, etc.)
2. Entity name
3. Brief description

Text: {input_text}

Entities:
"""

Error output:

Below is the debug log and traceback I get with llama 3.1 8B:

DEBUG:nano-graphrag:Entity extraction results: [({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {}), ({}, {})]
INFO:nano-graphrag:Inserting 0 vectors to entities
Traceback (most recent call last):
  File "/Users/tiliu/Documents/nano-graphrag/examples/using_ollama_as_llm.py", line 103, in <module>
    insert(text=text)
  File "/Users/tiliu/Documents/nano-graphrag/examples/using_ollama_as_llm.py", line 95, in insert
    rag.insert(text)
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/graphrag.py", line 145, in insert
    return loop.run_until_complete(self.ainsert(string_or_strings))
  File "/opt/homebrew/Caskroom/miniconda/base/envs/nano-graphrag/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/graphrag.py", line 226, in ainsert
    self.chunk_entity_relation_graph = await extract_entities(
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/_op.py", line 335, in extract_entities
    await entity_vdb.upsert(data_for_vdb)
  File "/Users/tiliu/Documents/nano-graphrag/nano_graphrag/_storage.py", line 108, in upsert
    embeddings = np.concatenate(embeddings_list)
ValueError: need at least one array to concatenate
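
My reading of the traceback: the ValueError is just a downstream symptom of the empty extraction. With zero extracted entities there are no embeddings to upsert, and numpy refuses to concatenate an empty list:

import numpy as np

# With no extracted entities, embeddings_list in _storage.py is empty,
# and np.concatenate on an empty sequence raises exactly this error.
np.concatenate([])  # ValueError: need at least one array to concatenate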

Code to reproduce the error:

import os
import logging
from ollama import AsyncClient
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)

os.environ["OPENAI_API_KEY"] = "sk-......"
OLLAMA_MODEL = "llama3.1"
WORKING_DIR = "./nano_graphrag_cache_ollama_TEST"

async def ollama_model_if_cache(
    prompt: str, system_prompt: str = None, history_messages: list = [], **kwargs
) -> str:
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Return the cached response if one exists -----------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(OLLAMA_MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------

    client = AsyncClient()
    response = await client.chat(model=OLLAMA_MODEL, messages=messages)
    content = response['message']['content']

    # Cache the response ----------------------------------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": content, "model": OLLAMA_MODEL}}
        )
    # -----------------------------------------------------
    return content

def remove_if_exist(file):
    if os.path.exists(file):
        os.remove(file)

def load_files(file_directory: str) -> list[str]:
    file_paths = [os.path.join(file_directory, file) for file in os.listdir(file_directory)]
    contents = []
    for file_path in file_paths:
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8-sig') as file:
                contents.append(file.read().strip())
        else:
            print(f"Warning: File not found - {file_path}")
    return contents

def query(query: str, param: QueryParam):
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
        best_model_max_token_size=4096,
        best_model_max_async=8,
        cheap_model_max_token_size=4096,
        cheap_model_max_async=8,
    )
    print(rag.query(query=query, param=param))

def insert(text: str | list[str]):
    from time import time
    remove_if_exist(f"{WORKING_DIR}/milvus_lite.db")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=ollama_model_if_cache,
        cheap_model_func=ollama_model_if_cache,
        best_model_max_token_size=1024,
        best_model_max_async=2,
        cheap_model_max_token_size=1024,
        cheap_model_max_async=2,
    )
    start = time()
    rag.insert(text)
    print("indexing time:", time() - start)

if __name__ == "__main__":
    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        text = f.read()
    insert(text=text)
    # query(
    #     query="What are the main themes in these documents?",
    #     param=QueryParam(mode="global"),
    # )
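
One more thing I plan to try (not verified yet): if I understand Ollama's defaults correctly, the server truncates prompts to its default context window of 2048 tokens unless num_ctx is raised, which is smaller than nano-graphrag's full entity-extraction prompt, so the model may never see the actual instructions. A variant of the wrapper like this (cache handling omitted) should rule that out:

from ollama import AsyncClient

async def ollama_model_large_ctx(
    prompt: str, system_prompt: str = None, history_messages: list = [], **kwargs
) -> str:
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    response = await AsyncClient().chat(
        model=OLLAMA_MODEL,
        messages=messages,
        # Assumption: the llama 3.1 8B served by Ollama can hold an 8k context;
        # the default num_ctx may silently truncate the long extraction prompt.
        options={"num_ctx": 8192},
    )
    return response["message"]["content"]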
gusye1234 commented 3 months ago

Yeah, I agree with you about needing specific prompts for smaller models. Many developers have reported that smaller models like qwen2-7B have trouble extracting entities and relations, so I added a FAQ.md to document this problem.