phymbert opened this issue 7 months ago
@ggerganov am I configuring something incorrectly?
32 sequences, each with 4096 tokens, require a KV cache of size 32*4096 = 131072 in order to handle the worst case, so this setup could theoretically run out of KV cache slots. I'm not sure about the specific error that you get, but it could be related.
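For reference, a minimal sketch of that arithmetic (illustrative variable names; the 32768 value is the `--ctx-size` used in the benchmark setup further down):

```cpp
// Back-of-the-envelope check: the shared KV cache must hold the worst case of
// every sequence reaching its maximum length at the same time.
#include <cstdio>

int main() {
    const int n_parallel    = 32;    // concurrent sequences (slots)
    const int n_seq_max_len = 4096;  // worst-case tokens per sequence
    const int n_ctx         = 32768; // --ctx-size actually configured

    const int n_cells_worst_case = n_parallel * n_seq_max_len; // 131072

    printf("worst case: %d cells, configured: %d cells\n", n_cells_worst_case, n_ctx);
    // 131072 > 32768, so in the worst case the cache runs out of slots.
    return 0;
}
```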
Another thing to look into is the hypothesis that we are evaluating batches partially:
https://github.com/ggerganov/llama.cpp/issues/6617#issuecomment-2051618514
I suspect that this logic does not always work as expected because the batch splitting in llama_decode means that a part of the batch may have been evaluated even if an error is returned.
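As a hypothetical illustration of that hazard (this is not the actual `llama_decode` code, just a sketch of the failure mode):

```cpp
// Illustrative sketch: a batch is processed in ubatch-sized chunks, so when a
// later chunk fails, the earlier chunks have already modified the KV cache even
// though the caller only sees an error code for the whole batch.
#include <cstdio>

// Stand-in for the per-ubatch compute; pretend the third chunk fails.
static int process_ubatch(int /*first_token*/, int /*n_tokens*/) {
    static int calls = 0;
    return (++calls == 3) ? -1 : 0;
}

int main() {
    const int n_batch  = 4096;
    const int n_ubatch = 256;

    for (int i = 0; i < n_batch; i += n_ubatch) {
        if (process_ubatch(i, n_ubatch) != 0) {
            // The caller gets an error, but i tokens are already in the KV cache.
            printf("error returned, yet %d of %d tokens were already evaluated\n", i, n_batch);
            return 1;
        }
    }
    return 0;
}
```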
Do you get the error with equal `batch` and `ubatch`?
Yes, sorry, the 4096 max tokens to predict is misleading; it's a max context of 1024 tokens per sequence in this use case. I will try adjusting the ubatch size.
Although, this is the first time I have run performance tests on 2 A100s with a 70B F16 model. Before, I was usually using only one device and Q4_0, so I suspect the split is not working as expected; maybe @slaren has an idea.
Do you have the full call stack?
No :/ only the last line. I built with Release and the server is running in a Kubernetes cluster.
The size of a hash table is probably being underestimated. However, without the call stack it is hard to figure out which one.
Do you know how I can get the stack trace? The server is running in a Docker `nvidia/cuda:11.7.1-runtime-ubuntu20.04` image.
It should give you a full call stack automatically if gdb is installed. Otherwise just run `gdb --args ./server ...` and type `bt` to get the call stack when it crashes. It will be more accurate with a debug build, but even a release build will probably be enough to figure out the issue.
Here it is (manually copied):
```
ggml_hash_find_or_insert
ggml_backend_sched_split_graph
ggml_backend_sched_alloc_graph
ggml_backend_sched_graph_compute_async
llama_kv_cache_update_internal
llama_decode_internal
llama_decode
server_context::update_slots
server_queue::start_loop
main
```
Also, please note I downgraded from CUDA 12 to 11 recently; I will try again with CUDA 12, but I do not think it is related.
@ggerganov Same crash with `n_batch == n_ubatch = 256`.
Is there a simple command that I can run to reproduce this issue? It seems that the issue is the size of the KV defrag graph.
Can you run a 70B llama model? We can try to reproduce with the server `bench.py`. I will send you the steps, or I can try to reproduce first with Mixtral 8x7B or a smaller model, but it will take time for me to download and split them. Meanwhile, should I remove `--defrag-thold 0.1`?
> It seems that the issue is the size of the KV defrag graph.
This could potentially be fixed by using 2 `ggml_get_rows` per layer (one for K, one for V) instead of many views per layer per non-contiguous span of cell ids to move, though I'm not sure how feasible this is with the transposed V cache. But `ggml_get_rows` returns an F32 tensor, which might be relatively big if used with a quantized KV cache (though if done per layer, it might not be that bad, since it's not done all at once and intermediate tensors during inference are bigger anyway). I wonder why it was made to dequantize rather than simply do a row-wise memcpy?
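For example, a rough sketch of that idea for the K cache of one layer (a hypothetical helper, not existing llama.cpp code; it assumes `k_l` stores one KV cell per row and `cell_ids` is an I32 tensor of the source cells to gather):

```cpp
#include "ggml.h"

// Gather all K rows of one layer that need to move with a single ggml_get_rows,
// instead of one view + copy pair per non-contiguous span of cells.
struct ggml_tensor * defrag_gather_k(
        struct ggml_context * ctx,
        struct ggml_tensor  * k_l,       // assumed layout: ne[0] = row size, ne[1] = n_kv cells
        struct ggml_tensor  * cell_ids   // I32 tensor with the source cell ids to move
) {
    // Note: the result is F32, so with a quantized K cache it would have to be
    // quantized again before being written back into the destination cells.
    return ggml_get_rows(ctx, k_l, cell_ids);
}
```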
@ggerganov Yes, removing defragmentation seems to solve it, but I guess that's not expected.
I think that the issue is that `ggml_backend_sched` allocates a hash table based on the actual size of the graph used during measure, but `llama_kv_cache_defrag_internal` uses up to `LLAMA_MAX_NODES`. Limiting the size of the hash table is important for performance, as clearing a large table has a significant cost.
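A minimal sketch of that mismatch (illustrative numbers and names, not the actual ggml internals; 8192 is an assumed value for `LLAMA_MAX_NODES`):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // ggml_backend_sched sizes its hash table from the graph it sees during the
    // measure/reserve pass (a regular decode graph)...
    const size_t measured_nodes = 2048;             // illustrative
    const size_t hash_slots     = 2*measured_nodes; // table sized for that graph

    // ...but the KV defrag graph can grow up to LLAMA_MAX_NODES tensors.
    const size_t defrag_nodes   = 8192;             // assumed LLAMA_MAX_NODES value

    if (defrag_nodes > hash_slots) {
        // This is the situation that trips GGML_ASSERT: i != GGML_HASHTABLE_FULL.
        printf("defrag graph (%zu nodes) exceeds the %zu hash slots sized at measure time\n",
               defrag_nodes, hash_slots);
    }
    return 0;
}
```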
@slaren Thanks a lot. Tell me which model works best for you to reproduce the issue, and I will work on a reproducible setup.
The smaller the model, the easier it will be to test, but 70B is fine.
I am not sure which model size will cause the issue, but these steps can be done with any model; I am trying to reproduce with a llama 70B F16 in an isolated environment.
Assuming you are on a Debian-like OS:
```bash
# May not install the latest Go
sudo apt install golang-go
# Or install manually
wget https://go.dev/dl/go1.22.2.linux-amd64.tar.gz
sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go1.22.2.linux-amd64.tar.gz

# Build the k6 binary with the SSE extension
cd examples/server/bench
go install go.k6.io/xk6/cmd/xk6@latest
export PATH=~/go/bin:$PATH
xk6 build master \
    --with github.com/phymbert/xk6-sse
```
Download the test dataset:

```bash
cd examples/server/bench
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
Run the server benchmark (it includes `--defrag-thold 0.1`):

```bash
cd examples/server/bench
pip install -r requirements.txt
mkdir models
LLAMA_SERVER_BIN_PATH=../../../cmake-build-debug/bin/server python bench.py \
    --runner-label local \
    --name local \
    --branch `git rev-parse --abbrev-ref HEAD` \
    --commit `git rev-parse HEAD` \
    --scenario script.js \
    --duration 10m \
    --hf-repo TheBloke/KafkaLM-70B-German-V0.1-GGUF \
    --hf-file kafkalm-70b-german-v0.1.Q2_K.gguf \
    --model-path-prefix models \
    --parallel 32 \
    -ngl 81 \
    --batch-size 4096 \
    --ubatch-size 256 \
    --ctx-size 32768 \
    --n-prompts 1000 \
    --max-prompt-tokens 512 \
    --max-tokens 1024
```
I will update the steps if I missed something here, but this is the general idea.
EDIT: I am not German, but this is the latest model I found on HF :D
I am sorry, but what I meant by "simple command" was something like curl; this is too complicated and it is going to take too much time to reproduce.
I understand, but I think we need concurrent requests to reproduce, no?
Maybe @ggerganov has a suggestion for how to trigger a worst-case defragment in a simple way.
This command triggers the error quite easily:
```bash
./parallel -m models/llama-70b-v2/ggml-model-q4_k_s.gguf -np 32 -ngl 99 -b 4096 -ub 256 -c 32768 -ns 1024 -n 300 -cb -dt 0.1
```
Let me know if it does not work and I can send you a patch that always triggers it.
Edit: this one also triggers it and uses a 13B model:

```bash
./parallel -m models/llama-13b-v2/ggml-model-q4_0.gguf -np 32 -ngl 99 -b 4096 -ub 256 -c 32768 -ns 1024 -n 300 -cb -dt 0.1
```
A proper fix will take some time, but this should fix it for now.
```diff
--- a/llama.cpp
+++ b/llama.cpp
@@ -10666,7 +10666,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
```
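For intuition, a small sketch of the node budget before and after this change (assuming `LLAMA_MAX_NODES` is 8192 and `n_layer` is 80 for a 70B model; both values are my assumptions):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t LLAMA_MAX_NODES = 8192; // assumed value
    const uint32_t n_layer         = 80;   // assumed: 70B LLaMA

    // Each KV move contributes 6*n_layer tensors to the defrag graph (see the diff above).
    const uint32_t max_moves_old = LLAMA_MAX_NODES / (6*n_layer);               // 17 moves
    const uint32_t max_moves_new = (LLAMA_MAX_NODES - 2*n_layer) / (6*n_layer); // 16 moves

    printf("old limit: %u moves -> up to %u tensors\n", max_moves_old, max_moves_old*6*n_layer); // 8160
    printf("new limit: %u moves -> up to %u tensors\n", max_moves_new, max_moves_new*6*n_layer); // 7680
    // The patched formula keeps 2*n_layer nodes of headroom below LLAMA_MAX_NODES.
    return 0;
}
```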
Thanks, all, for taking the time to investigate and resolve this issue. I will test tomorrow.
> 32 sequences, each with 4096 tokens, require a KV cache of size 32*4096 = 131072 in order to handle the worst case, so this setup could theoretically run out of KV cache slots. I'm not sure about the specific error that you get, but it could be related.
@phymbert This calculation method cleared up yesterday's confusion. It seems that running parallel sequences requires an accurate calculation of `ctx-size`, but I still don't understand how to calculate `batch-size` and `ubatch-size`. @ggerganov, can you help me?
> This calculation method cleared up yesterday's confusion.
Not sure which problem you are referencing.
The optimal values of `batch-size` and `ubatch-size` depend on the hardware that you have. Experiment with powers of 2 in [128, 4096], with `ubatch-size <= batch-size`.
> The optimal values of `batch-size` and `ubatch-size` depend on the hardware that you have. Experiment with powers of 2 in [128, 4096], with `ubatch-size <= batch-size`.
Thanks, that's helpful @ggerganov. In https://github.com/ggerganov/llama.cpp/issues/6681 I made several attempts, and the VRAM usage does indeed differ between them.
@slaren The temporary fix is behaving well: 32 users on a 70B F16 with 1024 max context each for 30 min. No shifting, no KV cache full, and no crash :) I have pushed eedd42e3767efb49cd497cdef3943397b42ee935 so I can retrieve it safely, but I will delete this temporary branch once you feel ready to submit the final patch. FYI: on 2 A100s, average PP=200 tk/s per sequence, TG=4.5 tk/s per sequence, KV cache usage ratio=0.78.
I will also test flash attention, but this is quite out of the scope of this issue.
Context
Using the latest llama.cpp server at commit 17e98d4c96a583d420f12046bc92102381dbd28e.
The server was started with a llama-70B-F16-like model.
When sending 32 concurrent requests, the server crashes with:
```
GGML_ASSERT: /llama.cpp/ggml.c:16521: i != GGML_HASHTABLE_FULL
```
The backend is CUDA, on 2 A100s (compute capability 8.0).
EDIT: The issue is related to defragmentation; quick fix: disable defragmentation.