ggerganov / llama.cpp

LLM inference in C/C++
MIT License
Bug: Qwen2-72B-Instruct (and finetunes) Q4_K_M, Q5_K_M generates random output with CuBLAS prompt processing #8025

Open anunknowperson opened 2 weeks ago

anunknowperson commented 2 weeks ago

What happened?

Qwen2-72B-Instruct Q4_K_M generates output with random tokens (numbers, special symbols, random chunks of words from different languages, etc).

Has been tested on: 1) Tesla P40 24GB with CPU partitioning, offloading half of the layers; 2) inference fully in RAM (on a different PC from setup 1).

Other people say it works with Q6, so maybe the problem is with Q4_K_M (I can't test Q6).

I've tried with FlashAttention on and off and with MMQ on and off; it doesn't work in any combination.

I tested with llama.cpp binaries, koboldcpp, and text-generation-webui; it doesn't work in any of them.

related: https://github.com/LostRuins/koboldcpp/issues/909

Name and Version

version: 3181 (37bef894) built with MSVC 19.29.30154.0 for x64

What operating system are you seeing the problem on?

Windows

Relevant log output

No response

anunknowperson commented 2 weeks ago

As a workaround I use IQ4_XS.

ggerganov commented 2 weeks ago

Find the commit that breaks it

tim-janik commented 1 week ago

> Find the commit that breaks it

Today, I also ran into Qwen2 generating random characters. I think what I am seeing matches this bug description. Whether it is triggered depends on the prompt, so I am attaching a shell script with a fairly long prompt that reliably triggers random character generation, which I used to bisect the commit.

repro.sh.md

The bisect result:

a818f3028d1497a51cb2b8eb7d993ad58784940e is the first bad commit
commit a818f3028d1497a51cb2b8eb7d993ad58784940e (tag: b3216)
Author: Johannes Gäßler <johannesg@5d6.de>
Date:   Mon Jun 24 17:43:42 2024 +0200

    CUDA: use MMQ instead of cuBLAS by default (#8075)

 CMakeLists.txt       | 15 ++++++++++-----
 Makefile             |  3 +++
 README.md            |  5 +++--
 ggml-cuda.cu         | 96 ++++++++++++++++++++++++++++++++++--------------------------------------------------------------
 ggml-cuda/common.cuh | 36 +++---------------------------------
 ggml-cuda/mmq.cu     | 36 +++++++++++++++++++++++++++++++++---
 ggml-cuda/mmq.cuh    | 53 ++++++++++++++++++++++++++++++++++++-----------------
 ggml-cuda/mmvq.cuh   |  2 ++
 8 files changed, 124 insertions(+), 122 deletions(-)

Good generation (with d62e4aaa02540c89be8b59426340b909d02bbc9e):

./repro.sh
[...]
* Replaced icon-gen with mogrify

Bad generation (with a818f3028d1497a51cb2b8eb7d993ad58784940e):

./repro.sh
[...]
0!097G$.>@".!F4%=0$!%&D22$!@52::-C<@G47*F9C>>F#CB4G=A40)G):/&BC((07$,7!4FB6G;7H<CD5"*0B0);:D,-7/.>,

This happens with https://huggingface.co/Qwen/Qwen2-72B-Instruct quantized to Q4_K_S or Q4_K_M. It also happens with https://huggingface.co/mradermacher/Smaug-Qwen2-72B-Instruct-GGUF/resolve/main/Smaug-Qwen2-72B-Instruct.Q4_K_S.gguf

llama-cli runs on a 4090 here and is compiled with make LLAMA_CUDA=1 CUDA_DOCKER_ARCH=sm_86 CC=ccache\ clang CXX=ccache\ clang++ -j

slaren commented 1 week ago

That commit was made after this issue was opened, so I don't see how that can be the issue.

tim-janik commented 1 week ago

> That commit was made after this issue was opened, so I don't see how that can be the issue.

As I wrote above, the behaviour depends on the prompt. I'm currently running d62e4aaa02540c89be8b59426340b909d02bbc9e (the parent of the commit identified above), and I still occasionally see random characters of the same kind, but only with longer prompts and not reliably. This means there is some chance for the Qwen2 Q4 code to run into this, and a818f3028d1497a51cb2b8eb7d993ad58784940e made it much more likely, to the point where I could produce one example that triggers it reliably.
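For readers wondering how a bisect can land on a commit newer than the issue itself: `git bisect` is a binary search that assumes every commit after the first bad one is also bad. The following is a minimal C++ sketch of that search, not git's implementation; `first_bad_commit` and `reproducer_fails` are hypothetical names and the commit indices are made up. It shows that when a defect exists earlier but the chosen test only fails reliably from a later commit, the search reports the later commit.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>

// Binary search for the first "bad" commit, in the spirit of `git bisect`.
// Commits are indexed 0..n-1 in history order; commit 0 is assumed good and
// commit n-1 bad. The result is only meaningful if is_bad() is monotonic,
// i.e. good, good, ..., bad, bad.
std::size_t first_bad_commit(std::size_t n,
                             const std::function<bool(std::size_t)>& is_bad) {
    assert(n >= 2);
    std::size_t good = 0;     // invariant: is_bad(good) == false
    std::size_t bad = n - 1;  // invariant: is_bad(bad) == true
    while (bad - good > 1) {
        const std::size_t mid = good + (bad - good) / 2;
        if (is_bad(mid)) {
            bad = mid;
        } else {
            good = mid;
        }
    }
    return bad;
}

// Hypothetical model of a prompt-dependent bug: the defect is present from
// commit `introduced` onward, but the chosen reproducer only fails reliably
// from commit `made_likely` onward (introduced <= made_likely).
bool reproducer_fails(std::size_t commit, std::size_t introduced,
                      std::size_t made_likely) {
    const bool defect_present = commit >= introduced;
    const bool reliably_triggered = commit >= made_likely;
    return defect_present && reliably_triggered;
}
```

With a defect introduced at index 3 that only reproduces reliably from index 7, bisecting 10 commits with `reproducer_fails` returns 7, which is consistent with both slaren's objection and the explanation above.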

tim-janik commented 1 week ago

Ok, one more observation.

I am still running d62e4aaa02540c89be8b59426340b909d02bbc9e like this: /ml/llama.cpp/llama-server -fa -nkvo --host 0.0.0.0 --port 8080 -np 4 -m /ml/models/Smaug-Qwen2-72B-Instruct.Q4_K_S.gguf -c 32768 -ngl 43

I have a number of prompts like the one in the shell script above. After sending 29 of them (via a Python script that closes the connection after reading the first line), Smaug-Qwen2 starts generating random chars for anything, even through the web UI:

User: Hello?

Llama: "DD+39;DH433D)+>3%678=4$GG8B)&7")@7A9D;@%9A9;2C@A!<A+#&4,:569!F>"<#!CF#4!-4EG7G2@!H;6/9E//0)$$>).=6!>/#<(EC:!."%7@H7F0A92=2":/5&2!G,8062A0D$:A&B,HD&38C!/:6+(HA74;D%CFB1:'%CHE><1!2(21B61D#!:;--'$+.6&>>">>G0&5&>%6:##%4F0F2D@B8!BA$0;CD42&6&84A,4&!2$:7396294,$48F+%-8.,+@9-,/F:F08:.E&"0>:/)<'+$#&,F)+B:;@'A4B-A>,+)!FG&E+)7::.(GBHC@!A&5H@"%-A8F82C5!C56#'AB&A0B,85AA.C"6+89C<=5C#>D<$.05EC";+"'/*/CB

400 predicted, 446 cached, 389ms per token, 2.57 tokens per second

Same thing for qwen2-72b-ins.Q4km.gguf: it fails at the same prompt in the series of 29, but not if the server is restarted and I start with that 29th prompt.

Same thing for qwen2-72b-ins.Q4ks.gguf: it also generates random chars at the 29th prompt and then keeps generating random chars for any other request.

So the server state seems to be messed up after that series of requests.

But I have not seen this with other model families. So far I have tried LaDameBlanche-103b_iQ3xxs, c4ai-command-r-v01-Q6, Hermes-2-Theta-Llama-3-70B-32k.Q4_K_S, Yi... Those are still fine after 80+ requests.

I'm not sure it would help if I started bisecting this; it would take fairly long, and what would be a good commit to start from?

slaren commented 1 week ago

If I recall correctly, this model had issues with precision in the attention. Can you test if this change fixes it?

diff --git a/src/llama.cpp b/src/llama.cpp
index 3dc0f853..310f46c8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7478,7 +7478,7 @@ static struct ggml_tensor * llm_build_kqv(

         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);

-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }

@@ -7487,7 +7487,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);

-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
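
The comment in the patch says that without F32 precision the KQ multiplication produces NaNs. Below is a minimal C++ sketch of one way that can happen, assuming the non-F32 path keeps its running sum in FP16: the accumulator overflows past the FP16 maximum of 65504 to infinity, and a max-subtracted softmax then computes inf - inf = NaN. The FP16 model is deliberately simplified (overflow only, no mantissa rounding), the function names are hypothetical, and the numbers are illustrative rather than taken from Qwen2.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Largest finite IEEE 754 half-precision (FP16) value.
constexpr double kFp16Max = 65504.0;

// Simplified FP16 store: only overflow is modelled (no mantissa rounding);
// any magnitude above kFp16Max becomes +/-infinity.
double store_fp16_overflow_only(double x) {
    const double inf = std::numeric_limits<double>::infinity();
    if (x > kFp16Max) return inf;
    if (x < -kFp16Max) return -inf;
    return x;
}

// Dot product whose running sum is stored as (simplified) FP16 after each
// step, standing in for a half-precision accumulator.
double dot_fp16_accum(const std::vector<double>& q, const std::vector<double>& k) {
    assert(q.size() == k.size());
    double acc = 0.0;
    for (std::size_t i = 0; i < q.size(); ++i) {
        acc = store_fp16_overflow_only(acc + store_fp16_overflow_only(q[i] * k[i]));
    }
    return acc;
}

// Same dot product with a wide accumulator, standing in for GGML_PREC_F32.
double dot_wide_accum(const std::vector<double>& q, const std::vector<double>& k) {
    assert(q.size() == k.size());
    double acc = 0.0;
    for (std::size_t i = 0; i < q.size(); ++i) {
        acc += q[i] * k[i];
    }
    return acc;
}

// Softmax that subtracts the row maximum first. If the maximum is +inf,
// inf - inf yields NaN, and the NaN spreads to every output via the sum.
std::vector<double> softmax(const std::vector<double>& x) {
    double m = -std::numeric_limits<double>::infinity();
    for (double v : x) m = std::max(m, v);
    std::vector<double> out(x.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - m);
        sum += out[i];
    }
    for (double& v : out) v /= sum;
    return out;
}
```

For example, with 128-element Q and K rows whose entries are all 30.0, each product (900) fits in FP16 but the running sum exceeds 65504 after 73 terms, so `dot_fp16_accum` returns infinity while `dot_wide_accum` returns 115200, and a softmax row containing that infinity is all NaN. tim-janik reports below that this patch alone did not fix the problem, so the sketch only illustrates the failure mode the patch targets, not a confirmed root cause.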
anunknowperson commented 1 week ago

I have some time for tests, so I will try different commits. Please note that, unlike for @tim-janik, the model generates garbage every time on both of my PCs, not just in some situations.

All tested with Q4_K_M, 8k context, same prompt. This time I don't have access to the P40, so I'm testing fully in RAM.

KoboldCpp 1.68 (still doesn't work)

Llama.cpp b3258 (doesn't work)

llama.cpp commit a27aa50ab7e07fe46aae619076b6e31d5663e914, built with a plain "make" on Windows (everything default, no CUDA): works, though slowly.

anunknowperson commented 1 week ago

KoboldCpp 1.68, 8k context, mmap off, mmq off: works. Fully on CPU. Unfortunately I can't test the P40 now.

shibe2 commented 1 week ago

I got wrong output from Qwen2-72B-Instruct Q5_K_M with hipBLAS back-end and good output with Vulkan back-end.

tim-janik commented 1 week ago

> If I recall correctly, this model had issues with precision in the attention. Can you test if this change fixes it?

I have applied this on top of commit d62e4aaa02540c89be8b59426340b909d02bbc9e, but the problems remain. After 29 requests I get random chars, and from then on the web UI also generates only random junk:

User: Hi there...

Llama: -20.@26G'81=(5H#G8+C;:"="H:4'#,5!9-0,%9#-B!>6D2$)=B"H<.3#)7G8>/+9$FDE=A9A$F<"D%H%20%.F>'0+E3,B$=),00292<A%2<135/%6!2B/D&D"2&C.'8<4H!A:#<"B#A46"#3)$'2!7B)C<G27E8C>!51/E0E#E(,-3%A%;.-/-5G)H0F@=,(F1+(+F.#87)E:#;6$2E"%7D@:C55",2;+;=6BHBDA7#DC9D4#2##5D.,)B2'8A;<B"@#B7,A#C7"'!70A2(F70)97<00C4)82*=#B

How is it possible that follow-up requests to the server are affected at all? Is some state perhaps not being reset between requests?

anunknowperson commented 1 week ago

> KoboldCpp 1.68 8k context, mmap off, mmq off - works. Fully on CPU. Can't test P40 unfortunately now.

Update: I tested it again and it actually doesn't work. If I ask anything more complex than "Hi", "Hello", etc., it starts generating random characters again (similar to what I initially observed).

honuvo commented 1 week ago

Hi, to add to this: my observations are the same, with random tokens after a few messages. I may have some new information: I don't have these problems when using a non-K-quant. Regular Q4_0, Q4_1 and Q5_0 worked for me. Prompted by the ongoing discussions, I wanted to be sure how the quants I use are made, so I downloaded the original weights (of a Qwen2-72B-Instruct fine-tune), converted them to GGUF at F32 precision with "convert-hf-to-gguf.py", and quantized them to Q4_0, Q4_1, Q5_0 and Q4_K_M. I also made a Q8_0 quant and requantized it to Q4_K_M. In both cases the K-quant produced gibberish after a good message or two. My system can't handle larger quants, so I can't test further (RTX 3070 Laptop with 8GB VRAM, 64GB RAM). I tested with and without GPU offloading, with and without FlashAttention, and with and without CuBLAS (I'm using koboldcpp_cu12_v1.68, by the way). I will test with llama.cpp and the forced-F32-precision patch mentioned above when I find the time.
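For context on the K-quant versus non-K-quant distinction: a legacy Q4_0 block stores 32 weights as 4-bit integers plus one scale `d`, and each weight is reconstructed as (q - 8) * d. K-quants such as Q4_K and Q5_K instead group weights into 256-weight super-blocks with additional quantized per-sub-block scales (and mins), so they are handled by different dequantization and dot-product code. The C++ sketch below illustrates only the Q4_0 layout, following ggml's element order (low nibbles first, then high nibbles); the struct and function names are hypothetical, and the scale is a `float` here whereas ggml stores it as FP16. It says nothing about where the bug is.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Number of weights per Q4_0-style block.
constexpr std::size_t kBlockSize = 32;

// Simplified Q4_0-style block (hypothetical name). ggml stores `d` as FP16;
// a float is used here for brevity.
struct BlockQ4_0Sketch {
    float d;                          // per-block scale
    std::uint8_t qs[kBlockSize / 2];  // 32 4-bit values, two per byte
};

// Dequantize one block: low nibbles -> elements 0..15, high nibbles ->
// elements 16..31; each 4-bit value q in [0, 15] maps to (q - 8) * d.
std::vector<float> dequantize_block(const BlockQ4_0Sketch& b) {
    std::vector<float> out(kBlockSize);
    for (std::size_t j = 0; j < kBlockSize / 2; ++j) {
        const int lo = (b.qs[j] & 0x0F) - 8;
        const int hi = (b.qs[j] >> 4) - 8;
        out[j] = lo * b.d;
        out[j + kBlockSize / 2] = hi * b.d;
    }
    return out;
}
```

For example, a packed byte 0x1F in position 0 with d = 0.5 yields 3.5 for element 0 (low nibble 15) and -3.5 for element 16 (high nibble 1).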

anunknowperson commented 1 week ago

I downloaded Q5_K_M and it has exactly the same problem. The latest llama.cpp built with default parameters, however, works perfectly, but slowly (Q4_K_M behaved the same way).

anunknowperson commented 1 week ago

Looks like Q5_K_M works fine in the latest koboldcpp if I use OpenBLAS instead of CuBLAS for prompt processing. It's slow, though. EDIT: CLBlast also works on GPU.

honuvo commented 5 days ago

> Looks like Q5km works ok in latest koboldcpp if I use OpenBLAS instead of CuBLAS for prompt processing. It's slow, though. EDIT: CLBlast works too on GPU.

Wow, indeed, using CLBlast doesn't trigger the problem. It is slower, but still usable in my case. So does the problem lie in the CUDA code that handles inference?

henk717 commented 3 days ago

On the Koboldcpp side I can confirm the issue happens with CuBLAS if Flash Attention is not activated; if Flash Attention is activated, the issue goes away. Similarly, if I use another backend such as Vulkan, the issue also goes away.

Update: I think it's best to treat this as a separate issue with the CUDA backend rather than a regression. I tested all the way back to Koboldcpp 1.56, dated Jan 27, which is the first version I can get this model to run on, and it already had the issue.

Just as with the versions released since, switching to Vulkan bypasses the problem. That version did not have Flash Attention yet, so I can't test that.

anunknowperson commented 3 days ago

I don't know if it's directly related to flash attention, because for me, as I said, enabling or disabling flash attention doesn't change anything: the model still generates garbage from the first message. The only way I can get rid of this behaviour is by changing the prompt processing backend (I tested both with and without offloading). CUDA 11.

flastir commented 3 days ago

It appears that the Qwen2-72B-Instruct Q5_K_M model stopped functioning correctly after release b3091.

Result in the release b3091:

User: who are you?
Llama: I am Llama, a friendly and helpful chatbot designed to assist you with any questions or tasks you might have. My goal is to make your day easier by providing accurate information and engaging in meaningful conversations.

Result in the next release b3130:

User: who are you?
Llama: !t 10.
The- .4s híst, and of
A) ()/b0–all.
...

In the latest release I checked, b3291, the model still doesn't work correctly.

The GGUF model was downloaded from here.

I started the server from llama-b3091-bin-win-cuda-cu12.2.0-x64.zip with this command:

server.exe -m Qwen2-72B-Instuct-Q5_K_M-00001-of-00002.gguf -c 2048 -ngl 0 -fa

The GGUF works normally in KoboldCpp v1.69.1 only if "Use FlashAttention" is checked and "Use QuantMatMul (mmq)" is unchecked.

flastir commented 3 days ago

> Looks like Q5km works ok in latest koboldcpp if I use OpenBLAS instead of CuBLAS for prompt processing. It's slow, though. EDIT: CLBlast works too on GPU.

Try CuBLAS with "Use FlashAttention" checked and "Use QuantMatMul (mmq)" unchecked. This works in my case (KoboldCpp v1.69.1, CUDA 12.3).

henk717 commented 2 days ago

Did an extra test just to make sure, and what I expected is true: in the CUDA 11 build of Koboldcpp, Flash Attention can't be used the same way as in the CUDA 12 build, since it needs some CUDA 12 features and otherwise uses fallbacks. In the CUDA 11 build, turning Flash Attention on does nothing for me, while in the CUDA 12 build it bypasses the bug.