microsoft / graphrag

A modular graph-based Retrieval-Augmented Generation (RAG) system
https://microsoft.github.io/graphrag/
MIT License

Bug: Could not automatically map llama3 to a tokeniser. Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect. #365

Closed: bmaltais closed this issue 3 months ago

bmaltais commented 4 months ago

When trying to use graphrag.prompt_tune with `python -m graphrag.prompt_tune --root . --no-entity-types` and the following settings.yaml:


encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: ${GRAPHRAG_API_KEY}
  type: openai_chat # or azure_openai_chat
  model: llama3
  model_supports_json: true # recommended if this is available for your model.
  max_tokens: 1500
  # request_timeout: 180.0
  api_base: http://localhost:11434/v1
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  max_retries: 2
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  concurrent_requests: 1 # the number of parallel inflight requests that may be made

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  llm:
    api_key: ${GRAPHRAG_API_KEY}
    type: openai_embedding # or azure_openai_embedding
    model: text-embedding-3-small
    api_base: https://api.openai.com/v1
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    max_retries: 1
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    concurrent_requests: 1 # the number of parallel inflight requests that may be made
    batch_size: 1 # the number of documents to send in a single request
    batch_max_tokens: 8191 # the maximum number of tokens to send in a single request
    # target: required # or optional

chunks:
  size: 300
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization,person,geo,event]
  max_gleanings: 0

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 0

community_report:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 4000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: true
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  max_tokens: 2000

global_search:
  max_tokens: 5000
  data_max_tokens: 5000
  map_max_tokens: 1000
  reduce_max_tokens: 2000
  concurrency: 1

I get:

INFO: Reading settings from settings.yaml

Loading Input (text).
INFO: Generating domain...

INFO: Generated domain: I'd be happy to help you analyze the information in this text document!

The document appears to be a cloud governance standard for object naming, specifically for Azure. It provides guidelines and conventions for naming objects such as resources, storage accounts, queues, and tables.

Here are some key takeaways:

1. **Naming Convention**: The standard recommends using a specific format for naming objects, which includes:
        * Environment (e.g., "Azure")
        * CSP region (e.g., "Canada")
        * Device Type (e.g., "VM")
        * User-defined string
        * Suffix (e.g., "-mycontainername")
2. **Allowed Characters**: The standard specifies that the following characters are allowed in object names:
        * Alphanumeric characters
        * Hyphen (-)
        * @ symbol
3. **Length and Casing**: The standard provides guidelines for length and casing of object names, including:
        * Length: 1-99 characters (e.g., for resource group and resource names)
        * Case-insensitive (e.g., for table names)
4. **Object Naming Tables**: The document includes several tables that provide specific naming conventions for different types of objects, such as:
        * Resource Group and Resource
        * Storage Account
        * Queue name
        * Table name
5. **Automation and Accounting**: The standard emphasizes the importance of automation and accounting in cloud governance, including:
        * Automation: enabling IT to manage resources more efficiently
        * Accounting: supporting chargeback/showback accounting by organizing cloud resources

Overall, this document provides a comprehensive framework for naming objects in Azure, with specific guidelines for different types of objects. It also highlights the importance of automation and accounting in cloud governance.

Would you like me to help you analyze any specific aspects of this document or provide recommendations for implementing this standard?

INFO: Generating persona...

INFO: Generated persona: You are an expert Data Analyst. You are skilled at extracting insights from complex data sets, identifying patterns and relationships, and creating visualizations to communicate findings. You are adept at helping people with analyzing large datasets, identifying trends and correlations, and providing actionable recommendations for decision-making.  

INFO: Generating entity relationship examples...

INFO: Done generating entity relationship examples

INFO: Generating entity extraction prompt...
Traceback (most recent call last):
  File "C:\Users\berna\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\berna\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\graphrag\prompt_tune\__main__.py", line 100, in <module>
    loop.run_until_complete(
  File "C:\Users\berna\AppData\Local\Programs\Python\Python310\lib\asyncio\base_events.py", line 649, in run_until_complete
    return future.result()
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\graphrag\prompt_tune\cli.py", line 59, in fine_tune
    await fine_tune_with_config(
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\graphrag\prompt_tune\cli.py", line 127, in fine_tune_with_config
    await generate_indexing_prompts(
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\graphrag\prompt_tune\cli.py", line 194, in generate_indexing_prompts
    create_entity_extraction_prompt(
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\graphrag\prompt_tune\generator\entity_extraction_prompt.py", line 77, in create_entity_extraction_prompt
    example_tokens = num_tokens_from_string(example_formatted, model=model_name)
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\graphrag\index\utils\tokens.py", line 16, in num_tokens_from_string
    encoding = tiktoken.encoding_for_model(model)
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\tiktoken\model.py", line 103, in encoding_for_model
    return get_encoding(encoding_name_for_model(model_name))
  File "H:\llm_stuff\graphrag\venv\lib\site-packages\tiktoken\model.py", line 90, in encoding_name_for_model
    raise KeyError(
KeyError: 'Could not automatically map llama3 to a tokeniser. Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect.'

Looks like the use of a local model via Ollama is not expected by the code.
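
For reference, the failure can be reproduced with tiktoken alone, and the explicit `tiktoken.get_encoding` call that the error message suggests works as a manual workaround. This is only a minimal sketch; the choice of cl100k_base is an assumption that matches the encoding_model value in the settings.yaml above.

import tiktoken

try:
    # tiktoken only knows OpenAI model names, so a local model served through
    # Ollama's OpenAI-compatible endpoint cannot be mapped automatically
    encoding = tiktoken.encoding_for_model("llama3")
except KeyError:
    # explicit fallback suggested by the error message; cl100k_base is assumed here
    encoding = tiktoken.get_encoding("cl100k_base")

print(len(encoding.encode("hello world")))  # token count using the chosen encoding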

bmaltais commented 4 months ago

I have been able to fix this with the following code update to entity_extraction_prompt.py:

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Entity Extraction prompt generator module."""

from pathlib import Path

# from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.prompt_tune.template import (
    EXAMPLE_EXTRACTION_TEMPLATE,
    GRAPH_EXTRACTION_JSON_PROMPT,
    GRAPH_EXTRACTION_PROMPT,
    UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE,
    UNTYPED_GRAPH_EXTRACTION_PROMPT,
)

import tiktoken

ENTITY_EXTRACTION_FILENAME = "entity_extraction.txt"

def num_tokens_from_string(string: str, model: str) -> int:
    """Returns the number of tokens in a text string."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
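        # the model name (e.g. llama3 served via Ollama) is unknown to tiktoken,
        # so fall back to a generic encoding instead of raising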
        encoding = tiktoken.get_encoding("gpt2")

    num_tokens = len(encoding.encode(string))
    return num_tokens

def create_entity_extraction_prompt(
    entity_types: str | list[str] | None,
    docs: list[str],
    examples: list[str],
    model_name: str,
    max_token_count: int,
    json_mode: bool = False,
    output_path: Path | None = None,
) -> str:
    """
    Create a prompt for entity extraction.

    Parameters
    ----------
    - entity_types (str | list[str]): The entity types to extract
    - docs (list[str]): The list of documents to extract entities from
    - examples (list[str]): The list of examples to use for entity extraction
    - model_name (str): The name of the model to use for token counting
    - max_token_count (int): The maximum number of tokens to use for the prompt
    - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
    - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file.

    Returns
    -------
    - str: The entity extraction prompt
    """
    prompt = (
        (GRAPH_EXTRACTION_JSON_PROMPT if json_mode else GRAPH_EXTRACTION_PROMPT)
        if entity_types
        else UNTYPED_GRAPH_EXTRACTION_PROMPT
    )
    if isinstance(entity_types, list):
        entity_types = ", ".join(entity_types)

    tokens_left = (
        max_token_count
        - num_tokens_from_string(prompt, model=model_name)
        - num_tokens_from_string(entity_types, model=model_name)
        if entity_types
        else 0
    )

    examples_prompt = ""

    # Iterate over examples, while we have tokens left or examples left
    for i, output in enumerate(examples):
        input = docs[i]
        example_formatted = (
            EXAMPLE_EXTRACTION_TEMPLATE.format(
                n=i + 1, input_text=input, entity_types=entity_types, output=output
            )
            if entity_types
            else UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE.format(
                n=i + 1, input_text=input, output=output
            )
        )

        example_tokens = num_tokens_from_string(example_formatted, model=model_name)

        # Squeeze in at least one example
        if i > 0 and example_tokens > tokens_left:
            break

        examples_prompt += example_formatted
        tokens_left -= example_tokens

    prompt = (
        prompt.format(entity_types=entity_types, examples=examples_prompt)
        if entity_types
        else prompt.format(examples=examples_prompt)
    )

    if output_path:
        output_path.mkdir(parents=True, exist_ok=True)

        output_path = output_path / ENTITY_EXTRACTION_FILENAME
        # Write file to output path
        with output_path.open("w", encoding="utf-8") as file:
            file.write(prompt)

    return prompt

Should I make a PR to fix the code in the project, or would you address this in a different manner? My fix might be a bit on the rough side ;-)

You will notice I also set the file encoding to utf-8 because I ran into an error writing the prompt output without it.

AlonsoGuevara commented 4 months ago

@bmaltais I'm about to submit a PR that will address this using the encoding_model setting from the settings.yaml file
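
A minimal sketch of what that could look like in graphrag/index/utils/tokens.py, assuming the configured encoding name is passed through to the token counter (the encoding_name parameter here is hypothetical, not the actual PR):

import tiktoken

def num_tokens_from_string(
    string: str, model: str | None = None, encoding_name: str | None = None
) -> int:
    """Return the number of tokens in a text string."""
    if encoding_name is not None:
        # use the encoding_model value from settings.yaml (e.g. cl100k_base)
        encoding = tiktoken.get_encoding(encoding_name)
    else:
        encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(string))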

AlonsoGuevara commented 3 months ago

Closing as code is now merged

sichaolong commented 3 months ago

You need to modify the graphrag source code and add

import tiktoken
tiktoken.model.MODEL_TO_ENCODING["your LLM, e.g. llama3.1"] = "cl100k_base"

to the graphrag/prompt_tune/__main__.py file.

That is how I solved it.
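
If it helps to verify, registering the model name in tiktoken's model-to-encoding map makes the automatic lookup succeed (the model name here is just an example):

import tiktoken

tiktoken.model.MODEL_TO_ENCODING["llama3.1"] = "cl100k_base"
encoding = tiktoken.encoding_for_model("llama3.1")
print(encoding.name)  # cl100k_base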