ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Bug: Llama.cpp with CUDA support outputs garbage response when prompt is above 30-40ish tokens #9838

Open bmahabirbu opened 8 hours ago

bmahabirbu commented 8 hours ago

What happened?

You are a helpful assistant
> what is 2+2+2+2
44444444444444444444444444444444444444444444444444444444444444444444444444444444444444444
> 

When I run llama-cli with CUDA support, I get garbage output depending on how long the prompt is. What could be causing this? I'm running this in a container on WSL2.

Name and Version

version: 3899 (dca1d4b5) built with cc (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3) for x86_64-redhat-linux

What operating system are you seeing the problem on?

WSL2 Linux

Relevant log output

==========
== CUDA ==
==========

CUDA Version 12.6.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

llama-cli -m /var/lib/ramalama/models/ollama/granite-code:latest --in-prefix  --in-suffix  -ngl 99 -t 8 --n-predict 100 -b 2048 -ub 2048 --interactive-first --no-warmup -p You are a helpful assistant -c 2048 -cnv
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3080, compute capability 8.6, VMM: yes
build: 3899 (dca1d4b5) with cc (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3) for x86_64-redhat-linux
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 33 key-value pairs and 514 tensors from /var/lib/ramalama/models/ollama/granite-code:latest (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Granite 3b Code Instruct 128k
llama_model_loader: - kv   3:                           general.finetune str              = code-instruct-128k
llama_model_loader: - kv   4:                           general.basename str              = granite
llama_model_loader: - kv   5:                         general.size_label str              = 3B
llama_model_loader: - kv   6:                            general.license str              = apache-2.0
llama_model_loader: - kv   7:                               general.tags arr[str,3]       = ["code", "granite", "text-generation"]
llama_model_loader: - kv   8:                           general.datasets arr[str,9]       = ["bigcode/commitpackft", "TIGER-Lab/M...
llama_model_loader: - kv   9:                          llama.block_count u32              = 32
llama_model_loader: - kv  10:                       llama.context_length u32              = 128000
llama_model_loader: - kv  11:                     llama.embedding_length u32              = 2560
llama_model_loader: - kv  12:                  llama.feed_forward_length u32              = 10240
llama_model_loader: - kv  13:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv  14:              llama.attention.head_count_kv u32              = 32
llama_model_loader: - kv  15:                       llama.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  16:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  17:                          general.file_type u32              = 2
llama_model_loader: - kv  18:                           llama.vocab_size u32              = 49152
llama_model_loader: - kv  19:                 llama.rope.dimension_count u32              = 80
llama_model_loader: - kv  20:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  21:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  22:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  23:                         tokenizer.ggml.pre str              = refact
llama_model_loader: - kv  24:                      tokenizer.ggml.tokens arr[str,49152]   = ["<|endoftext|>", "<fim_prefix>", "<f...
llama_model_loader: - kv  25:                  tokenizer.ggml.token_type arr[i32,49152]   = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  26:                      tokenizer.ggml.merges arr[str,48891]   = ["Ġ Ġ", "ĠĠ ĠĠ", "ĠĠĠĠ ĠĠ...
llama_model_loader: - kv  27:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  28:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  29:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  30:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  31:                    tokenizer.chat_template str              = {% for message in messages %}\n{% if m...
llama_model_loader: - kv  32:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  289 tensors
llama_model_loader: - type q4_0:  224 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special tokens cache size = 19
llm_load_vocab: token to piece cache size = 0.2826 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 49152
llm_load_print_meta: n_merges         = 48891
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 128000
llm_load_print_meta: n_embd           = 2560
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_rot            = 80
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 80
llm_load_print_meta: n_embd_head_v    = 80
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 2560
llm_load_print_meta: n_embd_v_gqa     = 2560
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 10240
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 128000
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 3B
llm_load_print_meta: model ftype      = Q4_0
llm_load_print_meta: model params     = 3.48 B
llm_load_print_meta: model size       = 1.86 GiB (4.58 BPW) 
llm_load_print_meta: general.name     = Granite 3b Code Instruct 128k
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: PAD token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 145 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: EOG token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 512
llm_load_tensors: ggml ctx size =    0.43 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    98.44 MiB
llm_load_tensors:      CUDA0 buffer size =  1903.13 MiB
.............................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 2048
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   640.00 MiB
llama_new_context_with_model: KV self size  =  640.00 MiB, K (f16):  320.00 MiB, V (f16):  320.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.19 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   608.02 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    36.02 MiB
llama_new_context_with_model: graph nodes  = 1254
llama_new_context_with_model: graph splits = 2
main: llama threadpool init, n_threads = 8
main: in-suffix/prefix is specified, chat template will be disabled

system_info: n_threads = 8 (n_threads_batch = 8) / 16 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 

main: interactive mode on.
sampler seed: 2079466916
sampler params: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> top-k -> tail-free -> typical -> top-p -> min-p -> temp-ext -> softmax -> dist 
generate: n_ctx = 2048, n_batch = 2048, n_predict = 100, n_keep = 0

== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.

You are a helpful assistant
> what is 2+2+2+2
44444444444444444444444444444444444444444444444444444444444444444444444444444444444444444
> 
llama_perf_sampler_print:    sampling time =       1.56 ms /   100 runs   (    0.02 ms per token, 64102.56 tokens per second)
llama_perf_context_print:        load time =   61134.80 ms
llama_perf_context_print: prompt eval time =   60365.52 ms /    16 tokens ( 3772.85 ms per token,     0.27 tokens per second)
llama_perf_context_print:        eval time =     588.91 ms /    88 runs   (    6.69 ms per token,   149.43 tokens per second)
llama_perf_context_print:       total time =   63294.18 ms /   104 tokens
Interrupted by user
bmahabirbu commented 8 hours ago
ggml_debug:                   Kcur-2 = (f32)       ROPE(Kcur-2 (reshaped){128, 8, 11, 1}, CUDA0#inp_pos#0{11, 1, 1, 1}}) = {128, 8, 11, 1}
                                     [
                                      [
                                       [         nan,          nan,          nan, ...,          nan,          nan,          nan],
                                       [         nan,          nan,          nan, ...,          nan,          nan,          nan],
                                       [         nan,          nan,          nan, ...,          nan,          nan,          nan],
                                       ..., 
                                       [         nan,          nan,          nan, ...,          nan,          nan,          nan],
                                       [         nan,          nan,          nan, ...,          nan,          nan,          nan],
                                       [         nan,          nan,          nan, ...,          nan,          nan,          nan],

I ran llama-eval-callback with a prompt that triggers the error, and the output is the same as in https://github.com/ggerganov/llama.cpp/issues/7048. That issue was never resolved, but its author turned out to have a hardware error. I'm not sure how to test for a VRAM hardware error, although I'm sure that I don't have one. Could this be an issue with WSL2 and VRAM allocation?

slaren commented 8 hours ago

Could this be an issue with WSL2 and VRAM allocation?

I don't think that's very likely; I use CUDA almost exclusively with WSL2.

bmahabirbu commented 8 hours ago

https://github.com/ggerganov/llama.cpp/issues/6957#issuecomment-2251308757

I came across this. A little more insight would be great, as I'm not sure what it means. Is it a problem with how I'm using the llama3 model?

slaren commented 7 hours ago

I don't know. You could try adding this arch to the list of exceptions to see if that fixes it.

JohannesGaessler commented 7 hours ago

If this problem only occurs for specific prompts (and not for all prompts of the same length), it could be due to numerical issues. Does the ggml_debug print above mean that the output of the CUDA implementation of RoPE is NaN even though the inputs are not NaN?
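One way to answer this from a llama-eval-callback dump is to find the first tensor whose printed values contain `nan`: if that tensor is the ROPE node (like `Kcur-2` in the excerpt above) and the tensors printed before it are clean, the NaNs appear at RoPE. The sketch below assumes the dump format shown in this thread, a `ggml_debug: NAME = (type) OP(...)` header followed by rows of values; `first_nan_tensor` is a hypothetical helper, not part of llama.cpp. Because the dump elides the middle of each row (the `...`), a NaN hidden there would go unnoticed.

```shell
#!/bin/sh
# Sketch: print the header of the first tensor in a llama-eval-callback dump
# whose printed values contain "nan". Reads the files given as arguments, or
# stdin when none are given. Exits 1 if no NaN is found.
# Assumes the "ggml_debug: NAME = (type) OP(...) = {shape}" header format shown
# in this thread; first_nan_tensor is a hypothetical helper name.
first_nan_tensor() {
    awk '
        # Remember the most recent tensor header; value rows follow it.
        /^ggml_debug:/ { header = $0; next }
        # A value row containing a standalone "nan" (also matches "-nan").
        header != "" && /(^|[^a-z])nan([^a-z]|$)/ {
            print header
            found = 1
            exit
        }
        END { exit (found ? 0 : 1) }
    ' "$@"
}
```

For example, after sourcing the function: `llama-eval-callback -m model.gguf -p "<prompt that triggers the bug>" -ngl 99 2>&1 | first_nan_tensor`, where the model path and prompt are placeholders and `2>&1` is there in case the dump goes to stderr.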

bmahabirbu commented 6 hours ago

Not sure. The ggml_debug output starts out normal, then NaNs start to flood in after this:

ggml_debug:                      k-7 = (f16)       VIEW(cache_k_l7{524288, 1, 1, 1}, }) = {64, 32, 4, 1}
                                     [
                                      [
                                       [     -0.0157,      -0.0056,      -0.0117, ...,       0.3660,       1.1211,      -1.8447],
                                       [      0.2224,      -0.1504,       0.7095, ...,      -0.4265,      -0.0087,       0.7388],
                                       [      0.3860,       0.6807,      -1.2812, ...,      -2.0469,      -1.5107,       0.7300],

From what I've found, it happens with prompts of the same length, not just specific prompts.

An important update from me: in WSL2 I'm using Podman with a container I made:

FROM nvidia/cuda:12.6.1-devel-ubi9

# renovate: datasource=github-releases depName=huggingface/huggingface_hub extractVersion=^v(?<version>.*)
ARG HUGGINGFACE_HUB_VERSION=0.25.0
# renovate: datasource=github-releases depName=containers/omlmd extractVersion=^v(?<version>.*)
ARG OMLMD_VERSION=0.1.4
# renovate: datasource=git-refs depName=ggerganov/llama.cpp packageName=https://github.com/ggerganov/llama.cpp gitRef=master versioning=loose type=digest
ARG LLAMA_CPP_SHA=dca1d4b58a7f1acf1bd253be84e50d6367f492fd
# renovate: datasource=git-refs depName=ggerganov/whisper.cpp packageName=https://github.com/ggerganov/whisper.cpp gitRef=master versioning=loose type=digest
ARG WHISPER_CPP_SHA=5caa19240d55bfd6ee316d50fbad32c6e9c39528

# vulkan-headers vulkan-loader-devel vulkan-tools glslc glslang python3-pip mesa-libOpenCL-$MESA_VER.aarch64
RUN dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
    crb enable && \
    dnf install -y epel-release && \
    dnf --enablerepo=ubi-9-appstream-rpms install -y git procps-ng vim \
      dnf-plugins-core python3-dnf-plugin-versionlock cmake gcc-c++ \
      python3-pip && \
    dnf clean all && \
    rm -rf /var/cache/*dnf*

RUN /usr/bin/python3 --version
RUN pip install "huggingface_hub[cli]==${HUGGINGFACE_HUB_VERSION}"
RUN pip install "omlmd==${OMLMD_VERSION}"

# The build wouldn't complete because it couldn't find libcuda, so I made a symlink,
# but this didn't work

# RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
# RUN LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/

# Disable caching for debugging the build

# ENV GGML_CCACHE=0

# The build wouldn't complete with CMake, even with the NVIDIA Container Toolkit installed

RUN git clone https://github.com/ggerganov/llama.cpp && \
    cd llama.cpp && \
    git reset --hard ${LLAMA_CPP_SHA} && \
    make -j $(nproc) GGML_CUDA=1 CUDA_ARCH=ALL && \
    mv llama-cli /usr/bin/llama-cli && \
    mv llama-server /usr/bin/llama-server && \
    mv llama-eval-callback /usr/bin/llama-eval-callback && \
    cd / && \
    rm -rf llama.cpp

That builds llama.cpp with CUDA on top of a maintained NVIDIA base image.

However, when I built llama.cpp with CUDA directly on WSL2, without a container, it ran perfectly! Something goes wrong when doing this from within a container. I would greatly appreciate setup advice from anyone who has gotten llama.cpp with CUDA running in a container. I have a feeling the container does not have access to the full VRAM, or something along those lines.

JohannesGaessler commented 6 hours ago

The Makefile has no CUDA_ARCH argument, so that argument is being ignored. Does it work with CUDA_DOCKER_ARCH=compute_86?
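A minimal sketch of the suggested change, assuming the Makefile at the pinned commit reads CUDA_DOCKER_ARCH as this reply implies: the value comes from the compute capability that llama-cli reported for the RTX 3080 (8.6), and it replaces `CUDA_ARCH=ALL` in the Dockerfile's build step. `cuda_docker_arch` is a hypothetical helper, and the script only prints the make command (a dry run), since no GPU is visible during a container build and the capability has to be supplied by hand.

```shell
#!/bin/sh
# Sketch only: print the corrected make invocation for the Dockerfile above.
# cuda_docker_arch is a hypothetical helper, not part of llama.cpp.

# Convert a compute capability such as "8.6" into "compute_86".
cuda_docker_arch() {
    printf 'compute_%s\n' "$(printf '%s' "$1" | tr -d '.')"
}

# Default to 8.6, the capability llama-cli reported for the RTX 3080.
CAP="${1:-8.6}"
ARCH="$(cuda_docker_arch "$CAP")"

# Dry run: print the command that would replace "CUDA_ARCH=ALL" in the RUN step.
echo "make -j \$(nproc) GGML_CUDA=1 CUDA_DOCKER_ARCH=$ARCH"
```

With the default capability this prints `make -j $(nproc) GGML_CUDA=1 CUDA_DOCKER_ARCH=compute_86`, which is the line the Dockerfile's `RUN` step would use instead of the ignored `CUDA_ARCH=ALL`.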