ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: llama-cli generates incoherent output with full gpu offload #9535

Closed: 8XXD8 closed this issue 1 month ago

8XXD8 commented 1 month ago

What happened?

Offloading 31 of the 33 layers of an 8B model produces correct results. With 32 layers offloaded, the response is incoherent. With 33 or more layers offloaded, the instruction is ignored when the seed is 1, and with any other seed no response is printed at all. Both conversational and normal modes are affected. llama-server works without problems.

Name and Version

version: 3782 (8a308354) built with clang version 20.0.0git (https://github.com/ROCm/llvm-project.git 487d0fd20dcbb6fbf926333d7b0b355788efb009) for x86_64-unknown-linux-gnu

What operating system are you seeing the problem on?

No response

Relevant log output

**31 layers offloaded:**
HIP_VISIBLE_DEVICES=1 ./llama-cli -m  ~/text*/models/Meta-Llama-3-8B-Instruct-Q8_0.gguf -p "Tell a joke " -s 1 -n 64 -ngl 31
build: 3782 (8a308354) with clang version 20.0.0git (https://github.com/ROCm/llvm-project.git 487d0fd20dcbb6fbf926333d7b0b355788efb009) for x86_64-unknown-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 26 key-value pairs and 291 tensors from /home/user/text-generation-webui/models/Meta-Llama-3-8B-Instruct-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = Meta-Llama-3-8B-Instruct
llama_model_loader: - kv   2:                          llama.block_count u32              = 32
llama_model_loader: - kv   3:                       llama.context_length u32              = 8192
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   7:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   8:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 7
llama_model_loader: - kv  11:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  12:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  14:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  16:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  17:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  18:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 128001
llama_model_loader: - kv  20:                    tokenizer.chat_template str              = {% set loop_messages = messages %}{% ...
llama_model_loader: - kv  21:               general.quantization_version u32              = 2
llama_model_loader: - kv  22:                      quantize.imatrix.file str              = /models/Meta-Llama-3-8B-Instruct-GGUF...
llama_model_loader: - kv  23:                   quantize.imatrix.dataset str              = /training_data/groups_merged.txt
llama_model_loader: - kv  24:             quantize.imatrix.entries_count i32              = 224
llama_model_loader: - kv  25:              quantize.imatrix.chunks_count i32              = 88
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q8_0:  226 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.8000 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 8B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 8.03 B
llm_load_print_meta: model size       = 7.95 GiB (8.50 BPW)
llm_load_print_meta: general.name     = Meta-Llama-3-8B-Instruct
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128001 '<|end_of_text|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 31 repeating layers to GPU
llm_load_tensors: offloaded 31/33 layers to GPU
llm_load_tensors:      ROCm0 buffer size =  6851.97 MiB
llm_load_tensors:        CPU buffer size =  8137.64 MiB
.........................................................................................
llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      ROCm0 KV buffer size =   992.00 MiB
llama_kv_cache_init:  ROCm_Host KV buffer size =    32.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:  ROCm_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      ROCm0 compute buffer size =   798.81 MiB
llama_new_context_with_model:  ROCm_Host compute buffer size =    24.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 15
llama_init_from_gpt_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 64

system_info: n_threads = 64 (n_threads_batch = 64) / 128 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |

sampler seed: 1
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> top-k -> tail-free -> typical -> top-p -> min-p -> temp-ext -> softmax -> dist
generate: n_ctx = 8192, n_batch = 2048, n_predict = 64, n_keep = 1

Tell a joke  Tell a joke
Bert's got a bad joke for you!
Why couldn't the bicycle stand up by itself?
(wait for it...)
Because it was two-tired! Hahahaha! Get it? Two-tired? Like a bike has two tires, but it's also tired... Ah, nevermind!

llama_perf_sampler_print:    sampling time =       6.67 ms /    69 runs   (    0.10 ms per token, 10343.28 tokens per second)
llama_perf_context_print:        load time =    2367.69 ms
llama_perf_context_print: prompt eval time =      52.25 ms /     5 tokens (   10.45 ms per token,    95.69 tokens per second)
llama_perf_context_print:        eval time =    1815.69 ms /    63 runs   (   28.82 ms per token,    34.70 tokens per second)
llama_perf_context_print:       total time =    1884.21 ms /    68 tokens

**32 layers offloaded:**
sampler seed: 1
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> top-k -> tail-free -> typical -> top-p -> min-p -> temp-ext -> softmax -> dist
generate: n_ctx = 8192, n_batch = 2048, n_predict = 64, n_keep = 1

Tell a joke atedRoute : + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

llama_perf_sampler_print:    sampling time =       6.57 ms /    69 runs   (    0.10 ms per token, 10510.28 tokens per second)
llama_perf_context_print:        load time =    2447.36 ms
llama_perf_context_print: prompt eval time =      72.25 ms /     5 tokens (   14.45 ms per token,    69.21 tokens per second)
llama_perf_context_print:        eval time =    1962.80 ms /    63 runs   (   31.16 ms per token,    32.10 tokens per second)
llama_perf_context_print:       total time =    2051.52 ms /    68 token

**33 layers offloaded:**
sampler seed: 1
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> top-k -> tail-free -> typical -> top-p -> min-p -> temp-ext -> softmax -> dist
generate: n_ctx = 8192, n_batch = 2048, n_predict = 64, n_keep = 1

Tell a joke assistant

I'm here to help you with any questions or topics you'd like to discuss! [end of text]

llama_perf_sampler_print:    sampling time =       1.18 ms /    27 runs   (    0.04 ms per token, 22939.68 tokens per second)
llama_perf_context_print:        load time =    2607.18 ms
llama_perf_context_print: prompt eval time =      64.18 ms /     5 tokens (   12.84 ms per token,    77.91 tokens per second)
llama_perf_context_print:        eval time =     595.69 ms /    21 runs   (   28.37 ms per token,    35.25 tokens per second)
llama_perf_context_print:       total time =     663.40 ms /    26 tokens

**Random seed with -ngl 99:**
sampler seed: 1190865593
sampler params:
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> top-k -> tail-free -> typical -> top-p -> min-p -> temp-ext -> softmax -> dist
generate: n_ctx = 8192, n_batch = 2048, n_predict = 64, n_keep = 1

Tell a joke

llama_perf_sampler_print:    sampling time =       3.69 ms /    69 runs   (    0.05 ms per token, 18719.48 tokens per second)
llama_perf_context_print:        load time =    2606.19 ms
llama_perf_context_print: prompt eval time =      62.95 ms /     5 tokens (   12.59 ms per token,    79.42 tokens per second)
llama_perf_context_print:        eval time =    1848.42 ms /    63 runs   (   29.34 ms per token,    34.08 tokens per second)
llama_perf_context_print:       total time =    1920.18 ms /    68 tokens

ggerganov commented 1 month ago

I'm able to reproduce this, but I'm not sure what the cause is. Can you help trace which commit introduced the regression?

ggerganov commented 1 month ago

It looks like commit 0226613853133c081b55bb892a41bb5eacc0bc94 introduced the regression. I believe @max-krasnyansky is working on a fix.

slaren commented 1 month ago

This should fix the issue:

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bccb6237..2e8c806c 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -20239,6 +20239,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
             ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
         }
     } else {
+        threadpool->n_threads_cur = 1;
         ggml_graph_compute_thread(&threadpool->workers[0]);
     }
 #else