ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Speculative Decoding is slower than expected on A100 #3649

LiuXiaoxuanPKU closed this 5 months ago

LiuXiaoxuanPKU commented 10 months ago

Thanks for the great project! I am benchmarking the performance of llama.cpp with speculative decoding.

draft:

llama_print_timings:        load time =      65.11 ms
llama_print_timings:      sample time =     524.95 ms /     1 runs   (  524.95 ms per token,     1.90 tokens per second)
llama_print_timings: prompt eval time =       8.59 ms /    94 tokens (    0.09 ms per token, 10946.78 tokens per second)
llama_print_timings:        eval time =     322.80 ms /   216 runs   (    1.49 ms per token,   669.15 tokens per second)
llama_print_timings:       total time =    2924.72 ms

target:

llama_print_timings:        load time =    1144.77 ms
llama_print_timings:      sample time =       4.02 ms /   259 runs   (    0.02 ms per token, 64411.84 tokens per second)
llama_print_timings: prompt eval time =    1939.02 ms /   351 tokens (    5.52 ms per token,   181.02 tokens per second)
llama_print_timings:        eval time =      13.19 ms /     1 runs   (   13.19 ms per token,    75.82 tokens per second)
llama_print_timings:       total time =    2999.59 ms

I am using greedy decoding and disabling all the heuristics (fixed `n_draft`: always propose `n_draft` tokens and never stop drafting early). My command is:

./build/bin/speculative \
    -ngl 1000 \
    -ngld 100 \
    -m /data/model/llama-7b/ggml-model-f16.gguf \
    -md /data/model/lama-160m/ggml-model-f16.gguf \
    -p "${prompt}" \
    -e --temp "-1" -n 256 -s 1 --top-k 0 --top-p 1 --repeat-last-n 0 --repeat-penalty 1.0 --draft 5


When the token acceptance rate is 0.44, speculative decoding is actually slower than the target model alone (50 tokens/s vs. 75 tokens/s):

encoded   94 tokens in    0.076 seconds, speed: 1231.914 t/s
decoded  108 tokens in    2.145 seconds, speed:   50.341 t/s

n_draft   = 5
n_predict = 108
n_drafted = 165
n_accept  = 74
accept    = 44.848%


However, based on the original speculative decoding [paper](https://proceedings.mlr.press/v202/leviathan23a/leviathan23a.pdf), the speedup should be:
$$\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}$$
where `alpha` is the token acceptance rate, `gamma` is the number of tokens proposed each step, and `c` is the ratio between the execution times of the draft and target models. In the example above, `c` is roughly `76/669=0.11`.
Plugging in the numbers above, the expected speedup should be:
`(1-0.44^6)/[(1-0.44)*(0.11*0.44+1)]=1.69x`.
However, the benchmarking results show that it's actually `50/76=0.66x`.

To debug this, I set the token acceptance rate to 100% by removing the `id==draft_id[i_dft]` check [here](https://github.com/ggerganov/llama.cpp/blob/940efa95fec0b8a98c226a889d2ad839dfeeae0d/examples/speculative/speculative.cpp#L158). After doing this, the speed is ~90 tokens/s, a `90/76=1.18x` speedup. However, this is much smaller than what the formula above predicts (using 0.99 as the token acceptance rate instead of 1):
`(1-0.99^6)/[(1-0.99)*(0.11*0.99+1)]=5.27x`.

I wonder which part of speculative decoding causes such a large overhead. Any comments are highly appreciated! Thanks!
ggerganov commented 10 months ago

Thank you for the detailed report - very useful information!

~If you add -nommq CLI arg, do the numbers improve?~ (Edit: nvm, -nommq does not make a difference for F16 models)

I'll try to do the same test today and see if I can find the bottleneck.

ggerganov commented 10 months ago

I might be missing something, but I think the wrong numbers were plugged into the equation. In the first case, for example, it should be:

(1-0.44^6)/[(1-0.44)*(0.11*5+1)]= 1.14x

because gamma in the denominator is 5, not 0.44.
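To make the correction concrete, here is a minimal Python sketch of the paper's expected speedup, evaluated with the numbers from the first post (alpha = 0.44, gamma = 5, c ≈ 0.11). The function name `expected_speedup` is a hypothetical helper, not part of llama.cpp; like the paper, it assumes that one target pass verifies the whole draft.

```python
def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected walltime improvement factor (Leviathan et al., 2023).

    alpha: per-token acceptance rate
    gamma: number of draft tokens proposed per step
    c:     time of one draft pass divided by time of one target pass
    """
    expected_tokens = (1 - alpha ** (gamma + 1)) / (1 - alpha)
    cost_per_step = gamma * c + 1  # gamma draft passes + 1 target pass, in target units
    return expected_tokens / cost_per_step


if __name__ == "__main__":
    alpha, gamma, c = 0.44, 5, 0.11
    # Correct: gamma multiplies c in the denominator.
    print(f"{expected_speedup(alpha, gamma, c):.2f}x")  # 1.14x
    # Mistaken substitution of alpha for gamma in the denominator.
    wrong = (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (c * alpha + 1))
    print(f"{wrong:.2f}x")  # 1.69x
```

The same function with `c = 52.5 / ((619.4 + 121.7) / 2)` reproduces the first value printed by the script below (about 2.16).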

I did some more testing on a V100 16GB GPU using the same models. Here is my script to determine the theoretical speed-up according to the paper:

import numpy as np

# case 0, alpha = 0.8, gamma = 5
sd_tg = 619.4
st_tg = 52.5
st_pp = 121.7 # pp 5

s_avg = (sd_tg + st_pp)/2

a = 0.80
c = st_tg/s_avg
g = 5

speed = (1 - a**(g + 1))/((1 - a)*(c*g + 1))

print(speed)

# case 1, alpha = 0.875, gamma = 8
sd_tg = 625
st_tg = 52.5
st_pp = 194.6 # pp 8

s_avg = (sd_tg + st_pp)/2

a = 0.875
c = st_tg/s_avg
g = 8

speed = (1 - a**(g + 1))/((1 - a)*(c*g + 1))

print(speed)

In case 0 I run the following:

make -j && ./bin/speculative -ngl 1000 -ngld 100 -m /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-f16.gguf -md /mnt/llama.cpp/models/llama-160m/ggml-model-f16.gguf -p "${prompt}" -e --temp -1 -n 256 --repeat-last-n 0 --repeat-penalty 1.0 --draft 5 -np 1

encoded   58 tokens in    0.111 seconds, speed:  521.414 t/s
decoded  261 tokens in    2.967 seconds, speed:   87.954 t/s

n_draft   = 5
n_predict = 261
n_drafted = 260
n_accept  = 208
accept    = 80.000%

draft:

llama_print_timings:        load time =     125.62 ms
llama_print_timings:      sample time =       9.58 ms /   260 runs   (    0.04 ms per token, 27128.55 tokens per second)
llama_print_timings: prompt eval time =      10.95 ms /    58 tokens (    0.19 ms per token,  5296.80 tokens per second)
llama_print_timings:        eval time =     480.14 ms /   261 runs   (    1.84 ms per token,   543.59 tokens per second)
llama_print_timings:       total time =    3079.20 ms

target:

llama_print_timings:        load time =    2827.23 ms
llama_print_timings:      sample time =       9.84 ms /   261 runs   (    0.04 ms per token, 26513.61 tokens per second)
llama_print_timings: prompt eval time =    2387.92 ms /   369 tokens (    6.47 ms per token,   154.53 tokens per second)
llama_print_timings:        eval time =      19.31 ms /     1 runs   (   19.31 ms per token,    51.80 tokens per second)
llama_print_timings:       total time =    3219.78 ms

In case 1 I run this:

make -j && ./bin/speculative -ngl 1000 -ngld 100 -m /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-f16.gguf -md /mnt/llama.cpp/models/llama-160m/ggml-model-f16.gguf -p "${prompt}" -e --temp -1 -n 256 --repeat-last-n 0 --repeat-penalty 1.0 --draft 8 -np 1

encoded   58 tokens in    0.110 seconds, speed:  525.272 t/s
decoded  257 tokens in    2.083 seconds, speed:  123.394 t/s

n_draft   = 8
n_predict = 257
n_drafted = 256
n_accept  = 224
accept    = 87.500%

draft:

llama_print_timings:        load time =     125.64 ms
llama_print_timings:      sample time =       9.42 ms /   256 runs   (    0.04 ms per token, 27190.65 tokens per second)
llama_print_timings: prompt eval time =      10.97 ms /    58 tokens (    0.19 ms per token,  5286.66 tokens per second)
llama_print_timings:        eval time =     446.25 ms /   257 runs   (    1.74 ms per token,   575.91 tokens per second)
llama_print_timings:       total time =    2193.53 ms

target:

llama_print_timings:        load time =    2770.13 ms
llama_print_timings:      sample time =      11.02 ms /   257 runs   (    0.04 ms per token, 23323.35 tokens per second)
llama_print_timings: prompt eval time =    1543.14 ms /   345 tokens (    4.47 ms per token,   223.57 tokens per second)
llama_print_timings:        eval time =      19.56 ms /     1 runs   (   19.56 ms per token,    51.12 tokens per second)
llama_print_timings:       total time =    2334.47 ms

I am using the branch from #3624; with -np 1 it should be equivalent to master. I have applied the following patch to simulate a high acceptance rate (similar to your changes):

--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
                         continue;
                     }

-                    if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
+                    if (i_dft < (int) drafts[s].tokens.size() - 1) {
                         LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());

                         s_keep = s;
@@ -273,11 +273,11 @@ int main(int argc, char ** argv) {
                 }

                 // TODO: make this configurable
-                if (cur_p[0].p < 0.4) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
-                    drafts[s].drafting = false;
-                    continue;
-                }
+                //if (cur_p[0].p < 0.4) {
+                //    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+                //    drafts[s].drafting = false;
+                //    continue;
+                //}

                 std::vector<int> sa(1, s);

To determine sd_tg, st_tg and st_pp I run the following benchmark commands:

# draft model text-generation bench
./bin/llama-bench -m /mnt/llama.cpp/models/llama-160m/ggml-model-f16.gguf -p 0 -n 128 -ngl 99

  Device 0: Tesla V100-PCIE-16GB, compute capability 7.0
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama ?B mostly F16            | 309.82 MiB |   162.42 M | CUDA       |  99 | tg 128     |   619.43 ± 20.03 |

# target model PP and TG bench
./bin/llama-bench -m /mnt/llama.cpp/models/open-llama/7B-v2/ggml-model-f16.gguf -p 1,2,3,4,5,6,7,8,64,128,256,512 -n 128 -ngl 99

  Device 0: Tesla V100-PCIE-16GB, compute capability 7.0
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 1       |     30.72 ± 3.88 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 2       |     51.22 ± 0.43 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 3       |     75.92 ± 0.64 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 4       |    101.66 ± 0.29 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 5       |    121.74 ± 0.55 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 6       |    146.37 ± 0.35 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 7       |    163.65 ± 0.46 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 8       |    194.58 ± 0.72 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 64      |   899.08 ± 23.64 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 128     |  1708.83 ± 56.78 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 256     |  2515.20 ± 10.46 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | pp 512     |   2866.48 ± 1.45 |
| llama 7B mostly F16            |  12.55 GiB |     6.74 B | CUDA       |  99 | tg 128     |     52.49 ± 0.13 |

For sd_tg I pick the result from the first bench.

For st_tg I pick the tg 128 result from the second bench. For st_pp I pick the pp result whose batch size equals g (pp 5 for g = 5, pp 8 for g = 8).

My understanding is that s_avg is the average speed in speculative mode, taking into account both the drafting speed and the speed of evaluating the drafted batch on the target model.

The theoretical results this way are as follows:

python3 speed.py
2.1594861448542773
2.7629832057356403

While the observed speedups are:

87.954 / 52.49 = 1.675633453991236
123.394 / 52.49 = 2.350809678033911
LiuXiaoxuanPKU commented 10 months ago

Thanks for the correction. Yes, I plugged in the wrong numbers; your calculation is correct.

I will also try to benchmark on V100 today or tomorrow and let you know the numbers. Thanks for the detailed experiment!

LiuXiaoxuanPKU commented 10 months ago

Hi

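  I benchmarked on V100 and A100. My numbers on V100 are promising. Concretely, the generation-phase speed (tokens/s) of the draft and target models is:

| GPU  | Draft (160M) | Target (7B) |
| ---- | -----------: | ----------: |
| A100 |          616 |          74 |
| V100 |          630 |          51 |

With a token acceptance rate of ~60%:

| GPU  | Speculative decoding, greedy sampling (t/s) | Speedup |
| ---- | ------------------------------------------: | ------: |
| A100 |                                          63 |    0.85 |
| V100 |                                          69 |    1.35 |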

I'm still confused about the performance on A100.

  1. For the formula $\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}$, the paper assumes that, for the target model, verifying gamma+1 tokens takes the same time as generating a single token. Therefore, c is the ratio between the time of a single draft-model run and the time of a single target-model run. I think your way of calculating c (taking s_avg as "the average speed in speculative mode where we take into account the speed both for drafting and evaluating the drafted batch on the target model") will underestimate the speedup a bit, but I guess it's good for now.

  2. I want to confirm that my understanding of the variables is correct: st_tg is the speed of the target model in the generation phase, and st_pp is the speed of the target model in the prompt phase.

  3. Could you show how you set the token acceptance rate to 80%? I tried generating a random value between 0 and 1 and accepting the token when it is < 0.8 (still modifying the condition here), but this does not fix the overall token acceptance rate at exactly 0.8.
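To make point 1 concrete, the sketch below compares the two ways of estimating c using the V100 llama-bench numbers posted earlier in the thread. The helper names are hypothetical: `c_single_pass` follows the paper's assumption that one target pass verifies the whole draft, and `c_averaged` reproduces the s_avg definition from ggerganov's script.

```python
def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    # Expected walltime improvement factor from the speculative decoding paper.
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


def c_single_pass(sd_tg: float, st_tg: float) -> float:
    # Paper-style c: time of one draft pass over time of one target pass.
    # With speeds in t/s this is (1 / sd_tg) / (1 / st_tg) = st_tg / sd_tg.
    return st_tg / sd_tg


def c_averaged(sd_tg: float, st_tg: float, st_pp: float) -> float:
    # The s_avg variant: target TG speed divided by the mean of the
    # draft TG speed and the target PP speed at batch size gamma.
    return st_tg / ((sd_tg + st_pp) / 2)


# V100 llama-bench numbers quoted earlier in the thread (t/s).
SD_TG = 619.43                  # draft (160M), tg 128
ST_TG = 52.49                   # target (7B), tg 128
ST_PP = {5: 121.74, 8: 194.58}  # target, pp 5 / pp 8

if __name__ == "__main__":
    for alpha, gamma in [(0.80, 5), (0.875, 8)]:
        single = expected_speedup(alpha, gamma, c_single_pass(SD_TG, ST_TG))
        averaged = expected_speedup(alpha, gamma, c_averaged(SD_TG, ST_TG, ST_PP[gamma]))
        print(f"alpha={alpha}, gamma={gamma}: single-pass c -> {single:.2f}x, "
              f"averaged c -> {averaged:.2f}x")
```

In both cases the single-pass c (about 0.085) yields the larger estimate, which is the underestimate described in point 1. The single-pass c also assumes that verifying a batch costs the same as one target token, which the V100 table does not support (pp 5 takes about 41 ms per batch versus about 19 ms per token for tg 128).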

ggerganov commented 10 months ago
  1. Yes, my assumption might be incorrect. At least in llama.cpp, currently the time TN to verify N tokens in a batch is not the same as the time T1 for 1 token. Typically, we have T1 < TN < N*T1, except for very small values of N (e.g. ~2, 3)

  2. Yes. st_pp is the prompt processing speed with a certain batch size, also known as the prefill phase. We assume that when we verify a draft of size gamma, the speed is the same as the prompt processing speed with batch size gamma.

  3. If you perform a sequential rejection with acceptance probability per token p, your total acceptance rate will not be equal to p. Here is a script to compute the acceptance rate for a given draft size N and acceptance probability p:

# p - probability to accept a token
#
# probability to accept M tokens from a draft of N:
#
# M   probability
# 0   (1-p)
# 1   (1-p)*p
# 2   (1-p)*p^2
# ...
# N-1 (1-p)*p^(N-1)
# N   p^N
#
# expectation:
#
# E[X] = 0*(1-p) + 1*p*(1-p) + 2*p^2*(1-p) + 3*p^3*(1-p) + ... + (N-1)*p^(N-1)*(1-p) + N*p^N
#

import numpy as np
import sys

N = int(sys.argv[1])
p = float(sys.argv[2])

print("N = ", N)
print("p = ", p)

E = 0

for i in range(N):
    E += i * p**i * (1-p)

E += N * p**N

print("E = ", round(E, 2), " (", round(100*(E/N), 2), "% )")

So for a draft size of 8, you can use p = 0.95 to get ~80% total acceptance rate:

$ ▶ python3 expect.py 8 0.95
N =  8
p =  0.95
E =  6.4  ( 79.94 % )
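The expectation above also has a closed form, E = p(1 - p^N)/(1 - p), because E[X] equals the sum over k = 1..N of P(X >= k) = p^k. The sketch below (hypothetical helper names, Python standard library only) checks it against a Monte Carlo simulation of the same sequential rejection rule:

```python
import random


def expected_accepted(n: int, p: float) -> float:
    # Closed form of E[X] = sum_{i<N} i p^i (1-p) + N p^N = p (1 - p^N) / (1 - p).
    if p == 1.0:
        return float(n)
    return p * (1 - p ** n) / (1 - p)


def simulate_acceptance_rate(n: int, p: float, trials: int = 100_000, seed: int = 0) -> float:
    # Sequential rejection: accept draft tokens one by one with probability p
    # and stop at the first rejection. Returns accepted / drafted over all trials.
    rng = random.Random(seed)
    accepted = 0
    for _ in range(trials):
        for _ in range(n):
            if rng.random() < p:
                accepted += 1
            else:
                break
    return accepted / (trials * n)


if __name__ == "__main__":
    n, p = 8, 0.95
    print(f"closed form: {100 * expected_accepted(n, p) / n:.2f} %")  # 79.94 %
    print(f"simulated:   {100 * simulate_acceptance_rate(n, p):.2f} %")
```

For example, a per-token threshold of p = 0.8 with a draft of 5 gives only about 54 % overall, which is why accepting each token with probability 0.8 could not fix the total acceptance rate at 0.8.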

Here is the diff on master:

--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -8,6 +8,10 @@
 #include <string>
 #include <vector>

+static float frand() {
+    return (float) rand() / RAND_MAX;
+}
+
 struct seq_draft {
     bool active   = false;
     bool drafting = false;
@@ -37,7 +41,7 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;

     // TODO: make this configurable
-    const float p_accept = 0.80f;
+    const float p_accept = -1.0f; // always draft n_draft tokens
     const float p_split  = 0.10f;

 #ifndef LOG_DISABLE_LOGS
@@ -178,7 +182,7 @@ int main(int argc, char ** argv) {
                         continue;
                     }

-                    if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
+                    if (i_dft < (int) drafts[s].tokens.size() && frand() < 0.95) { // use the python script to find the value that will give you the desired acceptance rate
                         LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());

                         s_keep = s;

Alternatively, you can do what I did in my previous comment - simply accept the first 0.8*N tokens unconditionally:

    const int n_seq_dft = params.n_parallel;

     // TODO: make this configurable
-    const float p_accept = 0.80f;
+    const float p_accept = -1.0f; // always draft n_draft tokens
     const float p_split  = 0.10f;

 #ifndef LOG_DISABLE_LOGS
@@ -178,7 +182,7 @@ int main(int argc, char ** argv) {
                         continue;
                     }

-                    if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
+                    if (i_dft < 0.8*(int) drafts[s].tokens.size()) {
                         LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());

                         s_keep = s;
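Point 1 (T1 < TN < N*T1) can be folded into the estimate by charging the verification step its measured batched cost instead of one target token. The following is a hypothetical cost model (not the paper's formula and not a description of llama.cpp internals), assuming, as in point 2, that the verification batch runs at the llama-bench `pp gamma` speed. When the batched pass costs exactly one target token, it reduces to the paper's formula.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    # Expected tokens produced per draft+verify step under the paper's i.i.d. model.
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def speedup_with_batch_cost(tokens_per_step: float, gamma: int,
                            draft_tg: float, target_tg: float, target_pp: float) -> float:
    # Hypothetical cost model (an assumption, not the paper's formula):
    #   step time = gamma sequential draft passes + one batched target pass,
    # where the batch is assumed to run at the llama-bench "pp gamma" speed.
    t_draft = 1.0 / draft_tg       # seconds per draft token
    t_target = 1.0 / target_tg     # seconds per token for plain target decoding
    t_verify = gamma / target_pp   # seconds for one batched target pass
    step_time = gamma * t_draft + t_verify
    return tokens_per_step * t_target / step_time


if __name__ == "__main__":
    # V100 numbers from the llama-bench tables earlier in the thread (t/s).
    draft_tg, target_tg = 619.43, 52.49
    target_pp = {5: 121.74, 8: 194.58}
    for alpha, gamma in [(0.80, 5), (0.875, 8)]:
        est = speedup_with_batch_cost(expected_tokens_per_step(alpha, gamma), gamma,
                                      draft_tg, target_tg, target_pp[gamma])
        print(f"alpha={alpha}, gamma={gamma}: {est:.2f}x")
```

Passing `gamma` as `tokens_per_step` instead of the geometric expectation models the patched runs earlier in the thread, where the first gamma - 1 drafted tokens are always accepted and the target supplies one more token per step. The model ignores sampling and other per-step overheads, so treat its output as a rough estimate.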
ggerganov commented 10 months ago

With #3749 now merged, the batched decoding performance for F16 models has been significantly improved.

A few speculative decoding tests on A100 from today achieve 2-3x speed-up using Codellama 34B Target + Codellama 7B Q4_0 Draft:

https://twitter.com/ggerganov/status/1716727296269193702

Here are some examples:

LLAMA_CUBLAS=1 make -j && ./speculative -m ./models/codellama-34b/ggml-model-f16.gguf -md ./models/codellama-7b/ggml-model-f16.gguf -p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0

 # Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:

# Time complexity: O(V^2)
# Space complexity: O(V)

def dijkstra(graph, start):
    distances = {vertex: float('inf') for vertex in graph}
    previous_vertices = {vertex: None for vertex in graph}
    distances[start] = 0
    vertices = list(graph.keys())

    while len(vertices) > 0:
        current_vertex = min(vertices, key=lambda vertex: distances[vertex])
        vertices.remove(current_vertex)

        if distances[current_vertex] == float('inf'):
            break

        for neighbor in graph[current_vertex]:
            distance = distances[current_vertex] + graph[current_vertex][neighbor]

            if distance < distances[neighbor]:
                distances[neighbor] = distance
                previous_vertices[neighbor] = current_vertex

    return distances, previous_vertices

encoded   25 tokens in    0.129 seconds, speed:  193.979 t/s
decoded  238 tokens in    6.033 seconds, speed:   39.453 t/s

n_draft   = 16
n_predict = 238
n_drafted = 323
n_accept  = 215
accept    = 66.563%

draft:

llama_print_timings:        load time =    2186.21 ms
llama_print_timings:      sample time =      45.08 ms /   326 runs   (    0.14 ms per token,  7231.91 tokens per second)
llama_print_timings: prompt eval time =      19.28 ms /    25 tokens (    0.77 ms per token,  1296.88 tokens per second)
llama_print_timings:        eval time =    4599.92 ms /   346 runs   (   13.29 ms per token,    75.22 tokens per second)
llama_print_timings:       total time =    6161.42 ms

target:

llama_print_timings:        load time =    6532.70 ms
llama_print_timings:      sample time =      28.60 ms /   238 runs   (    0.12 ms per token,  8320.81 tokens per second)
llama_print_timings: prompt eval time =    1301.73 ms /   369 tokens (    3.53 ms per token,   283.47 tokens per second)
llama_print_timings:        eval time =      50.88 ms /     1 runs   (   50.88 ms per token,    19.65 tokens per second)
llama_print_timings:       total time =    8365.12 ms

---

LLAMA_CUBLAS=1 make -j && ./speculative -m ./models/codellama-34b/ggml-model-f16.gguf -md ./models/codellama-7b/ggml-model-q4_0.gguf -p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0

 # Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:

# Time complexity: O(V^2)
# Space complexity: O(V)

def dijkstra(graph, start):
    distances = {vertex: float('inf') for vertex in graph}
    previous_vertices = {vertex: None for vertex in graph}
    distances[start] = 0
    vertices = list(graph.keys())

    while len(vertices) > 0:
        current_vertex = min(vertices, key=lambda vertex: distances[vertex])
        vertices.remove(current_vertex)

        if distances[current_vertex] == float('inf'):
            break

        for neighbor in graph[current_vertex]:
            distance = distances[current_vertex] + graph[current_vertex][neighbor]

            if distance < distances[neighbor]:
                distances[neighbor] = distance
                previous_vertices[neighbor] = current_vertex

    return distances, previous_vertices

encoded   25 tokens in    0.199 seconds, speed:  125.459 t/s
decoded  238 tokens in    4.270 seconds, speed:   55.736 t/s

n_draft   = 16
n_predict = 238
n_drafted = 365
n_accept  = 214
accept    = 58.630%

draft:

llama_print_timings:        load time =    1195.22 ms
llama_print_timings:      sample time =      49.97 ms /   366 runs   (    0.14 ms per token,  7323.95 tokens per second)
llama_print_timings: prompt eval time =      89.48 ms /    25 tokens (    3.58 ms per token,   279.41 tokens per second)
llama_print_timings:        eval time =    2766.19 ms /   389 runs   (    7.11 ms per token,   140.63 tokens per second)
llama_print_timings:       total time =    4469.67 ms

target:

llama_print_timings:        load time =    6555.17 ms
llama_print_timings:      sample time =      28.34 ms /   238 runs   (    0.12 ms per token,  8397.14 tokens per second)
llama_print_timings: prompt eval time =    1361.73 ms /   412 tokens (    3.31 ms per token,   302.56 tokens per second)
llama_print_timings:        eval time =      51.10 ms /     1 runs   (   51.10 ms per token,    19.57 tokens per second)
llama_print_timings:       total time =    5682.48 ms

---

LLAMA_CUBLAS=1 make -j && ./speculative -m ./models/codellama-34b/ggml-model-f16.gguf -md ./models/codellama-7b/ggml-model-q4_0.gguf -p "// Below is an implementation in C++ of an algorithm that finds the convex hull of a set of 2D points:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0

 // Below is an implementation in C++ of an algorithm that finds the convex hull of a set of 2D points:

// https://en.wikipedia.org/wiki/Graham_scan

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;

struct Point {
    double x, y;
};

double cross(const Point &o, const Point &a, const Point &b) {
    return (a.x - o.x) * (b.y - o.y) - (a.y - o.y) * (b.x - o.x);
}

bool cmp(const Point &a, const Point &b) {
    if (a.x != b.x) return a.x < b.x;
    return a.y < b.y;
}

vector<Point> convex_hull(vector<Point> points) {
    int n = points.size(), k = 0;
    vector<Point> hull(2 * n);
    sort(points.begin(), points.end(), cmp);
    for (int i = 0; i < n; hull[k++] = points[i++]) {
        while (k >= 2 && cross(hull[k - 2], hull[k - 1], points[i]) <= 0) k--;
    }
    for (int i = n - 2, t = k + 1; i >= 0; hull[k++] = points[i--]) {
        while (k >= t && cross(hull[k - 2], hull[k - 1], points[i]) <= 0) k--;
    }
    return vector<Point>(hull.begin(), hull.begin() + k - (k > 1));
}

int main() {
    int n; cin >> n;
    vector<Point> points(n);
    for (int i = 0; i < n; ++i) {
        cin >> points[i].x >> points[i].y;
    }
    auto hull = convex_hull(points);
    cout << "The convex hull has " << hull.size() << " points\n";
}

encoded   29 tokens in    0.200 seconds, speed:  144.745 t/s
decoded  510 tokens in    7.867 seconds, speed:   64.824 t/s

n_draft   = 16
n_predict = 510
n_drafted = 642
n_accept  = 467
accept    = 72.741%

draft:

llama_print_timings:        load time =    1163.29 ms
llama_print_timings:      sample time =      87.45 ms /   644 runs   (    0.14 ms per token,  7363.96 tokens per second)
llama_print_timings: prompt eval time =      89.61 ms /    29 tokens (    3.09 ms per token,   323.64 tokens per second)
llama_print_timings:        eval time =    5088.07 ms /   685 runs   (    7.43 ms per token,   134.63 tokens per second)
llama_print_timings:       total time =    8067.89 ms

target:

llama_print_timings:        load time =    6548.10 ms
llama_print_timings:      sample time =      60.25 ms /   510 runs   (    0.12 ms per token,  8464.17 tokens per second)
llama_print_timings: prompt eval time =    2405.81 ms /   711 tokens (    3.38 ms per token,   295.53 tokens per second)
llama_print_timings:        eval time =     103.98 ms /     2 runs   (   51.99 ms per token,    19.23 tokens per second)
llama_print_timings:       total time =    9248.23 ms

---

LLAMA_CUBLAS=1 make -j && ./speculative -m ./models/codellama-34b/ggml-model-f16.gguf -md ./models/codellama-7b/ggml-model-f16.gguf -p "// Below is an implementation in C++ of an algorithm that finds the convex hull of a set of 2D points:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0

 // Below is an implementation in C++ of an algorithm that finds the convex hull of a set of 2D points:

// https://en.wikipedia.org/wiki/Graham_scan

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;

struct Point {
    double x, y;
};

double cross(const Point &o, const Point &a, const Point &b) {
    return (a.x - o.x) * (b.y - o.y) - (a.y - o.y) * (b.x - o.x);
}

bool cmp(const Point &a, const Point &b) {
    if (a.x != b.x) return a.x < b.x;
    return a.y < b.y;
}

vector<Point> convex_hull(vector<Point> points) {
    int n = points.size(), k = 0;
    vector<Point> hull(2 * n);
    sort(points.begin(), points.end(), cmp);
    for (int i = 0; i < n; hull[k++] = points[i++]) {
        while (k >= 2 && cross(hull[k - 2], hull[k - 1], points[i]) <= 0) k--;
    }
    for (int i = n - 2, t = k + 1; i >= 0; hull[k++] = points[i--]) {
        while (k >= t && cross(hull[k - 2], hull[k - 1], points[i]) <= 0) k--;
    }
    return vector<Point>(hull.begin(), hull.begin() + k - (k > 0));
}

int main() {
    int n; cin >> n;
    vector<Point> points(n);
    for (int i = 0; i < n; ++i) {
        cin >> points[i].x >> points[i].y;
    }
    auto hull = convex_hull(points);
    cout << "The convex hull has " << hull.size() << " points\n";
}

encoded   29 tokens in    0.132 seconds, speed:  219.056 t/s
decoded  510 tokens in   13.127 seconds, speed:   38.850 t/s

n_draft   = 16
n_predict = 510
n_drafted = 703
n_accept  = 465
accept    = 66.145%

draft:

llama_print_timings:        load time =    1860.53 ms
llama_print_timings:      sample time =      95.55 ms /   704 runs   (    0.14 ms per token,  7367.72 tokens per second)
llama_print_timings: prompt eval time =      19.90 ms /    29 tokens (    0.69 ms per token,  1457.43 tokens per second)
llama_print_timings:        eval time =   10197.04 ms /   748 runs   (   13.63 ms per token,    73.35 tokens per second)
llama_print_timings:       total time =   13260.31 ms

target:

llama_print_timings:        load time =    6642.24 ms
llama_print_timings:      sample time =      60.96 ms /   510 runs   (    0.12 ms per token,  8365.87 tokens per second)
llama_print_timings: prompt eval time =    2588.94 ms /   775 tokens (    3.34 ms per token,   299.35 tokens per second)
llama_print_timings:        eval time =      50.84 ms /     1 runs   (   50.84 ms per token,    19.67 tokens per second)
llama_print_timings:       total time =   15140.22 ms

@LiuXiaoxuanPKU Let us know if you attempt more A100 experiments, and make sure to use the latest version of llama.cpp to get the best performance. Hopefully the numbers match the expectation better now.
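The speed-ups above can be read straight from the logs by dividing the speculative throughput by the target model's single-token eval speed. A minimal sketch with Python's `re` module, assuming the log line formats shown above (the regexes and function names are hypothetical, not a llama.cpp tool):

```python
import re

# Patterns for the summary lines printed by ./speculative and llama_print_timings.
DECODED_RE = re.compile(r"decoded\s+(\d+) tokens in\s+([\d.]+) seconds")
EVAL_RE = re.compile(r"llama_print_timings:\s+eval time =\s+([\d.]+) ms /\s+(\d+) runs")


def decoded_tps(log: str) -> float:
    # Speculative throughput from the "decoded N tokens in T seconds" line.
    m = DECODED_RE.search(log)
    if m is None:
        raise ValueError("no 'decoded ... tokens in ... seconds' line found")
    return int(m.group(1)) / float(m.group(2))


def eval_tps(timings: str) -> float:
    # Tokens/s from an "eval time = X ms / R runs" line; pass the *target* block.
    # "prompt eval time" lines are skipped because the pattern requires
    # "eval" right after the "llama_print_timings:" prefix.
    m = EVAL_RE.search(timings)
    if m is None:
        raise ValueError("no 'eval time' line found")
    return 1000.0 * int(m.group(2)) / float(m.group(1))


if __name__ == "__main__":
    spec = "decoded  238 tokens in    6.033 seconds, speed:   39.453 t/s"
    target = ("llama_print_timings:        eval time =      50.88 ms /"
              "     1 runs   (   50.88 ms per token,    19.65 tokens per second)")
    print(f"{decoded_tps(spec) / eval_tps(target):.2f}x")  # 2.01x
```

Applied to the four runs above, this gives roughly 2.0x for both F16 drafts and 2.8x and 3.4x for the Q4_0 drafts. The target's single-token eval line comes from only 1 or 2 runs, so it is a noisy baseline.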

calvintwr commented 10 months ago

Hi @ggerganov, I have some more data for you. I tried to speed up Llama-2 70B using either 13B or 7B as the draft model. In both cases, to no avail:

llama-13b-chat as draft model:

./speculative -m ../../../llama-2-70b-chat.Q6_K.gguf --threads 1 --n-gpu-layers 999 -md ../../../llama-2-13b-chat.Q8_0.gguf --n-gpu-layers-draft 999 -n 500 --prompt "<s>[INST]\nTell me about Joe Biden. [/INST] " --draft 8 --ctx-size 4096

 <s>[INST]\nTell me about Joe Biden. [/INST]  Joe Biden is the 46th President of the United States. He was born on November 20, 1942, in Scranton, Pennsylvania. He served as Vice President under Barack Obama from 2009 to 2017 and was elected President in 2020.

Biden earned a bachelor's degree from the University of Delaware and a law degree from Syracuse University. Before entering politics, he worked as a lawyer and served on the Senate staff. In 1970, he was elected to the New Castle County Council, and in 1972, he was elected to the United States Senate, where he served for six terms until 2009.

During his time in the Senate, Biden focused on issues related to criminal justice, foreign policy, and the rights of people with disabilities. He also served as chair of the Senate Foreign Relations Committee and was a strong advocate for the Violence Against Women Act.

In 2008, Biden was chosen by Barack Obama as his running mate in the presidential election. They won the election and served two terms together, during which time Biden focused on issues related to foreign policy and national security.

After leaving office, Biden established the Biden Foundation, a nonprofit organization focused on issues related to education, LGBTQ rights, and the prevention of sexual assault. He also began teaching at the University of Pennsylvania and authored several books.

In 2019, Biden announced his candidacy for the 2020 presidential election. He ran as a moderate Democrat, focusing on issues related to healthcare, education, and the economy. He won the nomination and went on to defeat incumbent President Donald Trump in the general election.

Biden's presidency has been marked by several significant accomplishments, including the passage of the American Rescue Plan, a $1.9 trillion stimulus package aimed at addressing the COVID-19 pandemic and economic downturn. He has also taken executive action to address climate change, expand access to healthcare, and protect the rights of LGBTQ individuals.

Biden has also pursued a number of foreign policy initiatives, including the

encoded   20 tokens in    0.607 seconds, speed:   32.959 t/s
decoded  503 tokens in   33.435 seconds, speed:   15.044 t/s

n_draft   = 8
n_predict = 503
n_drafted = 471
n_accept  = 399
accept    = 84.713%

draft:

llama_print_timings:        load time =    5963.01 ms
llama_print_timings:      sample time =    1645.45 ms /     1 runs   ( 1645.45 ms per token,     0.61 tokens per second)
llama_print_timings: prompt eval time =      67.36 ms /    20 tokens (    3.37 ms per token,   296.92 tokens per second)
llama_print_timings:        eval time =    9664.61 ms /   575 runs   (   16.81 ms per token,    59.50 tokens per second)
llama_print_timings:       total time =   34042.43 ms

target:

llama_print_timings:        load time =   16015.05 ms
llama_print_timings:      sample time =     275.97 ms /   503 runs   (    0.55 ms per token,  1822.64 tokens per second)
llama_print_timings: prompt eval time =   21001.56 ms /   578 tokens (   36.33 ms per token,    27.52 tokens per second)
llama_print_timings:        eval time =    1092.52 ms /    16 runs   (   68.28 ms per token,    14.65 tokens per second)

llama 7b chat as draft model:

./speculative -m ../../../llama-2-70b-chat.Q6_K.gguf --threads 1 --n-gpu-layers 999 -md ../../../llama-2-7b-chat.Q8_0.gguf --n-gpu-layers-draft 999 -n 500 --prompt "<s>[INST]\nTell me about Joe Biden. [/INST] " --draft 8 --ctx-size 4096

 <s>[INST]\nTell me about Joe Biden. [/INST]  Joe Biden is the 46th President of the United States. He was born on November 20, 1942, in Scranton, Pennsylvania. Biden served as Vice President under Barack Obama from 2009 to 2017 and represented Delaware in the United States Senate from 1973 to 2009. He is a member of the Democratic Party.

Biden graduated from the University of Delaware and Syracuse University College of Law. Before entering politics, he worked as a lawyer and served on the Senate staff. In 1972, Biden was elected to the New Castle County Council, and in 1970, he ran for the United States Senate, but he lost to incumbent Senator J. Caleb Boggs.

Biden was first elected to the Senate in 1972, at the age of 29, making him the youngest person to be elected to the Senate at the time. He served in the Senate for six terms, becoming one of the longest-serving Senators in American history. During his time in the Senate, Biden focused on issues related to criminal justice, foreign policy, and the rights of people with disabilities.

In 2008, Biden was chosen by Barack Obama as his running mate in the presidential election. They won the election, and Biden served as Vice President from 2009 to 2017. As Vice President, Biden focused on issues related to foreign policy, national security, and the economy.

In 2015, Biden announced that he would not run for President in the 2016 election, but he remained a prominent figure in the Democratic Party. In 2019, he announced his candidacy for the 2020 presidential election, and he won the nomination at the 2020 Democratic National Convention. Biden went on to defeat incumbent President Donald Trump in the general election, becoming the oldest person to be elected President of the United States at the age of 78.

Biden's presidency has focused on issues such as COVID-19 pandemic response, economic recovery, and addressing climate change. He has also taken steps to reform the immigration system,

encoded   20 tokens in    0.595 seconds, speed:   33.617 t/s
decoded  503 tokens in   30.448 seconds, speed:   16.520 t/s

n_draft   = 8
n_predict = 503
n_drafted = 426
n_accept  = 384
accept    = 90.141%

draft:

llama_print_timings:        load time =    3155.88 ms
llama_print_timings:      sample time =    1575.41 ms /     1 runs   ( 1575.41 ms per token,     0.63 tokens per second)
llama_print_timings: prompt eval time =      36.19 ms /    20 tokens (    1.81 ms per token,   552.72 tokens per second)
llama_print_timings:        eval time =    6042.37 ms /   545 runs   (   11.09 ms per token,    90.20 tokens per second)
llama_print_timings:       total time =   31043.23 ms

target:

llama_print_timings:        load time =   15269.15 ms
llama_print_timings:      sample time =     266.29 ms /   503 runs   (    0.53 ms per token,  1888.89 tokens per second)
llama_print_timings: prompt eval time =   21196.78 ms /   540 tokens (   39.25 ms per token,    25.48 tokens per second)
llama_print_timings:        eval time =    1635.22 ms /    24 runs   (   68.13 ms per token,    14.68 tokens per second)
llama_print_timings:       total time =   34226.27 ms

This is running the latest llama.cpp as of now.

I also tried adjusting --draft from 4 all the way up to 32, but still could not get a speedup.

Also, the accept computation looks wrong. For example, the accept rate for the 7B draft model is 90%, with 384 tokens accepted. Given that llama-7b generates at 90 tokens/s, having 90% of its tokens accepted should mean that the overall speed is close to 90 t/s. But the result is only 16.5 t/s.

ggerganov commented 10 months ago

Here is how to read the numbers:

n_predict = 503 is the total number of generated tokens. Having n_accept = 384 means that 503 - 384 = 119 tokens were generated on the target model at a speed of 14.68 t/s, i.e. regular single-batch sampling on the big model.

n_drafted = 426 is the total number of tokens generated on the draft model, at a speed of 90.20 t/s. All of those drafted tokens were additionally evaluated with the target model at a speed of about 25.48 t/s (this number depends on the batch size, so it is an approximation). So the average speed for these n_drafted tokens is ~19.86 t/s (this takes into account both the time to draft them and the time to evaluate them as batches on the target model).
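The ~19.86 t/s figure comes from adding the per-token times of the two stages (drafting on the small model, then batched evaluation on the big one), which is a harmonic-style combination of the two speeds rather than an average. A minimal sketch of that arithmetic, assuming the draft "eval time" speed and the target "prompt eval time" speed from the 7B-draft run are representative per-token rates:

```python
def combined_speed(draft_tps: float, target_batch_tps: float) -> float:
    """Tokens/s when each token costs 1/draft_tps + 1/target_batch_tps seconds.

    Each drafted token is first generated by the draft model and then
    evaluated (as part of a batch) by the target model, so the times add.
    """
    return 1.0 / (1.0 / draft_tps + 1.0 / target_batch_tps)


# Speeds from the llama-2-7b-chat draft run above
draft_tps = 90.20         # draft model "eval time" speed
target_batch_tps = 25.48  # target model "prompt eval time" speed (batched drafts)
print(f"{combined_speed(draft_tps, target_batch_tps):.2f} t/s")
```

This prints about 19.87 t/s, which matches the ~19.86 t/s quoted above up to rounding. Note that the result is always slower than the slower of the two stages, which is why a fast draft model alone cannot push the total toward 90 t/s.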

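Therefore we have:

$$\frac{503}{\frac{119}{14.68} + \frac{426}{19.86}} \approx \frac{503\ \text{tokens}}{8.11\ \text{s} + 21.45\ \text{s}} \approx 17.02\ \text{t/s}$$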

So this should give an expected total speed of about 17.02 t/s for n_predict = 503 tokens. We measure 16.52 t/s. The difference might be explained by the target model's batched evaluation speed varying with the batch size.
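The same back-of-the-envelope model can be written as a small function and applied to both runs posted earlier. This is a rough sketch, not how the speculative example measures anything: it assumes the target "eval time" speed is the single-token rate, the target "prompt eval time" speed is the batched-draft rate, and it ignores sampling overhead and the initial prompt (which is folded into the target's prompt-eval numbers).

```python
def expected_speed(n_predict: int, n_accept: int, n_drafted: int,
                   target_tps: float, draft_tps: float,
                   target_batch_tps: float) -> float:
    """Rough tokens/s estimate for a speculative decoding run."""
    # Tokens not covered by accepted drafts come from regular
    # single-token decoding on the target model.
    t_target = (n_predict - n_accept) / target_tps
    # Every drafted token, accepted or not, is generated by the draft model
    # and then evaluated as part of a batch on the target model.
    t_draft = n_drafted * (1.0 / draft_tps + 1.0 / target_batch_tps)
    return n_predict / (t_target + t_draft)


# Llama 2 7B chat draft run (measured: 16.52 t/s)
print(f"{expected_speed(503, 384, 426, 14.68, 90.20, 25.48):.2f} t/s")
# First run in this section, n_drafted = 471 (measured: 15.04 t/s)
print(f"{expected_speed(503, 399, 471, 14.65, 59.50, 27.52):.2f} t/s")
```

The estimates come out around 17.02 t/s and 15.66 t/s respectively, a little above the measured speeds in both cases. The dominant term is `t_draft`: even with high acceptance, every drafted token pays for a slow batched evaluation on the target model.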

You don't see a significant speedup likely because the drafts are evaluated in very small batches. Maybe try to reduce the acceptance threshold:

https://github.com/ggerganov/llama.cpp/blob/207b51900e15cc7f89763a3bb1c565fe11cbb45d/examples/speculative/speculative.cpp#L42-L43

It's currently hardcoded, so you will have to edit the code and recompile. Maybe try something like 0.2 - 0.3. This would make the draft batches larger in size and combined with the high acceptance rate that you have, it might help to increase the batch utilization of the target model and give you better performance.
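To illustrate why lowering the threshold makes draft batches larger, here is a hypothetical sketch of a confidence-based early stop. It assumes the draft loop stops as soon as the draft model's top-token probability falls below the threshold (called `p_accept` here); the function name, signature, and probabilities are made up for illustration and are not the code at the linked lines.

```python
def draft_length(top_probs: list[float], p_accept: float, n_draft: int) -> int:
    """Number of tokens drafted before the (hypothetical) confidence cut-off.

    top_probs[i] is the draft model's probability for its top candidate at
    draft position i. Drafting stops at the first position whose probability
    is below p_accept, or after n_draft tokens.
    """
    n = 0
    for p in top_probs[:n_draft]:
        if p < p_accept:
            break  # draft model not confident enough: stop drafting here
        n += 1
    return n


# Illustrative (made-up) confidences for one 8-token draft window
probs = [0.95, 0.90, 0.60, 0.35, 0.80, 0.70, 0.90, 0.50]
for threshold in (0.8, 0.5, 0.3):
    print(threshold, draft_length(probs, threshold, n_draft=8))
```

With these values, a threshold of 0.8 stops after 2 tokens, 0.5 after 3, and 0.3 drafts the full window of 8, so the target model gets a much larger batch to verify per step.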

Btw, thanks for looking into this. Definitely let us know your observations. I'm interested in this technique, and the current example can probably be improved in many ways.

ggerganov commented 10 months ago

Could you also post the result from this command:

LLAMA_CUBLAS=1 make -j && ./batched-bench ../../../llama-2-70b-chat.Q6_K.gguf 4096 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64

calvintwr commented 10 months ago

Thanks for explaining the numbers to me!

> Could you also post the result from this command:
>
> LLAMA_CUBLAS=1 make -j && ./batched-bench ../../../llama-2-70b-chat.Q6_K.gguf 4096 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64

@ggerganov what does `make -j` do? I don't use make because it gives me an `nvcc fatal : Value 'native' is not defined for option 'gpu-architecture'` error. This is on GCP.

I use cmake. Can you give me the equivalent? And I will run it.

KerfuffleV2 commented 10 months ago

> I use cmake. Can you give me the equivalent? And I will run it.

It's not really important; it just enables parallel builds (so compiling is faster, and the results are exactly the same).

The cmake build step is something like `cmake --build build --config Release -j8`. For both make and cmake, `-j` by itself lets the tool decide how many jobs to use; you can also specify the number explicitly, as in my example: `-j8` means 8 jobs.

calvintwr commented 10 months ago

@ggerganov @KerfuffleV2 Here's the result of the batched-bench:

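| PP | TG | B | N_KV | T_PP (s) | S_PP (t/s) | T_TG (s) | S_TG (t/s) | T (s) | S (t/s) |
|---|---|---|---|---|---|---|---|---|---|
| 512 | 128 | 1 | 640 | 0.654 | 782.53 | 8.804 | 14.54 | 9.458 | 67.66 |
| 512 | 128 | 2 | 768 | 0.643 | 796.25 | 22.753 | 11.25 | 23.396 | 32.83 |
| 512 | 128 | 3 | 896 | 0.643 | 796.43 | 23.024 | 16.68 | 23.667 | 37.86 |
| 512 | 128 | 4 | 1024 | 0.643 | 796.19 | 23.220 | 22.05 | 23.863 | 42.91 |
| 512 | 128 | 5 | 1152 | 0.643 | 796.49 | 28.660 | 22.33 | 29.303 | 39.31 |
| 512 | 128 | 6 | 1280 | 0.643 | 796.28 | 28.843 | 26.63 | 29.486 | 43.41 |
| 512 | 128 | 7 | 1408 | 0.643 | 796.40 | 29.105 | 30.78 | 29.748 | 47.33 |
| 512 | 128 | 8 | 1536 | 0.643 | 796.40 | 29.271 | 34.98 | 29.914 | 51.35 |
| 512 | 128 | 16 | 2560 | 0.643 | 795.71 | 47.656 | 42.97 | 48.300 | 53.00 |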

llama_print_timings:        load time =   16006.36 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =  238681.05 ms / 11152 tokens (   21.40 ms per token,    46.72 tokens per second)
llama_print_timings:        eval time =    8803.86 ms /   128 runs   (   68.78 ms per token,    14.54 tokens per second)
llama_print_timings:       total time =  263141.52 ms

ggerganov commented 10 months ago

Hm, yeah, the batched decoding performance for Q6_K seems terrible (the S_TG column). I guess we still have some work to do to improve this. If you try an F16 target model, the gains will be much higher, since the batched decoding speed with F16 scales much better with the batch size (the draft model can still be quantized).
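One way to read the S_TG column is to compare the cost of a single decode step at batch size B with the cost at B = 1. The sketch below does that with the T_TG values from the table. It is only an approximation for speculative decoding: batched-bench decodes B independent sequences one token each per step, whereas verifying a draft evaluates several tokens of one sequence.

```python
TG = 128  # tokens generated per sequence in the benchmark
# (B -> T_TG in seconds) copied from the batched-bench output above
T_TG = {1: 8.804, 2: 22.753, 3: 23.024, 4: 23.220, 5: 28.660,
        6: 28.843, 7: 29.105, 8: 29.271, 16: 47.656}


def step_cost_ratio(b: int) -> float:
    """Cost of one decode step with batch size b relative to batch size 1."""
    return T_TG[b] / T_TG[1]


def throughput(b: int) -> float:
    """Generated tokens per second across all b sequences (the S_TG column)."""
    return b * TG / T_TG[b]


for b in T_TG:
    print(f"B={b:2d}  step cost {step_cost_ratio(b):4.2f}x  S_TG {throughput(b):6.2f} t/s")
```

For this Q6_K model a B = 2 step costs about 2.58x a single-token step, so evaluating two tokens together is slower than evaluating them one at a time. Speculative decoding only pays off when the number of tokens accepted per verification step exceeds this ratio, which is why a target model whose batched cost grows slowly with B is needed for a real speedup.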

calvintwr commented 10 months ago

@ggerganov Got it. I would like to help if I can; this is an amazing project. I would appreciate it if you could point me in the right direction in the source code.

ggerganov commented 10 months ago

I think we will invest more effort in improving the quantized batched decoding performance after we finish some improvements to the GPU interface that are currently being developed. The solution will likely require implementing custom kernels and ops, so it will need a deeper understanding of the CUDA implementation and will probably involve some significant changes.

As I mentioned, you can test speculative sampling using an F16 target model with any quantized draft model and see how it performs. This would at least give confidence that the implemented strategy works as it should (which is the main problem discussed in this issue), and later, when quantized batched decoding is improved, similar gains would be expected.

BarfingLemurs commented 10 months ago

Hi @ggerganov

I did some tests with the speculative example, and some quantizations appear to work fine when 72% of the model is in VRAM.

Running this setup in the speculative example with a 70B Q3_K_S target model, offloading 57 of its layers plus all layers of the 1.5T TinyLlama base Q4_K_M draft model to a 3090, with top-k 1, I get a 1.3x speedup on all chat formats: my 5.4 t/s goes up to 7.1 t/s.

This is the same speedup factor I get with pure CPU speculative sampling using the same models (1.3x, where 1.5 t/s goes up to 2 t/s).

It's a general speedup, and shouldn't be limited to coding examples, it works for instruct/chat formats.

I also tried exllamav2's speculative sampling examples and sampling parameters. These give a speedup of mostly 1.5x on the chat responses. When running the Quicksort code example, I get 2-3x. (EDIT: the total may be mixed with prompt processing)

I also tried the 7B FP16 Medusa model via its command-line interface with default settings. This gave a consistent speedup of mostly 2x on the chat responses compared to the original Transformers implementation.

github-actions[bot] commented 5 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.