ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Question: Why is the prompt run through the network before generating new tokens? #719

Closed python273 closed 1 year ago

python273 commented 1 year ago

As I understand it, the NN doesn't have a state, so you should be able to put whatever tokens you want into the context and start generating new tokens immediately. But right now, it seems it first runs the NN on the prompt? So with a long prompt, it takes some time until it starts to generate new tokens.

I thought I was missing something, but huggingface transformers starts to generate tokens immediately.

eliseygusev commented 1 year ago

LLMs essentially continue the text they were given. Thus any LLM has to first run on the whole prompt before predicting what the next token will be. In that sense huggingface and llama.cpp should behave similarly. The difference between HF and llama.cpp behaviour probably lies somewhere else. For example, the huggingface class is usually initialized well in advance, so you don't see it loading weights when you call generate. Conversely, running ./main in llama.cpp first loads the weights, which can be time consuming.

python273 commented 1 year ago

I'm not talking about the weights. Even the original weights load into memory in a few seconds on my machine.

If you put a long prompt into prompt.txt and run:

sudo ./main -m ./models/7B/ggml-model-f16.bin -n -1 --mlock --ctx_size 2048 --color -f prompt.txt

You can see the initial prompt being printed slowly before it starts to generate new tokens.

cmp-nct commented 1 year ago

I'm not talking about the weights. Even the original weights load into memory in a few seconds on my machine.

If you put a long prompt into prompt.txt and run:

sudo ./main -m ./models/7B/ggml-model-f16.bin -n -1 --mlock --ctx_size 2048 --color -f prompt.txt

You can see the initial prompt being printed slowly before it starts to generate new tokens.

Increase the batch size and it will process a larger part of the prompt in one go.

python273 commented 1 year ago

It's faster with a larger batch size, but I still don't understand why it needs to do anything with the prompt. transformers pretty much starts to generate new tokens immediately.

MillionthOdin16 commented 1 year ago

It's faster with a larger batch size, but I still don't understand why it needs to do anything with the prompt. transformers pretty much starts to generate new tokens immediately.

You're right. There's something odd about it that's not quite working right. I've run across discussions about this in the past where people had the same reasoning as you. Unfortunately, I can't seem to find them right now. Basically it seemed like Georgi and others were aware of the issue and were trying to figure out how to resolve it. It's been a bit, but I'm sure they're planning to look more into it sometime.

It's not a simple adjust-an-argument fix.

cmp-nct commented 1 year ago

It's faster with a larger batch size, but I still don't understand why it needs to do anything with the prompt. transformers pretty much starts to generate new tokens immediately.

Your prompt is split into groups of b tokens, and those are processed in parallel using n threads until the whole prompt has been processed. If you have the spare memory you can use a larger -b; I'm not sure you actually win performance that way (I don't think so). Maybe someone familiar with the Python transformers implementation can explain what they do differently.
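
For readers following along, here is a schematic sketch of that batched prompt ingestion, not llama.cpp's actual code; ingest_prompt and eval_batch are hypothetical names, with eval_batch standing in for one forward pass that appends a chunk's keys/values to the KV cache.

# Schematic sketch of batched prompt ingestion. `eval_batch` is a hypothetical
# stand-in for a forward pass that appends the chunk's keys/values to the KV cache.
def ingest_prompt(prompt_tokens, batch_size, eval_batch):
    n_past = 0  # number of tokens already in the KV cache
    for start in range(0, len(prompt_tokens), batch_size):
        chunk = prompt_tokens[start:start + batch_size]
        eval_batch(chunk, n_past)  # one forward pass over the whole chunk
        n_past += len(chunk)
    return n_past  # generation of new tokens starts from this position

# A 1151-token prompt with -b 8 needs 144 forward passes; with -b 512 only 3.
if __name__ == "__main__":
    chunks = []
    ingest_prompt(list(range(1151)), 512, lambda c, n_past: chunks.append(len(c)))
    print(chunks)  # [512, 512, 127]

The prompt cost is one forward pass per chunk, so a larger -b reduces the number of passes, but every prompt token still goes through the network exactly once.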

MillionthOdin16 commented 1 year ago

I wonder, if I described the issue to ChatGPT and gave it some context with code, whether it could help narrow it down. 😂 But we should definitely leave this to someone more experienced :)


ghost commented 1 year ago

It looks to me like eval is being called on the "prompt" embeddings just to load them into the attention K/V memory (it doesn't look like they're kept anywhere else, and the KV calculations don't actually depend on the layer activations). That could probably be factored out into its own, much faster function that mutates the model state instead of calculating and then throwing away the predictions for tokens you already have.

EDIT: Oops, missed this line: https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L803 so that's wrong. Huggingface transformers does seem to manage to do the same thing nearly instantly (at least with gpt2); I wonder what it does differently.

python273 commented 1 year ago

Yep, seems like KV cache becomes the state, so it must run the network on the prompt. 🤔

Here's the code with transformers if anyone is curious: https://gist.github.com/python273/d2b16b267104179d0718fc266f74c132
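
For reference, a stripped-down sketch of the same idea with the Hugging Face transformers API (gpt2 is just an illustrative placeholder model): the forward pass over the prompt exists only to build past_key_values, which then acts as the state while decoding.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

with torch.no_grad():
    # "Prefill": one forward pass over the whole prompt. Its lasting effect
    # is the KV cache (past_key_values) -- this is the state.
    out = model(prompt_ids, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decode: each new token only needs a forward pass over itself,
    # reusing the cached keys/values of everything before it.
    for _ in range(5):
        out = model(next_token, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)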

LostRuins commented 1 year ago

@python273 @notswileynet did you ever figure out why transformers is still so much (10x+) faster compared to llama.cpp for prompt ingestion even on CPU?

ggerganov commented 1 year ago

@LostRuins Is it really that fast? If this is true, then there must be something wrong in our approach. Can you provide some time comparison for a sample prompt, making sure you are using llama.cpp with BLAS enabled?

python273 commented 1 year ago

Here's my attempt at benchmarking. With transformers, I couldn't get a quantized model to run on CPU, so it's not an entirely fair comparison.

https://gist.github.com/python273/ca23361caf1cde9dc06bbc9acd44b22d

tldr:

7B q4 on AMD Ryzen 9 5950X 16-Core Processor:

llama_print_timings:      sample time =    27.27 ms /   64 runs   (    0.43 ms per run)
llama_print_timings: prompt eval time = 27051.89 ms / 1151 tokens (   23.50 ms per token)
llama_print_timings:        eval time =  7044.22 ms /   63 runs   (  111.81 ms per run)

7B q4 with BLAS (seems to be slower):

llama_print_timings:      sample time =    35.45 ms /   64 runs   (    0.55 ms per run)
llama_print_timings: prompt eval time = 40494.75 ms / 1151 tokens (   35.18 ms per token)
llama_print_timings:        eval time =  7111.35 ms /   63 runs   (  112.88 ms per run)

7B 8bit transformers + bitsandbytes on NVIDIA GeForce RTX 2080 Ti (this is prompt + 64 tokens gen):

generation time:    5748.10601 ms
tokens consumed:    1217
gen time per token: 4.723176672144619 ms

7B 4bit transformers + GPTQ-for-LLaMa:

generation time:    21058.956373 ms
tokens consumed:    1217
gen time per token: 17.303990446179128 ms


If generating only 4 tokens in Python:

8bit:
generation time:    918.090631 ms
tokens consumed:    1157
gen time per token: 0.7935096205704408 ms
total time:         6.507722863 s

4bit:
generation time:    12009.928084 ms
tokens consumed:    1157
gen time per token: 10.38023170613656 ms
total time:         13.905669252 s
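
For anyone who wants to reproduce this style of measurement, here is a rough sketch under similar assumptions (gpt2 and the repeated dummy prompt are placeholders): timing generate() with different max_new_tokens values separates the fixed prompt (prefill) cost from the per-token decode cost.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("some long prompt " * 200, return_tensors="pt").input_ids

def timed_generate(n_new):
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(input_ids, max_new_tokens=n_new, do_sample=False)
    return (time.perf_counter() - start) * 1000  # milliseconds

t4, t64 = timed_generate(4), timed_generate(64)
per_token = (t64 - t4) / 60      # marginal cost of one extra decoded token
prefill = t4 - 4 * per_token     # cost that does not depend on n_new
print(f"prefill ~{prefill:.0f} ms, decode ~{per_token:.1f} ms/token")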

LostRuins commented 1 year ago

Okay, maybe 10x is an exaggeration, especially considering BLAS. I apologize for the hyperbole, but it is still significantly faster.

With the above prompt (1151 tokens) and generating only one (1) extra token, I am getting:

No BLAS    = 151s
BLAS       =  65s
HF PyTorch =  39s

(this is fully on CPU, on GPU the pytorch one is much much faster)

python273 commented 1 year ago

Can you post your Python code to run on CPU? Also, the full output from compiling and running llama.cpp might be useful.

LostRuins commented 1 year ago

Can you post your Python code to run on CPU? Also, the full output from compiling and running llama.cpp might be useful.

I am running it through KoboldAI: https://github.com/0cc4m/KoboldAI

LostRuins commented 1 year ago

Some messing around: inside ggml_compute_forward_mul_mat_use_blas, even after commenting out cblas_sgemm and dequantize_row_q, I am still getting an overhead of about 40 ms per token on 7B (so a 1024-token prompt takes ~40 seconds with the above-mentioned functions replaced with NOPs), and this overhead scales with context length too. Any idea what it could be?

It's kinda significant, because with both functions enabled (BLAS) it takes about 70 ms per token. So the overhead costs more than the matmul and the dequantization combined.

slaren commented 1 year ago

@LostRuins I tried to replicate that on my computer and the overhead that I got was between 10 and 20ms. Most of it seems to be in small matrix multiplication. You can try the steps described at https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks to get a breakdown per operation.

MillionthOdin16 commented 1 year ago

It doesn't seem to me that this was actually resolved. I've seen this comparison of how quickly transformers starts to generate tokens vs llama.cpp before, and I've never really seen an answer. I think the core of the issue is described in the first few messages in the thread, and it might have been overshadowed by LostRuins's 10x exaggeration. But if this is solved, it will have a significant impact for many people.

LostRuins commented 1 year ago

Yes, I admit it was naive of me to say 10x faster, especially without comparing against BLAS. I should have just said significantly faster. I think further gains are possible, but it may still be somewhat slower than a GPU alternative.

ggerganov commented 1 year ago

Okay, maybe 10x is an exaggeration, especially considering BLAS. I apologize for the hyperbole, but it is still significantly faster.

With the above prompt (1151 tokens) and generating only one (1) extra token, I am getting:

No BLAS    = 151s
BLAS       =  65s
HF PyTorch =  39s

(this is fully on CPU, on GPU the pytorch one is much much faster)

These numbers look reasonable to me. I believe PyTorch is using the MKL implementation for matrix multiplication.

Based on a discussion with @guillaumekln (see https://github.com/ggerganov/whisper.cpp/discussions/589#discussioncomment-5289714), the MKL implementation is considerably faster than OpenBLAS. I think if we "plug" MKL into ggml, we might observe parity with PyTorch.
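
As a side note, here is a quick sketch (shapes are illustrative, loosely modelled on a 7B-sized projection and the 1151-token prompt above) for checking which BLAS backend a NumPy/PyTorch install was built against and getting a rough single-GEMM timing.

import time
import numpy as np
import torch

np.__config__.show()            # prints BLAS/LAPACK link info (OpenBLAS or MKL)
print(torch.__config__.show())  # PyTorch build info, including MKL / oneDNN

# Roughly prompt_len x hidden times hidden x hidden, as in a single projection.
a = np.random.rand(1151, 4096).astype(np.float32)
b = np.random.rand(4096, 4096).astype(np.float32)

start = time.perf_counter()
a @ b
print(f"sgemm 1151x4096x4096: {(time.perf_counter() - start) * 1000:.1f} ms")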

LostRuins commented 1 year ago

Would be awesome if that is possible and works just as well, at least for the Intel users.

MillionthOdin16 commented 1 year ago

Interesting, looks like #811 also includes a mention of oneDNN which is great.

Edit: ahh, but not int4

ghost commented 1 year ago

Since MKL is CBLAS-compatible, it's an easy drop-in replacement for OpenBLAS. I added this to the Makefile using the MKL link advisor, with the default libmkl 2020 packages on Ubuntu 22. I used OpenMP threading and did not see a performance difference with TBB threading.

# MKL exposes a CBLAS interface, so ggml's existing OpenBLAS code path is reused
# (hence -DGGML_USE_OPENBLAS). Link flags generated with the MKL link advisor
# for the libmkl 2020 packages on Ubuntu 22, using OpenMP threading.
ifdef LLAMA_MKL
    CFLAGS  += -DGGML_USE_OPENBLAS -DMKL_ILP64 -m64 -I"/usr/include/mkl"
    LDFLAGS += -L/lib/x86_64-linux-gnu -Wl,--no-as-needed -lmkl_intel_ilp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl
endif

We can confirm that MKL is linked properly in the binary when compiled with LLAMA_MKL=on:

<REDACTED>:llama.cpp$ ldd main 
    linux-vdso.so.1 (0x00007ffd83d54000)
    libmkl_intel_ilp64.so => /lib/x86_64-linux-gnu/libmkl_intel_ilp64.so (0x00007f1c45600000)
    libmkl_gnu_thread.so => /lib/x86_64-linux-gnu/libmkl_gnu_thread.so (0x00007f1c43800000)
    libmkl_core.so => /lib/x86_64-linux-gnu/libmkl_core.so (0x00007f1c3f000000)
    libgomp.so.1 => /lib/x86_64-linux-gnu/libgomp.so.1 (0x00007f1c462e3000)
    ...

I ran my usual test script, which has a ~320 token prompt, on Llama 13B with a batch size of 1024. Averaged out, the timings at the prompt eval stage with MKL were very similar to OpenBLAS (around 150 ms/token on a 16GB i5-6500).

Example MKL result:

llama_print_timings:        load time = 48678.67 ms
llama_print_timings:      sample time =   518.22 ms /   100 runs   (    5.18 ms per run)
llama_print_timings: prompt eval time = 48162.42 ms /   324 tokens (  148.65 ms per token)
llama_print_timings:        eval time = 45609.20 ms /    99 runs   (  460.70 ms per run)
llama_print_timings:       total time = 94807.68 ms

Example OpenBLAS result:

llama_print_timings:        load time = 47797.46 ms
llama_print_timings:      sample time =   505.43 ms /   100 runs   (    5.05 ms per run)
llama_print_timings: prompt eval time = 47273.33 ms /   324 tokens (  145.91 ms per token)
llama_print_timings:        eval time = 45507.38 ms /    99 runs   (  459.67 ms per run)
llama_print_timings:       total time = 93811.85 ms

For fun, here's a run with no BLAS lib (generally I see a 2x improvement in prompt eval speed with OpenBLAS):

llama_print_timings:        load time = 98129.39 ms
llama_print_timings:      sample time =   474.32 ms /   100 runs   (    4.74 ms per run)
llama_print_timings: prompt eval time = 97608.26 ms /   324 tokens (  301.26 ms per token)
llama_print_timings:        eval time = 46747.59 ms /    99 runs   (  472.20 ms per run)
llama_print_timings:       total time = 145352.80 ms

As I'm using an older architecture and an older (2020) version of MKL, I'm curious whether people are seeing actual performance improvements with a newer setup.

ggerganov commented 1 year ago

Ok, so based on the results from @LostRuins and the MKL test by @eiery, ggml is almost 2 times slower than PyTorch on x86. I was hoping that MKL would close the gap, but unfortunately that is not the case.

@eiery Maybe wrap this LLAMA_MKL in a PR, so other people can easily give it a try as well?

0cc4m commented 1 year ago

Rerunning tests...

ggerganov commented 1 year ago

@0cc4m What is the batch and prompt size in your experiments? (i.e. the -n and -p parameters)

0cc4m commented 1 year ago

@ggerganov These are results from a few days back and I just noticed some inconsistencies. I'll retest and update them, along with the batch and prompt sizes used.

0cc4m commented 1 year ago

@ggerganov I apologize, my last results were wrong. Here is what I found:

EPYC 7302, 8x 16GB DDR4-3200, RTX 3060 for ClBlast and Pytorch GPU results

model:       Llama 7B (q4_1)
batch size:  512
prompt size: 2047 tokens

ggml stock:                      Processing:145.5s (71ms/T)
ggml+openblas:                   Processing:126.3s (62ms/T)
ggml+clblast:                    Processing:113.1s (55ms/T)
Pytorch CPU:                     80.9s
Pytorch GPU 4bit GPTQ-for-llama: 2.71s

I think there must be some architectural advantage to how Pytorch/transformers handles the context.

LostRuins commented 1 year ago

@eiery I want to try MKL to compare, but I can't seem to find where to actually obtain the MKL library for Windows. Any idea?

ghost commented 1 year ago

@ggerganov The linking process for MKL is complex (hence the link advisor), and users need different commands depending on OS, MKL version, and so on. My current LLAMA_MKL option only works on Ubuntu with the MKL version from the package manager and probably won't work with, say, Mac or Windows. If people are interested in testing, I would recommend they add the lines to the Makefile themselves, replacing mine with ones that match their setup.

@LostRuins You can get the Windows version here. Note that I haven't tested it in Windows.

ghost commented 1 year ago

It would also be an interesting experiment for someone who has it all set up to try compiling llama.cpp with the full Intel oneAPI suite (ICC, MKL, etc.) to see what gains we can achieve. I used GCC for my tests and am not sure whether using the full suite will provide additional improvements.