ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Old GGUF have broken tokenization and there is no warning #7476

Closed: turian closed this issue 2 months ago

turian commented 4 months ago

As reported in https://github.com/ggerganov/llama.cpp/issues/6944#issuecomment-2101577066

The llama.cpp tokenizers give different results than HF for old GGUF files.

This is a subtle footgun, and at the very least there should be a warning, since it is now impossible to determine at what vintage your old GGUF models suddenly spoil.

Right now, the only reliable way to determine this is by running perplexity.cpp and comparing against HF. The key numbers for the first 512-token window of wikitext-2-raw test are as follows.

Using llama.cpp tokenizer:

model            quant  perplexity
llama.cpp        Q8_0   perplexity: 15.4660
Llama-CPP-python Q8_0   perplexity: 15.392684936523438
llama.cpp        Q5_K_M perplexity: 15.6994
Llama-CPP-python Q5_K_M perplexity: 15.637877464294434

Using HF tokenizer and passing those tokens into different model implementations:

model            quant  perplexity
Huggingface             perplexity: 6.205880641937256
Llama-CPP-python Q8_0   perplexity: 6.204566478729248
Llama-CPP-python Q5_K_M perplexity: 6.228440761566162

This is demonstrated in an attached notebook, which you can play with at this Colab. I'll paste the code below too.

# -*- coding: utf-8 -*-
"""Tokenizer: HF vs llama-cpp-python vs llama.cpp (perplexity.cpp)

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1RYlEj2UhylYWyaASFo-LLATzZ8d29Z0T

# Tokenizer: HF vs llama-cpp-python vs llama.cpp (perplexity.cpp)

We show that an OLD, previously converted TinyLlama GGUF (a) has buggy tokenization in llama.cpp and (b) llama.cpp doesn't emit any warning about it.

This leads to unusually bad perplexity. For simplicity, we report the perplexity of the first 512-token window of wikitext-2-raw-test.

If we use the HF tokenizer and feed the output to llama.cpp, we get the perplexity we expect.

Using llama.cpp tokenizer:

model            quant  perplexity
llama.cpp        Q8_0   perplexity: 15.4660
Llama-CPP-python Q8_0   perplexity: 15.392684936523438
llama.cpp        Q5_K_M perplexity: 15.6994
Llama-CPP-python Q5_K_M perplexity: 15.637877464294434


Using HF tokenizer and passing those tokens into different model implementations:

model            quant  perplexity
Huggingface             perplexity: 6.205880641937256
Llama-CPP-python Q8_0   perplexity: 6.204566478729248
Llama-CPP-python Q5_K_M perplexity: 6.228440761566162

"""

N_CTX = 512   # perplexity.cpp default

"""We trim to roughly 1024 tokens of text because the context window is 512 and due to pecularities in how perplexity.cpp is implemented we can't just work with 512 tokens of text.

Note that all our measurements are JUST over the first 512 tokens of wikitext-2-raw-test. Other windows are dummies needed to get perplexity.cpp to run.
"""

!wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
!unzip wikitext-2-raw-v1.zip

# Install llama-cpp-python
!pip install llama-cpp-python \
  --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu

# Download quantized TinyLlama models for llama-cpp-python
!wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf
!wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf

# Download Huggingface TinyLlama model and tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
hf_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Load llama-cpp-python quantized models
from llama_cpp import Llama
gguf_models = [
    "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf",
    "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
]
llama_models = [Llama(model_path=path, logits_all=True) for path in gguf_models]

"""3612 characters gives us 1024 tokens in llama.cpp tokenizer and 1032 tokens in HF tokenizer."""

text = open("wikitext-2-raw/wiki.test.raw").read()[:3612]
open("wikitext-2-test-3612.raw", "wt").write(text)

hf_tokens = hf_tokenizer.encode(text, add_special_tokens=True)
print(f"n HF tokens: {len(hf_tokens)}")

for llama_model, gguf_file in zip(llama_models, gguf_models):
    llama_tokens = llama_model.tokenize(text.encode("utf-8"))
    print(f"n Llama-CPP {gguf_file} tokens: {len(llama_tokens)}")

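"""Optional (not in the original notebook): a minimal sketch that finds the first index
where the HF and llama-cpp-python token streams disagree, reusing text, hf_tokens,
llama_models and gguf_models from above."""

for llama_model, gguf_file in zip(llama_models, gguf_models):
    llama_tokens = llama_model.tokenize(text.encode("utf-8"))
    n = min(len(hf_tokens), len(llama_tokens))
    # Index of the first disagreement, or None if the shared prefix matches entirely
    mismatch = next((i for i in range(n) if hf_tokens[i] != llama_tokens[i]), None)
    if mismatch is None:
        print(f"{gguf_file}: token streams agree on the first {n} tokens")
    else:
        print(f"{gguf_file}: first mismatch at index {mismatch}: "
              f"HF {hf_tokens[mismatch]} vs llama.cpp {llama_tokens[mismatch]}")
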
"""As a sanity check, we will use both a hand-rolled perplexity implementation AND the one from perplexity.cpp (later in the notebook)"""

# Helper Functions
import torch
import numpy as np
from scipy.special import log_softmax

def get_logits_hf(model, tokens):
    """Get logits from a Huggingface Transformers model.

    Preconditions:
    - model is a valid Huggingface Transformers model
    - tokens is a list of integers

    Postconditions:
    - logits is a 2D numpy array of shape (len(tokens), vocab_size)
    """
    assert isinstance(tokens, list)
    assert all(isinstance(t, int) for t in tokens)

    input_ids = torch.tensor(tokens).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        logits = outputs.logits.squeeze(0).cpu().numpy()

    assert logits.ndim == 2
    assert logits.shape[0] == len(tokens), f"logits.shape: {logits.shape}, tokens: {len(tokens)} = {tokens}"
    return logits

# Method to clear llama context
def clear_llama_context(llama):
    """Clear the llama context and reset the number of tokens."""
    llama.reset()
    llama._ctx.kv_cache_clear()
    llama.input_ids.fill(0)
    llama.scores.fill(0)

def get_logits_llama(model, tokens):
    """Get logits from a llama-cpp-python model.

    Preconditions:
    - model is a valid llama-cpp-python model
    - tokens is a list of integers

    Postconditions:
    - logits is a 2D numpy array of shape (len(tokens), vocab_size)
    """
    assert isinstance(model, Llama)
    assert isinstance(tokens, list)
    assert all(isinstance(t, int) for t in tokens)

    clear_llama_context(model)

    assert model.n_tokens == 0
    model.eval(tokens)
    assert model.n_tokens == len(tokens)
    logits = np.array(model.scores).reshape(-1, model.n_vocab())[:len(tokens)]

    assert logits.ndim == 2
    assert logits.shape[0] == len(tokens), f"logits.shape: {logits.shape}, tokens: {len(tokens)} = {tokens}"
    return logits

def compute_token_nlls(logits, tokens):
    """Compute NLLs of all tokens from logits and target tokens.

    Preconditions:
    - logits is a 2D numpy array of shape (len(tokens), vocab_size)
    - tokens is a list of integers with at least 2 elements

    Postconditions:
    - nlls is a numpy array of floats representing NLL of each token
    """
    assert logits.ndim == 2
    assert len(tokens) >= 2
    assert logits.shape[0] == len(tokens), f"logits.shape: {logits.shape}, tokens: {len(tokens)} = {tokens}"

    target_tokens = tokens[1:]
    log_probs = log_softmax(logits[:-1], axis=-1)
    nlls = -log_probs[np.arange(len(target_tokens)), target_tokens]

    assert isinstance(nlls, np.ndarray)
    assert nlls.ndim == 1
    assert len(nlls) == len(target_tokens)
    return nlls
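
"""A toy sanity check of the helpers above (not in the original notebook): with random
logits over a 32k vocabulary, perplexity = exp(mean NLL) should come out enormous."""

rng = np.random.default_rng(0)
toy_logits = rng.normal(size=(8, 32000))               # fake logits for 8 positions
toy_tokens = rng.integers(0, 32000, size=8).tolist()   # fake target tokens
toy_nlls = compute_token_nlls(toy_logits, toy_tokens)
print(f"toy perplexity (random logits): {np.exp(np.mean(toy_nlls))}")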

"""Here we see the huggingface tokens and perplexity:"""

# Compute NLLs for Huggingface model
hf_logits = get_logits_hf(hf_model, hf_tokens[:N_CTX])
assert hf_logits.shape[0] == N_CTX
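# Score only the second half of the window; the first half acts as context. This is
# meant to mirror perplexity.cpp, which (as far as I can tell) computes PPL over the
# latter part of each chunk.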
halfway = N_CTX // 2
hf_nll = compute_token_nlls(hf_logits[halfway:,...], hf_tokens[halfway:N_CTX])
print(f"Huggingface tokens:", hf_tokens[:N_CTX])
print(f"Huggingface NLL: {np.mean(hf_nll)}")
print(f"Huggingface perplexity: {np.exp(np.mean(hf_nll))}")

"""Now we see llama-cpp-python tokens and perplexity.

The llama-cpp-python tokenization of these GGUFs diverges from HF after the first few tokens. This really hurts NLL and perplexity.
"""

# Compute NLLs for llama-cpp-python models
for model, gguf_file in zip(llama_models, gguf_models):
#    llama_tokens = model.tokenize(text.encode("utf-8"), add_bos=True)
    llama_tokens = model.tokenize(text.encode("utf-8"))
    llama_logits = get_logits_llama(model, llama_tokens[:N_CTX])
    assert llama_logits.shape[0] == N_CTX
    halfway = N_CTX // 2
    llama_nll = compute_token_nlls(llama_logits[halfway:,...], llama_tokens[halfway:N_CTX])
    print(f"Llama-CPP {gguf_file} tokens:", llama_tokens[:N_CTX])
    print(f"Llama-CPP {gguf_file} NLL: {np.mean(llama_nll)}")
    print(f"Llama-CPP {gguf_file} perplexity: {np.exp(np.mean(llama_nll))}")

"""But if we use HF tokens with llama-cpp-python, we get identical tokens and nearly identical perplexity:"""

# Compute NLLs for llama-cpp-python models, but using HF tokens
for model, gguf_file in zip(llama_models, gguf_models):
    llama_logits = get_logits_llama(model, hf_tokens[:N_CTX])
    assert llama_logits.shape[0] == N_CTX
    halfway = N_CTX // 2
    llama_nll = compute_token_nlls(llama_logits[halfway:,...], hf_tokens[halfway:N_CTX])
    print(f"HF tokens:", hf_tokens[:N_CTX])
    print(f"Llama-CPP {gguf_file} on HF tokens NLL: {np.mean(llama_nll)}")
    print(f"Llama-CPP {gguf_file} on HF tokens perplexity: {np.exp(np.mean(llama_nll))}")

"""Here I use llama.cpp main, so we can make sure the tokens and perplexity are same as llama-cpp-python above."""

!git clone https://github.com/ggerganov/llama.cpp.git

!cd llama.cpp ; git log -1 --format="Commit Hash: %H%nCommit Date: %cd" --date=iso

"""This gnarly little one-liner just patches perplexity.cpp to output the token lists."""

#!perl -i -pe 's/(std::vector<llama_token> tokens = ::llama_tokenize\(.*\);)/$1 fprintf\(stderr, "%s: %d tokens\\n", __func__, int\(tokens.size\(\)\)\);/' llama.cpp/examples/perplexity/perplexity.cpp
!perl -i -pe 's/(std::vector<llama_token> tokens = ::llama_tokenize\(.*\);)/$1 fprintf\(stderr, "%s: %d tokens: ", __func__, int\(tokens.size\(\)\)\); for \(const auto& token : tokens\) { fprintf\(stderr, "%d ", int\(token\)\); } fprintf\(stderr, "\\n"\);/' llama.cpp/examples/perplexity/perplexity.cpp
!cd llama.cpp ; git diff

!cd llama.cpp && make -j2

"""Now, we see that llama.cpp (and not just llama-cpp-python) give bogus non-HF tokens with this old GGUF. And the perplexity scores of the FIRST WINDOW (`[1]`) are almost identical those above, and differ slightly possibly due to seeding which I didn't control for:

below:
llama.cpp Q5_K_M perplexity: 15.6994
llama.cpp Q8_0   perplexity: 15.4660

above:
Llama-CPP tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf perplexity: 15.637877464294434
Llama-CPP tinyllama-1.1b-chat-v1.0.Q8_0.gguf   perplexity: 15.392684936523438

"""

!llama.cpp/perplexity -f "wikitext-2-test-3612.raw" -m tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf

!llama.cpp/perplexity -f "wikitext-2-test-3612.raw" -m tinyllama-1.1b-chat-v1.0.Q8_0.gguf
slaren commented 4 months ago

Don't you get this warning when using an old, broken model?

llm_load_vocab: missing pre-tokenizer type, using: 'default'
llm_load_vocab:
llm_load_vocab: ************************************
llm_load_vocab: GENERATION QUALITY WILL BE DEGRADED!
llm_load_vocab: CONSIDER REGENERATING THE MODEL
llm_load_vocab: ************************************
llm_load_vocab:
turian commented 4 months ago

@slaren I don't, no. Here is the output of perplexity.cpp from my colab. Perhaps the issue is as simple as adding the warning to perplexity.cpp too, if it's present elsewhere.

main: build = 2971 (1e374365)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1716408409
llama_model_loader: loaded meta data with 23 key-value pairs and 201 tensors from tinyllama-1.1b-chat-v1.0.Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = tinyllama_tinyllama-1.1b-chat-v1.0
llama_model_loader: - kv   2:                       llama.context_length u32              = 2048
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 2048
llama_model_loader: - kv   4:                          llama.block_count u32              = 22
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 5632
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 64
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 4
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:                          general.file_type u32              = 7
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:                      tokenizer.ggml.merges arr[str,61249]   = ["▁ t", "e r", "i n", "▁ a", "e n...
llama_model_loader: - kv  17:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  18:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  19:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  20:            tokenizer.ggml.padding_token_id u32              = 2
llama_model_loader: - kv  21:                    tokenizer.chat_template str              = {% for message in messages %}\n{% if m...
llama_model_loader: - kv  22:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   45 tensors
llama_model_loader: - type q8_0:  156 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 2048
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 4
llm_load_print_meta: n_layer          = 22
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_embd_head_k    = 64
llm_load_print_meta: n_embd_head_v    = 64
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 256
llm_load_print_meta: n_embd_v_gqa     = 256
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 5632
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 1B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 1.10 B
llm_load_print_meta: model size       = 1.09 GiB (8.50 BPW) 
llm_load_print_meta: general.name     = tinyllama_tinyllama-1.1b-chat-v1.0
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 2 '</s>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.10 MiB
llm_load_tensors:        CPU buffer size =  1114.91 MiB
..........................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =    44.00 MiB
llama_new_context_with_model: KV self size  =   44.00 MiB, K (f16):   22.00 MiB, V (f16):   22.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
llama_new_context_with_model:        CPU compute buffer size =   148.01 MiB
llama_new_context_with_model: graph nodes  = 710
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 1 / 2 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
perplexity: tokenizing the input ..
perplexity: 1024 tokens: 1 259 13 353 4755 12476 1896 261 353 29871 13 29871 13 4755 12476 1896 261 338 385 4223 2706 1919 11456 322 24520 11339 869 940 750 263 17838 732 29899 29992 5810 5393 6297 373 278 11456 3652 450 6682 297 29871 29906 29900 29900 29900 869 910 471 5643 491 263 5810 5393 6297 297 278 1708 22167 1983 3971 491 11254 14317 29879 1919 607 471 8560 297 29871 29906 29900 29900 29896 472 278 7021 9245 15521 869 940 750 263 17838 6297 297 278 11456 3652 26817 2259 897 287 297 29871 29906 29900 29900 29906 869 512 29871 29906 29900 29900 29946 12476 1896 261 2982 287 263 6297 408 376 28050 376 297 278 12720 376 22040 4518 525 29879 13740 376 310 278 11456 3652 450 6242 14152 29885 2056 540 5810 1127 19963 29701 4485 3767 549 322 2452 1416 10968 29875 869 940 471 4320 297 278 29871 29906 29900 29900 29945 24520 5802 29879 310 278 14920 21710 11671 1032 1708 29389 29891 7509 1919 607 471 8560 472 278 16597 29885 15521 297 1858 962 2438 322 278 7567 631 14542 15519 371 27561 297 4517 869 940 471 10624 491 2259 18439 600 1384 322 5810 1127 19963 4111 806 728 1450 1919 1383 1662 16753 1362 1919 10686 13272 1919 7347 643 15846 690 1919 19122 347 7813 880 322 13298 293 6573 869 29871 13 512 29871 29906 29900 29900 29953 1919 12476 1896 261 5810 1127 19963 806 728 1450 297 278 1708 21353 19642 3527 3971 491 4485 28093 264 29131 869 940 7470 373 263 29871 29906 29900 29900 29953 12720 310 278 11456 3652 1919 15460 29879 1919 5643 491 263 6297 297 278 29871 29906 29900 29900 29955 24520 5802 310 1128 304 10837 344 10624 491 5875 347 15915 29878 446 869 1128 304 10837 344 471 8560 472 24715 15521 297 278 4517 6780 820 310 26356 414 2415 29882 322 23004 3391 869 12476 1896 261 5810 1127 297 1023 12298 297 29871 29906 29900 29900 29947 1919 8373 4366 6417 495 29891 491 2706 28107 3681 10255 2034 1919 322 3872 1989 12129 17608 29882 10624 491 7137 368 6054 18712 869 512 2610 29871 29906 29900 29900 29947 1919 12476 1896 261 1754 263 17838 10097 373 263 1023 732 29899 29992 760 12720 15232 310 278 11456 3652 22552 9292 278 16992 1919 5643 491 385 10097 373 278 11456 3652 6298 24759 943 297 3979 29871 29906 29900 29900 29947 869 940 750 263 1162 1038 292 6297 297 3006 23238 310 278 11456 3652 6960 950 1017 297 29871 29906 29900 29896 29900 1919 408 376 16540 1489 29876 13859 14246 2276 376 869 12476 1896 261 5810 1127 297 278 29871 29906 29900 29896 29896 2706 4702 10278 4314 10624 491 3681 10255 2034 869 29871 13 29871 13 353 353 15825 353 353 29871 13 29871 13 29871 13 353 353 353 29871 29906 29900 29900 29900 785 29871 29906 29900 29900 29945 353 353 353 29871 13 29871 13 512 29871 29906 29900 29900 29900 12476 1896 261 750 263 17838 732 29899 29992 5810 5393 6297 373 278 11456 3652 450 6682 2056 540 2011 25724 376 8075 1459 719 376 297 278 12720 1919 376 512 14795 29872 5166 29879 376 869 12476 1896 261 5810 1127 408 376 8075 376 297 278 1708 22167 1983 3971 491 11254 14317 29879 1919 607 471 8560 297 29871 29906 29900 29900 29896 472 278 7021 9245 15521 869 319 9076 310 12476 1896 261 525 29879 4180 297 450 25266 373 16340 5439 1075 408 376 4029 1091 368 1757 9390 376 297 278 6297 1919 322 540 4520 12187 21804 297 450 2439 2741 1919 322 7753 292 10117 869 940 7470 297 278 11456 3652 26817 2259 897 287 297 29871 29906 29900 29900 29906 408 376 3462 331 8481 16639 376 297 278 12720 376 25137 12027 15844 3819 376 1919 322 750 263 6297 408 263 1422 2931 376 22354 29891 2443 6146 376 373 450 6682 869 29871 13 940 750 263 1162 1038 292 6297 297 29871 29906 29900 29900 29941 373 1023 23238 310 
450 6682 1919 408 2931 376 1281 15459 20743 376 869 512 29871 29906 29900 29900 29946 12476 1896 261 2982 287 263 6297 408 376 28050 376 297 278 12720 376 22040 4518 525 29879 13740 376 310 278 11456 3652 450 6242 14152 29885 2056 540 5810 1127 19963 29701 4485 3767 549 322 2452 1416 10968 29875 869 12476 1896 261 5810 1127 408 376 7335 1267 376 1919 297 278 29871 29906 29900 29900 29945 24520 5802 29879 310 278 14920 21710 11671 1032 1708 29389 29891 7509 869 739 471 8560 472 278 16597 29885 15521 297 1858 962 2438 1919 322 278 7567 631 14542 15519 371 27561 297 4517 869 940 471 10624 491 2259 18439 600 1384 322 5810 1127 19963 4111 806 728 1450 1919 1383 1662 16753 1362 1919 10686 13272 1919 7347 643 15846 690 1919 19122 347 7813 880 322 13298 293 6573 869 12476 1896 261 4520 263 7853 519 9076 297 450 23331 9699 4262 584 376 450 16684 338 528 2620 11687 12838 29872 1919 411 8014 29881 21637 515 4111 806 728 1450 313 1286 443 29423 8069 569 515 670 4180 408 6479 5267 12487 15755 525 29879 7904 1026 1723 1919 4755 12476 1896 261 1919 1383 1662 16753 1362 322 7347 643 15846 690 869 376 450 29429 11682 1919 376 4111 806 728 1450 322 4755 12476 1896 261 5957 22707 2264 28655 278 4048 1875 29891 869 376 29871 13 29871 13 353 353 353 29871 29906 29900 29900 29953 785 2198 353 353 353 29871 13 29871 13 512 29871 29906 29900 29900 29953 12476 1896 261 5810 1127 297 278 1708 21353 19642 3527 3971 491 4485 28093 264 29131 869 450 1708 471 760 310 263 3652 607 15000 1422 1708 15866 5861 1919 29871 
perplexity: tokenization took 5.168 ms
perplexity: calculating perplexity over 2 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 88.79 seconds per pass - ETA 0.73 minutes
[1]15.4660,[2]23.7605,
Final estimate: PPL = 23.7605 +/- 4.15181

llama_print_timings:        load time =     794.34 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =   88785.80 ms /  1024 tokens (   86.70 ms per token,    11.53 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =   88888.76 ms /  1025 tokens
slaren commented 4 months ago

The warning is only generated for models with a BPE tokenizer, but that model has an SPM tokenizer. I am not sure whether that's because the model was converted with the wrong tokenizer type or something else, but I don't see an easy way to detect that.
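
If you want to see what tokenizer metadata an old GGUF actually carries, a minimal sketch along these lines (assuming the gguf-py GGUFReader API from the gguf Python package; this is not something llama.cpp checks automatically) will at least show which tokenizer model the file declares and whether a tokenizer.ggml.pre entry is present:

# pip install gguf
from gguf import GGUFReader

reader = GGUFReader("tinyllama-1.1b-chat-v1.0.Q8_0.gguf")
for name in reader.fields:  # fields is keyed by the full KV name
    if name.startswith("tokenizer."):
        print(name)
print("has tokenizer.ggml.pre:", "tokenizer.ggml.pre" in reader.fields)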

ggerganov commented 4 months ago

@turian Try to find the exact commit that changes the results for your model

giladgd commented 4 months ago

I think I found the cause of this. Release b2953, which corresponds to https://github.com/ggerganov/llama.cpp/pull/7375, introduced a change in the tokenizer that affected tokenization for an old GGUF model I use in tests.

Prior to that release, using this model to tokenize <|from|>system with special tokens enabled resulted in [32002, 6574], which corresponds to that same initial text. After that release, tokenizing the exact same text results in [32002, 1587], which corresponds to <|from|> system (notice the added space in the middle).

The issue appears to be that text after special tokens begins with an added space.

ggerganov commented 4 months ago

@giladgd This issue was fixed shortly after: https://github.com/ggerganov/llama.cpp/pull/7425. It's not related

giladgd commented 4 months ago

@ggerganov I've just tested it again with the latest release (b2998) and the issue is still there

turian commented 4 months ago

@ggerganov the issue appears to be longstanding, and perhaps implicit in perplexity.cpp itself.

3358c381f6251bf6e65855e1c93bfaa9ec82ddb3 from September of last year still has bad perplexity values.

ee77efea2a1e3f7d153976b0934522b6bbaa62e6 from August of last year simply cannot load the model.

ggerganov commented 4 months ago

Not sure what the problem could be. We do our best to keep backwards compatibility, or at least print warnings when there are breaking changes, but it's possible we overlook some cases. Therefore, the only recommended way to use llama.cpp is to convert and quantize a model yourself using the latest version of the code. Downloading pre-quantized models always carries the risk of compatibility problems if you use an incorrect version of the code or if the model was not converted to GGUF correctly.
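
For reference, re-converting in the Colab above would look roughly like the following (a sketch only: script and binary names are as of the builds discussed in this thread, and later releases renamed quantize to llama-quantize):

!huggingface-cli download TinyLlama/TinyLlama-1.1B-Chat-v1.0 --local-dir tinyllama-hf
!python llama.cpp/convert-hf-to-gguf.py tinyllama-hf --outfile tinyllama-1.1b-chat-v1.0.f16.gguf --outtype f16
!llama.cpp/quantize tinyllama-1.1b-chat-v1.0.f16.gguf tinyllama-1.1b-chat-v1.0.Q8_0.gguf Q8_0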

github-actions[bot] commented 2 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.