HKUNLP / ChunkLlama

[ICML'24] Data and code for our paper "Training-Free Long-Context Scaling of Large Language Models"
Apache License 2.0

Llama-2-70b-chat-hf model fails to pass the 125k passkey test #6

Closed 2793145003 closed 4 months ago

2793145003 commented 4 months ago

Hi, great work!

I have been running passkey retrieval tests on several models. TinyLlama-1.1B-Chat-v1.0 (2k context) passed the 20k test and, after fine-tuning, the 125k test with 60% accuracy. However, Llama-2-70b-chat-hf (4k context) achieved only 40% accuracy with a 50k context and 0% with a 125k context.

I have been using the following script:

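import random

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from chunkllama_attn_replace import replace_with_chunkllama

# Patch Llama attention with ChunkLlama (called before the model is loaded);
# 4096 is the Llama-2 pretraining context length.
replace_with_chunkllama(pretraining_length=4096)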
model_path = '/model/Llama-2-70b-chat-hf'
tokenizer_path = model_path
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
model.eval()

def gen_prompt(all_count):
    prompt = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information later.\n"
    n = random.randint(0, all_count)
    prompt += 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n' * n
    pass_key = random.randint(10000, 99999)
    prompt += f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.\n'
    prompt += 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n' * (all_count-n)
    prompt += 'What is the pass key? The pass key is '
    return prompt, pass_key

for i in range(10):
    prompt, target = gen_prompt(5000)
    prompt_postfix = "What is the pass key? The pass key is "
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    print("-----------------------------------")
    print(f"#Tokens of Prompt: {input_ids.shape[1]}", end=" ")
    print(f"Passkey target: {target}")

    tokens = model.generate(input_ids, max_new_tokens=6)
    r = tokenizer.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
    answer = f"ChunkLlama:     [{prompt_postfix}{r}]"
    answer = answer.replace("\n", "\\n")
    print(answer)
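
The accuracy figures above (60%, 40%, 0%) do not say how a generation was scored. A minimal sketch, assuming a generation counts as correct only when the first integer in the decoded continuation `r` equals `target`; the helpers `passkey_correct` and `passkey_accuracy` are hypothetical and not part of ChunkLlama.

```python
import re


def passkey_correct(generated: str, target: int) -> bool:
    """True if the first integer in the model's continuation equals the passkey."""
    match = re.search(r"\d+", generated)
    return match is not None and int(match.group()) == target


def passkey_accuracy(results: list[tuple[str, int]]) -> float:
    """Fraction of (generation, target) pairs answered correctly."""
    if not results:
        return 0.0
    hits = sum(passkey_correct(text, target) for text, target in results)
    return hits / len(results)


# Example with the first two generations reported below.
print(passkey_accuracy([("9456. Remember", 97476), ("934212", 96339)]))  # 0.0
```

Inside the loop, one could append `(r, target)` to a list and call `passkey_accuracy` after the ten iterations. Matching the first integer exactly is stricter than checking whether the target appears anywhere in the output, which could reward a model that lists several candidate numbers.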

The results I've been getting are as follows:

#Tokens of Prompt: 125070 Passkey target: 97476
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
ChunkLlama:     [What is the pass key? The pass key is 9456. Remember]
-----------------------------------
#Tokens of Prompt: 125070 Passkey target: 96339
ChunkLlama:     [What is the pass key? The pass key is 934212]
-----------------------------------
#Tokens of Prompt: 125070 Passkey target: 10344
ChunkLlama:     [What is the pass key? The pass key is 1234. Remember]
-----------------------------------
#Tokens of Prompt: 125070 Passkey target: 64645
ChunkLlama:     [What is the pass key? The pass key is 654632]
-----------------------------------
#Tokens of Prompt: 125070 Passkey target: 55491
ChunkLlama:     [What is the pass key? The pass key is 54211.]
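
From the log, `gen_prompt(5000)` yields 125070 tokens, i.e. roughly 25 Llama-2 tokens per filler line once the short header, passkey sentence, and question are accounted for. A rough sketch for choosing `all_count` at other context lengths (such as the 50k test mentioned earlier), assuming that 25-tokens-per-line estimate holds; `filler_lines_for` is a hypothetical helper.

```python
def filler_lines_for(target_tokens: int, tokens_per_line: float = 25.0) -> int:
    """Estimate the all_count argument of gen_prompt for roughly target_tokens.

    tokens_per_line=25 is an estimate taken from this issue (5000 lines -> 125070
    tokens with the Llama-2 tokenizer); re-measure it if the filler text or tokenizer changes.
    """
    return max(0, round(target_tokens / tokens_per_line))


print(filler_lines_for(125_000))  # 5000
print(filler_lines_for(50_000))   # 2000
```

Because the estimate ignores the fixed overhead and depends on the tokenizer, it is worth confirming the result with `input_ids.shape[1]`, which the script already prints.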

How can I achieve results consistent with those reported in the README?

Thank you.

ChenxinAn-fdu commented 4 months ago

Hi! Thank you for this issue!
The results reported in the README and our paper are measured by perplexity. However, Llama-70B-chat should not perform this poorly. Based on your code, I suggest following the Llama-2-chat prompt template, e.g. [INST] <<SYS>> This is a passkey retrieval task <</SYS>> long doc What is the pass key? [/INST] The pass key is (please refer to run_llama_100k.py). If changing the prompt doesn't help, feel free to leave a comment and I will look into it in detail.
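
A minimal sketch of a prompt in the Llama-2-chat shape described above, with the answer prefix `The pass key is ` placed after `[/INST]` so the model continues directly with digits. The function name `build_llama2_passkey_prompt` and the exact whitespace are assumptions based on the standard Llama-2-chat template; run_llama_100k.py in the repository remains the reference for the prompt the authors actually use.

```python
def build_llama2_passkey_prompt(
    document: str,
    system_prompt: str = "This is a passkey retrieval task.",
) -> str:
    """Wrap a long document in the Llama-2-chat [INST] / <<SYS>> template.

    The BOS token <s> is not included here; LlamaTokenizer adds it by default.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    question = "What is the pass key?"
    # The trailing answer prefix primes the model to emit the number itself.
    return (
        f"{B_INST} {B_SYS}{system_prompt}{E_SYS}"
        f"{document}\n{question} {E_INST} The pass key is "
    )


prompt = build_llama2_passkey_prompt("...filler...\nThe pass key is 12345. Remember it.\n...filler...")
```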

2793145003 commented 4 months ago

Thanks for your response!

I've changed the prompts to

def gen_prompt(all_count):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    sys_prompt = "This is a passkey retrieval task.\nThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    n = random.randint(0, all_count)
    content = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n' * n
    pass_key = random.randint(10000, 99999)
    content += f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.\n'
    content += 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n' * (all_count-n)
    question = 'What is the pass key?'
    message = B_INST + B_SYS + sys_prompt + E_SYS + content + f"Question:\n{question}" + E_INST + "Answer:\n"
    return message, pass_key

and got answers like

#Tokens of Prompt: 125100 Passkey target: 95994
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
ChunkLlama:     [The key is 942.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 77050
ChunkLlama:     [The key is 750.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 16789
ChunkLlama:     [The key is 788.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 33519
ChunkLlama:     [The key is 31A.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 54180
ChunkLlama:     [The pass is 542.\nThe grass is green. The sky is blue. The]

Let me know if I did anything wrong.

ChenxinAn-fdu commented 4 months ago

Thank you so much for letting me know! It will definitely help us improve this work.

I will try to solve this problem on Monday.

ChenxinAn-fdu commented 4 months ago

Hi, I've updated chunkllama_attn_replace.py to add the attention temperature adaptation proposed by YaRN (line 65). In my tests, Llama-70B (chat) now achieves 100% accuracy on the 100k passkey retrieval task.
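
For reference, the YaRN paper recommends an attention temperature t with sqrt(1/t) = 0.1 ln(s) + 1, where s is the context extension ratio; multiplying the pre-softmax logits by 1/t sharpens attention that would otherwise spread thin over very long inputs. The sketch below only illustrates that formula in plain Python. It is not the code at line 65 of chunkllama_attn_replace.py, which may apply the factor differently, and the helper names are hypothetical.

```python
import math


def yarn_mscale(scale: float) -> float:
    """YaRN's recommended sqrt(1/t) = 0.1 * ln(s) + 1 for extension ratio s (1.0 if s <= 1)."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_attention_probs(logits: list[float], scale: float) -> list[float]:
    """Multiply pre-softmax logits by 1/t = yarn_mscale(scale) ** 2, then normalize."""
    factor = yarn_mscale(scale) ** 2
    return softmax([x * factor for x in logits])


s = 100_000 / 4096  # extension ratio for a 100k context on a 4k-pretrained model
print(round(yarn_mscale(s), 3))  # ≈ 1.32
```

Scaling both queries and keys by `yarn_mscale(s)` multiplies their dot product by `yarn_mscale(s) ** 2`, which is equivalent to the logit scaling shown here.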

2793145003 commented 4 months ago

Thanks for your work. However, I now get an out-of-memory (OOM) error. I'm running on 8xA100.

-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 51703
Traceback (most recent call last):
  File "/data/model/test.py", line 69, in <module>
    tokens = model.generate(input_ids, max_new_tokens=20)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1474, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1043, in forward
    attention_mask = _prepare_4d_causal_attention_mask(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 307, in _prepare_4d_causal_attention_mask
    attention_mask = attn_mask_converter.to_4d(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 121, in to_4d
    causal_4d_mask = self._make_causal_mask(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 160, in _make_causal_mask
    mask = mask.to(dtype)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 29.15 GiB. GPU 0 has a total capacty of 95.04 GiB of which 17.10 GiB is free. Process 3318804 has 77.94 GiB memory in use. Of the allocated memory 76.74 GiB is allocated by PyTorch, and 5.12 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
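
The failed allocation in the last line is consistent with the dense causal mask built by `_prepare_4d_causal_attention_mask`: a 125100 x 125100 matrix at 2 bytes per element (fp16/bf16) is almost exactly 29.15 GiB, and its size grows quadratically with prompt length. A quick check, with `dense_mask_gib` as a hypothetical helper:

```python
def dense_mask_gib(seq_len: int, bytes_per_elem: int = 2) -> float:
    """Size in GiB of one dense (seq_len x seq_len) mask; 2 bytes per element for fp16/bf16."""
    return seq_len * seq_len * bytes_per_elem / 2**30


print(f"{dense_mask_gib(125_100):.2f} GiB")       # 29.15 GiB, the allocation in the traceback
print(f"{dense_mask_gib(4_096) * 1024:.0f} MiB")  # 32 MiB at the 4k pretraining length
```

As far as we can tell, the Llama code in transformers 4.37 skips building this 4D mask when `attn_implementation="flash_attention_2"` is used, which is consistent with the fix suggested in the reply.
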

ChenxinAn-fdu commented 4 months ago

I have adapted the code to transformers 4.37, so please remember to set the attn_implementation parameter.

In this case: model = LlamaForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2", device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16)

2793145003 commented 4 months ago

It works like a charm! Thank you for your patience.

ChenxinAn-fdu commented 4 months ago

Thank you! This issue really helped improve this work. I'm happy to answer any further questions.