microsoft / LLMLingua

To speed up LLMs' inference and enhance LLMs' perception of key information, compress the prompt and KV-Cache, achieving up to 20x compression with minimal performance loss.
https://llmlingua.com/
MIT License

[Question]: Reproduce LongLLMLingua on the LongBench MultiDoc dataset #136

Open Liangyx2 opened 2 months ago

Liangyx2 commented 2 months ago

Describe the issue

We are interested in your LongLLMLingua results on LongBench. We referred to these two parts of your code, https://github.com/microsoft/LLMLingua/blob/main/experiments/llmlingua2/evaluation/compress.py and https://github.com/microsoft/LLMLingua/blob/main/experiments/llmlingua2/evaluation/eval_longbench.py, and the other settings are as follows.

target token: 3000; compress model: Llama-2-7b-chat-hf; LLM model: Mistral-7B-v0.1; compress config:

contexts_list = origin.split("\n")
demonstration = [
    "\n".join(contexts_list[ii : ii + 4]) for ii in range(0, len(contexts_list), 4)
]
comp_dict = compressor.compress_prompt(
    demonstration,
    instruction=dataset2instruction[copy.deepcopy(sample["task"])],
    question=dataset2question[copy.deepcopy(sample["task"])].format(input=copy.deepcopy(sample["question"])),
    target_token=args.target_token,
    condition_compare=True,
    condition_in_question="after",
    rank_method="longllmlingua",
    use_sentence_level_filter=False,
    context_budget="+100",
    dynamic_context_compression_ratio=0.4,
    reorder_context="sort",
)

The results when running with a local LLM (Mistral-7B-v0.1) appear different from the conclusions in the paper:

  1. Overall latency increases, because compression adds a second inference step before generation;
  2. Accuracy drops severely on some tasks (e.g., dureader);
  3. A bug occurs on the musique dataset (traceback below).

result (200 samples per task):

task: 2wikimqa
    compress_time: 816.76 s
    Original Prompt: score 34.85, predict_time 304.50 s
    LongLLMLingua: score 34.99, predict_time 224.47 s

task: dureader
    compress_time: 986.97 s
    Original Prompt: score 0.68, predict_time 888.87 s
    LongLLMLingua: score 0.02, predict_time 735.03 s

task: hotpotqa
    compress_time: 1154.14 s
    Original Prompt: score 41.16, predict_time 342.11 s
    LongLLMLingua: score 38.77, predict_time 237.46 s

task: musique

Traceback (most recent call last):
  File "/data2/lkk/yuxiang/llmlingua/compress.py", line 174, in <module>
    comp_dict = compressor.compress_prompt(
  File "/home/lvkaokao/anaconda3/envs/llmlingua1/lib/python3.10/site-packages/llmlingua/prompt_compressor.py", line 676, in compress_prompt
    start = self.get_prefix_length(prefix + "\n\n", context[0])
  File "/home/lvkaokao/anaconda3/envs/llmlingua1/lib/python3.10/site-packages/llmlingua/prompt_compressor.py", line 995, in get_prefix_length
    assert self.tokenizer.decode(full_input_ids[i:]) == text[:100]
AssertionError
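
For context, the failing assertion compares self.tokenizer.decode(full_input_ids[i:]) against a 100-character slice of the original context, so it relies on decoding a suffix of the token ids reproducing the corresponding slice of the string exactly. Below is a standalone sketch (hypothetical text; the tokenizer id is only an assumption, and no LLMLingua internals are reproduced) of how such suffix decoding can drift from the original string:

# Sketch only: decoding a suffix of a tokenized sequence does not always
# reproduce the corresponding character slice of the original string, which is
# the kind of mismatch the assertion in get_prefix_length guards against.
from transformers import AutoTokenizer

# Assumption: a local Llama-2 path also works; any SentencePiece tokenizer behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

text = "Answer the question based on the given passages."  # hypothetical stand-in for a musique context
ids = tokenizer.encode(text, add_special_tokens=False)

for i in range(1, len(ids)):
    # Some of these decoded suffixes will not match any exact slice of `text`,
    # e.g. because the leading space of the first piece is dropped on decode.
    print(i, repr(tokenizer.decode(ids[i:])))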

Test hardware platform: We used 1 Nvidia A100-SXM4-80GB per experiment.

I would like to discuss the reasons for these phenomena with you. Thank you very much for your reply.

iofu728 commented 2 months ago

Hi @Liangyx2, thanks for your support and the very detailed experiment information.

Although we haven't tried using LongLLMLingua with Mistral as the LLM, my intuition tells me that it should still work.

  1. Concerning the latency issue, there are mainly two reasons: 1) The current version's sentence segmentation is quite coarse, leading to an excessively high number of sentences; 2) The coarse-level compression process does not use batch inference, which significantly increases latency. You might try modifying the segmentation method with the following code (see the fuller sketch after this list for how to plug it into your compress.py loop):
    context = tokenizer.encode(context)
    context = [tokenizer.decode(context[ii:ii + 500]) for ii in range(0, len(context), 500)]
  2. Thank you for your experimental results. It looks like the results for 2wikimqa are reasonably higher than the original prompt, which is encouraging. However, it seems there might be some bugs with dureader. Could you provide your scripts? I can try to help you debug this and the musique issue when I have the bandwidth. However, it might take a few days.
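
To make the segmentation suggestion in item 1 concrete, here is a minimal sketch (not an official recipe) of how token-window splitting could replace the origin.split("\n") / groups-of-4-lines logic in the compress.py loop shown later in this thread. The tokenizer id, the 500-token window, and the variable origin are assumptions taken from the snippets in this thread; adjust them to your local paths and settings.

# Sketch only: segment the raw context into fixed-size token windows instead of
# splitting on "\n" and grouping 4 lines per chunk.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # or a local path

window = 500  # tokens per segment, as in the snippet above
ids = tokenizer.encode(origin, add_special_tokens=False)  # `origin` is the raw context string
demonstration = [
    tokenizer.decode(ids[ii : ii + window]) for ii in range(0, len(ids), window)
]
# `demonstration` can then be passed to compressor.compress_prompt(...) exactly as
# in the original script.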

Thanks again for your support and interest.

Liangyx2 commented 2 months ago

Thanks for your reply and help! We'll continue to improve the experiment based on your suggestions. The command that triggers the error is:

CUDA_VISIBLE_DEVICES=2 python compress.py \
--model_name /models/Llama-2-7b-chat-hf/ \
--load_origin_from experiment/longbench_musique.json \
--load_key context \
--save_key comp \
--save_path experiment/longbench_test_llama2_7b_3000_musique.json \
--target_token 3000 \
--device cuda

The scripts we use are as follows.

format_data_longbench.py

import json
import os

from datasets import load_dataset

os.makedirs("experiment/", exist_ok=True)

tasks = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "2wikimqa", "dureader", "hotpotqa", "musique"]

for task in tasks:
    dataset = load_dataset("THUDM/LongBench", task, split="test")
    data = []
    for idx, sample in enumerate(dataset):
        if task=="multifieldqa_zh":
            context=sample["context"].encode('unicode-escape').decode('unicode-escape')
            input=sample["input"].encode('unicode-escape').decode('unicode-escape')
            answers=[sample["answers"][0].encode('unicode-escape').decode('unicode-escape')]
        else:
            context=sample["context"]
            input=sample["input"]
            answers=sample["answers"]

        temp = {}
        temp["idx"] = idx
        temp["task"] = task
        temp["context"] = context
        temp["question"] = input
        temp["answers"] = answers
        temp["all_classes"] = sample["all_classes"]
        temp["length"] = sample["length"]
        data.append(temp)

    # Dump the formatted samples for this task once after the loop finishes,
    # rather than rewriting the file on every sample.
    json.dump(
        data,
        open("experiment/longbench_" + task + ".json", "w", encoding="utf8"),
        indent=4,
        ensure_ascii=False,
    )

compress.py

import argparse
import copy
import json
import os
import time

from tqdm import tqdm

from llmlingua.prompt_compressor import PromptCompressor

parser = argparse.ArgumentParser(description="compress any prompt.")

parser.add_argument(
    "--model_name",
    help="llm used to compress",
    default="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
)
parser.add_argument(
    "--load_origin_from", help="dataset used to compress", required=True
)
parser.add_argument(
    "--load_key", help="the key to load the text to compress", default="prompt"
)
parser.add_argument(
    "--save_key",
    help="the key to save the compressed text",
    default="compressed_prompt",
)

parser.add_argument("--save_path", help="path to save results", required=True)

parser.add_argument(
    "--target_token", help="number of target tokens", type=int, default=-1
)
parser.add_argument(
    "--force_tokens",
    help="the tokens which will be forcely preserved, comma separated",
    type=str,
    default=None,
)
parser.add_argument("--device", type=str)

args = parser.parse_args()
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
if args.force_tokens is not None:
    args.force_tokens = [
        str(item).replace("\\n", "\n") for item in args.force_tokens.split(",")
    ]
else:
    args.force_tokens = []
print(f"force tokens: {args.force_tokens}")

data = json.load(open(args.load_origin_from))
print(f"num data: {len(data)}")

dataset2instruction = {
    "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: ",
    "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: ",
    "multifieldqa_en": "Read the following text and answer briefly.\n\n",
    "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n",
    "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n",
    "dureader": "请基于给定的文章回答下述问题。\n\n文章:",
    "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n",
    "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n",
    }

dataset2question = {
    "narrativeqa": "\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
    "qasper": "\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
    "multifieldqa_en": "\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "multifieldqa_zh": "\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
    "2wikimqa": "\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "dureader": "\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
    "hotpotqa": "\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "musique": "\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    }

compressor = PromptCompressor(model_name=args.model_name, device_map=args.device)

results = {}
results_list = []
total_time = 0

if os.path.exists(args.save_path):
    results = json.load(open(args.save_path))

for sample in tqdm(data):
    idx = int(sample["idx"])
    origin = copy.deepcopy(sample[args.load_key])
    if origin is None:
        continue
    if idx in results or str(idx) in results:
        print(f"{idx}-th sample is processed")
        continue
    t = time.time()

    contexts_list = origin.split("\n")
    demonstration = [
        "\n".join(contexts_list[ii : ii + 4]) for ii in range(0, len(contexts_list), 4)
    ]
    comp_dict = compressor.compress_prompt(
        # demonstration.split("\n"),
        demonstration,
        instruction=dataset2instruction[copy.deepcopy(sample["task"])],
        question=dataset2question[copy.deepcopy(sample["task"])].format(input=copy.deepcopy(sample["question"])),
        target_token=args.target_token,
        condition_compare=True,
        condition_in_question="after",
        rank_method="longllmlingua",
        use_sentence_level_filter=False,
        context_budget="+100",
        dynamic_context_compression_ratio=0.4,
        reorder_context="sort",
    )

    total_time += time.time() - t
    comp = comp_dict["compressed_prompt"]

    new_sample = copy.deepcopy(sample)
    new_sample[args.save_key] = comp

    results[idx] = new_sample
    json.dump(
        results,
        open(args.save_path, "w", encoding="utf8"),
        indent=4,
        ensure_ascii=False,
    )

print(args.save_path, total_time)
json.dump(
    results, open(args.save_path, "w", encoding="utf8"), indent=4, ensure_ascii=False
)

iofu728 commented 2 months ago

Hi @Liangyx2, the last issue is fixed. You can update to the latest version using pip install git+https://github.com/microsoft/LLMLingua.git.