ggerganov / llama.cpp

LLM inference in C/C++
MIT License

llama : enable FA by default and disable it per-layer #10005

Open ggerganov opened 1 week ago

ggerganov commented 1 week ago

See the discussion starting here: https://github.com/ggerganov/llama.cpp/issues/9991#issuecomment-2428407002 and the proposed solution here: https://github.com/ggerganov/llama.cpp/issues/9991#issuecomment-2428868490.

Additionally, switch to F32 precision for the K*Q matrix multiplication by default.
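Here is why the precision of the K*Q product can matter. The largest finite FP16 value is 65504, so a raw attention logit that is computed or stored in half precision can overflow to infinity when activations are large. F32 has ample range. The sketch below is a standalone illustration of that assumption. The function names and values are hypothetical and do not come from ggml. In ggml itself, the per-op precision of a matrix multiplication is requested with `ggml_mul_mat_set_prec(tensor, GGML_PREC_F32)`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Largest finite value representable in IEEE 754 half precision (FP16).
constexpr float kFp16Max = 65504.0f;

// Dot product of one query row with one key row, accumulated in F32.
// This is a single unscaled K*Q logit (hypothetical helper, not ggml code).
float kq_logit_f32(const std::vector<float>& q, const std::vector<float>& k) {
    assert(q.size() == k.size());
    float acc = 0.0f;
    for (std::size_t i = 0; i < q.size(); ++i) {
        acc += q[i] * k[i];
    }
    return acc;
}

// True if the value cannot be stored as a finite FP16 number,
// i.e. an FP16 result or accumulator would overflow to infinity.
bool overflows_fp16(float x) {
    return x > kFp16Max || x < -kFp16Max;
}
```

For example, with a head size of 128 and every component of q and k equal to 24, the logit is 128 * 24 * 24 = 73728. That value is outside the FP16 range, but it is exactly representable in F32.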

Marking this as a good first issue as an opportunity for new contributors, but it is also kind of high priority, so we should probably implement this in a day or two if there is no progress. @slaren or @JohannesGaessler, in case you have already started to work on it, feel free to assign yourself to the issue and finish it.

slaren commented 1 week ago

I will work on this as part of the changes to the model loader that I am working on. This is more complicated than it looks. It requires determining which device each layer's KV will run on, then testing whether the operation is supported there, then updating all the session-writing code to support per-layer v-trans. That would be very hard to do for new contributors without experience with the code base.
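Below is a rough sketch of the per-layer decision described above. It assumes each layer's placement and the result of an FA support check are already known. The structs and the function are hypothetical and are not llama.cpp's API. The point is that once FA can differ between layers, the V-cache layout can differ too, because the non-FA path keeps V transposed. That is why the session read/write code needs a per-layer `v_trans`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical description of where one layer's attention (and its KV cache) runs.
struct LayerPlacement {
    std::string device;      // illustrative device name, e.g. "CUDA0" or "CPU"
    bool device_supports_fa; // result of asking that device whether FA is supported
};

// Hypothetical per-layer attention configuration derived from the placement.
struct LayerAttnConfig {
    bool use_fa;  // run flash attention for this layer
    bool v_trans; // store V transposed, as the non-FA path expects
};

// Enable FA for a layer only when it is allowed globally and the layer's device
// supports it; the V layout then follows the FA choice layer by layer.
std::vector<LayerAttnConfig> plan_attention(const std::vector<LayerPlacement>& layers,
                                            bool fa_allowed) {
    std::vector<LayerAttnConfig> out;
    out.reserve(layers.size());
    for (const auto& layer : layers) {
        const bool fa = fa_allowed && layer.device_supports_fa;
        out.push_back({fa, !fa});
    }
    return out;
}
```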

> Additionally, switch to F32 precision for the K*Q matrix multiplication by default.

This can already be done separately.

Dampfinchen commented 1 week ago

With partial offloading, FA slows down generation speed significantly, though. So I wonder if it's a good idea to make FA the default, considering how the killer feature of llama.cpp is the fast CPU offloading.

I can provide benchmarks, if needed.

JohannesGaessler commented 1 week ago

Despite user reports I have never been able to reproduce and investigate the issue where FlashAttention is detrimental with partial offloading.

Dampfinchen commented 1 week ago

> Despite user reports I have never been able to reproduce and investigate the issue where FlashAttention is detrimental with partial offloading.

I investigated this further and could not reproduce it using llama-bench. Token generation and prompt processing were both faster with FA in llama-bench, which goes completely against my real-world experience. Token generation was 5.5 tokens/s with 20 layers offloaded for Mistral Nemo q4_k_s. Something is wrong here.

However, I noticed something curious. Even when I increased the prompt size from the default 512 to 8096 with the -p option, the token generation speed in llama-bench did not change. This doesn't make sense, because in real-world use token generation is significantly faster with a smaller prompt than with a bigger one. So I tried llama-server.

With llama-server, these were my results:

 ./llama-server -m "D:\KI\LLMs\Mistral-Nemo-Instruct-2407-Q4_K_S.gguf" -t 6 -c 8096 -fa -ngl 20 -b 512 --host 127.0.0.1 --port 8080

prompt eval time =   11302.73 ms /  4716 tokens (    2.40 ms per token,   417.24 tokens per second)
       eval time =   41852.62 ms /   128 tokens (  326.97 ms per token,     3.06 tokens per second)
      total time =   53155.35 ms /  4844 tokens

 ./llama-server -m "D:\KI\LLMs\Mistral-Nemo-Instruct-2407-Q4_K_S.gguf" -t 6 -c 8096 -ngl 20 -b 512 --host 127.0.0.1 --port 8080

prompt eval time =   12661.91 ms /  4764 tokens (    2.66 ms per token,   376.25 tokens per second)
       eval time =   30581.22 ms /   128 tokens (  238.92 ms per token,     4.19 tokens per second)
      total time =   43243.14 ms /  4892 tokens
srv  update_slots: all slots are idle

As I expected, there is a drastic slowdown in token generation. If you were not able to reproduce this, perhaps you were using llama-bench as well? Could you try sending a prompt of 8096 tokens to llama-server and see if FA slows down token generation for you as well?

If you are still not able to reproduce the issue, then the slowdown may be related to the GPU architecture, since I'm using Turing.

Theoretically I should be able to reproduce the issue using llama-cli as well, I may test this later today.

slaren commented 1 week ago

There is too much noise in the server/cli performance measurements to be used to draw any conclusions from it, that's one of the reasons llama-bench exists in the first place.

Dampfinchen commented 1 week ago

> There is too much noise in the server/cli performance measurements to be used to draw any conclusions from it, that's one of the reasons llama-bench exists in the first place.

Shouldn't real-world performance count more in this case? Something seems to be off with the llama-bench results. I'm seeing the slower token generation in all of my use cases, regardless of what noise there may be, and so I always turn FA off when using partial offloading.

Wouldn't you agree that 5.5 tokens/s on an 8096-token prompt using a 12B model (Nemo in this case) seems unrealistic on an RTX 2060 6 GB with just 20/41 layers offloaded?

If a benchmark is that far off from real-world results, then in my opinion it's not worth much. Given the results, I suspect llama-bench may be hard-coded to use a context of only 512 for token generation, regardless of the prompt size. It may very well be that the slower token generation with FA and partial offloading only appears with larger context sizes. In that case, the results of llama-bench alone would not be enough to investigate this issue properly.

slaren commented 1 week ago

If you are expecting the value of -p to affect tg results in llama-bench, you are misunderstanding the way it works. If you want to test pp followed by tg, you need to use -pg.

Dampfinchen commented 1 week ago

> If you are expecting the value of -p to affect tg results in llama-bench, you are misunderstanding the way it works. If you want to test pp followed by tg, you need to use -pg.

Thank you. I was able to reproduce it on llama-bench as well now. And indeed, it's a clear slowdown.

./llama-bench -m D:\KI\LLMs\Mistral-Nemo-Instruct-2407-Q4_K_S.gguf -ngl 17 -pg 8096,128 -fa 1

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060, compute capability 7.5, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 13B Q4_K - Small         |   6.62 GiB |    12.25 B | CUDA       |  17 |  1 |         pp512 |        474.73 ± 2.64 |
| llama 13B Q4_K - Small         |   6.62 GiB |    12.25 B | CUDA       |  17 |  1 |         tg128 |          5.25 ± 0.14 |
| llama 13B Q4_K - Small         |   6.62 GiB |    12.25 B | CUDA       |  17 |  1 |  pp8096+tg128 |         98.21 ± 1.64 |

build: 2f8bd2b9 (3976)

 ./llama-bench -m D:\KI\LLMs\Mistral-Nemo-Instruct-2407-Q4_K_S.gguf -ngl 17 -pg 8096,128 -fa 0

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060, compute capability 7.5, VMM: yes
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 13B Q4_K - Small         |   6.62 GiB |    12.25 B | CUDA       |  17 |         pp512 |        451.15 ± 7.14 |
| llama 13B Q4_K - Small         |   6.62 GiB |    12.25 B | CUDA       |  17 |         tg128 |          5.10 ± 0.07 |
| llama 13B Q4_K - Small         |   6.62 GiB |    12.25 B | CUDA       |  17 |  pp8096+tg128 |        137.83 ± 1.04 |

build: 2f8bd2b9 (3976)

For some reason, the benchmark combines pp and tg into one number instead of showing them separately. Since prompt processing with FA is faster, period, this narrows the difference in tg between FA on and off. Still, 137.83 t/s with FA off vs 98.21 t/s with FA on shows the slowdown clearly.
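The `pp8096+tg128` row reports a single rate over both phases. The tg-only rate can still be recovered if the prompt phase is timed separately at the same length. Below is a minimal sketch of that arithmetic. It assumes the combined figure is (prompt tokens + generated tokens) / total time. The function name and the example timings in the note are hypothetical and are not llama-bench output.

```cpp
#include <cassert>

// Recover the token-generation rate from a combined "ppN+tgM" figure.
//   combined_ts : rate for the whole run, assumed to be (n_prompt + n_gen) / total_seconds
//   pp_ts       : prompt-processing rate measured separately for the same n_prompt
// Returns the generation-only rate n_gen / (total_seconds - prompt_seconds).
double tg_rate_from_combined(double combined_ts, double pp_ts, int n_prompt, int n_gen) {
    assert(combined_ts > 0.0 && pp_ts > 0.0 && n_prompt >= 0 && n_gen > 0);
    const double total_s  = (n_prompt + n_gen) / combined_ts;
    const double prompt_s = n_prompt / pp_ts;
    assert(total_s > prompt_s); // otherwise the two measurements are inconsistent
    return n_gen / (total_s - prompt_s);
}
```

For example, suppose a hypothetical run processes 1000 prompt tokens at 500 t/s (2 s) and then generates 100 tokens in 10 s. It reports 1100 / 12 ≈ 91.7 t/s combined, and the function recovers 10 t/s for generation alone.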

slaren commented 1 week ago

I can reproduce this as well; the CPU flash attn does not scale well with context size. CUDA does not have the same problem.

It would be possible to disable fattn in the CPU layers, and keep it enabled for the CUDA layers, but I would prefer if we didn't have to add new parameters to configure this. We could add an environment variable to enable the CPU fattn in case memory usage is more important, and leave it disabled by default.
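Below is a minimal sketch of the policy suggested here. It assumes a hypothetical environment variable named `LLAMA_CPU_FATTN`, which is not an existing llama.cpp setting. FA stays enabled for GPU layers where the backend supports it. CPU layers only use FA when the user explicitly opts in, for example because memory usage matters more than speed.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Hypothetical opt-in switch; the variable name is illustrative only.
constexpr const char* kCpuFattnEnv = "LLAMA_CPU_FATTN";

// True if the opt-in environment variable is set to exactly "1".
bool cpu_fattn_opt_in() {
    const char* v = std::getenv(kCpuFattnEnv);
    return v != nullptr && std::strcmp(v, "1") == 0;
}

// Default policy: FA on for GPU layers when supported; off for CPU layers unless opted in.
bool layer_uses_fa(bool layer_on_cpu, bool backend_supports_fa) {
    if (!backend_supports_fa) {
        return false;
    }
    if (layer_on_cpu) {
        return cpu_fattn_opt_in();
    }
    return true;
}
```

Using an environment variable keeps the default behaviour free of new command-line parameters, which matches the preference stated above.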

These results are obtained with a modified version of llama-bench to show the t/s for the tg part only:

| model                          |       size |     params | backend    | threads | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | -: | ------------: | -------------------: |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |    pp128+tg32 |         20.03 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |    pp256+tg32 |         19.67 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |    pp512+tg32 |         18.92 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |   pp1024+tg32 |         17.94 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |   pp2048+tg32 |         15.68 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |   pp4096+tg32 |         12.68 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  0 |   pp8192+tg32 |          9.10 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |    pp128+tg32 |         20.16 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |    pp256+tg32 |         19.44 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |    pp512+tg32 |         18.25 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |   pp1024+tg32 |         16.27 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |   pp2048+tg32 |         13.23 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |   pp4096+tg32 |          9.11 ± 0.00 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CPU        |      16 |  1 |   pp8192+tg32 |          5.82 ± 0.00 |

build: ff252ea4 (3978)

  Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |    pp128+tg32 |        137.38 ± 2.14 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |    pp256+tg32 |        137.80 ± 2.03 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |    pp512+tg32 |        127.80 ± 1.68 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |   pp1024+tg32 |        126.12 ± 0.53 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |   pp2048+tg32 |        114.80 ± 1.08 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |   pp4096+tg32 |         98.33 ± 0.93 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  0 |   pp8192+tg32 |         74.14 ± 0.43 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |    pp128+tg32 |        149.18 ± 2.81 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |    pp256+tg32 |        144.97 ± 0.78 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |    pp512+tg32 |        144.01 ± 2.93 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |   pp1024+tg32 |        135.53 ± 1.45 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |   pp2048+tg32 |        127.47 ± 1.29 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |   pp4096+tg32 |        106.04 ± 2.50 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | CUDA       |  99 |  1 |   pp8192+tg32 |         82.70 ± 2.55 |

build: ff252ea4 (3978)
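The relative effect of FA can be read off the two tables by dividing the FA-on rate by the FA-off rate at each prompt length. This sketch does only that arithmetic, on the tg t/s values reported in the tables above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Prompt lengths of the pp*+tg32 tests and the t/s values from the two tables above.
const std::vector<int>    kPrompt  = {128, 256, 512, 1024, 2048, 4096, 8192};
const std::vector<double> kCpuFa0  = {20.03, 19.67, 18.92, 17.94, 15.68, 12.68, 9.10};
const std::vector<double> kCpuFa1  = {20.16, 19.44, 18.25, 16.27, 13.23, 9.11, 5.82};
const std::vector<double> kCudaFa0 = {137.38, 137.80, 127.80, 126.12, 114.80, 98.33, 74.14};
const std::vector<double> kCudaFa1 = {149.18, 144.97, 144.01, 135.53, 127.47, 106.04, 82.70};

// Ratio of FA-on to FA-off throughput per prompt length; below 1.0 means FA is slower.
std::vector<double> fa_speed_ratio(const std::vector<double>& fa_on,
                                   const std::vector<double>& fa_off) {
    assert(fa_on.size() == fa_off.size());
    std::vector<double> ratio;
    ratio.reserve(fa_on.size());
    for (std::size_t i = 0; i < fa_on.size(); ++i) {
        ratio.push_back(fa_on[i] / fa_off[i]);
    }
    return ratio;
}
```

At 8192 prompt tokens, the CPU ratio is about 0.64, so FA is roughly 36% slower. The CUDA ratio is about 1.12, so FA is roughly 11.5% faster. The CPU ratio falls steadily as the context grows.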

JohannesGaessler commented 1 week ago

I can reproduce the CPU FA performance issue using LLaMA 3 8b q4_0, both with and without -pg. I'm very certain I had tested this multiple times previously with negative results; there may at some point have been a performance regression.

> It would be possible to disable fattn in the CPU layers, and keep it enabled for the CUDA layers, but I would prefer if we didn't have to add new parameters to configure this.

If possible I think we should try to better optimize the CPU FA implementation for a batch size of 1. If FA is consistently faster there is no issue with making it the default where available.
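At batch size 1, a single query row attends over the whole KV cache. The generic way flash attention handles this is one pass over the KV positions with an online softmax, which keeps a running maximum and a running sum. This avoids materializing a full row of logits. The sketch below shows that technique in plain C++. It is not ggml's CPU kernel, and its signature and row-major layout are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Single-query attention over n_kv cached positions using an online softmax.
// K and V are row-major [n_kv][d], q has length d, the result has length d.
std::vector<float> attn_one_query_online(const std::vector<float>& q,
                                         const std::vector<float>& K,
                                         const std::vector<float>& V,
                                         std::size_t n_kv, std::size_t d, float scale) {
    assert(n_kv > 0 && q.size() == d && K.size() == n_kv * d && V.size() == n_kv * d);
    std::vector<float> acc(d, 0.0f);                   // running sum of exp(s - m) * v
    float m = -std::numeric_limits<float>::infinity(); // running maximum logit
    float l = 0.0f;                                    // running sum of exp(s - m)
    for (std::size_t j = 0; j < n_kv; ++j) {
        float s = 0.0f;
        for (std::size_t i = 0; i < d; ++i) s += q[i] * K[j * d + i];
        s *= scale;
        const float m_new = std::max(m, s);
        const float correction = std::exp(m - m_new); // rescales earlier partial sums
        const float p = std::exp(s - m_new);
        for (std::size_t i = 0; i < d; ++i) acc[i] = acc[i] * correction + p * V[j * d + i];
        l = l * correction + p;
        m = m_new;
    }
    for (float& a : acc) a /= l; // normalize by the softmax denominator
    return acc;
}
```

Subtracting the running maximum keeps the exponentials bounded, so even very large logits produce finite results.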

Dampfinchen commented 6 days ago

> If possible I think we should try to better optimize the CPU FA implementation for a batch size of 1. If FA is consistently faster there is no issue with making it the default where available.

From an end-user perspective, I think that would be awesome. Then I could also use a quantized KV cache with partially offloaded models, which would result in a nice speedup.

I do wonder, though: isn't what slaren suggested basically what ggerganov mentioned in the OP? I'm pretty sure he was aware of the slowdown, considering he suggested disabling FA per layer and using FP32 for the matrix multiplication, which is CPU-friendly. I was thinking about this before I started this whole discussion about the slowdown.

Perhaps GG knew all along!

> I can reproduce the CPU FA performance issue using LLaMA 3 8b q4_0, both with and without -pg. I'm very certain I had tested this multiple times previously with negative results; there may at some point have been a performance regression.

That is strange because I do not remember a time where FA did not slow down partially offloaded models.

ggerganov commented 5 days ago

> If possible I think we should try to better optimize the CPU FA implementation for a batch size of 1. If FA is consistently faster there is no issue with making it the default where available.

Yes, that should be the ideal solution. For larger batch size, CPU FA already seems to perform better, so we need to fix BS=1.

JohannesGaessler commented 5 days ago

For large batch sizes the entire layer is evaluated on the GPU so the CPU code is not being invoked at all.
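Below is a simplified sketch of that scheduling rule, assuming a fixed batch-size threshold. The constant and function names are hypothetical, and this thread does not state the exact threshold llama.cpp uses. The sketch shows why only token generation (n_tokens == 1) exercises the CPU FA code for layers whose weights stay in system RAM.

```cpp
#include <cassert>

// Hypothetical batch size from which a CPU-resident layer is evaluated on the GPU
// (its weights are copied over for the batch); the value is an assumption.
constexpr int kMinOffloadBatch = 32;

enum class Where { Cpu, Gpu };

// Where a layer is evaluated, given where its weights live and the batch size.
Where eval_location(bool weights_in_vram, int n_tokens) {
    assert(n_tokens > 0);
    if (weights_in_vram) {
        return Where::Gpu;
    }
    // Small batches, such as token generation, run the CPU kernels (including CPU FA).
    return n_tokens >= kMinOffloadBatch ? Where::Gpu : Where::Cpu;
}
```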

slaren commented 5 days ago

I wonder if it would be worth adding an option for offloading only the KV. When working on #7315 I tested this briefly, and the performance with large batches was very close to full GPU offload, so there are some advantages to keeping the weights on the CPU and offloading the KV instead.

ggerganov commented 5 days ago

> For large batch sizes the entire layer is evaluated on the GPU so the CPU code is not being invoked at all.

Oh right, I missed that.

> I wonder if it would be worth adding an option for offloading only the KV.

Do you think that full KV offload + partial weight offload would perform better than what we currently have with partial offload at BS=1? If so, then we should definitely add it.

slaren commented 5 days ago

The main benefit would be for processing large batches, not BS=1. Essentially if the batch size is large enough, then the cost of copying the weights to VRAM can be hidden almost completely by doing it asynchronously while other operations are running. However, unless the KV is also kept on VRAM, that still causes a stall on every layer, which destroys the overall performance. But if the weights are kept on system RAM, and the KV is fully offloaded to VRAM, then it is possible to obtain similar performance with large batches as if the weights were offloaded to VRAM instead. For BS=1 I guess it could help with very large contexts, since eventually as the context size grows, the attention dominates the time spent during inference, but I don't think it would be very useful overall for this case.
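The argument can be expressed with a deliberately simplified per-layer cost model. All names and numbers are hypothetical. The model ignores everything except the weight upload and the compute time. It also treats a CPU-side KV as a stall that fully serializes the two.

```cpp
#include <algorithm>
#include <cassert>

// Simplified cost of one layer whose weights stay in system RAM and are uploaded
// to VRAM for a large batch.
//   copy_ms    : time to upload the layer's weights
//   compute_ms : GPU compute time for the batch (grows with batch size)
//   kv_in_vram : whether the layer's KV cache is already resident in VRAM
double layer_time_ms(double copy_ms, double compute_ms, bool kv_in_vram) {
    assert(copy_ms >= 0.0 && compute_ms >= 0.0);
    // With the KV in VRAM, the upload can overlap with compute, so the slower one wins;
    // otherwise the model assumes a stall that serializes upload and compute.
    return kv_in_vram ? std::max(copy_ms, compute_ms) : copy_ms + compute_ms;
}
```

Take a 4 ms upload and 10 ms of compute for a large batch. The overlapped layer costs 10 ms, the same as if the weights already lived in VRAM, versus 14 ms with the stall. At BS=1 the compute time is tiny, so the upload dominates either way.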

JohannesGaessler commented 5 days ago

> For BS=1 I guess it could help with very large contexts, since eventually as the context size grows, the attention dominates the time spent during inference, but I don't think it would be very useful overall for this case.

No, I think for batch size 1 the weights will always be the better choice to offload first. The most important factor is minimizing the amount of data that you need to load from RAM. It doesn't really matter whether that data is weights or KV cache. But unless the context is full, the memory allocated for the KV cache is always larger than the memory that actually needs to be loaded. So it makes more sense to prioritize the weights, where 100% of the allocated memory always needs to be loaded.
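The reasoning can be checked with a small worked model of the bytes still read from system RAM per generated token. The model assumes per-layer offload. It also assumes that every weight byte is read once per token and that only the filled fraction of the KV cache (n_past / n_ctx) is read. The function and parameter names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// Bytes still read from system RAM per generated token (BS=1) after filling
// vram_budget bytes of VRAM, offloading weights first or KV first.
//   weights_total, kv_alloc : bytes on the CPU side before any offload
//   kv_fill                 : n_past / n_ctx, between 0 and 1
double ram_bytes_per_token(double weights_total, double kv_alloc, double vram_budget,
                           double kv_fill, bool weights_first) {
    assert(kv_fill >= 0.0 && kv_fill <= 1.0 && vram_budget >= 0.0);
    double w_gpu = 0.0;
    double kv_gpu = 0.0;
    if (weights_first) {
        w_gpu  = std::min(vram_budget, weights_total);
        kv_gpu = std::min(vram_budget - w_gpu, kv_alloc);
    } else {
        kv_gpu = std::min(vram_budget, kv_alloc);
        w_gpu  = std::min(vram_budget - kv_gpu, weights_total);
    }
    // All CPU-side weights are read, but only the filled part of the CPU-side KV.
    return (weights_total - w_gpu) + kv_fill * (kv_alloc - kv_gpu);
}
```

For example, take 8 GB of CPU-side weights, a 4 GB KV allocation, 4 GB of free VRAM and a half-full context. Offloading the weights first leaves 6 GB to read per token, versus 8 GB when offloading the KV first. The two only tie when the context is full.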

slaren commented 5 days ago

> It doesn't really matter whether that data is weights or KV cache.

In theory this is true, the KV could also be copied asynchronously like weights, but in practice implementing this is very complicated due to the way the KV cache is handled in ggml.