ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: uncached prompt is not used for penalty #8971

Open z80maniac opened 1 month ago

z80maniac commented 1 month ago

What happened?

Sometimes the part of the initial prompt that should be considered for penalties is ignored, and only the newly generated tokens are used to calculate the penalty. My current assumption is that this is related to prompt caching (explained below).

Let's add the following debug code to llama_sample_repetition_penalties_impl, right after the token_count map is filled in:

printf("------\n");
for (const auto & entry : token_count) {
    printf("[%d] = %d\n", entry.first, entry.second);
}

It will show the tokens that will be used for penalty calculation.

After starting the server and running this:

curl -s --data '{"prompt": "Note that the file, line, and message properties are", "n_predict": 4, "repeat_penalty": 1.1, "cache_prompt": true}' http://127.0.0.1:8080/completion > /dev/null

the server log shows:

------
[0] = 64
------
[2016] = 1
[0] = 63
------
[1562] = 1
[2016] = 1
[0] = 62
------
[1278] = 1
[1562] = 1
[2016] = 1
[0] = 61

So it ignores the initial prompt and only uses the new tokens.

However, if I run the exact same query the second time, I get this:

------
[1584] = 1
[7192] = 1
[3110] = 1
[5117] = 1
[1321] = 1
[3323] = 1
[1044] = 2
[1278] = 1
[1455] = 1
[12791] = 1
[1] = 1
[0] = 52
------
[1584] = 1
[7192] = 1
[3110] = 1
[5117] = 1
[1321] = 1
[3323] = 1
[1044] = 2
[1278] = 1
[1455] = 1
[12791] = 1
[2016] = 1
[1] = 1
[0] = 51
------
[1536] = 1
[1] = 1
[2016] = 1
[1455] = 1
[12791] = 1
[1278] = 1
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[0] = 50
[1044] = 2
[7192] = 1
[1584] = 1
------
[1536] = 1
[1] = 1
[2016] = 1
[1455] = 1
[12791] = 1
[1278] = 2
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[0] = 49
[1044] = 2
[7192] = 1
[1584] = 1

Now it has all the initial tokens + one new token each step.
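
For comparison, here is a minimal standalone sketch of the behavior I would expect on every run: the penalty window covers the prompt tokens plus everything generated so far, whether or not the prompt came from the cache. This is not the llama.cpp code. The names token_id, count_recent_tokens and apply_repeat_penalty are hypothetical, and the penalty rule (divide positive logits, multiply non-positive ones) is an assumption about how a multiplicative repeat penalty is usually applied.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

using token_id = int32_t; // hypothetical stand-in for the real token type

// Count how often each token occurs in the last `last_n` entries of `history`.
// `history` is assumed to hold the prompt tokens followed by generated tokens.
std::unordered_map<token_id, int> count_recent_tokens(
        const std::vector<token_id> & history, size_t last_n) {
    std::unordered_map<token_id, int> counts;
    const size_t n = std::min(last_n, history.size());
    for (size_t i = history.size() - n; i < history.size(); ++i) {
        counts[history[i]]++;
    }
    return counts;
}

// Apply a multiplicative repeat penalty to every token that appears in `counts`.
// Assumption: positive logits are divided, non-positive logits are multiplied.
void apply_repeat_penalty(std::vector<float> & logits,
                          const std::unordered_map<token_id, int> & counts,
                          float penalty) {
    for (const auto & entry : counts) {
        const token_id id = entry.first;
        if (id < 0 || static_cast<size_t>(id) >= logits.size()) {
            continue; // ignore ids outside the vocabulary
        }
        float & logit = logits[id];
        logit = (logit <= 0.0f) ? logit * penalty : logit / penalty;
    }
}
```

With a history seeded from the prompt, the first sampling step would already see the prompt tokens in the map, as in the second run above, instead of only the new tokens.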

The bug seems to be related to prompt caching: it does not happen when the cached prompt is used, but it happens in all other cases.

I tested it with CUDA/no-CUDA builds and two different models - the results are the same.

Name and Version

./llama-server --version
version: 3565 (6e02327e)
built with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

z80maniac commented 4 weeks ago

After some more testing I discovered that this bug is even worse than I described above. It seems like the repeat_last_n API parameter is completely ignored when choosing the tokens for the penalty. The penalty_last_n that is passed to llama_sample_repetition_penalties_impl is always 64, no matter what is passed in the repeat_last_n parameter. So, only 64 tokens are considered for penalties, even if e.g. repeat_last_n = 2048.

How to test:

  1. Add printf("LAST_N: %ld\n", penalty_last_n); into the llama_sample_repetition_penalties_impl.
  2. Build the server and then run it with -c 256.
  3. Run the following query:
    curl -SsL --data '{"prompt": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, and where is the use of a book, thought Alice, without pictures or conversations? So she was considering in her own mind, (as well as she could, for the hot day made her feel very sleepy and stupid,) whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a white rabbit with pink eyes ran close by her.", "n_predict": 32, "cache_prompt": true, "repeat_penalty": 1.5, "repeat_last_n": 128}' http://127.0.0.1:8080/completion > /dev/null

    The context size is 256 tokens, the prompt is 132 tokens and repeat_last_n is 128, but in the server output you will see LAST_N: 64. It seems like this 64 comes from llama_sampling_params::n_prev.

Unless I'm misunderstanding something, this breaks all the penalties for almost all use-cases. Tested on 2fb92678.
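
The numbers above are consistent with a fixed-size history that starts zero-filled and caps the penalty window at its own length. The following sketch models that under those assumptions. The names token_history, accept and effective_last_n are hypothetical and are not the actual llama.cpp API.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using token_id = int32_t; // hypothetical stand-in for the real token type

// Hypothetical model of a sampler history with a fixed length of n_prev entries.
struct token_history {
    std::vector<token_id> prev;

    // Assumption: the history starts out filled with token 0.
    explicit token_history(size_t n_prev) : prev(n_prev, 0) {}

    // Drop the oldest entry and append the newly sampled token.
    void accept(token_id id) {
        if (prev.empty()) {
            return;
        }
        prev.erase(prev.begin());
        prev.push_back(id);
    }

    // The window can never be larger than the history itself, so a requested
    // repeat_last_n above n_prev is silently reduced to n_prev.
    size_t effective_last_n(size_t repeat_last_n) const {
        return std::min(repeat_last_n, prev.size());
    }
};
```

With n_prev = 64, a request with repeat_last_n = 128 ends up with a window of 64, and a fresh history holds 63 zero tokens after the first accepted token, which matches the [0] = 63 entry in the first log of the original report.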

bviksoe commented 1 week ago

The bug described in the original issue is caused by a faulty ring-buffer implementation here:
https://github.com/ggerganov/llama.cpp/blob/82e3b03c11826d20a24ab66d60f4de58f48ddcdb/common/sampling.cpp#L454
It basically always throws out the last token. So when you start with a clean buffer, there will always be just 1 token in the buffer.

I would wait for #9294 to land before trying to address this. At least that PR seems to fix the ring_buffer ejection issue.
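
For reference, a ring buffer used for sampler history would normally evict only the oldest element, and only once it is full. The sketch below shows that behavior. fixed_ring is a hypothetical class written for illustration. It is neither the implementation at the link above nor the one in #9294.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical fixed-capacity ring buffer: keeps the most recent `capacity`
// values and overwrites the oldest one only when the buffer is already full.
template <typename T>
class fixed_ring {
public:
    explicit fixed_ring(size_t capacity) : data_(capacity), first_(0), size_(0) {}

    void push_back(const T & value) {
        if (data_.empty()) {
            return; // zero capacity: nothing can be stored
        }
        // Next write position; equals first_ (the oldest slot) when full.
        const size_t pos = (first_ + size_) % data_.size();
        data_[pos] = value;
        if (size_ == data_.size()) {
            first_ = (first_ + 1) % data_.size(); // oldest element was overwritten
        } else {
            ++size_;
        }
    }

    size_t size() const { return size_; }

    // Element i, where 0 is the oldest stored value. No bounds checking.
    const T & at(size_t i) const { return data_[(first_ + i) % data_.size()]; }

private:
    std::vector<T> data_;
    size_t first_; // index of the oldest element
    size_t size_;  // number of stored elements
};
```

Starting from a clean buffer, pushing the prompt tokens one by one leaves all of them in place until the capacity is reached, rather than a single token.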

ggerganov commented 6 days ago

> How to test:
>
>   1. Add printf("LAST_N: %ld\n", penalty_last_n); into the llama_sample_repetition_penalties_impl.
>   2. Build the server and then run it with -c 256.
>   3. Run the following query:

I think on latest master this is fixed. Tried the instructions above and the printf shows the correct number. Let us know if you spot any other issues.

z80maniac commented 6 days ago

Yes, the issue with the penalty_last_n is fixed. Thank you for the update.

However, the first issue (the prompt not being used for the penalty) is not fixed. The output is slightly different now (the zero token is gone), but the result is effectively the same.

curl -s --data '{"prompt": "Note that the file, line, and message properties are", "n_predict": 4, "repeat_penalty": 1.1, "cache_prompt": true}' http://127.0.0.1:8080/completion > /dev/null

This is what I get on the first try:

------
------
[5949] = 1
------
[5949] = 1
[1317] = 1
------
[5949] = 1
[1317] = 1
[1747] = 1

Only the new tokens are in the token_count map. No tokens from the existing context are considered.

On the second try (exactly the same query) I get:

------
[1] = 1
[12791] = 1
[1455] = 1
[1278] = 1
[1044] = 2
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[7192] = 1
[1584] = 1
------
[1] = 1
[12791] = 1
[1455] = 1
[1278] = 1
[1044] = 2
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[7192] = 1
[1584] = 1
[1805] = 1
------
[1] = 1
[12791] = 1
[1455] = 1
[1278] = 1
[1044] = 2
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[7192] = 1
[1584] = 1
[1805] = 1
[1307] = 1
------
[1] = 1
[12791] = 1
[1455] = 1
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[7192] = 1
[1584] = 1
[1805] = 1
[1307] = 1
[1044] = 2
[1278] = 2

Now it's properly including all the prior context.

Tested on 49006c67.