run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License
36.53k stars 5.22k forks source link

在使用rebel提取实体的时候发生了下面的警告,且好像rebel并没有起到作用,在进行检索和问答时,我的图数据库中没有任何实体和关系,所以没有结果返回。 #14621

Closed li-hhhh closed 3 weeks ago

li-hhhh commented 4 months ago

Question Validation

Question

警告:DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use llama_index.settings.Settings instead.) -- Deprecated since version 0.10.0. service_context = ServiceContext.from_defaults(llm=llm,embed_model=embed_model,chunk_size=256) 代码如下: from llama_index.core import SimpleDirectoryReader,KnowledgeGraphIndex,StorageContext,Settings,ServiceContext import nest_asyncio

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.core import PromptTemplate from llama_index.graph_stores.neo4j import Neo4jGraphStore import torch from transformers import pipeline

import yapf

nest_asyncio.apply()

def extracttriplets(text): triplets = [] relation, subject, relation, object = '', '', '', '' text = text.strip() current = 'x' for token in text.replace("", "").replace("", "").replace("", "").split(): if token == "": current = 't' if relation != '': triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object.strip()}) relation = '' subject = '' elif token == "": current = 's' if relation != '': triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object.strip()}) object = '' elif token == "": current = 'o' relation = '' else: if current == 't': subject += ' ' + token elif current == 's': object += ' ' + token elif current == 'o': relation += ' ' + token if subject != '' and relation != '' and object != '': triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object.strip()}) return triplets

Load documents

documents = SimpleDirectoryReader('/home/cdhd/PythonCode/pythonProject/TestFile/公司主要人员架构类').load_data() print("documents",documents[0].text) model_config = {"protected_namespaces": ()}

SYSTEM_PROMPT = """You are a helpful AI assistant. """

query_wrapper_prompt = PromptTemplate( "[INST]<>\n" + SYSTEM_PROMPT + "<>\n\n{query_str}[/INST] " )

llm = HuggingFaceLLM( context_window=4096, max_new_tokens=1024, generate_kwargs={"temperature": 0.2, "do_sample": True}, query_wrapper_prompt=query_wrapper_prompt,

tokenizer_name='/home/cdhd/PythonCode/pythonProject/model/gpt2',

# model_name='/home/cdhd/PythonCode/pythonProject/model/gpt2',
tokenizer_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
model_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
device_map="auto",
# change these settings below depending on your GPU
model_kwargs={"torch_dtype": torch.float16},

)

Settings.llm = llm

triplet_extractor = pipeline('text2text-generation', model='/home/cdhd/PythonCode/pythonProject/model/rebel', tokenizer='/home/cdhd/PythonCode/pythonProject/model/rebel')

llm = AutoModel.from_pretrained('/home/cdhd/Desktop/glm4/glm-4-9b-chat', trust_remote_code=True)

We need to use the tokenizer manually since we need special tokens.

extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor(documents[0].text, return_tensors=True, return_text=False)[0]["generated_token_ids"]])

embed_model = HuggingFaceEmbedding( model_name='/home/cdhd/PythonCode/pythonProject/model/bge-large-zh-v1.5/models--BAAI--bge-large-zh-v1.5/snapshots/79e7739b6ab944e86d6171e44d24c997fc1e0116',

model_name = r"E:\Huggingface_model\bge-large-zh-v1.5",

# cache_folder=r"E:\Huggingface_model\bge-large-zh-v1.5\models--BAAI--bge-large-zh-v1.5\snapshots\79e7739b6ab944e86d6171e44d24c997fc1e0116",

)

Settings.embed_model = embed_model

graph_store = Neo4jGraphStore( username="neo4j", password="12345678", url="bolt://localhost:7687", database="neo4j"

) service_context = ServiceContext.from_defaults(llm=llm,embed_model=embed_model,chunk_size=256)

清理数据集

graph_store.query( """ MATCH (n) DETACH DELETE n """ )

创建存储上下文

storage_context = StorageContext.from_defaults(graph_store=graph_store) extracted_triplets = extract_triplets(extracted_text[0])

构建包含嵌入的索引

index = KnowledgeGraphIndex.from_documents( documents, kg_triplet_extract_fn=extract_triplets, service_context=service_context, storage_context=storage_context, max_triplets_per_chunk=2, include_embeddings=True, embed_model=embed_model, llm=llm, ) retriever = index.as_retriever() ans = retriever.retrieve("公司的董事长是谁") print(ans) for node in ans: print(node.text)

query_engine = index.as_query_engine( include_text = True, response_mode="tree_summarize", llm=llm )

response = query_engine.query( "公司的董事长是谁?" ) print("response:",response)

dosubot[bot] commented 4 months ago

要解决弃用警告并确保正确提取和存储实体和关系,请移除 ServiceContext 并使用 Settings 进行配置。以下是更新后的代码示例:

from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex, StorageContext, Settings
import nest_asyncio
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import PromptTemplate
from llama_index.graph_stores.neo4j import Neo4jGraphStore
import torch
from transformers import pipeline

nest_asyncio.apply()

def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
    return triplets

# Load documents
documents = SimpleDirectoryReader('/home/cdhd/PythonCode/pythonProject/TestFile/公司主要人员架构类').load_data()
print("documents", documents[0].text)

SYSTEM_PROMPT = """You are a helpful AI assistant.
"""

query_wrapper_prompt = PromptTemplate(
    "[INST]<<SYS>>\n" + SYSTEM_PROMPT + "<</SYS>>\n\n{query_str}[/INST] "
)

llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=1024,
    generate_kwargs={"temperature": 0.2, "do_sample": True},
    query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
    model_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
    device_map="auto",
    model_kwargs={"torch_dtype": torch.float16},
)

Settings.llm = llm

triplet_extractor = pipeline('text2text-generation', model='/home/cdhd/PythonCode/pythonProject/model/rebel', tokenizer='/home/cdhd/PythonCode/pythonProject/model/rebel')
extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor(documents[0].text, return_tensors=True, return_text=False)[0]["generated_token_ids"]])

embed_model = HuggingFaceEmbedding(
    model_name='/home/cdhd/PythonCode/pythonProject/model/bge-large-zh-v1.5/models--BAAI--bge-large-zh-v1.5/snapshots/79e7739b6ab944e86d6171e44d24c997fc1e0116',
)

Settings.embed_model = embed_model

graph_store = Neo4jGraphStore(
    username="neo4j",
    password="12345678",
    url="bolt://localhost:7687",
    database="neo4j"
)

# 清理数据集
graph_store.query(
    """
MATCH (n) DETACH DELETE n
"""
)

# 创建存储上下文
storage_context = StorageContext.from_defaults(graph_store=graph_store)
extracted_triplets = extract_triplets(extracted_text[0])

# 构建包含嵌入的索引
index = KnowledgeGraphIndex.from_documents(
    documents,
    kg_triplet_extract_fn=extract_triplets,
    storage_context=storage_context,
    max_triplets_per_chunk=2,
    include_embeddings=True,
    embed_model=embed_model,
    llm=llm,
)

retriever = index.as_retriever()
ans = retriever.retrieve("公司的董事长是谁")
print(ans)
for node in ans:
    print(node.text)

query_engine = index.as_query_engine(
    include_text=True,
    response_mode="tree_summarize",
    llm=llm
)

response = query_engine.query(
    "公司的董事长是谁?"
)
print("response:", response)

此更新移除了 ServiceContext 并使用 Settings 进行配置,确保实体和关系正确提取并存储在图数据库中[1][2].

To continue talking to Dosu, mention @dosu.

li-hhhh commented 4 months ago

要解决弃用警告并确保正确提取和存储实体和关系,请移除 ServiceContext 并使用 Settings 进行配置。以下是更新后的代码示例:

from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex, StorageContext, Settings
import nest_asyncio
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import PromptTemplate
from llama_index.graph_stores.neo4j import Neo4jGraphStore
import torch
from transformers import pipeline

nest_asyncio.apply()

def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
    return triplets

# Load documents
documents = SimpleDirectoryReader('/home/cdhd/PythonCode/pythonProject/TestFile/公司主要人员架构类').load_data()
print("documents", documents[0].text)

SYSTEM_PROMPT = """You are a helpful AI assistant.
"""

query_wrapper_prompt = PromptTemplate(
    "[INST]<<SYS>>\n" + SYSTEM_PROMPT + "<</SYS>>\n\n{query_str}[/INST] "
)

llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=1024,
    generate_kwargs={"temperature": 0.2, "do_sample": True},
    query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
    model_name='/home/cdhd/Desktop/glm4/glm-4-9b-chat',
    device_map="auto",
    model_kwargs={"torch_dtype": torch.float16},
)

Settings.llm = llm

triplet_extractor = pipeline('text2text-generation', model='/home/cdhd/PythonCode/pythonProject/model/rebel', tokenizer='/home/cdhd/PythonCode/pythonProject/model/rebel')
extracted_text = triplet_extractor.tokenizer.batch_decode([triplet_extractor(documents[0].text, return_tensors=True, return_text=False)[0]["generated_token_ids"]])

embed_model = HuggingFaceEmbedding(
    model_name='/home/cdhd/PythonCode/pythonProject/model/bge-large-zh-v1.5/models--BAAI--bge-large-zh-v1.5/snapshots/79e7739b6ab944e86d6171e44d24c997fc1e0116',
)

Settings.embed_model = embed_model

graph_store = Neo4jGraphStore(
    username="neo4j",
    password="12345678",
    url="bolt://localhost:7687",
    database="neo4j"
)

# 清理数据集
graph_store.query(
    """
MATCH (n) DETACH DELETE n
"""
)

# 创建存储上下文
storage_context = StorageContext.from_defaults(graph_store=graph_store)
extracted_triplets = extract_triplets(extracted_text[0])

# 构建包含嵌入的索引
index = KnowledgeGraphIndex.from_documents(
    documents,
    kg_triplet_extract_fn=extract_triplets,
    storage_context=storage_context,
    max_triplets_per_chunk=2,
    include_embeddings=True,
    embed_model=embed_model,
    llm=llm,
)

retriever = index.as_retriever()
ans = retriever.retrieve("公司的董事长是谁")
print(ans)
for node in ans:
    print(node.text)

query_engine = index.as_query_engine(
    include_text=True,
    response_mode="tree_summarize",
    llm=llm
)

response = query_engine.query(
    "公司的董事长是谁?"
)
print("response:", response)

此更新移除了 ServiceContext 并使用 Settings 进行配置,确保实体和关系正确提取并存储在图数据库中[1][2].

To continue talking to Dosu, mention @dosu.

我这个程序执行,rebel依然没有成功抽取出实体,neo4j图数据中依然没有任何实体和关系,这是怎么回事儿,有其他的解决方法吗?

dosubot[bot] commented 4 months ago

rebel 依然没有成功抽取出实体,Neo4j 图数据库中依然没有任何实体和关系,可能是因为以下几个原因:

  1. 抽取函数的问题extract_triplets 函数可能没有正确解析输入文本中的实体和关系。
  2. 数据过滤问题:如果使用了 extract_triplets_wiki 函数,可能是因为过滤器 WikiFilter 没有找到匹配的维基百科页面,导致所有的关系都被过滤掉了。
  3. 数据加载问题:从 ArxivReader 加载的数据可能不包含任何可识别的实体和关系。

可以尝试以下解决方法:

  1. 检查 extract_triplets 函数:确保该函数能够正确解析输入文本并返回有效的三元组。
  2. 检查 WikiFilter:确保过滤器能够正确连接到维基百科并返回有效的实体。
  3. 检查输入数据:确保从 ArxivReader 加载的数据包含可识别的实体和关系。

以下是 extract_triplets 函数的代码,可以检查并调试:

def extract_triplets(input_text):
    text = triplet_extractor.tokenizer.batch_decode(
        [
            triplet_extractor(
                input_text, return_tensors=True, return_text=False
            )[0]["generated_token_ids"]
        ]
    )[0]

    triplets = []
    relation, subject, relation, object_ = "", "", "", ""
    text = text.strip()
    current = "x"
    for token in (
        text.replace("<s>", "")
        .replace("<pad>", "")
        .replace("</s>", "")
        .split()
    ):
        if token == "<triplet>":
            current = "t"
            if relation != "":
                triplets.append(
                    (subject.strip(), relation.strip(), object_.strip())
                )
                relation = ""
            subject = ""
        elif token == "<subj>":
            current = "s"
            if relation != "":
                triplets.append(
                    (subject.strip(), relation.strip(), object_.strip())
                )
            object_ = ""
        elif token == "<obj>":
            current = "o"
            relation = ""
        else:
            if current == "t":
                subject += " " + token
            elif current == "s":
                object_ += " " + token
            elif current == "o":
                relation += " " + token

    if subject != "" and relation != "" and object_ != "":
        triplets.append((subject.strip(), relation.strip(), object_.strip()))

    return triplets

如果使用了 extract_triplets_wiki 函数,可以检查以下代码:

def extract_triplets_wiki(text):
    relations = extract_triplets(text)

    filtered_relations = []
    for relation in relations:
        (subj, rel, obj) = relation
        filtered_subj = wiki_filter.filter(subj)
        filtered_obj = wiki_filter.filter(obj)

        # skip if at least one entity not linked to wiki
        if filtered_subj is None and filtered_obj is None:
            continue

        filtered_relations.append(
            (
                filtered_subj or subj,
                rel,
                filtered_obj or obj,
            )
        )

    return filtered_relations

确保 WikiFilter 能够正确工作:

class WikiFilter:
    def __init__(self):
        self.cache = {}

    def filter(self, candidate_entity):
        # check the cache to avoid network calls
        if candidate_entity in self.cache:
            return self.cache[candidate_entity]["title"]

        # pull the page from wikipedia -- if it exists
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary,
            }

            # cache the page title and original entity
            self.cache[candidate_entity] = entity_data
            self.cache[page.title] = entity_data

            return entity_data["title"]
        except:
            return None

如果以上检查都没有问题,建议检查从 ArxivReader 加载的数据是否包含有效的文本[1]

To continue talking to Dosu, mention @dosu.