ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED #10080

Open morgen52 opened 6 days ago

morgen52 commented 6 days ago

What happened?

Hi there.

My llama-server works well with the following command:

./llama.cpp-b3985/build_gpu/bin/llama-server -m ../artifact/models/Mistral-7B-Instruct-v0.3.Q4_1.gguf -ngl 31 --threads 16 --batch-size 32 --ubatch-size 8

However, when I keep only the -ngl parameter, my server crashes with a confusing error message:

./llama.cpp-b3985/build_gpu/bin/llama-server -m ../artifact/models/Mistral-7B-Instruct-v0.3.Q4_1.gguf -ngl 31

I get a CUDA error, CUBLAS_STATUS_NOT_INITIALIZED:

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
build: 0 (unknown) with cc (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0 for x86_64-linux-gnu
system info: n_threads = 6, n_threads_batch = 6, total_threads = 16

system_info: n_threads = 6 (n_threads_batch = 6) / 16 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | AMX_INT8 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 

main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 15
main: loading model
llama_load_model_from_file: using device CUDA0 (NVIDIA GeForce RTX 3060) - 10362 MiB free
llama_model_loader: loaded meta data with 26 key-value pairs and 291 tensors from ../artifact/models/Mistral-7B-Instruct-v0.3.Q4_1.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = Mistral-7B-Instruct-v0.3
llama_model_loader: - kv   2:                          llama.block_count u32              = 32
llama_model_loader: - kv   3:                       llama.context_length u32              = 32768
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   7:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   8:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                          general.file_type u32              = 3
llama_model_loader: - kv  11:                           llama.vocab_size u32              = 32768
llama_model_loader: - kv  12:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  13:            tokenizer.ggml.add_space_prefix bool             = true
llama_model_loader: - kv  14:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  15:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  16:                      tokenizer.ggml.tokens arr[str,32768]   = ["<unk>", "<s>", "</s>", "[INST]", "[...
llama_model_loader: - kv  17:                      tokenizer.ggml.scores arr[f32,32768]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  18:                  tokenizer.ggml.token_type arr[i32,32768]   = [2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  19:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  20:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  21:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  22:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  23:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  24:                    tokenizer.chat_template str              = {{ bos_token }}{% for message in mess...
llama_model_loader: - kv  25:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_1:  225 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
llm_load_vocab: special tokens cache size = 771
llm_load_vocab: token to piece cache size = 0.1731 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32768
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 32768
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 32768
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_1
llm_load_print_meta: model params     = 7.25 B
llm_load_print_meta: model size       = 4.24 GiB (5.03 BPW) 
llm_load_print_meta: general.name     = Mistral-7B-Instruct-v0.3
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 781 '<0x0A>'
llm_load_print_meta: EOG token        = 2 '</s>'
llm_load_print_meta: max token length = 48
llm_load_tensors: ggml ctx size =    0.27 MiB
llm_load_tensors: offloading 31 repeating layers to GPU
llm_load_tensors: offloaded 31/33 layers to GPU
llm_load_tensors:        CPU buffer size =  4346.02 MiB
llm_load_tensors:      CUDA0 buffer size =  4030.97 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 32768
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =   128.00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =  3968.00 MiB
llama_new_context_with_model: KV self size  = 4096.00 MiB, K (f16): 2048.00 MiB, V (f16): 2048.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.25 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  2266.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    72.01 MiB
llama_new_context_with_model: graph nodes  = 1030
llama_new_context_with_model: graph splits = 15
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
/home/data1/llm_agent/llama.cpp-b3985/ggml/src/ggml-cuda.cu:70: CUDA error
CUDA error: CUBLAS_STATUS_NOT_INITIALIZED
  current device: 0, in function cublas_handle at /home/data1/llm_agent/llama.cpp-b3985/ggml/src/ggml-cuda/common.cuh:663
  cublasCreate_v2(&cublas_handles[device])
[New LWP 1623633]
[New LWP 1623634]
[New LWP 1623635]
[New LWP 1623636]
[New LWP 1623637]
[New LWP 1623638]
[New LWP 1623639]
[New LWP 1623640]
[New LWP 1623641]
[New LWP 1623642]
[New LWP 1623643]
[New LWP 1623644]
[New LWP 1623645]
[New LWP 1623646]
[New LWP 1623647]
[New LWP 1623648]
[New LWP 1623649]
[New LWP 1623650]
[New LWP 1623651]
[New LWP 1623652]
[New LWP 1623662]
[New LWP 1623663]
[New LWP 1623664]
[New LWP 1623665]
[New LWP 1623666]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x0000754c8e2ea42f in __GI___wait4 (pid=1623667, stat_loc=0x7ffe736cd754, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory.
#0  0x0000754c8e2ea42f in __GI___wait4 (pid=1623667, stat_loc=0x7ffe736cd754, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x0000754c8ea3b5e2 in ggml_abort () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/ggml/src/libggml.so
#2  0x0000754c8eb233a6 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/ggml/src/libggml.so
#3  0x0000754c8eb26010 in ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/ggml/src/libggml.so
#4  0x0000754c8eb2e49e in ggml_cuda_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/ggml/src/libggml.so
#5  0x0000754c8eb30581 in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/ggml/src/libggml.so
#6  0x0000754c8ea86183 in ggml_backend_sched_graph_compute_async () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/ggml/src/libggml.so
#7  0x0000754ca4513312 in llama_decode_internal(llama_context&, llama_batch) () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/src/libllama.so
#8  0x0000754ca451523b in llama_decode () from /home/data1/llm_agent/llama.cpp-b3985/build_gpu/src/libllama.so
#9  0x000055cd59d334ff in common_init_from_params(common_params&) ()
#10 0x000055cd59ccac19 in server_context::load_model(common_params const&) ()
#11 0x000055cd59c7bf65 in main ()
[Inferior 1 (process 1623632) detached]
Aborted (core dumped)

Maybe it is a resource issue? I am not sure, but when I set -ngl to 32, the server crashes with a clearer error message: "cudaMalloc failed: out of memory".

Name and Version

./llama.cpp-b3985/build_gpu/bin/llama-server --version
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
version: 0 (unknown)
built with cc (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux


morgen52 commented 6 days ago

> common was recently changed a lot, probably something to do with that

I think you are right. When I start the server with:

./llama.cpp-b3985/build_gpu/bin/llama-server -m ../artifact/models/Mistral-7B-Instruct-v0.3.Q4_1.gguf -ngl 31 -c 8192

it works properly.

And when I add the --no-warmup option:

./llama.cpp-b3985/build_gpu/bin/llama-server -m ../artifact/models/Mistral-7B-Instruct-v0.3.Q4_1.gguf -ngl 31 --no-warmup

the server rejects it as an invalid argument:

error: invalid argument: --no-warmup

So I think this hint in the startup log should be updated:

common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

morgen52 commented 6 days ago

May I ask how context size affects GPU memory allocation? My understanding is that context size is just a sliding window for context length. Is memory pre-allocated based on context size?