Open ggerganov opened 7 months ago
I tried to add both of these operations (in a very naive form: just the boilerplate and a copy/paste of the CPU implementation, replacing memcpy with for loops, and using 1 block with WARP_SIZE==32 threads) here:
https://github.com/jploski/llama.cpp/commit/677ad0af0a9a38a02c933e14cb290457b42bd23a
While it works in the sense that the implementation generates plausible-looking output, it differs from the CPU version. The initial prompt seems to be processed incorrectly, with the first generated token already looking off, and the generation not influenced by the content of the prompt. Furthermore, changing batch size (the -b parameter) affects the generation (it shouldn't). So maybe something else needs to be improved besides the new ops implementation.
Sample correct output from CPU version:
/mnt/seagate/dalai/llama.cpp.bak/bin/main -m /mnt/f2fs/mamba/tinystories.gguf -t 1 -p 'Sara and Ben are playing in the snow.' --temp 0.0 -c 1024 --keep -1 --repeat_penalty 1.3 -n 1024
Sara and Ben are playing in the snow. They make a big snowman with a hat, a scarf and a carrot nose. They put on their warm clothes and go inside to get some cold chocolate.
"Look at my snowman!" Sara says. "He is so cool! He can talk like us."
Ben looks at his snowman and smiles. He likes the hat, the scarf and the carrot nose too. He wants to make a new friend.
Sara nods and takes her hat from the closet. She puts it on Ben's head and says, "Hello, Mr. Snowman! Do you want to play with me?"
Ben looks at Sara and smiles. He likes his snowman. He says, "Yes, Mr. Snowman! I like your hat, scarf and carrot nose too."
They hug each other and play together in the snow. They make a big snowman, a blue snowman, a red snowman and a yellow snowman. They are happy to have fun with their new friend. [end of text]
llama_print_timings: load time = 299.08 ms
llama_print_timings: sample time = 167.65 ms / 206 runs ( 0.81 ms per token, 1228.74 tokens per second)
llama_print_timings: prompt eval time = 123.71 ms / 11 tokens ( 11.25 ms per token, 88.92 tokens per second)
llama_print_timings: eval time = 5023.01 ms / 205 runs ( 24.50 ms per token, 40.81 tokens per second)
llama_print_timings: total time = 5394.42 ms / 216 tokens
Sample incorrect output from the GPU version:
Sara and Ben are playing in the snow. not too sure to eat it.swed wasmometimes been
"I am so happy when a room is gone," was were, into them witholidated. They had done something for their family.
They were out at the park andsled down together best best of amidt- covari.
The sun was in his eye not long before they were been to one. He wasmnoted to each other and this is where he wasered been been been been been been been been been done.
He had been been beby, so to all who have been been there been been been been been been taken out of them best part onlys in the long- march with us.
The end. [end of text]
llama_print_timings: load time = 1359.73 ms
llama_print_timings: sample time = 117.06 ms / 147 runs ( 0.80 ms per token, 1255.71 tokens per second)
llama_print_timings: prompt eval time = 52.04 ms / 11 tokens ( 4.73 ms per token, 211.40 tokens per second)
llama_print_timings: eval time = 911.71 ms / 146 runs ( 6.24 ms per token, 160.14 tokens per second)
llama_print_timings: total time = 1135.02 ms / 157 tokens
The GPU output looks especially mangled in the example above, but with a single-token prompt consisting of a space it almost looks "normal" for this model, which makes me suspect that maybe it's something unrelated to the new CUDA kernels:
avorable, a little girl was walking in the park. She saw something strange and she wanted to take a look. She walked closer and saw that it was a big, round, green frog! The frog hopped around and said "Hello!"
The girl was so surprised but also excited. She asked the frog if he could help her. The frog smiled and said "Of course I can help you!" He took out some food from his bag and gave it to the girl.
"Thank you very much," she said with a big smile on her face. The frog then hopped away, but he was still there when she got home.
The girl was so happy that she had helped the frog. She thanked him again and went back to playing in the park. [end of text]
(The same also happens when I reduce the CUDA kernels to use just a single thread rather than 32.)
Another observation: with one offloaded layer (-ngl 1) I get the same output as for -ngl 0 (= CPU). But as the number of offloaded layers is increased, the output changes each time.
You can add tests for the ops in test-backend-ops to compare the GPU and CPU implementations.
Thanks. I added a test for each of the new ops here: https://github.com/ggerganov/llama.cpp/commit/35f2f868b14db3a0a67b1992c92895d1f4c2d847
The CPU vs. GPU comparison ("test-backend-ops test -o SSM_CONV", "test-backend-ops test -o SSM_SCAN") passes. Maybe because of non-representative input. I set "sq", whatever it means, to zero because initializing it to any other value triggered an assertion on the CPU backend. (I should mention that I have little clue about how the Mamba algorithm works in theory or practice; I just attempted a blind port FWIW.)
@jploski Thank you for doing this!
I set "sq", whatever it means, to zero because initializing it to any other value triggered an assertion on the CPU backend
Zero is correct, sq is the seq_id of the tokens corresponding to each part of the input. It was added for simultaneous sequence processing.
But note that this is being changed in #7531. sq no longer exists there, because equal-length simultaneous sequences allow directly using the fourth tensor dimension to separate the states for each sequence, which simplifies SSM_SCAN and SSM_CONV a bit.
The CPU vs. GPU comparison ("test-backend-ops test -o SSM_CONV", "test-backend-ops test -o SSM_SCAN") passes. Maybe because of non-representative input.
Hmm. If both result in the same output with random input (for all tensors apart from sq), maybe the problem isn't with the operators, but how the data gets in and out of them in practice?
I wonder if using ctx_layer instead of ctx_split for the tensors that get passed to ggml_ssm_conv and ggml_ssm_scan would help?
diff --git a/llama.cpp b/llama.cpp
index 841be1de..0c710f8d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6099,7 +6099,7 @@ static bool llm_load_tensors(
layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
- layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
+ layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
@@ -6108,7 +6108,7 @@ static bool llm_load_tensors(
layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
// no "weight" suffix for these
- layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+ layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
// out_proj
I have no idea what exactly this would do, but I guess it's worth a try?
Replacing ctx_split by ctx_layer did not change anything.
Maybe it's good to wait with further tests until #7531 is merged (there's slight hope that the observed issue might go away...)
So I was impatient, and applied my changes to #7531 here:
https://github.com/ggerganov/llama.cpp/commit/697fab6b7490ca62e702df60ba2228ca4e752dd1
The result is as before. The CPU output is the same as before (meaning that #7531 did not break anything.) The CPU vs. GPU backend tests for ssm_conv and ssm_scan with random data pass, but the actual generation derails (even faster than before, namely already with -ngl 1). The differences feel like there's some unintended "latent space exploration" going on, but with more GPU layers they get worse.
CPU:
Sara and Ben are playing in the snow. They make a big snowman with a hat, a scarf and a carrot nose. They put on their warm clothes and go inside to get some cold chocolate. "Look at my snowman!" Sara says. "He is so cool! He can talk like us." Ben looks at his snowman and smiles. He likes the hat, the scarf and the carrot nose too. He wants to make a new friend.
GPU (-ngl 1):
Sara and Ben are playing in the snow. They make a big snowman with a hat, a scarf and a carrot nose. They put on their warm clothes and go inside to have some hot chocolate. "Look at my snowman!" Sara says. "He is so cool! He can do anything." Ben looks around his room and sees that the snowman has no hat or scarf. He feels sad for him.
ctx_split should only be used for matrices. Unless using CUDA with -sm row, it is the same as ctx_layer. If the results look reasonable and the tests pass, there might not be anything wrong; the results from different backends are always slightly different due to small floating point differences. perplexity is good for verifying that.
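For readers unfamiliar with the metric, here is a minimal sketch of what perplexity measures: the exponential of the average negative log-likelihood the model assigns to the reference tokens. The function name and the flat vector of per-token probabilities are assumptions made for illustration; the actual perplexity tool works on chunks of logits and is not reproduced here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Perplexity from the probabilities the model assigned to the correct tokens.
// Hypothetical helper for illustration; every probability must be in (0, 1].
double perplexity(const std::vector<double> & token_probs) {
    assert(!token_probs.empty());
    double nll = 0.0;  // accumulated negative log-likelihood
    for (std::size_t i = 0; i < token_probs.size(); ++i) {
        assert(token_probs[i] > 0.0 && token_probs[i] <= 1.0);
        nll -= std::log(token_probs[i]);
    }
    return std::exp(nll / (double) token_probs.size());
}
```

Two backends that differ only by rounding should yield nearly identical values, whereas a real bug tends to push the number up by orders of magnitude.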
Btw, on master when building with CUDA and running with -ngl 0 the ppl is bogus:
LLAMA_CUDA=1 make -j && ./perplexity -m models/mamba-130m/ggml-model-f16.gguf -f build-cublas/wikitext-2-raw/wiki.test.raw -ngl 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size = 0.12 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/25 layers to GPU
llm_load_tensors: CPU buffer size = 256.96 MiB
.................................................
llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 10.69 MiB
llama_new_context_with_model: KV self size = 10.69 MiB, K (f32): 1.69 MiB, V (f32): 9.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.77 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 184.98 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 10.07 MiB
llama_new_context_with_model: graph nodes = 896
llama_new_context_with_model: graph splits = 292
system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
perplexity: tokenizing the input ..
perplexity: tokenization took 807.662 ms
perplexity: calculating perplexity over 650 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 1.21 seconds per pass - ETA 3.27 minutes
[1]54658547752546633926049792.0000,[2]179461831369242333449027584.0000,[3]31554836485675804180611072.0000,[4]6706279437817913437847552.0000
If I build without CUDA, it is OK:
make -j && ./perplexity -m models/mamba-130m/ggml-model-f16.gguf -f build-cublas/wikitext-2-raw/wiki.test.raw
llm_load_tensors: ggml ctx size = 0.12 MiB
llm_load_tensors: CPU buffer size = 256.96 MiB
.................................................
llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CPU KV buffer size = 10.69 MiB
llama_new_context_with_model: KV self size = 10.69 MiB, K (f32): 1.69 MiB, V (f32): 9.00 MiB
llama_new_context_with_model: CPU output buffer size = 0.77 MiB
llama_new_context_with_model: CPU compute buffer size = 99.71 MiB
llama_new_context_with_model: graph nodes = 896
llama_new_context_with_model: graph splits = 1
system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
perplexity: tokenizing the input ..
perplexity: tokenization took 805.807 ms
perplexity: calculating perplexity over 650 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 1.43 seconds per pass - ETA 3.85 minutes
[1]17.9926,[2]26.3160,[3]26.9249,[4]27.3857,[5]25.7235,[6]24.7161,[7]24.1991,[8]24.1345
ctx_split should only be used for matrices. Unless using CUDA with -sm row, it is the same as ctx_layer. If the results look reasonable and the tests pass, there might not be anything wrong; the results from different backends are always slightly different due to small floating point differences. perplexity is good for verifying that.
Unfortunately, since the model output deteriorates to the point of not generating eos token, I am pretty sure it's not a slight difference in this case (even without inspecting the perplexity metric; as @ggerganov pointed out, there seems to be something wrong with the perplexity calculation for CUDA version as well).
The CPU vs. GPU backend tests for ssm_conv and ssm_scan with random data pass
@jploski Try making the tests process more tokens and more sequences at a time.
Looking at the tests, it doesn't really compare all outputs, since in #7531 the state inputs are also written back. I think I should change the output of the SSM operators back to concatenated tensors, just to make the output comparisons easier, and to make the graph dependencies more representative of the actual data dependencies.
Btw, on master when building with CUDA and running with -ngl 0 the ppl is bogus:
@ggerganov My hypothesis in this case is that it's probably related to where the KV cache is located.
llama_kv_cache_init: CUDA_Host KV buffer size = 10.69 MiB
llama_new_context_with_model: KV self size = 10.69 MiB, K (f32): 1.69 MiB, V (f32): 9.00 MiB
vs
llama_kv_cache_init: CPU KV buffer size = 10.69 MiB
llama_new_context_with_model: KV self size = 10.69 MiB, K (f32): 1.69 MiB, V (f32): 9.00 MiB
But maybe CUDA_Host is the CPU? I'm not used to CUDA device names. (EDIT: that name is defined in ggml-cuda.cu.)
CUDA_Host is just a pinned CPU buffer.
The CPU vs. GPU backend tests for ssm_conv and ssm_scan with random data pass
@jploski Try making the tests process more tokens and more sequences at a time.
Thanks, I applied your patch, and the test for ssm_conv fails now, which sounds like good progress!
/mnt/seagate/dalai/llama.cpp> bin/test-backend-ops test -o SSM_CONV
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Testing 2 backends
Backend 1/2 (CPU)
Skipping CPU backend
Backend 2/2 (CUDA0)
Backend name: CUDA0
SSM_CONV(type=f32,d_conv=4,d_inner=1536,n_seq_tokens=7,n_seqs=2): [SSM_CONV] NMSE = 0.207277691 > 0.000000100 FAIL
1207/1208 tests passed
Backend CUDA0: FAIL
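The NMSE figure in the log above is a normalized mean squared error between the CPU and GPU outputs. Below is a minimal sketch of such a metric, assuming the usual definition (squared error divided by the squared magnitude of the reference); the function name and signature are illustrative and not copied from test-backend-ops.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Normalized mean squared error of `test` against `ref`:
// sum((ref - test)^2) / sum(ref^2). Assumes `ref` is not all zeros.
double nmse(const std::vector<float> & ref, const std::vector<float> & test) {
    assert(ref.size() == test.size());
    double err  = 0.0;  // squared difference between the two outputs
    double norm = 0.0;  // squared magnitude of the reference output
    for (std::size_t i = 0; i < ref.size(); ++i) {
        const double d = (double) ref[i] - (double) test[i];
        err  += d * d;
        norm += (double) ref[i] * (double) ref[i];
    }
    return err / norm;
}
```

Under this definition a value of about 0.2, as reported above, means the error energy is roughly a fifth of the reference signal, which is far beyond rounding noise.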
This should fix the CUDA ppl. Not sure if both conts are actually needed.
diff --git a/llama.cpp b/llama.cpp
index 841be1de..b311467a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10499,6 +10499,9 @@ struct llm_build_context {
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
+ x = ggml_cont(ctx0, x);
+ z = ggml_cont(ctx0, z);
+
// conv
{
// Custom operator which is needed only to ease simultaneous sequence processing.
I tried applying the patch to my PR #7531 branch around here: https://github.com/jploski/llama.cpp/blob/mamba_cuda_pr7531/llama.cpp#L8694 - and as a result perplexity calculation on CPU appeared ok. But with this fix the GPU-based generation fails as follows (so did not commit it):
Sara and Ben are playing in the snow. They maketerminate called after throwing an instance of 'std::out_of_range'
what(): vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)
Aborted
I got the failing ssm_conv test working now - even if I increase n_tokens from 7 to 1024 and n_seqs to 8. The fail was due to a misplaced parenthesis (https://github.com/jploski/llama.cpp/commit/982a22e17f5b8066ea5ad801624682a1f0a7333c), which was a mistake I made today.
However, even with the test working, the output qualitatively degrades with an increasing number of GPU layers (as was the case with the original non-PR-7531 version, which did not suffer from that parenthesis bug).
So I'd say the overall state of the GPU port is now as bad as yesterday, but not worse - and we have more representative tests for ssm_conv and ssm_scan in place thanks to @compilade.
I managed to do the equivalent of GGML_OP_SSM_CONV with ggml_concat, ggml_im2col, reshapes and a ggml_mul_mat!!! And the memory usage is reasonable!
But performance is measurably worse for prompt processing (around 85% of what it was with ggml_ssm_conv for pp512 on my CPU with a 100k parameter Mamba model). However, text generation speed is very similar to before.
I think ggml_mul_mat might not be optimized for small row sizes (4 elements per row in this case). Or maybe the overhead is due to im2col. Not sure.
I'm not sure if I should remove GGML_OP_SSM_CONV yet.
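To make the equivalence concrete, here is a rough single-channel sketch of what the concat + im2col + mul_mat pipeline computes: a causal convolution in which each output is the dot product of a sliding window with the kernel. The struct and function names are hypothetical, and the layout (one channel, plain vectors) is an assumption chosen for readability rather than ggml's actual tensor layout.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct ConvResult {
    std::vector<float> y;      // one output per new token
    std::vector<float> state;  // last (d_conv - 1) inputs, carried to the next call
};

// Causal 1D convolution for a single channel (hypothetical helper, not ggml's API).
// state: previous (d_conv - 1) inputs; x: new inputs; w: d_conv kernel weights.
ConvResult conv_channel(const std::vector<float> & state,
                        const std::vector<float> & x,
                        const std::vector<float> & w) {
    assert(!w.empty() && state.size() + 1 == w.size());

    // Concat step: previous state followed by the new tokens.
    std::vector<float> ext(state);
    ext.insert(ext.end(), x.begin(), x.end());

    ConvResult res;
    res.y.resize(x.size());
    for (std::size_t t = 0; t < x.size(); ++t) {
        // Window t is what one im2col row would hold; its dot product with w
        // is the mul_mat (or, equivalently, a mul followed by sum_rows).
        float sum = 0.0f;
        for (std::size_t k = 0; k < w.size(); ++k) {
            sum += ext[t + k] * w[k];
        }
        res.y[t] = sum;
    }

    // New state: the trailing (d_conv - 1) values of the extended sequence.
    res.state.assign(ext.end() - (std::ptrdiff_t) state.size(), ext.end());
    return res;
}
```

With d_conv = 4 each dot product touches only four elements, which is consistent with the guess above that a general matrix multiplication carries relatively more overhead than a fused op here.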
@jploski note that I also changed ggml_ssm_scan to make it not modify its inputs. This should make tests more representative, because otherwise they did not check the final state. It's now including the final state in its output, similarly to how it is in master, but a bit simpler, due to the equal-sequence-length batching.
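As an illustration of a scan that leaves its inputs untouched and returns the final state alongside the per-token outputs, here is a single-channel sketch. It assumes the common Mamba recurrence (softplus on dt, exponential decay of the state, then a projection through C); the names, argument layout and return struct are hypothetical and do not mirror the ggml operator's signature.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct ScanResult {
    std::vector<float> y;      // one output per token
    std::vector<float> state;  // final state, returned instead of written back
};

// Sequential selective scan for one channel (hypothetical helper, not ggml's API).
// s0, A: d_state values; x, dt: n_tokens values; B, C: n_tokens rows of d_state values.
ScanResult scan_channel(const std::vector<float> & s0,
                        const std::vector<float> & x,
                        const std::vector<float> & dt,
                        const std::vector<float> & A,
                        const std::vector<std::vector<float>> & B,
                        const std::vector<std::vector<float>> & C) {
    assert(s0.size() == A.size());
    assert(x.size() == dt.size() && x.size() == B.size() && x.size() == C.size());

    ScanResult res;
    res.state = s0;  // work on a copy so the input state stays untouched
    res.y.assign(x.size(), 0.0f);

    for (std::size_t t = 0; t < x.size(); ++t) {
        assert(B[t].size() == s0.size() && C[t].size() == s0.size());
        // softplus(dt), with the usual shortcut for large values
        const float dt_sp = dt[t] <= 20.0f ? std::log1p(std::exp(dt[t])) : dt[t];
        for (std::size_t n = 0; n < s0.size(); ++n) {
            // Decay the old state, then add the contribution of the new input.
            res.state[n] = res.state[n] * std::exp(dt_sp * A[n]) + B[t][n] * x[t] * dt_sp;
            res.y[t] += res.state[n] * C[t][n];
        }
    }
    return res;
}
```

Because both y and the final state come back in the result, a CPU vs. GPU comparison of the return value checks the state as well, which is the point made above.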
I managed to do the equivalent of GGML_OP_SSM_CONV with ggml_concat, ggml_im2col, reshapes and a ggml_mul_mat!!! And the memory usage is reasonable!
I updated my branch to catch up, but the non-ssm_conv implementation now fails on the GPU because apparently MUL_MAT is not completely implemented (or maybe you could work around with some transpose, didn't examine it further):
ggml_cuda_compute_forward: cannot compute node_31: src0->ne[3] = 1, src1->ne[3] = 2 - fallback to CPU
ggml_backend_cuda_graph_compute: op not supported node_31 (MUL_MAT)
GGML_ASSERT: /mnt/seagate/dalai/llama.cpp/ggml-cuda.cu:2675: ok
I imagine that a version with a more coarse-granular op might offer more potential for optimization on the GPU, keyword "fused kernel" (not claiming that I know how to do it for CUDA specifically, but generally knowing "what comes next" in a computation pipeline aids optimization).
I updated my branch to catch up, but the non-ssm_conv implementation now fails on the GPU because apparently MUL_MAT is not completely implemented (or maybe you could work around with some transpose, didn't examine it further)
@jploski Thank you for trying this.
I did not expect it would fail like this. It seems this is also a problem on SYCL and Vulkan, which also expect ne[3] to be the same on both tensors when doing MUL_MAT.
A transpose can't work around this. The first tensor in that MUL_MAT really has to be broadcast on the second tensor over the fourth dimension.
I guess it will likely be simpler (and faster) with GGML_OP_SSM_CONV then.
I'm not sure whether or not I should fuse the concat into it. The implementation is much simpler when it's separate (the shift doesn't need to move the state, it's only +1 on a pointer, no need to shift a temporary buffer), but it's also slightly slower (97% of tg128 and pp512) on CPU, although it might be due to something else (EDIT: when making CONCAT work with a transposed src1, there is practically no performance difference (on CPU) vs the previous SSM_CONV which fused the concat).
I did not expect it would fail like this. It seems this is also a problem on SYCL and Vulkan, which also expect ne[3] to be the same on both tensors when doing MUL_MAT.
The op can be extended to support broadcast in all backends. If it's not too much hassle, try to keep the im2col + mul_mat version around (even behind ifdef if necessary) and I will try to make it work on the GPU at some point
I think it's also possible to replace that ggml_mul_mat with a ggml_sum_rows and ggml_mul, but again, it's less performant than fusing it all in ggml_ssm_conv.
(around 95% tg128 and 81% pp512 on my CPU, when comparing MUL+SUM_ROWS with SSM_CONV, while MUL_MAT gets around 94% tg128 and 79% pp512, so strangely, MUL+SUM_ROWS is slightly faster than MUL_MAT in this case, with a very small model like https://huggingface.co/delphi-suite/v0-mamba-100k, and also with bigger models, I think)
I compared the tensor contents of CPU and GPU versions to find the source of first discrepancy. While small rounding errors which might account for confusing hot chocolate with cold chocolate are evident from the get go, the first GPU tensor with a really alarming difference was the one resulting from ggml_silu(ctx, z). I wrapped it with ggml_cont to fix:
https://github.com/ggerganov/llama.cpp/commit/7509b9eed26ec8424cfbf2da87f56ab056fc76aa
With that (and using ssm_conv for reasons explained previously) the tensor looks almost identical. Although the facts about Sara's and Ben's encounter with the abominable snowman still do not match between CPU and GPU, I no longer get strange tokens, and it remains on topic regardless of the CPU/GPU layer split.
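A plausible reading of why ggml_cont helps: x and z are views into the wider xz tensor, so consecutive rows of a view are not adjacent in memory, and a kernel that assumes dense packing reads the wrong elements. The sketch below demonstrates the idea on a plain buffer; the helper name and the element-based offset and stride parameters are hypothetical, and this is an illustration rather than how ggml_cont is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Copy a strided 2D view into a densely packed buffer (hypothetical helper).
// offset and row_stride are counted in elements, not bytes.
std::vector<float> make_contiguous(const std::vector<float> & buf,
                                   std::size_t offset, std::size_t row_stride,
                                   std::size_t n_cols, std::size_t n_rows) {
    assert(n_cols <= row_stride);
    assert(n_rows == 0 || offset + (n_rows - 1) * row_stride + n_cols <= buf.size());

    std::vector<float> out;
    out.reserve(n_cols * n_rows);
    for (std::size_t r = 0; r < n_rows; ++r) {
        for (std::size_t c = 0; c < n_cols; ++c) {
            // Honor the row stride of the parent buffer instead of assuming
            // that row r starts right after row r - 1.
            out.push_back(buf[offset + r * row_stride + c]);
        }
    }
    return out;
}
```

For example, with d_inner = 2 and two rows stored as {0, 1, 2, 3, 4, 5, 6, 7}, the second half of each row is {2, 3} and {6, 7}, whereas reading four consecutive floats from offset 2 would give {2, 3, 4, 5}.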
I pushed a new branch https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda, which is based on the recent master of llama.cpp with reapplied patches from my original attempt in June (https://github.com/jploski/llama.cpp/tree/mamba_cuda_pr7531)
With one small additional fix (https://github.com/ggerganov/llama.cpp/commit/fae826fb56b6f40e73fb4721e11adc1cf2795431) this implementation is now working. I tested it with https://huggingface.co/tiiuae/falcon-mamba-7b-instruct. It produces coherent output in f16, q8_0 and q5_k_m quantizations.
Maybe someone with CUDA experience could have a look at the ggml/src/ggml-cuda/ssm_scan.cu and ggml/src/ggml-cuda/ssm_conv.cu regarding grid configuration and memory access patterns in those kernels (remember I just copy-pasted the CPU version without regard for what is good for CUDA's parallel execution; so it could probably be optimized for performance).
I improved performance by 10x (A100) based on @jploski's work. The PR is https://github.com/ggerganov/llama.cpp/pull/9186.
A small promotion: we have implemented initial SYCL and CUDA support for RWKV. Everyone is welcome to use and improve it! https://github.com/ggerganov/llama.cpp/pull/10133
Recently, initial Mamba support (CPU-only) has been introduced in #5328 by @compilade
In order to support running these models efficiently on the GPU, we seem to be lacking kernel implementations for the following 2 ops:
GGML_OP_SSM_CONV
GGML_OP_SSM_SCAN
Creating this issue to keep track of this and give more visibility of this feature. Help with implementing the missing kernels for CUDA and Metal (and other backends potentially) is welcome. We can also discuss if anything else is required to better support this architecture in llama.cpp