ggerganov / llama.cpp


Bug: KV cache load/save is slow #8915

Open josharian opened 1 month ago

josharian commented 1 month ago

What happened?

I wrote a cache of KV caches (saving sequence state so it can be restored later), and then benchmarked it.

llama_state_seq_get_size, llama_state_seq_get_data, and llama_state_seq_set_data are slow enough that it is significantly (13x) faster to just start over from nothing each time than to save and restore the state.

However, from looking through the code, I think there is room for considerable improvement. (It is unclear to me whether these improvements would be enough to make managing an external cache worthwhile, but in theory I think it ought to be possible.)

Here are a few observations, starting with just the get APIs...


llama_state_seq_get_size does a full copy from the GPU and throws it away. (My cache management implementation is in Go, so for GC/allocator reasons, I need the size up front.)

size_t llama_state_seq_get_size(struct llama_context *ctx,
                                llama_seq_id seq_id) {
  // The dummy writer discards the bytes, but the internal path still reads
  // every tensor back from the backend before "writing" it.
  llama_data_write_dummy data_ctx;
  return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
}
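A minimal sketch of an alternative (hypothetical type, not something that exists in llama.cpp today): a size-only writer could total up byte counts from tensor metadata, so nothing is read back from the backend just to be discarded.

#include <cstddef>

// Hypothetical size-only "writer" (illustrative only): wherever the real
// writer copies tensor data via ggml_backend_tensor_get, this one only adds
// up how many bytes *would* be written, so no data ever leaves the GPU.
struct llama_data_size_counter {
    size_t size_written = 0;

    // Metadata (cell counts, positions, ...) -- just count the bytes.
    void write(const void * /* src */, size_t n_bytes) {
        size_written += n_bytes;
    }

    // Tensor data -- count the bytes without reading them from the backend.
    void write_tensor_data_size(size_t n_bytes) {
        size_written += n_bytes;
    }
};

The trade-off is that the size computation has to mirror the write path's layout decisions exactly, so in practice it would probably share a code path with the real writer.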

In write_kv_cache_data, there is a lot of double-copying: data moves from the GPU to a staging buffer, and then from the staging buffer to the destination. For example:

        tmp_buf.resize(range_size * k_size_row);
        // first copy: backend (GPU) -> temporary host buffer
        ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(),
                                range.first * k_size_row,
                                range_size * k_size_row);
        // second copy: temporary host buffer -> output
        write(tmp_buf.data(), tmp_buf.size());

An extremely crude benchmark suggests that this double-copy is ~5% of the runtime of llama_state_seq_get_data.
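A minimal sketch of how that intermediate copy could be dropped when the destination buffer is already known, as it is in llama_state_seq_get_data (dst and dst_offset here are hypothetical bookkeeping, not variables from the existing code):

// Hypothetical variant of the snippet above: read directly into the caller's
// output buffer and drop tmp_buf entirely.
ggml_backend_tensor_get(kv_self.k_l[il],
                        dst + dst_offset,            // write straight into the output
                        range.first * k_size_row,    // same source offset as before
                        range_size  * k_size_row);   // same byte count as before
dst_offset += range_size * k_size_row;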


We call ggml_backend_tensor_get many times. In the case where the tensors are contiguous, it would probably be significantly faster to do a single transfer. A back-of-the-envelope calculation of PCIe transfer rates suggests that we are nowhere near saturating the bus, and there is very little computation going on, which points to per-transfer latency overhead as a major culprit.
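To make the latency argument concrete, here is a toy calculation (all numbers are assumptions for illustration, not measurements from this setup): with the non-flash-attention layout the transposed V cache is read row by row, so the call count scales with layers times rows, and a fixed per-call cost dominates even when each call moves very little data.

#include <cstdio>

// Illustrative arithmetic only (assumed, round numbers -- not measured).
int main() {
    const int    n_layer       = 32;      // assumption: typical 7B-class model
    const int    n_embd_v_rows = 1024;    // assumption: V rows per layer
    const double per_call_us   = 10.0;    // assumption: fixed latency per ggml_backend_tensor_get

    const long   n_calls  = (long) n_layer * n_embd_v_rows;  // one call per row per layer
    const double total_ms = n_calls * per_call_us / 1000.0;

    // ~32k calls and ~330 ms of pure per-call overhead, independent of how
    // little data each call moves -- versus 32 calls if V were contiguous.
    printf("calls: %ld, latency overhead: %.0f ms\n", n_calls, total_ms);
    return 0;
}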


I'm using an RTX 4090 with a server-grade motherboard.

cc @abetlen cc @slaren (per suggestion of @abetlen)

Name and Version

$ ./llama-cli --version
version: 3488 (75af08c4)
built with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

slaren commented 1 month ago

Overall these look like good ideas to improve the performance. Enabling flash attention will make the V cache contiguous, which will reduce the number of calls to ggml_backend_tensor_get. For CUDA at least, an additional copy could be removed by reading the data into a pinned host buffer; otherwise the data is read first into a driver-managed buffer and then copied into the user buffer. However, it seems that by far the biggest issue is doing all the reads in the get_size function.
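For illustration, pinning the destination buffer at the CUDA level would look roughly like this (a sketch only; state_size is assumed to be known up front, and this is raw CUDA rather than anything llama.cpp exposes):

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: pin the caller's destination buffer so the CUDA driver
// can DMA into it directly instead of staging the copy through its own
// pageable-memory buffer first.
void read_state_into_pinned_buffer(size_t state_size) {
    std::vector<uint8_t> buf(state_size);   // state_size assumed known up front

    // Pin the existing allocation (cudaMallocHost is the alternative).
    cudaHostRegister(buf.data(), buf.size(), cudaHostRegisterDefault);

    // ... llama_state_seq_get_data / ggml_backend_tensor_get writes into buf here ...

    cudaHostUnregister(buf.data());         // unpin when finished
}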

Tagging @compilade since he did some work on this code recently.