ggerganov / llama.cpp

LLM inference in C/C++
MIT License

With the newest builds I only get gibberish output #1735

Closed. maddes8cht closed this issue 1 year ago.

maddes8cht commented 1 year ago

After the CUDA refactor PR #1703 by @JohannesGaessler was merged, I wanted to try it out this morning and measure the performance difference on my hardware. I use my standard prompts with different models of different sizes.

I use the prebuilt win-cublas-cu12.1.0-x64 versions.

With the new builds I only get gibberish as a response for all prompts used and all models. It looks like a random mix of words in different languages.

On my current PC I can only use the win-avx-x64 version; there I still get normal output.

I will be back at the CUDA PC in a few hours; then I can provide sample output or more details. Am I the only one with this problem?

RahulVivekNair commented 1 year ago

Same here. It gives gibberish output only when layers are offloaded to the GPU via -ngl. Without offloading it works as it should. I had to roll back to the pre-CUDA-refactor commit.

dranger003 commented 1 year ago

Same here, this only happens when offloading layers to the GPU; running on the CPU works fine. Also, I noticed that the more GPU layers you offload, the more gibberish you get.

JohannesGaessler commented 1 year ago

Thank you for reporting this issue. Just to make sure: are you getting garbage outputs with all model sizes, even with 7b?

JohannesGaessler commented 1 year ago

Also to make sure: is there anyone with this issue that is compiling llama.cpp themselves or is everyone using the precompiled Windows binary?

dranger003 commented 1 year ago

I tried different models and model sizes, and they all produce gibberish using GPU layers but work fine using CPU. Also, I am compiling from the latest commit on the master branch, using Windows and cmake.

JohannesGaessler commented 1 year ago

In llama.cpp line 1158 there should be:

        vram_scratch = n_batch * MB;

Could someone who is experiencing the issue please try replacing that line with this:

        vram_scratch = 4 * n_batch * MB;

Ideally the problem is just that Windows, for whatever reason, needs more VRAM scratch than Linux; in that case the fix would simply be a bigger scratch buffer. Otherwise it may be that I'm accidentally relying on undefined behavior somewhere and the fix will be more difficult.
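
As a quick sanity check of what the suggested change does, here is a minimal sketch of the arithmetic (it assumes MB is the usual 1 MiB constant and the default n_batch of 512 that appears in the logs below; these are assumptions for illustration, not the exact llama.cpp definitions):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t MB      = 1024u * 1024u; // assumed 1 MiB scratch unit
        const size_t n_batch = 512;           // default batch size seen in the logs below

        const size_t original = n_batch * MB;     // the current line
        const size_t patched  = 4 * n_batch * MB; // the suggested 4x buffer

        printf("original scratch: %zu MB\n", original / MB); // 512 MB
        printf("patched  scratch: %zu MB\n", patched  / MB); // 2048 MB
        return 0;
    }

So with the default batch size the suggested change simply grows the scratch buffer from 512 MB to 2048 MB, which matches the "2048 MB VRAM for the scratch buffer" line in the log further down.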

FlareP1 commented 1 year ago

Yes, I agree with @dranger003 above: a local compile does not fix the issue. I also tried both cuBLAS and CLBlast; both options produce gibberish. I only have one GPU. Do I need any new command-line options?

Will try the change that @JohannesGaessler suggests above.

RahulVivekNair commented 1 year ago

Thank you for reporting this issue. Just to make sure: are you getting garbage outputs with all model sizes, even with 7b?

I've tried it with all quantization types and model sizes. It still produces weird gibberish output.

dranger003 commented 1 year ago

In llama.cpp line 1158 there should be:

        vram_scratch = n_batch * MB;

Could someone who is experiencing the issue please try replacing that line with this:

        vram_scratch = 4 * n_batch * MB;

Ideally the problem is just that Windows, for whatever reason, needs more VRAM scratch than Linux; in that case the fix would simply be a bigger scratch buffer. Otherwise it may be that I'm accidentally relying on undefined behavior somewhere and the fix will be more difficult.

Same issue on my end with this change.

diff --git a/llama.cpp b/llama.cpp
index 16d6f6e..e06d503 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1155,7 +1155,7 @@ static void llama_model_load_internal(

         (void) vram_scratch;
 #ifdef GGML_USE_CUBLAS
-        vram_scratch = n_batch * MB;
+        vram_scratch = 4 * n_batch * MB;
         ggml_cuda_set_scratch_size(vram_scratch);
         if (n_gpu_layers > 0) {
             fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",

FlareP1 commented 1 year ago

In llama.cpp line 1158 there should be:

        vram_scratch = n_batch * MB;

Could someone who is experiencing the issue please try replacing that line with this:

        vram_scratch = 4 * n_batch * MB;

Ideally the problem is just that Windows, for whatever reason, needs more VRAM scratch than Linux; in that case the fix would simply be a bigger scratch buffer. Otherwise it may be that I'm accidentally relying on undefined behavior somewhere and the fix will be more difficult.

This does not fix the issue for me.

main: build = 635 (5c64a09)
main: seed  = 1686151114
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3080
llama.cpp: loading model from ../models/GGML/selfee-13b.ggmlv3.q5_1.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32001
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 9 (mostly Q5_1)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 6820.77 MB (+ 1608.00 MB per state)
llama_model_load_internal: allocating batch_size x 1 MB = 2048 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 20 layers to GPU
llama_model_load_internal: total VRAM used: 6587 MB
..................................................
llama_init_from_file: kv self size  = 1600.00 MB

system_info: n_threads = 5 / 16 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 0, tfs_z = 1.000000, top_p = 0.730000, typical_p = 1.000000, temp = 0.730000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 2048, n_batch = 512, n_predict = 2048, n_keep = 0

 ### Instruction:\
Write a detailed account on the benefits and challenges of using automated assistants based on LLMs.  Suggest how LLMs are likely to be used in the near future. What effect this will have on employment and skills needs in the workplace.  How will businesses need to adapt and evolve to maximise the benefits from this technology.\
### Response:
 benefits &amp Vertigowebsitesearch engines Search engines Search google search Google searchGoogle search engineGoogle search engine Google search engine Google search engineiallyikuwaμClientele clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele Clientele ClienteleClientele Clientele ClienteleClienteleClienteleClientele Clientele ClienteleClientele clientele Clientele ClienteleClienteleClientele Clientele Clientele ClienteleClienteleClientele ClienteleClientele Clientele ClienteleClientele Clientele ClienteleClientele Clientele ClienteleClientilehiyawaifikMT machine learning Sebastianurity securitysecurity Security Security Securityintegration integration integration integration integration integration integration integration integration integration integration integration integration integration integrationintegration integration integration integration integration integration Linkexchangeabletonikonidéortodoxyfit wittersburgidé修 connaissanceable magnpackageeuropewo meshnetworkedayoutWEvikipediawikiidéangrhythmembergelesupportente Witmaternalismsavedrabblementқreb

JohannesGaessler commented 1 year ago

Alright. I currently don't have CUDA installed on my Windows partition but I'll go ahead and install it to see if I can reproduce the issue.

dranger003 commented 1 year ago

Alright. I currently don't have CUDA installed on my Windows partition but I'll go ahead and install it to see if I can reproduce the issue.

Thanks, this is the command I use to compile on my end.

cmake -DLLAMA_CUBLAS=ON . && cmake --build . --config Release

RahulVivekNair commented 1 year ago

Alright. I currently don't have CUDA installed on my Windows partition but I'll go ahead and install it to see if I can reproduce the issue.

Is it working as intended on Linux?

JohannesGaessler commented 1 year ago

It is working as intended on my machines which all run Linux. The first step for me to make a fix is to be able to reproduce the issue and the difference in operating system is the first thing that seems worthwhile to check. Of course, if someone that is experiencing the issue could confirm whether or not they have the same problem on Linux that would be very useful to me.

FlareP1 commented 1 year ago

It is working as intended on my machines which all run Linux. The first step for me to make a fix is to be able to reproduce the issue and the difference in operating system is the first thing that seems worthwhile to check. Of course, if someone that is experiencing the issue could confirm whether or not they have the same problem on Linux that would be very useful to me.

When I compile this under WSL2 and run with -ngl 0, it works OK. When I run with -ngl 10, I get CUDA error 2 at /home/ubuntu/llama.cpp/ggml-cuda.cu:1241: out of memory, even though I should have plenty free (10 GB) and I am only requesting 10 layers of a 7B model.

However, I have never run under WSL before, so it might be another issue with my setup; accelerated prompt processing is OK, though.

Which commit do I need to pull to try a rebuild from before the issue occurred?

JohannesGaessler commented 1 year ago

Did you revert the change that increases the size of the VRAM scratch buffer? In any case, since the GPU changes that I did are most likely the problem the last good commit should be 44f906e8537fcec965e312d621c80556d6aa9bec.

dranger003 commented 1 year ago

Did you revert the change that increases the size of the VRAM scratch buffer? In any case, since the GPU changes that I did are most likely the problem the last good commit should be 44f906e8537fcec965e312d621c80556d6aa9bec.

I don't have a Linux install to test at the moment, but on Windows I confirm commit 44f906e8537fcec965e312d621c80556d6aa9bec works fine with all GPU layers offloaded.

dranger003 commented 1 year ago

It is working as intended on my machines which all run Linux. The first step for me to make a fix is to be able to reproduce the issue and the difference in operating system is the first thing that seems worthwhile to check. Of course, if someone that is experiencing the issue could confirm whether or not they have the same problem on Linux that would be very useful to me.

When I compile this under WSL2 and run with -ngl 0, it works OK. When I run with -ngl 10, I get CUDA error 2 at /home/ubuntu/llama.cpp/ggml-cuda.cu:1241: out of memory, even though I should have plenty free (10 GB) and I am only requesting 10 layers of a 7B model.

However, I have never run under WSL before, so it might be another issue with my setup; accelerated prompt processing is OK, though.

Which commit do I need to pull to try a rebuild from before the issue occurred?

Could this be related?

WSL: CUDA error 2 at ggml-cuda.cu:359: out of memory (Fix found) https://github.com/ggerganov/llama.cpp/issues/1230

FlareP1 commented 1 year ago

I have reverted the changes and checked out the 44f906e8537fcec965e312d621c80556d6aa9bec commit. On my version of WSL2 this still does not work and gives the same out-of-memory error, so I guess I probably have a WSL/CUDA setup issue.

On Windows I can compile and the code works fine from this commit.

dranger003 commented 1 year ago

It is working as intended on my machines which all run Linux. The first step for me to make a fix is to be able to reproduce the issue and the difference in operating system is the first thing that seems worthwhile to check. Of course, if someone that is experiencing the issue could confirm whether or not they have the same problem on Linux that would be very useful to me.

Working fine on WSL2 (Ubuntu) using CUDA on commit 5c64a09. So on my end this seems to be a Windows-only bug.

FlareP1 commented 1 year ago

It is working as intended on my machines which all run Linux. The first step for me to make a fix is to be able to reproduce the issue and the difference in operating system is the first thing that seems worthwhile to check. Of course, if someone that is experiencing the issue could confirm whether or not they have the same problem on Linux that would be very useful to me.

When I compile this under WSL2 and run with -ngl 0, it works OK. When I run with -ngl 10, I get CUDA error 2 at /home/ubuntu/llama.cpp/ggml-cuda.cu:1241: out of memory, even though I should have plenty free (10 GB) and I am only requesting 10 layers of a 7B model. However, I have never run under WSL before, so it might be another issue with my setup; accelerated prompt processing is OK, though. Which commit do I need to pull to try a rebuild from before the issue occurred?

Could this be related?

WSL: CUDA error 2 at ggml-cuda.cu:359: out of memory (Fix found) #1230

Awesome tip, thanks.

Due to a current CUDA bug you need to disable pinned memory via an environment variable. Command for it: export GGML_CUDA_NO_PINNED=1

Now the old commit works under WSL; I will try the latest again.

UPDATE: Yes, it works fine on the latest commit under WSL2, as long as pinned memory is disabled.
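
For context, GGML_CUDA_NO_PINNED makes the CUDA backend skip pinned (page-locked) host buffers, which is what appears to run into a limit under WSL2. A minimal sketch of that kind of opt-out (illustrative only; the helper name and fallback behaviour here are assumptions, not the actual ggml-cuda.cu code):

    #include <cstddef>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Allocate host memory, preferring pinned memory unless the user opted out.
    static void * host_alloc(size_t size) {
        if (std::getenv("GGML_CUDA_NO_PINNED") != nullptr) {
            return std::malloc(size);      // plain pageable memory
        }
        void * ptr = nullptr;
        if (cudaMallocHost(&ptr, size) != cudaSuccess) {
            return std::malloc(size);      // fall back if pinning fails (e.g. under WSL2)
        }
        return ptr;
    }

With the variable set, host buffers stay pageable; transfers may be a bit slower, but the WSL2 out-of-memory error from pinned allocations is avoided.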

RahulVivekNair commented 1 year ago

Can confirm: I ran under WSL and the output is as expected. The gibberish output seems to be a problem only on the Windows side.

JohannesGaessler commented 1 year ago

I have bad news: on my main desktop I am not experiencing the bug when using Windows. I'll try setting up llama.cpp on the other machine that I have.

FlareP1 commented 1 year ago

I have bad news: on my main desktop I am not experiencing the bug when using Windows. I'll try setting up llama.cpp on the other machine that I have.

If it helps I am using the following config.

Microsoft Windows [Version 10.0.19044.2965]

>nvidia-smi
Wed Jun  7 19:48:35 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 531.61                 Driver Version: 531.61       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                      TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080       WDDM | 00000000:01:00.0  On |                  N/A |
| 40%   27C    P8               16W / 320W|    936MiB / 10240MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

dranger003 commented 1 year ago

And here's mine.

Microsoft Windows [Version 10.0.22621.1778]
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.98                 Driver Version: 535.98       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3090      WDDM  | 00000000:2D:00.0  On |                  N/A |
|  0%   45C    P8              32W / 420W |  12282MiB / 24576MiB |     44%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

And also this one.

Microsoft Windows [Version 10.0.22621.1702]
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.98                 Driver Version: 535.98       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2070 ...  WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   52C    P8               4W /  80W |    265MiB /  8192MiB |     10%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

maddes8cht commented 1 year ago

Here is mine:

Microsoft Windows [Version 10.0.19045.3031]
(c) Microsoft Corporation. All rights reserved.

C:\Users\Mathias>nvidia-smi
Wed Jun  7 21:20:43 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 531.61                 Driver Version: 531.61       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                      TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3060       WDDM | 00000000:0A:00.0  On |                  N/A |
|  0%   41C    P8               16W / 170W|    704MiB / 12288MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

Maybe it is a Windows 10 thing? @JohannesGaessler, are you using Windows 10 or Windows 11? Is anyone having problems on Windows 11? Just a guess.

dranger003 commented 1 year ago

Maybe it is a Windows 10 thing? @JohannesGaessler, are you using Windows 10 or Windows 11? Is anyone having problems on Windows 11? Just a guess.

I'm on Windows 11 and I have the issue; build 10.0.22621.1778 is Windows 11, by the way.

JohannesGaessler commented 1 year ago

I'm now able to reproduce the issue. On my system it only occurs when I use the --config Release option. If I make a debug build by omitting the option the program produces correct results.

FlareP1 commented 1 year ago

I have bad news: on my main desktop I am not experiencing the bug when using Windows. I'll try setting up llama.cpp on the other machine that I have.

@JohannesGaessler, does the pre-built Windows release work for you, or only your local compile? I am wondering if it is a compiler configuration issue on Windows that is also affecting the pre-built binaries.

Is there anyone else who can run the latest CUDA builds on Windows without this error?

EDIT: Oh, I see you found part of the issue. So is it an uninitialised variable or something else that differs in a Release build?

Folko-Ven commented 1 year ago

Same problem here:

Windows 10 Version 10.0.19045.2846

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.98                 Driver Version: 535.98       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3070 Ti   WDDM  | 00000000:01:00.0  On |                  N/A |
|  0%   35C    P8              13W / 290W |    635MiB /  8192MiB |      4%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

JohannesGaessler commented 1 year ago

EDIT: Oh, I see you found part of the issue. So is it an uninitialised variable or something else that differs in a Release build?

No, what I think is happening is that both the Linux build and the Windows build have a defect; it's just that this defect only manifests as a bug on Windows, and only if the compiler does certain optimizations. I've narrowed down the commit that introduced the bug, and I've noticed that with only a single GPU layer the output is different but not gibberish. So what I think is happening is that somewhere something is writing to the wrong memory, and on Linux this just happens not to make a difference.

mirek190 commented 1 year ago

Yep, the same problem. The newest build is not working at all, and the previous one (this) gives nonsense with -ngl. Without -ngl it is fine.

With -ngl

PS F:\LLAMA\llama.cpp> build\bin\main --model models/new1/VicUnlocked-Alpaca-65B.ggmlv3.q5_1.bin --mlock --color --threads 8 --keep -1 --batch_size 512 --n_predict -1 --top_k 10000 --top_p 0.9 --temp 0.96 --repeat_penalty 1.1 --ctx_size 2048 --interactive --instruct --reverse-prompt "### Human:" --reverse-prompt "### User:" --reverse-prompt "### Assistant:" -ngl 36
main: build = 632 (35a8491)
main: seed = 1686175203
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090
llama.cpp: loading model from models/new1/VicUnlocked-Alpaca-65B.ggmlv3.q5_1.bin
llama_model_load_internal: format = ggjt v3 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 8192
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 64
llama_model_load_internal: n_layer = 80
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 9 (mostly Q5_1)
llama_model_load_internal: n_ff = 22016
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size = 0.18 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required = 29437.96 MB (+ 5120.00 MB per state)
llama_model_load_internal: allocating batch_size x 1 MB = 512 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 36 layers to GPU
llama_model_load_internal: total VRAM used: 21359 MB
....................................................................................................
llama_init_from_file: kv self size = 5120.00 MB

system_info: n_threads = 8 / 16 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 |
main: interactive mode on.
Reverse prompt: '### Human:'
Reverse prompt: '### User:'
Reverse prompt: '### Assistant:'
Reverse prompt: '### Instruction:

'
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 10000, tfs_z = 1.000000, top_p = 0.900000, typical_p = 1.000000, temp = 0.960000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 2048, n_batch = 512, n_predict = -1, n_keep = 2

== Running in interactive mode. ==

hello hello Stackoverflowustom remilt mir bezeichneteritutionengMDrez dedic Gul contraryglass

without -ngl

PS F:\LLAMA\llama.cpp> build\bin\main --model models/new1/VicUnlocked-Alpaca-65B.ggmlv3.q5_1.bin --mlock --color --threads 8 --keep -1 --batch_size 512 --n_predict -1 --top_k 10000 --top_p 0.9 --temp 0.96 --repeat_penalty 1.1 --ctx_size 2048 --interactive --instruct --reverse-prompt "### Human:" --reverse-prompt "### User:" --reverse-prompt "### Assistant:"
main: build = 632 (35a8491)
main: seed = 1686175409
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090
llama.cpp: loading model from models/new1/VicUnlocked-Alpaca-65B.ggmlv3.q5_1.bin
llama_model_load_internal: format = ggjt v3 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 8192
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 64
llama_model_load_internal: n_layer = 80
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 9 (mostly Q5_1)
llama_model_load_internal: n_ff = 22016
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size = 0.18 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required = 50284.21 MB (+ 5120.00 MB per state)
llama_model_load_internal: offloading 0 layers to GPU
llama_model_load_internal: total VRAM used: 512 MB
....................................................................................................
llama_init_from_file: kv self size = 5120.00 MB

system_info: n_threads = 8 / 16 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 |
main: interactive mode on.
Reverse prompt: '### Human:'
Reverse prompt: '### User:'
Reverse prompt: '### Assistant:'
Reverse prompt: '### Instruction:

'
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 10000, tfs_z = 1.000000, top_p = 0.900000, typical_p = 1.000000, temp = 0.960000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 2048, n_batch = 512, n_predict = -1, n_keep = 2

== Running in interactive mode. ==

hello Hello Human!!! Nice to meet you. What can I do for you today.

vsalibrary commented 1 year ago

Same problem here, but when running the Metal build on a MacBook Air. I have tried multiple models and settings, and it only happens when I try to use the -ngl option.

JohannesGaessler commented 1 year ago

I will need to investigate some more, but the problem may unironically be caused by a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables. Adding an assert for one of the variables I use to determine the data indices changes whether or not an assert for another variable gets triggered, which really shouldn't happen.

FlareP1 commented 1 year ago

I will need to investigate some more, but the problem may unironically be caused by a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables. Adding an assert for one of the variables I use to determine the data indices changes whether or not an assert for another variable gets triggered, which really shouldn't happen.

Good luck with your investigation. It could be a compiler issue, but possibly it is a memory-trampling issue and the different optimisation just affects where in memory the erroneous write occurs. That way, different compilers and memory allocations affect the symptoms: sometimes the erroneous write does not corrupt the output, other times it does.

JohannesGaessler commented 1 year ago

I pushed a fix: https://github.com/ggerganov/llama.cpp/pull/1753 . Please try whether it works.

helkaluin commented 1 year ago

Another data point, since you mentioned this is only happening on Windows: I couldn't trigger this multi-language gibberish on Linux if I compile with just -DLLAMA_CUBLAS=1, but if I do -DLLAMA_CUBLAS=1 -DLLAMA_BLAS=1 then it starts doing so.

But just plain -DLLAMA_BLAS=1 is fine.

JohannesGaessler commented 1 year ago

It's possible that there is more than one bug, though. Usually something like just one of the indices being slightly wrong results in the output being gibberish. Compiling with both LLAMA_CUBLAS and LLAMA_BLAS on my Linux system results in a compile-time error when using cmake and no gibberish when using make.

JohannesGaessler commented 1 year ago

Same problem here, but when running the Metal build on a MacBook Air. I have tried multiple models and settings, and it only happens when I try to use the -ngl option.

Do you get correct results when using 44f906e8537fcec965e312d621c80556d6aa9bec?

FlareP1 commented 1 year ago

I pushed a fix: #1753 . Please try whether it works.

I gave this a go (I'm not an expert in git, so I hope I did this right, @JohannesGaessler).

git fetch origin pull/1753/head:master
git checkout master

cd build
cmake .. -DLLAMA_CUBLAS=ON
cmake --build . --config Release

It appears to work for me, so it is definitely an improvement, but it is concerning when looking at the code change, because you did not change anything material that should change behaviour.

main: build = 640 (a50d003)
main: seed  = 1686218444
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3080
llama.cpp: loading model from ../models/GGML/Manticore-13B.ggmlv2.q5_1.bin
llama_model_load_internal: format     = ggjt v2 (pre #1508)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 9 (mostly Q5_1)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 4097.80 MB (+ 1608.00 MB per state)
llama_model_load_internal: allocating batch_size x 1 MB = 512 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 32 layers to GPU
llama_model_load_internal: total VRAM used: 7774 MB
....................................................................................................
llama_init_from_file: kv self size  = 1600.00 MB

system_info: n_threads = 5 / 16 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 |
main: interactive mode on.
Reverse prompt: 'Me:'
Reverse prompt: 'You:'
Reverse prompt: '###'
Reverse prompt: '### Instruction:

'
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 2, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 2048, n_batch = 512, n_predict = -1, n_keep = 2

== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - To return control to LLaMa, end your input with '\'.
 - To return control without starting a new line, end your input with '/'.

> tell me about llamas?
Llamas are a domesticated South American camelid, similar in appearance to alpacas but with shorter ears and a flatter face. They have been bred for their fiber, which is very soft and warm. Llamas are also used as pack animals and can carry heavy loads for long distances without complaint. Their friendly nature makes them popular among tourists and they're often seen in parades or other events. In the past, llamas were an important part of Andean culture and were even considered a sacred animal.
>

dranger003 commented 1 year ago

I pushed a fix: #1753 . Please try whether it works.

Thanks, this appears to work on my end on Win11; I will do some more testing.

IkariDevGIT commented 1 year ago

please fix this

mirek190 commented 1 year ago

Everyone is waiting for it 😅

IkariDevGIT commented 1 year ago

@mirek190 yep, it's pretty annoying; I just wanted to test out a new model today 😢

maddes8cht commented 1 year ago

@JohannesGaessler thanks for your great work - yes, everyone is waiting for it, so thank you for making it happen.

Now that I have tested the buggy version against the last working CUDA version, I noticed the following (in the working version 5220a99): I give prompts to a 33B model, always with the same seed, and accordingly get the same output every time - as long as -ngl is left at the same value.

For different values of -ngl, however, I get different responses. Is this correct? Shouldn't I get the same answer with the same seed on every PC, no matter whether the GPU is used and no matter how many layers are offloaded?

JohannesGaessler commented 1 year ago

It's expected that changing the number of GPU layers changes the output. The way in which the calculation is done is slightly different: on the CPU, f32 values are first quantized to q8_0 before they are multiplied with the quantized weights. On the GPU, however, the quantized weights are converted to f32 and then an f32 x f32 multiplication is done. Once a single token changes because of this small difference, the entire remaining output changes as well. And as the warning says: when using cuBLAS, results are not guaranteed to be reproducible anyway.
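
To illustrate why the two paths diverge, here is a toy sketch with made-up numbers (it only models the activation rounding on a tiny block; the real q8_0 format uses 32-value blocks and the weights are quantized too, so this is not the actual ggml kernel):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const float w[4] = {0.11f, -0.52f, 0.34f, 0.90f}; // weights, already in f32
        const float x[4] = {1.30f, -0.70f, 0.05f, 2.10f}; // f32 activations

        // GPU-style path: plain f32 x f32 dot product.
        float gpu = 0.0f;
        for (int i = 0; i < 4; ++i) gpu += w[i] * x[i];

        // CPU-style path: round the activations to 8-bit values with one shared
        // scale (q8_0-like) before multiplying; the rounding adds a small error.
        float amax = 0.0f;
        for (int i = 0; i < 4; ++i) amax = std::max(amax, std::fabs(x[i]));
        const float d = amax / 127.0f;
        float cpu = 0.0f;
        for (int i = 0; i < 4; ++i) {
            const int8_t q = (int8_t) std::lround(x[i] / d);
            cpu += w[i] * (q * d);
        }

        printf("f32 x f32 : %.7f\n", gpu);
        printf("q8_0-like : %.7f\n", cpu); // agrees only approximately with the line above
        return 0;
    }

The two results agree only to a few decimal places, and once that small difference flips one sampled token, everything generated after it diverges as well.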

prsyahmi commented 1 year ago

I pushed a fix: https://github.com/ggerganov/llama.cpp/pull/1753 . Please try whether it works.

I've tried the fix and it is working

tcristo commented 1 year ago

I'm running this on Termux (Linux), built in Release mode with CLBlast, and this was occurring when -ngl was greater than 0. Your change appears to have fixed the gibberish output issue.

JohannesGaessler commented 1 year ago

Sorry, are you saying the change that I did to ggml-cuda.cu in this PR https://github.com/ggerganov/llama.cpp/pull/1753 has fixed the issue that you were having with CLBlast? That makes absolutely no sense to me, since ggml-cuda.cu doesn't even get loaded when using CLBlast.
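
For readers following along: the GPU backends are selected at compile time, so a CLBlast build does not even contain the CUDA code. A rough sketch of that kind of preprocessor selection (illustrative; not copied verbatim from the llama.cpp source):

    // Only one GPU backend gets compiled in, chosen by the build flags.
    #if defined(GGML_USE_CUBLAS)
    #include "ggml-cuda.h"    // CUDA / cuBLAS path (ggml-cuda.cu)
    #elif defined(GGML_USE_CLBLAST)
    #include "ggml-opencl.h"  // OpenCL / CLBlast path
    #endif

That is why a change to ggml-cuda.cu should not be able to affect a CLBlast build.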

tcristo commented 1 year ago

Yep, when I downloaded your version, compiled it, and ran it with -ngl > 0, I got a valid response. Before that I was getting gibberish. Same model and parameters, but I did use a different prompt than before.