EQ-bench / EQ-Bench

A benchmark for emotional intelligence in large language models
MIT License

Backend changes scores significantly #16

Closed (dnhkng closed this 6 months ago)

dnhkng commented 8 months ago

I've been using ExllamaV2 to test miqu-1-70b-sf-4.0bpw-h6-exl2 and I get a score of 82.71, so on par with that listed on the leaderboard at https://eqbench.com/.

However, I just tried running the original GGUF model, miqu-1-70b, using the server from llama.cpp, and I only scored 75.26!

That's significantly lower, so I'm wondering where the error could be.

Maybe it's related to this finding? https://www.reddit.com/r/LocalLLaMA/comments/1b1gxmq/the_definite_correct_miqu_prompt/

UPDATE: Nope, I modified the chat template and got a score of 75.29.

UPDATE 2: Just tested the original 5-bit miqu-1-70b.q5_K_M.gguf, and that also only scored 75.50, so it's not a quantization issue.

UPDATE 3: Retested miqu-1-70b-sf-4.0bpw-h6-exl2 on the same machine with a fresh install and got 81.75, so the ExLlama version scores much higher than the llama.cpp version, even though the llama.cpp version is the original leaked model! That seems very odd, as de-quantizing, changing the model format and re-quantizing shouldn't lead to much higher EQ-Bench scores!

dnhkng commented 8 months ago

OK, poking around the code, it seems the chat template may not be getting applied. In https://github.com/EQ-bench/EQ-Bench/blob/3fcbe53cbeca11e2e41c7b4c3a3001a1dd7a12b2/lib/util.py#L117, I see the chat template applied for transformers and, just below, for ooba, but not for openai. This means models served from the llama.cpp server are not being formatted correctly.

This might account for the big difference in scores.

Wuzzooy commented 8 months ago

I also see a performance difference between backends, except that I get better performance with llama server. I use ooba via the API with automatically_launch_ooba = false because I'm on Windows. I also somehow can't manage to make the openai-compatible option work, so to use the API on llama server I set the inference engine to "ooba" instead of "openai" and use --port 5000.

(comparison screenshot attached)

sam-paech commented 8 months ago

My understanding with OpenAI endpoints is that they expect the "messages" object with your unformatted prompts, like:

"messages": [
{
    "role": "system",
    "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."
},
{
    "role": "user",
    "content": "Write a limerick about python exceptions"
}
]

And the inferencing engine handles any additional formatting for you.
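
In code, a request to such an endpoint looks roughly like this (a minimal Python sketch; the base URL, model name and sampler settings are placeholders, not EQ-Bench's actual values):

import requests

# Send unformatted messages to an OpenAI-compatible chat endpoint; the backend
# (llama-cpp-python[server], ooba, etc.) applies the chat template itself.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "miqu-1-70b",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write a limerick about python exceptions"},
        ],
        "temperature": 0.01,
        "max_tokens": 256,
    },
    timeout=300,
)
print(response.json()["choices"][0]["message"]["content"])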

My testing of the openai compatible endpoint was with llama-cpp-python[server], which requires you to specify the chat format when you launch the server. Ollama I guess figures out the right prompt format, but you don't need to specify it through the api there either.

If we're running into an issue where llama.cpp doesn't support the prompt template you want to use, we might have to add an option to pre-apply the specified template before sending it to the OpenAI-compatible API. But tbh that would convolute things a bit, and I'm not sure that's the issue we're having here, since you guys are using ChatML and Alpaca, which llama.cpp supports.

I've seen some other reports recently of significantly lower scores with GGUFs via llama.cpp; the common factor so far seems to be that the environment is Mac/Windows (i.e. not Linux). Wuzzooy, you're on Windows. @dnhkng, what OS are you using?

sam-paech commented 8 months ago

Just to add: the Miqu scores in the ~82 range were from when I tested the original GGUFs (via ooba + llama.cpp) when the model was first released. I get the same results with the dequantized version via transformers. So evidently these combinations can produce the 80+ score, but aren't, because of some as-yet-unidentified factor. It's quite possible that prompt formats are slightly off or not getting applied properly somewhere in the chain.

CrispStrobe commented 8 months ago

Templates can indeed worsen performance significantly.

You can use verbose mode with -v to get an impression of what goes on.

Ollama I guess figures out the right prompt format, but you don't need to specify it through the api there either.

Every model has the template built in with Ollama; you specify it when creating the model with a simple Modelfile, and can in seconds create a copy with a changed template or other parameters. Some examples: https://github.com/ollama/ollama/issues/1977

For transformers, there is tokenizer.apply_chat_template()

OpenAI-compatible APIs use no templates AFAIK: https://platform.openai.com/docs/api-reference/chat/streaming
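
For reference, applying a chat template locally with transformers looks roughly like this (a sketch; the model name is just an example):

from transformers import AutoTokenizer

# Render the messages into a single prompt string using the tokenizer's built-in chat template
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [{"role": "user", "content": "Write a limerick about python exceptions"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # e.g. "<s>[INST] Write a limerick about python exceptions [/INST]"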

dnhkng commented 8 months ago

I'm using everything as a fresh install on a cloud instance of Ubuntu 22.04.

The inference backend is llama.cpp, compiled with cuBLAS and using the basic server.

I'll test using llama.cpp with a chat template applied, and the llama.cpp server via ooba, and report back.

CrispStrobe commented 8 months ago

llama.cpp server's "oai like" api is not 100% compatible with standard oai and indeed works with templates

https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/examples/server/oai.hpp#L35 https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/examples/server/utils.hpp#L158

sam-paech commented 8 months ago

llama.cpp server's "oai like" api is not 100% compatible with standard oai and indeed works with templates

https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/examples/server/oai.hpp#L35

Looking at the code there, it looks like the chat template is passed in as an arg to the function (probably from the --chat_format cmdline option?) rather than being pulled from llama_params like the rest. Which suggests to me that it doesn't support passing a chat format via the api like you can with ooba. Disclaimer: have not tried.

CrispStrobe commented 8 months ago

As I understand it, you would do it like this:

curl https://(path to llama.cpp server)/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -d '{
    "model": "gpt-3.5-turbo",
    "chat_template": "(template here)",
    "messages": [
      { "role": "system", "content": "You are a helpful assistant." },
      { "role": "user", "content": "Hello!" }
    ]
  }'

sam-paech commented 8 months ago

okay well if that's the case it's probably best that we support llama.cpp server directly as an inferencing engine option in the config, so we can pass this param correctly.

CrispStrobe commented 8 months ago

Oh, I see: they do not actually parse a template definition but work with a limited number of hardcoded options (https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/llama.h#L721), and when a template string differs from those they fall back on ChatML (https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/examples/server/server.cpp#L415); otherwise, they invoke a peculiar heuristic (https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/llama.cpp#L13252). So you could probably at least use a string like this: https://github.com/ggerganov/llama.cpp/blob/fa974646e1a2024fc7dc9e6f27cf1f2f5d4a3763/tests/test-chat-template.cpp#L39. And there I thought I was the only one who would code such awkward stuff...

dnhkng commented 8 months ago

Ollama is a huge PITA so far. It wants to move all model files to a ~/.ollama/something/something directory, which is not cool on a cloud instance where you mount storage, as you have to make a huge duplicate of every model...

Wuzzooy commented 8 months ago

I've done a few more tests:

With the ooba API, changing the prompt template in EQ-Bench does have an effect on the score (I went from 68 to 72 changing Alpaca to ChatML).

With llama server.exe, changing the prompt template in EQ-Bench doesn't have any effect (I didn't understand how to use their --chat-template argument).

With the Python llama server, changing the prompt template with its --chat_format argument does have an effect on the score.

ngxson commented 8 months ago

I'm the one who wrote that chat template function, so maybe I can help. The reason we hard-code the templates is that a Jinja parser in C++ is too big (imagine including a Jinja parser and ending up with a compiled llama.cpp binary 10x bigger; that's not ideal).

If you want to use templates not supported by that function, you can write a small Python proxy script that listens on localhost:9999/v1/chat/completions, applies your own chat template, then forwards the prompt to llama.cpp's /completions. Because you can apply a custom Jinja template easily in Python, I think the code for this proxy would not be too complicated.
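
A rough sketch of that proxy idea (not the author's code; the template string, ports and completion payload fields are assumptions):

import json

import requests
from http.server import BaseHTTPRequestHandler, HTTPServer
from jinja2 import Template

# A stand-in chat template; in practice you would paste the model's own Jinja template here.
CHAT_TEMPLATE = Template(
    "{% for m in messages %}"
    "{% if m.role == 'user' %}[INST] {{ m.content }} [/INST]{% else %}{{ m.content }}{% endif %}"
    "{% endfor %}"
)

class Proxy(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read the OpenAI-style request and render the prompt with our own template...
        body = json.loads(self.rfile.read(int(self.headers["Content-Length"])))
        prompt = CHAT_TEMPLATE.render(messages=body["messages"])
        # ...then forward it to llama.cpp's raw completion endpoint.
        r = requests.post(
            "http://localhost:8080/completion",
            json={"prompt": prompt, "n_predict": 512, "temperature": 0.01},
        )
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps(r.json()).encode())

HTTPServer(("localhost", 9999), Proxy).serve_forever()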

dnhkng commented 8 months ago

If that's the case, maybe EQ-Bench could call the /completions endpoint instead of /chat/completions when llama.cpp[server] is the backend? Then we would apply the same formatting as for ooba etc.

ngxson commented 8 months ago

Yeah, in fact for benchmarking I strongly recommend using /completions, because you can be sure that what you give as the prompt goes directly into inference. The chat/completions endpoint is quite "it just works": easy, but in some edge cases it may not be suitable.

ngxson commented 8 months ago

Also, if you use a custom chat template, don't forget to include the stop word, because some models do not emit an EOS token when they finish the response (for example, Alpaca needs ### as a stop word).

In the README.md:

    `stop`: Specify a JSON array of stopping strings.
    These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
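
For example, a raw completion request with an explicit stop string might look like this (a sketch; the Alpaca-style prompt, server URL and payload fields are assumptions, not EQ-Bench's code):

import requests

# Pre-formatted Alpaca-style prompt; "###" is passed as a stop string so generation
# halts at the start of the next turn even if the model never emits an EOS token.
prompt = "### Instruction:\nWrite a limerick about python exceptions\n\n### Response:\n"
response = requests.post(
    "http://localhost:8080/completion",
    json={"prompt": prompt, "n_predict": 256, "stop": ["###"], "temperature": 0.01},
)
print(response.json()["content"])
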
sam-paech commented 8 months ago

Ok, yeah it looks like the best approach is to have llama.cpp as an inference engine option, which will have its own function in run_query.py and apply the specified chat template there like we do with transformers. I'll start working on that.

@dnhkng to get ooba to load a local gguf through the benchmarking pipeline you set it up like this:

download gguf to ~/text-generation-webui/models/mymodel_dir/

Or wherever you have ooba installed. For some reason ooba won't load it unless it's there.

Then the config file is like:

1, ChatML, ~/text-generation-webui/models/mymodel_dir, , None, 1, ooba, --loader llama.cpp --n-gpu-layers -1,

I'm sure this could be made to work better; it's just not the way I usually load models, so it got overlooked. I'll add it to the list of things to improve.

dnhkng commented 8 months ago

I'm sure this could be made to work better; it's just not the way I usually load models, so it got overlooked. I'll add it to the list of things to improve.

If llama.cpp[server] is implemented, I would skip the install of ooba TBH. GPU servers are expensive, and that wastes time. Thanks for the tip; I just figured it out myself, and I symlink my models in from the mounted storage drive.

Update: It is pretty flaky though. I often see it lock up after about 30 tests, and I have to kill the ooba process and restart the run.

dnhkng commented 8 months ago

I made some small changes to run_query.py and a line or two in util.py, and added a llama server mode using the /completion endpoint. It doesn't do the server starting and stopping needed to change models with each test in the config file, but it did let me test Miqu again. This time I scored 79.5, so still a bit low, but almost there. It seems the chat/completions endpoint and the wrong template are definitely the issue.

@sqrkl should I make a pull request?

sam-paech commented 8 months ago

@dnhkng

@sqrkl should I make a pull request?

Yes ty.

dnhkng commented 8 months ago

Pull request

@dnhkng

@sqrkl should I make a pull request?

Yes ty.

PR #17

sam-paech commented 8 months ago

I've merged the llama.cpp changes to: https://github.com/EQ-bench/EQ-Bench/tree/main_v2_1_llamacpp

If anyone is able to test this build on mac or windows that would be appreciated! Particularly llama.cpp and ooba.

Wuzzooy commented 8 months ago

I've done multiple runs. The llama server API works, and the prompt template setting in EQ-Bench works on llama server too. The scores between ooba and llama server are similar now for the same model/prompt template. For ooba on Windows, I had to add two lines in ooba.py (class ooba) and then it worked:

elif script_path.endswith('bat'):
    self.script_command = 'powershell'

Also, when I installed EQ-Bench before this update, I had to install "pexpect" if I recall correctly, and maybe some other modules, but I can't tell because my Python environment was not fresh and I already had a few libs installed.

sam-paech commented 8 months ago

I've done multiple runs. The llama server API works, and the prompt template setting in EQ-Bench works on llama server too. The scores between ooba and llama server are similar now for the same model/prompt template. For ooba on Windows, I had to add two lines in ooba.py (class ooba) and then it worked:

elif script_path.endswith('bat'):
    self.script_command = 'powershell'

Also, when I installed EQ-Bench before this update, I had to install "pexpect" if I recall correctly, and maybe some other modules, but I can't tell because my Python environment was not fresh and I already had a few libs installed.

Great news! Thanks for testing this. I will add those lines for using PowerShell instead of bash.

Pexpect should no longer be a requirement in this version. When I have some time I'll set it up in a fresh Windows install and make a note of any extra reqs.

dnhkng commented 8 months ago

Just added a PR for the llama.cpp server, #18, which starts the server for each run using the params taken from the config file. It uses subprocess, so it should be cross-platform. Could someone test it on Windows?

Also, I hope it doesn't break ooba! I had to modify the config processing slightly. I think it's OK, but it needs a test too.
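
Roughly, the approach looks like this (a sketch of the idea, not the PR code; the binary path, port and /health polling are assumptions):

import subprocess
import time

import requests

def start_llama_server(server_path, model_path, extra_args):
    # Launch llama.cpp's server binary with the per-model args from the config file
    cmd = [server_path, "-m", model_path, "--port", "8080"] + extra_args
    proc = subprocess.Popen(cmd)
    # Poll until the server answers (or give up after ~2 minutes)
    for _ in range(120):
        try:
            requests.get("http://localhost:8080/health", timeout=1)
            return proc
        except requests.exceptions.RequestException:
            time.sleep(1)
    proc.terminate()
    raise RuntimeError("llama.cpp server did not start")

proc = start_llama_server("./server", "miqu-1-70b.q4_k_m.gguf", ["--ctx-size", "4096"])
try:
    pass  # run the benchmark against http://localhost:8080 here
finally:
    proc.terminate()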

sam-paech commented 8 months ago

Nice, I'll check this out on linux

Wuzzooy commented 8 months ago

Just added a PR for the llama.cpp server, #18, which starts the server for each run using the params taken from the config file. It uses subprocess, so it should be cross-platform. Could someone test it on Windows?

Also, I hope it doesn't break ooba! I had to modify the config processing slightly. I think it's OK, but it needs a test too.

I've tested on Windows, and it couldn't find llama server when a path was set in the config file. When putting the server file directly in the EQ-Bench directory, it works. By adding import os and changing

command = self.command + [model]

to

command = [os.path.join(self.llama_server_path, 'server'), '-m'] + [model]

I can set a path for llama server in a different directory. I tested ooba and it doesn't seem to be broken.

dnhkng commented 8 months ago

Thanks for testing on Windows. I've added your change, and it still runs on Linux.

dnhkng commented 8 months ago

Just testing the system, and I am still seeing Miqu scores around ~75.5, regardless of the prompt format. Printing the prompts shows that they are indeed different though. I am at a loss why the scores remain so similar between ChatML, Mistral, and Mistral formatted to match the optimal version (as per the first post on this thread).

Can anyone test if using llama.cpp's server is affected by the prompt choice?

sam-paech commented 8 months ago

Someone in discord was having issues with lower scores (on windows) and resolved it with some obscure fix:

After reinstalling Windows, ooba and EQ-Bench, testing Docker images, and looking at the BIOS, I believe I have found the reason my scores can sometimes be lower.

Oobabooga's default CPU affinity is not enough (on my desktop PC anyway), even when there are basically no other applications running. Running ooba as administrator, or setting it to above-normal or high priority, fixes it completely. It took me like 6 hours of testing to figure out why my GGUF sometimes doesn't run as fast as it could when it's a partial offload.

No wonder dual-booting Ubuntu was faster than WSL.

Now most models run faster and score better so far; more 'intelligent'.

Likely unrelated to the issues in this thread, but it just goes to show there are a lot of variables that can affect inference quality.

sam-paech commented 8 months ago

Just testing the system, and I am still seeing Miqu scores around ~75.5, regardless of the prompt format. Printing the prompts shows that they are indeed different though. I am at a loss why the scores remain so similar between ChatML, Mistral, and Mistral formatted to match the optimal version (as per the first post on this thread).

Can anyone test if using llama.cpp's server is affected by the prompt choice?

I'll give this a try now.

Wuzzooy commented 8 months ago

I always set a context size in the llama parameters, so I didn't get your issue at first, but when I was testing your commits I didn't set the context as usual and noticed that Miqu-alpaca-70BIQ3_XXS.gguf scored in the 75 range like you. I then set the context back to -c 8192 and got 81.21 with the Mistral format and 79.58 with the Alpaca format. I took the context parameter off again and I'm back in the 75 range for both prompt templates.

sam-paech commented 8 months ago

I always set a context size in the llama parameters, so I didn't get your issue at first, but when I was testing your commits I didn't set the context as usual and noticed that Miqu-alpaca-70BIQ3_XXS.gguf scored in the 75 range like you. I then set the context back to -c 8192 and got 81.21 with the Mistral format and 79.58 with the Alpaca format. I took the context parameter off again and I'm back in the 75 range for both prompt templates.

That's bizarre. I wonder why context length would have any effect on output quality.

Ok, so I loaded up miqudev/miqu-1-70b miqu-1-70b.q4_k_m.gguf and I've been able to reproduce the 75 score. So that's good; it should be easier to narrow down the issue now. I noticed it actually misspelled one of the emotions, leading to an unparseable response, which is very unusual for Miqu.

I'll try setting the context length and see if that changes the result.

dnhkng commented 8 months ago

Thanks @Wuzzooy!

This is fascinating. If this is repeatable on other benchmarks, it's worth a paper.

There are a lot of needle-in-a-haystack papers, but this would show that's maybe not the best metric. If people are not retesting with various context lengths, maybe no one has noticed until now.

But it's probably just a bug in the llama.cpp implementation. I will test varying the context length with ExLlamaV2 tomorrow.

Wuzzooy commented 8 months ago

Thanks @Wuzzooy!

This is fascinating. If this is repeatable on other benchmarks, it's worth a paper.

There are a lot of needle-in-a-haystack papers, but this would show that's maybe not the best metric. If people are not retesting with various context lengths, maybe no one has noticed until now.

But it's probably just a bug in the llama.cpp implementation. I will test varying the context length with ExLlamaV2 tomorrow.

You reproduced it on llama.cpp?

dnhkng commented 8 months ago

Not tonight, will do tomorrow morning.

Wuzzooy commented 8 months ago

Okay, I will try to do more tests with this context thing using ooba.

sam-paech commented 8 months ago

Trying some different permutations:

completions endpoint:
chatml / ctx default: 76.21
chatml / ctx 4096: 76.00
mistral / ctx 4096: 75.18

chat/completions endpoint:
chatml / ctx 4096: 77.36

sam-paech commented 8 months ago

chat/completions endpoint:
--chat-template chatml / ctx 4096: 81.72

ooba:
chatml / ctx 4096: 81.44

It appears there must be a difference in how the chat template is being applied.

CrispStrobe commented 8 months ago

OK, just revisiting this atm. Interesting finds. Did you already try logging the raw prompt strings (and responses, with otherwise identical settings) and diff-comparing them?

dnhkng commented 8 months ago

@CrispStrobe you can't: via ooba, you pass the base prompt and template, and it does the formatting internally.

I'll poke around in ooba, and log the final prompts.

CrispStrobe commented 8 months ago

I was thinking of https://github.com/oobabooga/text-generation-webui/blob/1934cb61ef879815644277c01c7295acbae542d8/modules/text_generation.py#L55 but did not try it.

Wuzzooy commented 8 months ago

llama.cpp server seems to have a default context of 512 according to their docs, so I tried a run with the context set to 512 and, yeah, I got 75.46 on Miqu-alpaca-70BIQ3_XXS.gguf, but setting 4096 or 8192 still gives me 81 with the Mistral prompt format. When I set a 512 context on ooba I don't get the same behavior; instead I scored 82.33. (screenshot attached)

dnhkng commented 8 months ago

@Wuzzooy seems to be a bug in the inference code on llama.cpp then!

My tests using Miqu on llama.cpp[server] with the Mistral format:

75.26, v2, 171.0, 1, llama.cpp, --ctx-size 512
81.88, v2, 171.0, 1, llama.cpp, --ctx-size 1024
82.01, v2, 171.0, 1, llama.cpp, --ctx-size 2048
82.11, v2, 171.0, 1, llama.cpp, --ctx-size 4096
81.95, v2, 171.0, 1, llama.cpp, --ctx-size 8192

And with various formats:

ChatML: 80.4, v2, 171.0, 1, llama.cpp, --ctx-size 2048
Mistral: 82.01, v2, 171.0, 1, llama.cpp, --ctx-size 2048
Mistral2: 81.27, v2, 171.0, 1, llama.cpp, --ctx-size 2048 (moves spaces)
Mistral3: 80.54, v2, 171.0, 1, llama.cpp, --ctx-size 2048 (moves spaces and adds system prompt)

sam-paech commented 8 months ago

Ooh, this is starting to make sense now. I was using --n_ctx, not --ctx-size, which I guess doesn't do the same thing. What must be happening is that the default context size of 512 tokens truncates some of the longer prompts.

So the solution is to make sure llama.cpp is launched with --ctx-size of at least 1024.
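
A quick way to sanity-check a prompt against a given context size (a sketch; the tokenizer here is just a stand-in for the model's own):

from transformers import AutoTokenizer

def fits_in_context(prompt: str, ctx_size: int, gen_budget: int = 512) -> bool:
    # True if the prompt plus the generation budget fits inside the server's context window
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    n_prompt_tokens = len(tokenizer(prompt)["input_ids"])
    return n_prompt_tokens + gen_budget <= ctx_size

example_prompt = 100 * "Robert: It's called progress, Claudia. It's how the world works.\n"
print(fits_in_context(example_prompt, 512))   # False: this prompt would be truncated
print(fits_in_context(example_prompt, 4096))  # True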

dnhkng commented 8 months ago

Found a small difference in the prompts:

For llama.cpp[server], the delivered first prompt is:

[INST] 
Your task is to predict the likely emotional responses of a character in this dialogue:

Robert: Claudia, you've always been the idealist. But let's be practical for once, shall we?
Claudia: Practicality, according to you, means bulldozing everything in sight.
Robert: It's called progress, Claudia. It's how the world works.
Claudia: Not my world, Robert.
Robert: Your world? You mean this...this sanctuary of yours?
Claudia: It's more than a sanctuary. It's a testament to our parents' love for nature.
[End dialogue]

At the end of this dialogue, Robert would feel...
Remorseful
Indifferent
Affectionate
Annoyed

Give each of these possible emotions a score from 0-10 for the relative intensity that they are likely to be feeling each.

You must output in the following format, including headings (of course, you should give your own scores), with no additional commentary:

Remorseful: <score>
Indifferent: <score>
Affectionate: <score>
Annoyed: <score>

[End of answer]

Remember: zero is a valid score, meaning they are likely not feeling that emotion. You must score at least one emotion > 0.

Your answer:
 [/INST]

For llama.cpp via ooba, it's:

 [INST] Your task is to predict the likely emotional responses of a character in this dialogue:

Robert: Claudia, you've always been the idealist. But let's be practical for once, shall we?
Claudia: Practicality, according to you, means bulldozing everything in sight.
Robert: It's called progress, Claudia. It's how the world works.
Claudia: Not my world, Robert.
Robert: Your world? You mean this...this sanctuary of yours?
Claudia: It's more than a sanctuary. It's a testament to our parents' love for nature.
[End dialogue]

At the end of this dialogue, Robert would feel...
Remorseful
Indifferent
Affectionate
Annoyed

Give each of these possible emotions a score from 0-10 for the relative intensity that they are likely to be feeling each.

You must output in the following format, including headings (of course, you should give your own scores), with no additional commentary:

Remorseful: <score>
Indifferent: <score>
Affectionate: <score>
Annoyed: <score>

[End of answer]

Remember: zero is a valid score, meaning they are likely not feeling that emotion. You must score at least one emotion > 0.

Your answer: [/INST] 

i.e. the newline characters adjacent to the [INST] / [/INST] tags are stripped in the ooba version. This now fully accounts for the lower scores of llama.cpp[server] compared to ooba + llama.cpp.

The results:

Mistral_strip: 82.07, v2, 171.0, 1, llama.cpp, --ctx-size 2048 (strips newlines)
Mistral2_strip: 82.55, v2, 171.0, 1, llama.cpp, --ctx-size 2048 (moves spaces and strips newlines)

CrispStrobe commented 8 months ago

Mistral / Mixtral are indeed quite sensitive to minor changes like spaces in the prompt, cf. https://github.com/ollama/ollama/issues/1977?#issuecomment-1942302584 / https://www.reddit.com/r/LocalLLaMA/comments/18ljvxb/llm_prompt_format_comparisontest_mixtral_8x7b/ / https://www.reddit.com/r/LocalLLaMA/comments/1b1gxmq/the_definite_correct_miqu_prompt/

dnhkng commented 8 months ago

@sqrkl Should I also make a PR for updating the Mistral prompt template?

sam-paech commented 8 months ago

@sqrkl Should I also make a PR for updating the Mistral prompt template?

The 0.5 difference in score is within the margin of error; it could vary the other way with a different model. I think we've accounted for the primary source of the discrepancy by ensuring ctx-size is set >= 1024. Unless we know the prompt difference is reliably causing lower scores with other models, I think leaving it as-is should be fine.