ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: Segmentation fault when running speculative decoding #9949

Open rationalism opened 3 days ago

rationalism commented 3 days ago

What happened?

Running speculative decoding with the new Llama-3.1-405B-Instruct, with Llama-3.1-8B-Instruct as a draft model (with the large model on CPU and the small one on GPU), results in a segfault and core dump. (I don't think it's simply an out-of-memory error; 405B runs OK by itself with llama-server, albeit slowly.)

Command used: ./build/bin/llama-speculative -m ~/.cache/huggingface/hub/models--ThomasBaruzier--Meta-Llama-3.1-405B-Instruct-GGUF/snapshots/8545acf6b66386cbe0c37a7a099d634531c62a1c/Meta-Llama-3.1-405B-Instruct-IQ3_XXS/Meta-Llama-3.1-405B-Instruct-IQ3_XXS-00001-of-00004.gguf -fa -ngl 0 -ctk q4_0 -ctv q4_0 -co -md ~/.cache/huggingface/hub/models--bartowski--Meta-Llama-3.1-8B-Instruct-GGUF/snapshots/9a8dec50f04fa8fad1dc1e7bc20a84a512e2bb01/Meta-Llama-3.1-8B-Instruct-Q6_K_L.gguf -ngld 33

Name and Version

(llama) alyssa@alyssa-desktop:~/lm_fun/llama.cpp$ ./build/bin/llama-cli --version
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4070 Ti SUPER, compute capability 8.9, VMM: yes
version: 3943 (cda0e4b6)
built with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

(llama) alyssa@alyssa-desktop:~/lm_fun/llama.cpp$ ./build/bin/llama-speculative -m ~/.cache/huggingface/hub/models--ThomasBaruzier--Meta-Llama-3.1-405B-Instruct-GGUF/snapshots/8545acf6b66386cbe0c37a7a099d634531c62a1c/Meta-Llama-3.1-405B-Instruct-IQ3_XXS/Meta-Llama-3.1-405B-Instruct-IQ3_XXS-00001-of-00004.gguf -fa -ngl 0 -ctk q4_0 -ctv q4_0 -co -md ~/.cache/huggingface/hub/models--bartowski--Meta-Llama-3.1-8B-Instruct-GGUF/snapshots/9a8dec50f04fa8fad1dc1e7bc20a84a512e2bb01/Meta-Llama-3.1-8B-Instruct-Q6_K_L.gguf -ngld 33
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4070 Ti SUPER, compute capability 8.9, VMM: yes
build: 3943 (cda0e4b6) with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu
llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce RTX 4070 Ti SUPER) - 15381 MiB free
llama_model_loader: additional 3 GGUFs metadata loaded.
llama_model_loader: loaded meta data with 36 key-value pairs and 1138 tensors from /home/alyssa/.cache/huggingface/hub/models--ThomasBaruzier--Meta-Llama-3.1-405B-Instruct-GGUF/snapshots/8545acf6b66386cbe0c37a7a099d634531c62a1c/Meta-Llama-3.1-405B-Instruct-IQ3_XXS/Meta-Llama-3.1-405B-Instruct-IQ3_XXS-00001-of-00004.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = .
llama_model_loader: - kv   3:                           general.finetune str              = .
llama_model_loader: - kv   4:                           general.basename str              = Meta-Llama-3.1
llama_model_loader: - kv   5:                         general.size_label str              = 405B
llama_model_loader: - kv   6:                            general.license str              = llama3.1
llama_model_loader: - kv   7:                               general.tags arr[str,6]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv   8:                          general.languages arr[str,8]       = ["en", "de", "fr", "it", "pt", "hi", ...
llama_model_loader: - kv   9:                          llama.block_count u32              = 126
llama_model_loader: - kv  10:                       llama.context_length u32              = 131072
llama_model_loader: - kv  11:                     llama.embedding_length u32              = 16384
llama_model_loader: - kv  12:                  llama.feed_forward_length u32              = 53248
llama_model_loader: - kv  13:                 llama.attention.head_count u32              = 128
llama_model_loader: - kv  14:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  15:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  16:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  17:                          general.file_type u32              = 23
llama_model_loader: - kv  18:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  19:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  20:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  21:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  22:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  23:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  24:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  25:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  27:                    tokenizer.chat_template str              = {{- bos_token }}\n{%- if custom_tools ...
llama_model_loader: - kv  28:               general.quantization_version u32              = 2
llama_model_loader: - kv  29:                      quantize.imatrix.file str              = gguf/Meta-Llama-3.1-405B-Instruct/ima...
llama_model_loader: - kv  30:                   quantize.imatrix.dataset str              = misc/calibration_datav3.txt
llama_model_loader: - kv  31:             quantize.imatrix.entries_count i32              = 882
llama_model_loader: - kv  32:              quantize.imatrix.chunks_count i32              = 125
llama_model_loader: - kv  33:                                   split.no u16              = 0
llama_model_loader: - kv  34:                                split.count u16              = 4
llama_model_loader: - kv  35:                        split.tensors.count i32              = 1138
llama_model_loader: - type  f32:  254 tensors
llama_model_loader: - type q4_K:  126 tensors
llama_model_loader: - type q5_K:    1 tensors
llama_model_loader: - type iq3_xxs:  378 tensors
llama_model_loader: - type iq3_s:  127 tensors
llama_model_loader: - type iq2_s:  252 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.7999 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 16384
llm_load_print_meta: n_layer          = 126
llm_load_print_meta: n_head           = 128
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 16
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 53248
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = IQ3_XXS - 3.0625 bpw
llm_load_print_meta: model params     = 405.85 B
llm_load_print_meta: model size       = 145.14 GiB (3.07 BPW) 
llm_load_print_meta: general.name     = .
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOM token        = 128008 '<|eom_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOG token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: ggml ctx size =    0.53 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/127 layers to GPU
llm_load_tensors:        CPU buffer size = 45213.72 MiB
llm_load_tensors:        CPU buffer size = 45425.75 MiB
llm_load_tensors:        CPU buffer size = 45190.75 MiB
llm_load_tensors:        CPU buffer size = 12789.19 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 131072
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size = 18144.00 MiB
llama_new_context_with_model: KV self size  = 18144.00 MiB, K (q4_0): 9072.00 MiB, V (q4_0): 9072.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1660.25 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   288.01 MiB
llama_new_context_with_model: graph nodes  = 3535
llama_new_context_with_model: graph splits = 1642
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce RTX 4070 Ti SUPER) - 13671 MiB free
llama_model_loader: loaded meta data with 33 key-value pairs and 292 tensors from /home/alyssa/.cache/huggingface/hub/models--bartowski--Meta-Llama-3.1-8B-Instruct-GGUF/snapshots/9a8dec50f04fa8fad1dc1e7bc20a84a512e2bb01/Meta-Llama-3.1-8B-Instruct-Q6_K_L.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Meta Llama 3.1 8B Instruct
llama_model_loader: - kv   3:                           general.finetune str              = Instruct
llama_model_loader: - kv   4:                           general.basename str              = Meta-Llama-3.1
llama_model_loader: - kv   5:                         general.size_label str              = 8B
llama_model_loader: - kv   6:                            general.license str              = llama3.1
llama_model_loader: - kv   7:                               general.tags arr[str,6]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv   8:                          general.languages arr[str,8]       = ["en", "de", "fr", "it", "pt", "hi", ...
llama_model_loader: - kv   9:                          llama.block_count u32              = 32
llama_model_loader: - kv  10:                       llama.context_length u32              = 131072
llama_model_loader: - kv  11:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv  12:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv  13:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv  14:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  15:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  16:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  17:                          general.file_type u32              = 18
llama_model_loader: - kv  18:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  19:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  20:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  21:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  22:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  23:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  24:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  25:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  27:                    tokenizer.chat_template str              = {{- bos_token }}\n{%- if custom_tools ...
llama_model_loader: - kv  28:               general.quantization_version u32              = 2
llama_model_loader: - kv  29:                      quantize.imatrix.file str              = /models_out/Meta-Llama-3.1-8B-Instruc...
llama_model_loader: - kv  30:                   quantize.imatrix.dataset str              = /training_dir/calibration_datav3.txt
llama_model_loader: - kv  31:             quantize.imatrix.entries_count i32              = 224
llama_model_loader: - kv  32:              quantize.imatrix.chunks_count i32              = 125
llama_model_loader: - type  f32:   66 tensors
llama_model_loader: - type q8_0:    2 tensors
llama_model_loader: - type q6_K:  224 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.7999 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 8B
llm_load_print_meta: model ftype      = Q6_K
llm_load_print_meta: model params     = 8.03 B
llm_load_print_meta: model size       = 6.37 GiB (6.82 BPW) 
llm_load_print_meta: general.name     = Meta Llama 3.1 8B Instruct
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOM token        = 128008 '<|eom_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOG token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =   532.31 MiB
llm_load_tensors:      CUDA0 buffer size =  5993.34 MiB
......................................................................................
llama_new_context_with_model: n_ctx      = 131072
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =  4608.00 MiB
llama_new_context_with_model: KV self size  = 4608.00 MiB, K (q4_0): 2304.00 MiB, V (q4_0): 2304.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   416.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   264.01 MiB
llama_new_context_with_model: graph nodes  = 903
llama_new_context_with_model: graph splits = 2
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

<|begin_of_text|>

Segmentation fault (core dumped)
slaren commented 3 days ago

Can you try building with -DLLAMA_SANITIZE_ADDRESS=ON -DCMAKE_BUILD_TYPE=Debug to get a stack trace when it crashes? If CUDA doesn't work with address sanitizer enabled, try setting the environment variable ASAN_OPTIONS="protect_shadow_gap=0:replace_intrin=0:detect_leaks=0".

rationalism commented 3 days ago

@slaren Here's the stack trace

......................................................................................
llama_new_context_with_model: n_ctx      = 131072
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =  4608.00 MiB
llama_new_context_with_model: KV self size  = 4608.00 MiB, K (q4_0): 2304.00 MiB, V (q4_0): 2304.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
llama_new_context_with_model:        CPU compute buffer size =   416.01 MiB
llama_new_context_with_model: graph nodes  = 903
llama_new_context_with_model: graph splits = 1
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

<|begin_of_text|>AddressSanitizer:DEADLYSIGNAL
=================================================================
==2453936==ERROR: AddressSanitizer: SEGV on unknown address 0xffffffffffffffff (pc 0x750b196fb331 bp 0x7fff3823e810 sp 0x7fff3823e7a0 T0)
==2453936==The signal is caused by a WRITE memory access.
    #0 0x750b196fb331 in llama_batch_allocr::llama_batch_allocr(llama_context*, llama_batch) (/home/alyssa/lm_fun/llama.cpp/build/src/libllama.so+0x2fb331) (BuildId: 38cc5dad531e28cca44c79c07bd8eb278eabf789)
    #1 0x750b196aa379 in llama_decode /home/alyssa/lm_fun/llama.cpp/src/llama.cpp:21195
    #2 0x5ae983c895b5 in main /home/alyssa/lm_fun/llama.cpp/examples/speculative/speculative.cpp:158
    #3 0x750b0242a1c9 in __libc_start_call_main ../sysdeps/nptl/libc_start_call_main.h:58
    #4 0x750b0242a28a in __libc_start_main_impl ../csu/libc-start.c:360
    #5 0x5ae983c88a74 in _start (/home/alyssa/lm_fun/llama.cpp/build/bin/llama-speculative+0x5aa74) (BuildId: 0e2b5ab693d2b3330b0476b4cb6390cee24af2a7)

AddressSanitizer can not provide additional info.
SUMMARY: AddressSanitizer: SEGV (/home/alyssa/lm_fun/llama.cpp/build/src/libllama.so+0x2fb331) (BuildId: 38cc5dad531e28cca44c79c07bd8eb278eabf789) in llama_batch_allocr::llama_batch_allocr(llama_context*, llama_batch)
==2453936==ABORTING
slaren commented 3 days ago

It seems to be caused by the changes in #9745. @ngxson can you take a look?

rationalism commented 3 days ago

@slaren Thanks! Reverting to an earlier build (b3942) fixes this error, but now I see a different, apparently unrelated error:

Logs:

llm_load_print_meta: general.name     = Meta Llama 3.1 8B Instruct
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOM token        = 128008 '<|eom_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOG token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =   532.31 MiB
llm_load_tensors:      CUDA0 buffer size =  5993.34 MiB
......................................................................................
llama_new_context_with_model: n_ctx      = 131072
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =  4608.00 MiB
llama_new_context_with_model: KV self size  = 4608.00 MiB, K (q4_0): 2304.00 MiB, V (q4_0): 2304.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   416.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   264.01 MiB
llama_new_context_with_model: graph nodes  = 903
llama_new_context_with_model: graph splits = 2
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

<|begin_of_text|>llama_decode_internal: n_tokens == 0
llama_decode: failed to decode, ret = -1
<|start_header_id|>/home/alyssa/lm_fun/llama.cpp/common/common.cpp:1484: GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded") failed
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.

Aborted (core dumped)

Command:

git clone https://github.com/ggerganov/llama.cpp; cd llama.cpp; git checkout b3942; cmake -B build -DGGML_CUDA=ON; cmake --build build --config Release -j 24; ./build/bin/llama-speculative -m ~/.cache/huggingface/hub/models--ThomasBaruzier--Meta-Llama-3.1-405B-Instruct-GGUF/snapshots/8545acf6b66386cbe0c37a7a099d634531c62a1c/Meta-Llama-3.1-405B-Instruct-IQ3_XXS/Meta-Llama-3.1-405B-Instruct-IQ3_XXS-00001-of-00004.gguf -fa -ngl 0 -ctk q4_0 -ctv q4_0 -co -md ~/.cache/huggingface/hub/models--bartowski--Meta-Llama-3.1-8B-Instruct-GGUF/snapshots/9a8dec50f04fa8fad1dc1e7bc20a84a512e2bb01/Meta-Llama-3.1-8B-Instruct-Q6_K_L.gguf -ngld 33
slaren commented 3 days ago

Try specifying a prompt with -p.

rationalism commented 3 days ago

@slaren Tried adding -p "Hello", still got the same error. Here's the stack trace when I run as root:

llm_load_print_meta: max token length = 256
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =   532.31 MiB
llm_load_tensors:      CUDA0 buffer size =  5993.34 MiB
......................................................................................
llama_new_context_with_model: n_ctx      = 131072
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =  4608.00 MiB
llama_new_context_with_model: KV self size  = 4608.00 MiB, K (q4_0): 2304.00 MiB, V (q4_0): 2304.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   416.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =   264.01 MiB
llama_new_context_with_model: graph nodes  = 903
llama_new_context_with_model: graph splits = 2
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

<|begin_of_text|>Hellouser/home/alyssa/lm_fun/llama.cpp/common/common.cpp:1484: GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded") failed
[New LWP 2493122]
[New LWP 2493121]
[New LWP 2493120]
[New LWP 2493119]
[New LWP 2493118]
[New LWP 2493117]
[New LWP 2493116]
[New LWP 2492904]
[New LWP 2492903]
[New LWP 2492902]
[New LWP 2492893]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x000075ee441107e3 in __GI___wait4 (pid=2493377, stat_loc=0x7ffcdcad2eb4, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#0  0x000075ee441107e3 in __GI___wait4 (pid=2493377, stat_loc=0x7ffcdcad2eb4, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30  in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x000075ee4483cd22 in ggml_abort () from /home/alyssa/lm_fun/llama.cpp/build/ggml/src/libggml.so
#2  0x00005eabecda5f08 in common_batch_add(llama_batch&, int, int, std::vector<int, std::allocator<int> > const&, bool) ()
#3  0x00005eabecd5c0c7 in main ()
[Inferior 1 (process 2492892) detached]
Aborted
slaren commented 3 days ago

It should work with -c 512 -n 128, but it looks like the speculative example is broken in several ways.

slaren commented 3 days ago

At the very least, the batches are not initialized correctly. This may fix it partially, but there may be other issues.

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 5a7b3084..8becd6ac 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -190,8 +190,8 @@ int main(int argc, char ** argv) {
         drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
     }

-    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+    llama_batch batch_dft = llama_batch_init(llama_n_ctx(ctx_dft), 0, 1);
+    llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx_tgt), 0, n_seq_dft);

     const auto t_dec_start = ggml_time_us();

cc @ggerganov

rationalism commented 3 days ago

@slaren That fixes the crash, but it's still behaving strangely:

- With 405B by itself: 0.34 tokens/second (after the prompt has been processed and the first token generated), which is roughly what you'd expect given the model size and the memory bandwidth
- With the default 5 draft tokens: 0.23 tokens/second
- With 16 draft tokens: 0.19 tokens/second
- With 32 draft tokens: 0.92 tokens/second

Per forward pass, it seems significantly faster with 32 draft tokens than with 16; I have no idea why that might be.

Note the draft model here is 50x smaller and is also on GPU (vs. 405B running on CPU), so this should be an ideal case for speculative decoding AFAICT.

ngxson commented 2 days ago

Hmm, I'm not very familiar with the speculative.cpp example.

@rationalism Could you establish a baseline for how it should behave? Ideally we'd use smaller models for simplicity, something like a 0.5B draft model and a 7B base model. I'd suggest the Qwen family, since it should be compatible with older versions of llama.cpp.

slaren commented 2 days ago

@ngxson the reason I pinged you is that this specific crash in the speculative example is caused by passing a zero-size batch to llama_decode, which causes llama_batch_allocr to crash, probably in logits[logits.size() - 1] = true;. Regardless of whether this is a correct call to llama_decode, it should not crash the application.

ggerganov commented 2 days ago

> @ngxson the reason I pinged you is that this specific crash in the speculative example is caused by passing a zero-size batch to llama_decode, which causes llama_batch_allocr to crash, probably in logits[logits.size() - 1] = true;. Regardless of whether this is a correct call to llama_decode, it should not crash the application.

Yup, when I reviewed #9745 I thought the check for an empty batch was already there, but I didn't realize it comes after the llama_batch_allocr is created. So we should fix this case so that it doesn't crash.

ngxson commented 1 day ago

OK, thanks. I didn't notice that logits[logits.size() - 1] = true; breaks when the batch is empty. Will fix it now.

rationalism commented 20 hours ago

@slaren Found another segfault (not fixed by either of the two PRs above). Log attached. I think I'm probably just calling the app wrong, but wanted to report segfaults because they're annoying to debug. Thanks :)

llama.log

slaren commented 20 hours ago

gguf_init_from_file: failed to open 'DRAFT_MODEL_PATH': 'No such file or directory'