Closed 2793145003 closed 4 months ago
Hi! Thank you for this issue!
The results reported in the README and our paper are measured by perplexity. Even so, Llama-70B-chat should not perform this poorly. Based on your code, I suggest following the chat template of the Llama2-chat models, for example: [INST] <<SYS>> This is a passkey retrieval task <</SYS>> long doc What is the pass key? [/INST] The pass key is (please refer to run_llama_100k.py). If changing the prompt doesn't help, please feel free to leave a comment and I will check in detail what is happening.
Thanks for your response!
I've changed the prompts to
import random

def gen_prompt(all_count):
    # Llama-2 chat template markers
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    sys_prompt = "This is a passkey retrieval task.\nThere is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    # Place the passkey after n filler lines, followed by (all_count - n) more filler lines
    n = random.randint(0, all_count)
    content = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n' * n
    pass_key = random.randint(10000, 99999)
    content += f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.\n'
    content += 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n' * (all_count-n)
    question = 'What is the pass key?'
    message = B_INST + B_SYS + sys_prompt + E_SYS + content + f"Question:\n{question}" + E_INST + "Answer:\n"
    return message, pass_key
and got answers like
#Tokens of Prompt: 125100 Passkey target: 95994
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
ChunkLlama: [The key is 942.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 77050
ChunkLlama: [The key is 750.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 16789
ChunkLlama: [The key is 788.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 33519
ChunkLlama: [The key is 31A.]
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 54180
ChunkLlama: [The pass is 542.\nThe grass is green. The sky is blue. The]
Let me know if I did anything wrong.
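For anyone comparing runs like the ones above, here is a minimal sketch of how the generated text could be scored against the target passkey. The helper names (extract_passkey, passkey_accuracy) are hypothetical and not part of the ChunkLlama repository; the scoring rule (the first run of digits must equal the target) is an assumption chosen for illustration.

```python
import re

def extract_passkey(generated: str):
    """Return the first run of digits in the generated text as an int, or None."""
    match = re.search(r"\d+", generated)
    return int(match.group()) if match else None

def passkey_accuracy(pairs):
    """Fraction of (generated_text, target_passkey) pairs answered correctly."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    hits = sum(extract_passkey(text) == target for text, target in pairs)
    return hits / len(pairs)

# Outputs and targets copied from the log above.
samples = [
    ("The key is 942.", 95994),
    ("The key is 750.", 77050),
    ("The key is 788.", 16789),
    ("The key is 31A.", 33519),
]
print(passkey_accuracy(samples))
```

With this rule, none of the answers in the log counts as a hit, since each one returns only a fragment of the five-digit key.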
Thank you so much for letting me know! It will definitely help us improve this work.
I will try to solve this problem on Monday.
Hi, I've updated the code in chunkllama_attn_replace.py by adding the temperature adaptation proposed by YaRN (Line 65). Based on my test, Llama-70B (chat) can now achieve 100% accuracy on the 100k passkey retrieval task.
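As background, a rough sketch of the scaling factor that YaRN recommends for Llama-family models, sqrt(1/t) = 0.1 * ln(s) + 1, where s is the ratio of the target context length to the pretraining length. How chunkllama_attn_replace.py applies it is not shown in this thread, so the function below is only an illustration: its name is hypothetical, and the default pretraining length of 4096 is taken from the warning in the log above.

```python
import math

def yarn_attention_scale(target_len: int, pretrain_len: int = 4096) -> float:
    """YaRN scaling factor sqrt(1/t) = 0.1 * ln(s) + 1, with s = target_len / pretrain_len.

    No scaling is applied (factor 1.0) when the target length does not
    exceed the pretraining length.
    """
    s = max(target_len / pretrain_len, 1.0)
    return 0.1 * math.log(s) + 1.0

# Example: extending a 4k model to a 128k context (s = 32).
print(round(yarn_attention_scale(128 * 1024), 4))
```

In YaRN this factor is usually folded into the rotary embeddings, so that both queries and keys are multiplied by it and the attention logits end up scaled by its square.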
Thanks for your work. But I get OOM this time. I'm running on 8xA100.
-----------------------------------
#Tokens of Prompt: 125100 Passkey target: 51703
Traceback (most recent call last):
  File "/data/model/test.py", line 69, in <module>
    tokens = model.generate(input_ids, max_new_tokens=20)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1474, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1043, in forward
    attention_mask = _prepare_4d_causal_attention_mask(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 307, in _prepare_4d_causal_attention_mask
    attention_mask = attn_mask_converter.to_4d(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 121, in to_4d
    causal_4d_mask = self._make_causal_mask(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 160, in _make_causal_mask
    mask = mask.to(dtype)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 29.15 GiB. GPU 0 has a total capacty of 95.04 GiB of which 17.10 GiB is free. Process 3318804 has 77.94 GiB memory in use. Of the allocated memory 76.74 GiB is allocated by PyTorch, and 5.12 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
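The 29.15 GiB in this error is consistent with a dense causal mask of shape seq_len x seq_len being materialized for the 125100-token prompt. A quick back-of-the-envelope check, assuming a 2-byte dtype such as bfloat16 or float16 (the dtype of the failing run is not stated in the thread, so that part is an assumption):

```python
def dense_mask_gib(seq_len: int, bytes_per_element: int = 2) -> float:
    """Memory in GiB for one dense (seq_len x seq_len) attention mask."""
    return seq_len * seq_len * bytes_per_element / 2**30

# 125100 is the prompt length reported in the log above.
print(round(dense_mask_gib(125100), 2))  # 29.15
```

The cost grows quadratically with the prompt length, so a code path that builds the full mask (here, _make_causal_mask) cannot handle a prompt of this size on a single GPU.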
I have adapted the code to transformers 4.37, so please remember to set the attn_implementation parameter. In this case:
model = LlamaForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2", device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16)
It works like a charm! Thank you for your patience.
Thank you! This issue really helped improve this work. I am happy to answer any further questions.
Hi, great work!
I have been conducting passkey tests on several models. The TinyLlama-1.1B-Chat-v1.0(2k) model successfully passed the 20k and, after fine-tuning, the 125k tests with a 60% accuracy rate. However, the Llama-2-70b-chat-hf(4k) model only achieved 40% accuracy in a 50k context and 0% in a 125k context.
I have been using the following script:
The results I've been getting are as follows:
How can I achieve results consistent with those reported in README?
Thank you.