abetlen / llama-cpp-python

Python bindings for llama.cpp
https://llama-cpp-python.readthedocs.io
MIT License

Inference Speed is Extremely Slow for 72B Model with Long Contexts #1767

Open wrench1997 opened 1 month ago

wrench1997 commented 1 month ago

Description:

When running inference on a 72B model with a long context (40,960 tokens), generation is extremely slow, taking approximately 40 minutes to produce a result. Running the same task with the standard transformers package takes only about 5 minutes.

Details:

Steps to Reproduce:

  1. Set up the 72B model with a long-context input and enforce multiple fields in JSON format.
  2. Run the inference.
  3. Compare the time taken against the standard transformers package (a rough baseline sketch follows below).
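For reference, this is roughly what the transformers-side comparison looks like. It is only a sketch: the checkpoint name `Qwen/Qwen2.5-72B-Instruct`, the dtype, and the generation settings are assumptions for illustration, and `prompt` is the same prompt string built in the snippet further down.

```python
# Hypothetical transformers baseline (illustrative, not the exact script behind the 5-minute number).
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-72B-Instruct"  # assumed Hugging Face checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
start = time.perf_counter()
output_ids = model.generate(**inputs, max_new_tokens=1024)
print(f"transformers elapsed: {time.perf_counter() - start:.1f}s")
print(tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True))
```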

Environment:

```python
llm_llamacpp = LlamaCPP(
    model_path="/root/Qwen/qwen2.5:72b-instruct-q8_0.gguf",
    model_kwargs={
        "n_gpu_layers": -1,
    },  # if compiled to use GPU
    max_new_tokens=40960,  # 131072
    context_window=40960,
    temperature=0,
    verbose=True,
)
```
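As I understand it, `model_kwargs` is forwarded to the underlying `llama_cpp.Llama`, so the setup above should roughly correspond to the direct construction below. This is only a sketch: the `n_batch` value is an assumption (and worth tuning for long prompts), and `verbose=True` makes llama.cpp print prompt-eval vs. eval timings after each call, which helps show where the time goes.

```python
# Rough direct-API equivalent of the wrapper configuration above (values partly assumed).
from llama_cpp import Llama

llm_raw = Llama(
    model_path="/root/Qwen/qwen2.5:72b-instruct-q8_0.gguf",
    n_ctx=40960,       # context window
    n_gpu_layers=-1,   # offload all layers if built with GPU support
    n_batch=512,       # prompt-processing batch size (assumed default; worth tuning)
    verbose=True,      # prints timing breakdowns per call
)
```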

```python
from pydantic import BaseModel

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'

class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int
    # ... about 50 fields in total

question = <4k length content>
question_with_schema = f'{question}{AnswerFormat.schema_json()}'
prompt = get_prompt(question_with_schema)
```
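Since the appended JSON schema for ~50 fields makes the prompt even longer, it is worth checking how many tokens the assembled prompt actually occupies. A quick sketch using the underlying `llama_cpp.Llama` instance (the same `llm._model` used by the logits-processor helper below):

```python
# Sketch: count the tokens of the assembled prompt with the model's own tokenizer.
prompt_tokens = llm_llamacpp._model.tokenize(prompt.encode("utf-8"))
print(f"prompt length: {len(prompt_tokens)} tokens (n_ctx = 40960)")
```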

```python
def llamaindex_llamacpp_lm_format_enforcer(llm: LlamaCPP, prompt: str,
                                           character_level_parser: Optional[CharacterLevelParser]) -> str:
    logits_processors: Optional[LogitsProcessorList] = None
    if character_level_parser:
        logits_processors = LogitsProcessorList(
            [build_llamacpp_logits_processor(llm._model, character_level_parser)]
        )

    # If changing the character-level parser on each call, inject it before calling complete().
    # If it's the same format each time, you can set it once after creating the LlamaCPP model.
    llm.generate_kwargs['logits_processor'] = logits_processors
    output = llm.complete(prompt)
    text: str = output.text
    return text
```

```python
result = llamaindex_llamacpp_lm_format_enforcer(
    llm_llamacpp, prompt, JsonSchemaParser(AnswerFormat.schema())
)
print(result)
```
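To narrow down where the 40 minutes go, a rough timing sketch (illustrative only, reusing the wrapper call defined above): comparing a run without the logits processor against one with it should show whether the time is spent on prompt processing in llama.cpp or in the JSON enforcement itself.

```python
# Sketch: time generation with and without the JSON-enforcing logits processor.
import time

def timed_run(llm, prompt, parser=None):
    start = time.perf_counter()
    text = llamaindex_llamacpp_lm_format_enforcer(llm, prompt, parser)
    print(f"elapsed: {time.perf_counter() - start:.1f}s")
    return text

timed_run(llm_llamacpp, prompt)                                           # plain generation
timed_run(llm_llamacpp, prompt, JsonSchemaParser(AnswerFormat.schema()))  # with JSON enforcement
```

If the plain run is already close to 40 minutes, the bottleneck would be prompt processing/offloading in llama.cpp for the ~40k-token prompt rather than the format enforcer.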