ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: RPC inference is drastically slower even on localhost #8832

Open hafezmg48 opened 1 month ago

hafezmg48 commented 1 month ago

What happened?

I am trying to run inference with the RPC example. When running llama-cli with the RPC backend against a single rpc-server on localhost, inference throughput is only 1.9 tok/s for llama3.1-8B on CUDA, while the same llama-cli built locally with CUDA and no RPC generates 25 tok/s.

That is about 13x slower, even though the server is on localhost and is effectively using the same local GPU, just through RPC.
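A quick sanity check of the ~13x figure: the sketch below converts the reported 25 tok/s baseline and the "eval time" line from the RPC log (32962.85 ms over 63 runs) into per-token latency and a slowdown ratio. The helper names are illustrative only and are not part of llama.cpp.

```python
def ms_per_token(tok_per_sec: float) -> float:
    """Per-token latency in milliseconds for a given throughput."""
    return 1000.0 / tok_per_sec


def tok_per_sec_from_eval(eval_ms: float, n_tokens: int) -> float:
    """Throughput implied by a total eval time (ms) over n generated tokens."""
    return n_tokens / (eval_ms / 1000.0)


def slowdown(baseline_tps: float, candidate_tps: float) -> float:
    """How many times slower the candidate is than the baseline."""
    return baseline_tps / candidate_tps


local_tps = 25.0                               # reported: CUDA build without RPC
rpc_tps = tok_per_sec_from_eval(32962.85, 63)  # "eval time" line of the RPC log

print(f"local: {ms_per_token(local_tps):.1f} ms/token")
print(f"rpc:   {rpc_tps:.2f} tok/s, {ms_per_token(rpc_tps):.2f} ms/token")
print(f"slowdown: {slowdown(local_tps, rpc_tps):.1f}x")
```

Put differently, each generated token takes roughly 523 ms over RPC versus about 40 ms locally, which is far more than a localhost round trip alone would be expected to add.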

Name and Version

I followed the exact steps in https://github.com/ggerganov/llama.cpp/tree/master/examples/rpc (build 3501, commit b7a08fd5, per the log below).

Running the CLI with:
bin/llama-cli -m ./llama3.1-8B-F16.gguf -p "Hello, my name is" -n 64 --rpc localhost:50052 -ngl 99

Running rpc-server:
bin/rpc-server -p 50052
create_backend: using CUDA backend
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: Tesla V100-PCIE-32GB, compute capability 7.0, VMM: yes
Starting RPC server on 0.0.0.0:50052, backend memory: 28170 MB
Accepted client connection, free_mem=29538713600, total_mem=34079899648
Client connection closed

What operating system are you seeing the problem on?

Linux

Relevant log output

Log start
main: build = 3501 (b7a08fd5)
main: built with cc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0 for x86_64-linux-gnu
main: seed  = 1722612132
llama_model_loader: loaded meta data with 27 key-value pairs and 291 tensors from ./llama3.1-8B-F16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Llama3.1 8b
llama_model_loader: - kv   3:                           general.basename str              = llama3.1
llama_model_loader: - kv   4:                         general.size_label str              = 8B
llama_model_loader: - kv   5:                            general.license str              = llama3
llama_model_loader: - kv   6:                               general.tags arr[str,6]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv   7:                          general.languages arr[str,1]       = ["en"]
llama_model_loader: - kv   8:                          llama.block_count u32              = 32
llama_model_loader: - kv   9:                       llama.context_length u32              = 8192
llama_model_loader: - kv  10:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv  11:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv  12:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv  13:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  14:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  15:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  16:                          general.file_type u32              = 1
llama_model_loader: - kv  17:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  18:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  19:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  20:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  21:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  22:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  23:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  24:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  25:                tokenizer.ggml.eos_token_id u32              = 128001
llama_model_loader: - kv  26:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:  226 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.8000 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 8B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 8.03 B
llm_load_print_meta: model size       = 14.96 GiB (16.00 BPW)
llm_load_print_meta: general.name     = Llama3.1 8b
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128001 '<|end_of_text|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors: RPC[localhost:50052] buffer size = 13898.98 MiB
llm_load_tensors:        CPU buffer size =  1418.03 MiB
.........................................................................................
llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: RPC[localhost:50052] KV buffer size =   992.00 MiB
llama_kv_cache_init:        CPU KV buffer size =    32.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
llama_new_context_with_model: RPC[localhost:50052] compute buffer size =   560.00 MiB
llama_new_context_with_model:        CPU compute buffer size =   560.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 44 / 88 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampling:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 8192, n_batch = 2048, n_predict = 64, n_keep = 1

Hello, my name is Paul. I am a graduate from the University of Michigan with a B.S. in Computer Science and a minor in Mathematics. I love math, and I know that it can be intimidating to some people. That is why I enjoy tutoring because it allows me to help those students who have trouble grasping certain concepts in math
llama_print_timings:        load time =   32468.00 ms
llama_print_timings:      sample time =      10.54 ms /    64 runs   (    0.16 ms per token,  6072.68 tokens per second)
llama_print_timings: prompt eval time =     617.41 ms /     6 tokens (  102.90 ms per token,     9.72 tokens per second)
llama_print_timings:        eval time =   32962.85 ms /    63 runs   (  523.22 ms per token,     1.91 tokens per second)
llama_print_timings:       total time =   33624.89 ms /    69 tokens
Log end
ski422 commented 1 month ago

I'm not sure if it's the same issue (I am not using RPC), but inference speed has dropped dramatically compared with an older version of llama.cpp.

This is CPU-only inference, and it is about 4-5 times slower than the older version that still used the "main" binary (2-3 tokens/sec -> 0.4 tokens/sec). For reference, I compared the old and current versions of llama.cpp on the same server.

Is there a solution to this performance issue?

rgerganov commented 1 month ago

@hafezmg48 I am not able to reproduce such a regression with the latest code (commit 6e02327e8b783):

Results without RPC:

➜  build-rpc-cuda git:(master) ✗ bin/llama-bench -m ../models/tinyllama-1.1b-f16.gguf -ngl 99 -fa 1                      
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce GTX 1660, compute capability 7.5, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | ---------------: |
| llama 1B F16                   |   2.05 GiB |     1.10 B | CUDA       |  99 |  1 |         pp512 |    315.02 ± 0.34 |
| llama 1B F16                   |   2.05 GiB |     1.10 B | CUDA       |  99 |  1 |         tg128 |     74.83 ± 0.05 |

build: 6e02327e (3565)

Results with RPC:

➜  build-rpc-cuda git:(master) ✗ bin/llama-bench -m ../models/tinyllama-1.1b-f16.gguf -ngl 99 -fa 1 --rpc localhost:50052
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce GTX 1660, compute capability 7.5, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | ---------------: |
| llama 1B F16                   |   2.05 GiB |     1.10 B | CUDA+RPC   |  99 |  1 |         pp512 |    314.06 ± 0.30 |
| llama 1B F16                   |   2.05 GiB |     1.10 B | CUDA+RPC   |  99 |  1 |         tg128 |     68.74 ± 0.01 |

build: 6e02327e (3565)
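To put these runs next to the ~13x reported in the issue, the sketch below computes the relative throughput loss from the two llama-bench tables (tinyllama-1.1b F16 on a GTX 1660). It uses only the t/s means shown there; the dictionary layout and helper name are illustrative.

```python
# llama-bench t/s means copied from the two tables (without and with RPC)
results = {
    "pp512": {"CUDA": 315.02, "CUDA+RPC": 314.06},
    "tg128": {"CUDA": 74.83, "CUDA+RPC": 68.74},
}


def rpc_overhead_pct(local_tps: float, rpc_tps: float) -> float:
    """Percentage of throughput lost when going through RPC."""
    return (local_tps - rpc_tps) / local_tps * 100.0


for test, r in results.items():
    pct = rpc_overhead_pct(r["CUDA"], r["CUDA+RPC"])
    ratio = r["CUDA"] / r["CUDA+RPC"]
    print(f"{test}: {pct:.1f}% slower with RPC ({ratio:.2f}x)")
```

On this setup, prompt processing is essentially unaffected (about 0.3%) and token generation loses about 8% (about 1.09x), an order of magnitude away from the slowdown in the original report.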

Can you run llama-bench and post the results for a few different models?