ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: Weird output from llama-speculative #8499

Closed bong-furiosa closed 1 day ago

bong-furiosa commented 1 month ago

What happened?

Hello, llama.cpp experts! Thank you for creating such an amazing LLM inference system. 😁 However, while using it, I encountered an unusual result when checking the speculative decoding output. I believe the observed behavior is a bug, so I am reporting it as an issue on this GitHub project.

First of all, I want to provide the configuration of my system.

Next, I will explain the steps I took to download and run the models until the bug occurred. Using llama.cpp was somewhat challenging for me.

# download draft model
huggingface-cli download TinyLlama/TinyLlama-1.1B-Chat-v1.0 --local-dir=./llama-1.1b
./venv/bin/python3 convert_hf_to_gguf.py ./llama-1.1b
# download target model
huggingface-cli download NousResearch/Llama-2-7b-hf --local-dir=./llama-7b
./venv/bin/python3 convert_hf_to_gguf.py ./llama-7b
# run llama-speculative
./build/bin/llama-speculative -m ./llama-7b/ggml-model-f16.gguf -md ./llama-1.1b/ggml-model-f16.gguf -p "Making cake is like" -e -ngl 100 -ngld 100 -t 4 --temp 1.0 -n 128 -c 4096 -s 20 --top-k 0 --top-p 1 --repeat-last-n 0 --repeat-penalty 1.0 --draft 5

And the printed result is as follows:

draft:

llama_print_timings:        load time =    4430.64 ms
llama_print_timings:      sample time =     897.28 ms /   555 runs   (    1.62 ms per token,   618.54 tokens per second)
llama_print_timings: prompt eval time =    9531.68 ms /   228 tokens (   41.81 ms per token,    23.92 tokens per second)
llama_print_timings:        eval time =    1968.11 ms /   444 runs   (    4.43 ms per token,   225.60 tokens per second)
llama_print_timings:       total time =   19874.43 ms /   672 tokens

target:

llama_print_timings:        load time =   26494.43 ms
llama_print_timings:      sample time =    1337.68 ms /   112 runs   (   11.94 ms per token,    83.73 tokens per second)
llama_print_timings: prompt eval time =    1840.43 ms /   673 tokens (    2.73 ms per token,   365.68 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =   24380.18 ms /   674 tokens

Here, unlike in #3649, the target model's eval line reports 0.00 ms and inf tokens per second.

I am currently comparing the generation-phase latency of the draft model and the target model in speculative decoding. So far, I have used llama-bench and llama-cli to measure tokens per second for each model, and the two tools gave different results (e.g. the latency ratio measured with llama-bench was significantly larger than that measured with llama-cli).

Therefore, I attempted additional measurements with llama-speculative, but obtained an unusual value of inf. Could you confirm whether this result is a bug or the expected behavior of llama.cpp? 🙏

Name and Version

./build/bin/llama-cli --version
version: 3392 (bda62d79) built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

ggerganov commented 1 month ago

The inf is expected: the target model never evaluates batches of single tokens, so the eval time metric remains inactive:

https://github.com/ggerganov/llama.cpp/blob/7acfd4e8d55082c1b597dfc3ffe04fb5d530c6dc/examples/speculative/speculative.cpp#L560-L564

You can run the same command, replacing llama-speculative with llama-cli to get the target model speed in this case:

./build/bin/llama-cli -m ./llama-7b/ggml-model-f16.gguf -md ./llama-1.1b/ggml-model-f16.gguf -p "Making cake is like" -e -ngl 100 -ngld 100 -t 4 --temp 1.0 -n 128 -c 4096 -s 20 --top-k 0 --top-p 1 --repeat-last-n 0 --repeat-penalty 1.0 --draft 5
bong-furiosa commented 1 month ago

@ggerganov, thank you for the quick response! Thanks to your precise answer, I was able to identify the cause of the inf value. However, one question remains unresolved. May I ask some additional questions about measuring model speed?

When I measured the speed of each model using llama-bench, the results were as follows:

./build/bin/llama-bench -m ./llama-7b/ggml-model-f16.gguf -m ./llama-1.1b/ggml-model-f16.gguf -m ./llama-160m/ggml-model-f16.gguf -m ./llama-68m/ggml-model-f16.gguf -b 1 -p 0 -n 256
| model | size | params | backend | ngl | n_batch | test | t/s | CostCoefficient(c) |
| --- | ---: | ---: | --- | ---: | ---: | --- | ---: | ---: |
| llama 7B F16 | 12.55 GiB | 6.74 B | CUDA | 99 | 1 | tg256 | 79.75 ± 0.39 | 1.00 |
| llama 1B F16 | 2.05 GiB | 1.10 B | CUDA | 99 | 1 | tg256 | 264.83 ± 3.32 | 0.29 |
| llama 160m F16 | 309.82 MiB | 162.42 M | CUDA | 99 | 1 | tg256 | 712.15 ± 0.64 | 0.11 |
| llama 68m F16 | 129.76 MiB | 68.03 M | CUDA | 99 | 1 | tg256 | 2942.56 ± 27.03 | 0.02 |

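Here, I believe we can approximate the cost coefficient c (from the original speculative decoding paper) by dividing the t/s (tokens per second) of the target model by the t/s of the draft model: c is the ratio of the draft model's per-token time to the target model's per-token time, and per-token time is the reciprocal of t/s.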

We can see that:

  1. The LLaMA-7B/LLaMA-68m pair reaches a c value similar to the one mentioned in the paper (c < 0.05). 👏
  2. The c value for the LLaMA-7B/LLaMA-160m pair is similar to the one mentioned in #3649. 👏

However, when I measured the tokens per second of each model during the generation phase using llama-cli, the results were different. The following table is not printed automatically by llama-cli; I gathered it manually. The strange generated text is due to either the model's limited ability or my failure to tune the inference parameters properly.

# change llama model {7b, 1.1b, 160m, 68m}
./build/bin/llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"

# output result
 What can we do with llama llm model?
How can llam llama llm model?
You might have to have llama llama llm model?
What llama llama llama llm model?
I llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama.
How can llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama ll. llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama llama ll
llama_print_timings:        load time =    4725.01 ms
llama_print_timings:      sample time =      23.96 ms /   512 runs   (    0.05 ms per token, 21367.16 tokens per second)
llama_print_timings: prompt eval time =       4.28 ms /    12 tokens (    0.36 ms per token,  2803.08 tokens per second)
llama_print_timings:        eval time =    1875.77 ms /   511 runs   (    3.67 ms per token,   272.42 tokens per second)
llama_print_timings:       total time =    7766.90 ms /   523 tokens
| model | t/s | CostCoefficient(c) |
| --- | ---: | ---: |
| llama 7B F16 | 78.18 | 1.00 |
| llama 1B F16 | 211.63 | 0.36 |
| llama 160m F16 | 244.85 | 0.31 |
| llama 68m F16 | 272.42 | 0.28 |

🤔 We can see that the c values calculated from the t/s measured with llama-cli are higher than those measured with llama-bench.

I do not have the expertise or time to analyze the llama-bench or llama-cli code (it is written in great detail). Could you possibly verify:

  1. whether these measurement results are reproducible on other systems, or
  2. whether there is an error in my measurement method (e.g. the command line) using llama-cli?

I apologize for asking such a complex question. However, I find the llama.cpp system truly amazing, and seeing it utilized in papers like OSD made me want to examine its robustness. 👍👍

ggerganov commented 1 month ago

When using llama-cli, if you redirect stdout to a file, does the number match?

./build/bin/llama-cli \
  -m ./llama-68m/ggml-model-f16.gguf \
  -e -ngl 100 -t 4 -n 512 -c 2048 \
  -p "What can we do with llama llm model?" > result.txt
bong-furiosa commented 1 month ago

Unfortunately, the results are still similar. However, if streaming and printing tokens is the main cause of the added latency, I will remove the printing code, rebuild, and test again. I will report the results after testing!

ggerganov commented 1 month ago

Not sure, with my RTX 2060 the results from llama-bench and llama-cli match:

GGML_CUDA=1 make -j && ./llama-bench -m models/llama-68m/ggml-model-f16.gguf -p 0 -n 128,256,512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | ---: | --- | ---: |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg128 | 2207.20 ± 7.61 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg256 | 2181.46 ± 2.34 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg512 | 2127.68 ± 0.53 |

build: 1666f92d (3404)

GGML_CUDA=1 make -j && ./llama-cli -m models/llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size =    0,02 MiB
llm_load_tensors: offloading 2 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 3/3 layers to GPU
llm_load_tensors:        CPU buffer size =    46,88 MiB
llm_load_tensors:      CUDA0 buffer size =    82,89 MiB
..............
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000,0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =    12,00 MiB
llama_new_context_with_model: KV self size  =   12,00 MiB, K (f16):    6,00 MiB, V (f16):    6,00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0,12 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =    64,00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     5,51 MiB
llama_new_context_with_model: graph nodes  = 70
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 4 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
    repeat_last_n = 64, repeat_penalty = 1,000, frequency_penalty = 0,000, presence_penalty = 0,000
    top_k = 40, tfs_z = 1,000, top_p = 0,950, min_p = 0,050, typical_p = 1,000, temp = 0,800
    mirostat = 0, mirostat_lr = 0,100, mirostat_ent = 5,000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 2048, n_batch = 2048, n_predict = 512, n_keep = 1

 What can we do with llama llm model?
How to find the correct llama llm model? We have a large variety of hottest styles that you can find in the following range of the most popular hottest styles.
There are a variety of hottest styles and styles that you can find in the following range of the most popular hottest styles: Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hottest Hott
llama_print_timings:        load time =     262,00 ms
llama_print_timings:      sample time =      11,34 ms /   512 runs   (    0,02 ms per token, 45161,86 tokens per second)
llama_print_timings: prompt eval time =       0,72 ms /    12 tokens (    0,06 ms per token, 16713,09 tokens per second)
llama_print_timings:        eval time =     239,42 ms /   511 runs   (    0,47 ms per token,  2134,32 tokens per second)
llama_print_timings:       total time =     266,97 ms /   523 tokens
Log end
bong-furiosa commented 1 month ago

This is so weird... I could not reproduce your results on my system (2 TB RAM, 16-core CPU; resources should be more than sufficient). After downloading the latest llama.cpp version and using the commands you provided above, I still observed a difference in speed between llama-cli and llama-bench.

# GGML_CUDA=1 make -j
# ./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
llm_load_tensors: ggml ctx size =    0.02 MiB
llm_load_tensors: offloading 2 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 3/3 layers to GPU
llm_load_tensors:        CPU buffer size =    46.88 MiB
llm_load_tensors:      CUDA0 buffer size =    82.89 MiB
..............
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =    12.00 MiB
llama_new_context_with_model: KV self size  =   12.00 MiB, K (f16):    6.00 MiB, V (f16):    6.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.12 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =    64.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     5.51 MiB
llama_new_context_with_model: graph nodes  = 70
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 4 / 128 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 2048, n_batch = 2048, n_predict = 512, n_keep = 1

 What can we do with llama llm model?
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is an llama llm model.
A llama llm model is a llama llm model.
A llama llm model is an llama llm model is an llama llm model is an llama llm model.
A llm model is an llama llm model is an llama llm model.
A llama llm model is a llama llm model. I’ve been meaning to write a blog post on the subject for the last 5 years but haven’t been able to find any time to do so. I’m currently using the blog post, so I’ve been going to the post again.
My name is Cathy, and I have a job. I was really looking for work and work on a book. I’m not a writer. I have a job for a client. I want to be my life.
I am a writer and a writer. I’ve been a writer for a long time. I’m an editor and have a job for a client. I have a job for a client. I’m looking for a writer that I can do that could not write a writer and have a job for a client. I’m looking for a writer that will send me a blog post.
I’m a writer. I was a writer that I could write a writer that I can write a blog post that I’m not looking for.
I’m currently working on a new site called ‘Craft Craft’. I’m still a freelance writer. I’m a freelance writer. I’m looking for a writer that I
llama_print_timings:        load time =     701.66 ms
llama_print_timings:      sample time =      22.24 ms /   512 runs   (    0.04 ms per token, 23022.62 tokens per second)
llama_print_timings: prompt eval time =       3.82 ms /    12 tokens (    0.32 ms per token,  3138.90 tokens per second)
llama_print_timings:        eval time =    1557.46 ms /   511 runs   (    3.05 ms per token,   328.10 tokens per second)
llama_print_timings:       total time =    5714.46 ms /   523 tokens
Log end

Let me check whether this issue is related to the Ampere GPU (since your GPU is Turing). After finishing my work (🤣), I will install llama.cpp and run the same commands on my personal desktop (RTX 3060) to check.

I'm curious if other users have had similar experiences. Thank you for keeping an eye on this!

ggerganov commented 1 month ago

Hm, yes, that is weird; not sure why that is. Let us know the results with the RTX 3060 when you get the chance.

bong-furiosa commented 1 month ago

Thank you for your interest and patience! First of all, I was able to obtain results similar to those you got with the RTX 2060. 🎉

The `llama-bench` results are as follows:
./llama-bench -m ./llama-68m/ggml-model-f16.gguf -p 0 -n 128,256,512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
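| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | ---: | --- | ---: |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg128 | 2200.59 ± 14.63 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg256 | 2172.34 ± 30.06 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg512 | 2148.46 ± 23.81 |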

The result of the llama-cli is as follows:

./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size =    0.02 MiB
llm_load_tensors: offloading 2 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 3/3 layers to GPU
llm_load_tensors:        CPU buffer size =    46.88 MiB
llm_load_tensors:      CUDA0 buffer size =    82.89 MiB
What can we do with llama llm model?
llm model is a device that runs in a virtualized virtual world.
llm model is a device that runs in a virtualized virtual world. It is the device that runs in a virtualized virtual world.
llm model is a device that runs in a virtualized virtual world. It is the device that runs in a virtualized virtual world.
llm model is a device that runs in a virtualized virtual world. It runs in a virtual machine.
llm model is a device that runs in a virtual machine.
llm model is a device that runs in a virtualized virtual environment. It runs in a virtual machine. It runs in a virtual machine.
llm model is a device that runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine. It runs in a virtual machine.
llama_print_timings:        load time =     202.19 ms
llama_print_timings:      sample time =      14.05 ms /   512 runs   (    0.03 ms per token, 36428.32 tokens per second)
llama_print_timings: prompt eval time =       0.77 ms /    12 tokens (    0.06 ms per token, 15645.37 tokens per second)
llama_print_timings:        eval time =     278.39 ms /   511 runs   (    0.54 ms per token,  1835.55 tokens per second)
llama_print_timings:       total time =     316.21 ms /   523 tokens
Log end

🤔 Based on these results, I plan to proceed as follows.

The different tokens-per-second results from llama-bench and llama-cli may be due to an issue with my server or an incorrect environment configuration (e.g. a somewhat outdated Docker image).

Therefore, I intend to conclude this issue by reporting the results of the plan above. If you would like to see the results and follow-up analysis, I will keep this issue open. If you believe the comparison is already sufficient, please feel free to close this issue. Thank you!

bong-furiosa commented 1 month ago

We ran experiments on an RTX 4090 (Ada Lovelace architecture, compute capability 8.9) and on the A100 server after upgrading its Docker image to the latest one (nvcr.io/nvidia/pytorch:24.06-py3).

The llama-bench and llama-cli results for the RTX 4090 are as follows:

# ./llama-bench -m ./llama-68m/ggml-model-f16.gguf -p 0 -n 128,256,512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | ---: | --- | ---: |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg128 | 4006.17 ± 11.78 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg256 | 3995.99 ± 117.75 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg512 | 3930.12 ± 101.34 |
# ./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"

llama_print_timings:        load time =     168.01 ms
llama_print_timings:      sample time =      13.30 ms /   512 runs   (    0.03 ms per token, 38481.77 tokens per second)
llama_print_timings: prompt eval time =       0.56 ms /    12 tokens (    0.05 ms per token, 21621.62 tokens per second)
llama_print_timings:        eval time =     150.58 ms /   511 runs   (    0.29 ms per token,  3393.52 tokens per second)
llama_print_timings:       total time =     185.14 ms /   523 tokens

Next, the llama-bench and llama-cli results for the A100 after upgrading the Docker image are as follows:

# ./llama-bench -m ./llama-68m/ggml-model-f16.gguf -p 0 -n 128,256,512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | ---: | --- | ---: |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg128 | 2587.95 ± 121.09 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg256 | 2800.25 ± 32.08 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg512 | 2762.87 ± 26.93 |
# ./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"

llama_print_timings:        load time =     459.88 ms
llama_print_timings:      sample time =      26.49 ms /   512 runs   (    0.05 ms per token, 19328.78 tokens per second)
llama_print_timings: prompt eval time =       3.73 ms /    12 tokens (    0.31 ms per token,  3215.43 tokens per second)
llama_print_timings:        eval time =    1739.18 ms /   511 runs   (    3.40 ms per token,   293.82 tokens per second)
llama_print_timings:       total time =    6505.22 ms /   523 tokens

On three different desktop GPUs (RTX 2060: Turing, RTX 3060: Ampere, RTX 4090: Ada Lovelace), llama-bench and llama-cli gave similar results.

However, on the A100 server (which has more CPU and RAM resources), llama-cli was still far slower than llama-bench.


While writing this comment, I was also able to test on an H100 server (Hopper architecture, compute capability 9.0). However, since the server is shared among several coworkers, I couldn't update its Docker image to the latest version. Here are the test results for the H100 server. The results are frustrating. 😢

# ./llama-bench -m ./llama-68m/ggml-model-f16.gguf -p 0 -n 128,256,512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
Device 0: NVIDIA H100 PCIe, compute capability 9.0, VMM: yes
Device 1: NVIDIA H100 PCIe, compute capability 9.0, VMM: yes
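| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | ---: | --- | ---: |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg128 | 2710.55 ± 16.09 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg256 | 2617.29 ± 3.47 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg512 | 2526.86 ± 18.50 |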
./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
llama_print_timings:        load time =     830.29 ms
llama_print_timings:      sample time =      27.73 ms /   512 runs   (    0.05 ms per token, 18465.76 tokens per second)
llama_print_timings: prompt eval time =       5.48 ms /    12 tokens (    0.46 ms per token,  2191.78 tokens per second)
llama_print_timings:        eval time =    2426.31 ms /   511 runs   (    4.75 ms per token,   210.61 tokens per second)
llama_print_timings:       total time =    8103.09 ms /   523 tokens

The H100 results from llama-bench and llama-cli show the same pattern as the A100. From these results, we can consider the following hypotheses:

  1. There might be a hidden critical cause in our server system configuration. 🤔
  2. The current version of llama.cpp could produce different t/s results from llama-bench and llama-cli in a GPU server environment (of course, this is the least likely hypothesis).

🤔 The best next step is to check llama-bench and llama-cli results for the current llama.cpp from other users on GPU server systems. Are there any opinions from other users, or simple test results from the llama.cpp team, for the latest llama.cpp on GPU servers?

ggerganov commented 1 month ago

Huh.. Could you run the following on the H100:

make clean
LLAMA_DISABLE_LOGS=1 GGML_CUDA=1 make -j

./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"

and also:

CUDA_VISIBLE_DEVICES=0 ./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
bong-furiosa commented 1 month ago

Including LLAMA_DISABLE_LOGS=1 in the make command produced interesting results!

The results of running the first command are as follows. llama-cli detects 2 GPUs in the server (2 H100 GPUs are installed).

./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
Device 0: NVIDIA H100 PCIe, compute capability 9.0, VMM: yes
Device 1: NVIDIA H100 PCIe, compute capability 9.0, VMM: yes
...............
llama_print_timings:        load time =     727.73 ms
llama_print_timings:      sample time =      14.97 ms /   512 runs   (    0.03 ms per token, 34201.74 tokens per second)
llama_print_timings: prompt eval time =       0.62 ms /    12 tokens (    0.05 ms per token, 19354.84 tokens per second)
llama_print_timings:        eval time =     200.63 ms /   511 runs   (    0.39 ms per token,  2546.91 tokens per second)
llama_print_timings:       total time =     236.70 ms /   523 tokens

The results of running the second command are as follows. Here, we forced llama-cli to use only GPU 0.

CUDA_VISIBLE_DEVICES=0 ./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA H100 PCIe, compute capability 9.0, VMM: yes
..............
llama_print_timings:        load time =     579.10 ms
llama_print_timings:      sample time =      15.20 ms /   512 runs   (    0.03 ms per token, 33693.08 tokens per second)
llama_print_timings: prompt eval time =       1.35 ms /    12 tokens (    0.11 ms per token,  8862.63 tokens per second)
llama_print_timings:        eval time =     194.37 ms /   511 runs   (    0.38 ms per token,  2628.98 tokens per second)
llama_print_timings:       total time =     239.52 ms /   523 tokens

😲 In both cases, llama-cli now exceeds 2000 tokens per second. What a surprise!

Additionally, we repeated the same process on the A100 server to see whether we could obtain similar results.

./llama-bench -m ./llama-68m/ggml-model-f16.gguf -p 0 -n 128,256,512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
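| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | ---: | --- | ---: |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg128 | 2627.97 ± 39.52 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg256 | 2827.87 ± 53.34 |
| llama ?B F16 | 129.76 MiB | 68.03 M | CUDA | 99 | tg512 | 2824.42 ± 13.34 |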
./llama-cli -m ./llama-68m/ggml-model-f16.gguf -e -ngl 100 -t 4 -n 512 -c 2048 -p "What can we do with llama llm model?"
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
..............
llama_print_timings:        load time =    1354.61 ms
llama_print_timings:      sample time =      18.16 ms /   512 runs   (    0.04 ms per token, 28187.62 tokens per second)
llama_print_timings: prompt eval time =       0.90 ms /    12 tokens (    0.07 ms per token, 13363.03 tokens per second)
llama_print_timings:        eval time =     200.27 ms /   511 runs   (    0.39 ms per token,  2551.52 tokens per second)
llama_print_timings:       total time =     241.64 ms /   523 tokens

Now, the remaining question is why llama-bench and llama-cli gave consistent results in the RTX GPU tests even without the LLAMA_DISABLE_LOGS=1 option. This question might be answered by:

  1. modifying the logging method, or
  2. analyzing the differences between the desktop and server environment configurations.

In conclusion, it was your thoughtful help that allowed us to solve the problem. Thank you so much! 👍👍👍

mscheong01 commented 1 month ago

Maybe the difference is because your local desktop has more CPU resources available for logging than the GPU servers?

I think this information is worth mentioning in the docs or in this issue as it seems to have a significant impact.

Also, thanks for sharing your benchmarks! 😄 🙇‍♂️

bong-furiosa commented 1 month ago

Thank you @mscheong01 for checking this issue! As you suggested, I have reported this issue in #6398. 👍 I'm not sure how closely this issue matches that discussion, but I'm looking forward to further discussion and to finding the exact cause.

ggerganov commented 1 month ago

The logging does incur some overhead, as it is synchronous and some of the things we log (e.g. batch contents) involve heavy operations like detokenization. For very small models such as the 68M one used in the tests earlier, this can have a noticeable impact, though it was still surprising to see such a big difference in your tests. For bigger models (i.e. above 1B), I expect the logging overhead to have a much smaller impact, maybe close to insignificant.

In any case, all of this will be resolved when we implement asynchronous logging and add compile-time verbosity levels (#8566)

github-actions[bot] commented 1 day ago

This issue was closed because it has been inactive for 14 days since being marked as stale.