ggerganov / llama.cpp

LLM inference in C/C++
MIT License

Errors w/ BPE tokenizers (GGML_ASSERT: llama.cpp:2029: codepoints_from_utf8(word).size() > 0 and more) #4360

Closed lhl closed 7 months ago

lhl commented 11 months ago


Expected Behavior

I have a Mistral 7B-based model, shisa-7b-v1, that has an extended (120128-token) BPE tokenizer. This works fine, and I have pulled the vocab.json from tokenizer.json (there is also a special_tokens_map.json, and some added_tokens in tokenizer.json).

I am able to convert the model with --vocabtype bpe with no errors.

I am even able to run llama_bench on the model; however, when running inference, I get this error:

GGML_ASSERT: llama.cpp:2695: codepoints_from_utf8(word).size() > 0
Aborted (core dumped)

Current Behavior

As mentioned, this is the assert that gets triggered: https://github.com/ggerganov/llama.cpp/blob/bcc0eb4591bec5ec02fad3f2bdcb1b265052ea56/llama.cpp#L2695

I did a bit of poking and ended up hacking in a replacement token just to see if I could make it go:

    // GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
    if (codepoints_from_utf8(word).empty()) {
        std::stringstream ss;
        for (unsigned char c : word) {  // Ensure char is treated as unsigned
            ss << std::hex << static_cast<int>(c) << " ";  // Convert each byte to hex
        }
        LLAMA_LOG_WARN("%s: Word '%s' could not be converted to UTF-8 codepoints and will be replaced with: ❌❌\n", __func__, ss.str().c_str());
        word = "❌❌";
    }

I tried printing the offending bytes. It turns out the check only triggers once, but sadly the word seems to be a literal null character (nothing gets printed):

llm_load_vocab: Word '' could not be converted to UTF-8 codepoints and will be replaced with: ❌❌         
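The `''` in that warning (the hex dump printed nothing at all) suggests the word reached the check with zero bytes. One plausible mechanism, assuming the vocab loader builds each word as a `std::string` from a NUL-terminated `const char *` (an assumption, not confirmed here), is that a token whose text is a single U+0000 gets truncated to the empty string, and an empty string has no codepoints, which is exactly what the assert rejects. A small sketch; `count_codepoints` is a hypothetical, simplified stand-in for `codepoints_from_utf8(word).size()`:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Builds a word the way a NUL-terminated C API hands it over:
// std::string(const char *) stops copying at the first '\0'.
std::string word_from_cstr(const char * text) {
    return std::string(text);
}

// Builds a word from a pointer plus an explicit byte length,
// which preserves leading or embedded '\0' bytes.
std::string word_from_ptr_len(const char * text, std::size_t len) {
    return std::string(text, len);
}

// Hypothetical stand-in for codepoints_from_utf8(word).size(): counts UTF-8
// lead bytes (anything that is not a 10xxxxxx continuation byte).
// Assumes well-formed UTF-8 and does no validation.
std::size_t count_codepoints(const std::string & utf8) {
    std::size_t n = 0;
    for (unsigned char c : utf8) {
        if ((c & 0xC0) != 0x80) {
            ++n;
        }
    }
    return n;
}
```

With `const char nul_token[] = {'\0'};`, the C-string path yields an empty word with zero codepoints, while the pointer-plus-length path keeps one byte and one codepoint (U+0000).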

Sadly, this was not the only error once things got running; this gets printed as well:

llm_load_vocab: mismatch in special tokens definition ( 1087/120128 vs 55/120128 ).                                                                                                                                                        

That's an awful lot of special tokens; there are only 4 in our special_tokens_map.json.

I modified the code to print out what tokens it thought were issues:

@@ -2811,9 +2827,11 @@ static void llm_load_vocab(
                         // Count manually found special tokens
                         special_tokens_count_from_verification++;

+
                         // If this manually found special token is not marked as such, flag a mismatch
                         if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
                             special_tokens_definition_mismatch = true;
+                            LLAMA_LOG_WARN("%s: Special token mismatch for token '%s'. Expected special, found normal.\n", __func__, vocab.id_to_token[id].text.c_str());
                         }
                     }
                 }

It prints out lots of regular tokens, so I'm not sure why it expects them to be special:

llm_load_vocab: Special token mismatch for token '▁7.5J▁7.50-18'. Expected special, found normal.                                                                                                                                          
llm_load_vocab: Special token mismatch for token 'MG150'. Expected special, found normal.                            
llm_load_vocab: Special token mismatch for token 'デイトナ▁1165'. Expected special, found normal.                                                                                                                                          
llm_load_vocab: Special token mismatch for token '(火)▁22'. Expected special, found normal.                                                                                                                                                
llm_load_vocab: Special token mismatch for token '平成24'. Expected special, found normal.                           
llm_load_vocab: Special token mismatch for token '18V'. Expected special, found normal.                              
llm_load_vocab: Special token mismatch for token '001▁概要▁仕様書▁'. Expected special, found normal.                                                                                                                                       
llm_load_vocab: Special token mismatch for token '分(▁02/'. Expected special, found normal.                         
llm_load_vocab: Special token mismatch for token '(火)▁23'. Expected special, found normal.                          
llm_load_vocab: Special token mismatch for token '7750搭載'. Expected special, found normal.                         
llm_load_vocab: Special token mismatch for token 'USB▁3.0'. Expected special, found normal.                          
...
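As far as I can tell, this verification pass does not read special_tokens_map.json at all; it guesses which tokens are special. A multi-character token whose text cannot be split at any character boundary into two halves that are both existing vocab entries gets counted as "special". With a vocab extended from another tokenizer, plenty of ordinary tokens like `MG150` can fail that test. Below is a simplified sketch of that heuristic under that reading, not the exact llama.cpp code; `looks_special` and `utf8_seq_len` are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>

// Byte length of the UTF-8 sequence that starts with lead byte c.
// Assumes well-formed UTF-8.
static std::size_t utf8_seq_len(unsigned char c) {
    if (c < 0x80) return 1;
    if ((c >> 5) == 0x6) return 2;
    if ((c >> 4) == 0xE) return 3;
    return 4;
}

// Hypothetical helper: true if `token` would be treated as "special" by a
// split heuristic, i.e. it spans more than one character and there is no
// character boundary where both the left and the right half are vocab entries.
bool looks_special(const std::string & token, const std::unordered_set<std::string> & vocab) {
    std::size_t n_chars = 0;
    for (std::size_t i = 0; i < token.size(); i += utf8_seq_len(token[i])) {
        ++n_chars;
    }
    if (n_chars <= 1) {
        return false; // single characters are never flagged
    }
    // Try every split at a UTF-8 character boundary.
    for (std::size_t i = utf8_seq_len(token[0]); i < token.size(); i += utf8_seq_len(token[i])) {
        if (vocab.count(token.substr(0, i)) && vocab.count(token.substr(i))) {
            return false; // can be built from two existing pieces
        }
    }
    return true;
}
```

In this sketch, `MG150` is flagged unless one of `M`+`G150`, `MG`+`150`, `MG1`+`50` or `MG15`+`0` exists as a pair in the vocab, which is easy to miss in a vocab merged from a different tokenizer.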

Once everything is loaded we get:

<s>Hello world: terminate called after throwing an instance of 'std::out_of_range'
  what():  unordered_map::at
Aborted (core dumped)
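`unordered_map::at` throws `std::out_of_range` when the key is missing, so this crash means some lookup on the tokenize/detokenize path was handed a string the map does not contain; which map is not identified here. A minimal sketch of the difference between `at()` and a non-throwing lookup; `at_throws` and `try_lookup` are hypothetical helpers:

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>

// .at() throws std::out_of_range for a missing key; that is the
// "unordered_map::at" in the crash message.
bool at_throws(const std::unordered_map<std::string, int> & m, const std::string & key) {
    try {
        (void) m.at(key);
        return false;
    } catch (const std::out_of_range &) {
        return true;
    }
}

// Hypothetical non-throwing lookup: returns std::nullopt instead of throwing,
// leaving the caller to pick a fallback (skip, placeholder, unknown-token id, ...).
std::optional<int> try_lookup(const std::unordered_map<std::string, int> & m, const std::string & key) {
    auto it = m.find(key);
    if (it == m.end()) {
        return std::nullopt;
    }
    return it->second;
}
```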

I didn't follow up further, since whatever is broken, somewhere in either the conversion or the token-handling code, is beyond my ken.

Note: I found two related open discussions/issues (though I don't think either got past the initial assert); I believe both involve models that use extended BPE tokenizers:

Those issues appear to have been unresolved for a few months, but I'm reporting this new issue in the hope that it sheds more light on what might be going on. Maybe the BPE conversion is actually broken?

Our base model card actually has a list of models using other tokenizers, which might also help in tracking down issues. StableLM Beta JAVocab and CALM2-7B are two more Llama2 models using non-standard tokenizers.

Environment and Context

I can relay more info if this isn't reproducible, but I don't think the environment is the problem.

cmp-nct commented 11 months ago

The BPE tokenizer was taken from a project of mine, where it was accompanied by a slim Unicode library (cmpnct_unicode.cpp). I assume that to keep the codebase smaller or simpler, a few shortcuts were taken: my library was not included, only some parts of it, and one of the parts left out was the codepoint conversion.

I ran into a similar issue with Chinese tokens when working with OpenBuddy, which contains a large Chinese BPE vocab as special tokens. I did not have time to properly debug or fix it, so until the bug is fixed I use a quick fix on my end: switching back to my library. Since you are already invested, you can use my method until the problem is solved. It fixes the decoding and the display of special (Chinese) tokens. You'll also need to add something similar to the load_vocab code where you hit the assert (I did not encounter that problem on my end).

1) In llama.cpp, add #include "common/cmpnct_unicode.h".
2) Add https://github.com/cmp-nct/ggllm.cpp/blob/master/cmpnct_unicode.cpp and its header to the common directory.
3) In the Makefile, add cmpnct_unicode.h and cmpnct_unicode.cpp to the common library. Now you have the full ggllm Unicode "library" available when needed; it's a static class.
4) In addition, add the ggllm BPE decoding functions to llama.cpp:

// split a string into unicode strings MIT licensed - ()
static std::unordered_map<std::string, unsigned char> unicode_to_bytes() {
    static std::unordered_map<std::string, unsigned char> hex_map = { { "\x21", 0x21 }, { "\x22", 0x22 }, { "\x23", 0x23 }, { "\x24", 0x24 }, { "\x25", 0x25 }, { "\x26", 0x26 }, { "\x27", 0x27 }, { "\x28", 0x28 }, { "\x29", 0x29 }, { "\x2A", 0x2A }, { "\x2B", 0x2B }, { "\x2C", 0x2C }, { "\x2D", 0x2D }, { "\x2E", 0x2E }, { "\x2F", 0x2F }, { "\x30", 0x30 }, { "\x31", 0x31 }, { "\x32", 0x32 }, { "\x33", 0x33 }, { "\x34", 0x34 }, { "\x35", 0x35 }, { "\x36", 0x36 }, { "\x37", 0x37 }, { "\x38", 0x38 }, { "\x39", 0x39 }, { "\x3A", 0x3A }, { "\x3B", 0x3B }, { "\x3C", 0x3C }, { "\x3D", 0x3D }, { "\x3E", 0x3E }, { "\x3F", 0x3F }, { "\x40", 0x40 }, { "\x41", 0x41 }, { "\x42", 0x42 }, { "\x43", 0x43 }, { "\x44", 0x44 }, { "\x45", 0x45 }, { "\x46", 0x46 }, { "\x47", 0x47 }, { "\x48", 0x48 }, { "\x49", 0x49 }, { "\x4A", 0x4A }, { "\x4B", 0x4B }, { "\x4C", 0x4C }, { "\x4D", 0x4D }, { "\x4E", 0x4E }, { "\x4F", 0x4F }, { "\x50", 0x50 }, { "\x51", 0x51 }, { "\x52", 0x52 }, { "\x53", 0x53 }, { "\x54", 0x54 }, { "\x55", 0x55 }, { "\x56", 0x56 }, { "\x57", 0x57 }, { "\x58", 0x58 }, { "\x59", 0x59 }, { "\x5A", 0x5A }, { "\x5B", 0x5B }, { "\x5C", 0x5C }, { "\x5D", 0x5D }, { "\x5E", 0x5E }, { "\x5F", 0x5F }, { "\x60", 0x60 }, { "\x61", 0x61 }, { "\x62", 0x62 }, { "\x63", 0x63 }, { "\x64", 0x64 }, { "\x65", 0x65 }, { "\x66", 0x66 }, { "\x67", 0x67 }, { "\x68", 0x68 }, { "\x69", 0x69 }, { "\x6A", 0x6A }, { "\x6B", 0x6B }, { "\x6C", 0x6C }, { "\x6D", 0x6D }, { "\x6E", 0x6E }, { "\x6F", 0x6F }, { "\x70", 0x70 }, { "\x71", 0x71 }, { "\x72", 0x72 }, { "\x73", 0x73 }, { "\x74", 0x74 }, { "\x75", 0x75 }, { "\x76", 0x76 }, { "\x77", 0x77 }, { "\x78", 0x78 }, { "\x79", 0x79 }, { "\x7A", 0x7A }, { "\x7B", 0x7B }, { "\x7C", 0x7C }, { "\x7D", 0x7D }, { "\x7E", 0x7E }, { "\xC2\xA1", 0xA1 }, { "\xC2\xA2", 0xA2 }, { "\xC2\xA3", 0xA3 }, { "\xC2\xA4", 0xA4 }, { "\xC2\xA5", 0xA5 }, { "\xC2\xA6", 0xA6 }, { "\xC2\xA7", 0xA7 }, { "\xC2\xA8", 0xA8 }, { "\xC2\xA9", 0xA9 }, { "\xC2\xAA", 0xAA }, { "\xC2\xAB", 
0xAB }, { "\xC2\xAC", 0xAC }, { "\xC2\xAE", 0xAE }, { "\xC2\xAF", 0xAF }, { "\xC2\xB0", 0xB0 }, { "\xC2\xB1", 0xB1 }, { "\xC2\xB2", 0xB2 }, { "\xC2\xB3", 0xB3 }, { "\xC2\xB4", 0xB4 }, { "\xC2\xB5", 0xB5 }, { "\xC2\xB6", 0xB6 }, { "\xC2\xB7", 0xB7 }, { "\xC2\xB8", 0xB8 }, { "\xC2\xB9", 0xB9 }, { "\xC2\xBA", 0xBA }, { "\xC2\xBB", 0xBB }, { "\xC2\xBC", 0xBC }, { "\xC2\xBD", 0xBD }, { "\xC2\xBE", 0xBE }, { "\xC2\xBF", 0xBF }, { "\xC3\x80", 0xC0 }, { "\xC3\x81", 0xC1 }, { "\xC3\x82", 0xC2 }, { "\xC3\x83", 0xC3 }, { "\xC3\x84", 0xC4 }, { "\xC3\x85", 0xC5 }, { "\xC3\x86", 0xC6 }, { "\xC3\x87", 0xC7 }, { "\xC3\x88", 0xC8 }, { "\xC3\x89", 0xC9 }, { "\xC3\x8A", 0xCA }, { "\xC3\x8B", 0xCB }, { "\xC3\x8C", 0xCC }, { "\xC3\x8D", 0xCD }, { "\xC3\x8E", 0xCE }, { "\xC3\x8F", 0xCF }, { "\xC3\x90", 0xD0 }, { "\xC3\x91", 0xD1 }, { "\xC3\x92", 0xD2 }, { "\xC3\x93", 0xD3 }, { "\xC3\x94", 0xD4 }, { "\xC3\x95", 0xD5 }, { "\xC3\x96", 0xD6 }, { "\xC3\x97", 0xD7 }, { "\xC3\x98", 0xD8 }, { "\xC3\x99", 0xD9 }, { "\xC3\x9A", 0xDA }, { "\xC3\x9B", 0xDB }, { "\xC3\x9C", 0xDC }, { "\xC3\x9D", 0xDD }, { "\xC3\x9E", 0xDE }, { "\xC3\x9F", 0xDF }, { "\xC3\xA0", 0xE0 }, { "\xC3\xA1", 0xE1 }, { "\xC3\xA2", 0xE2 }, { "\xC3\xA3", 0xE3 }, { "\xC3\xA4", 0xE4 }, { "\xC3\xA5", 0xE5 }, { "\xC3\xA6", 0xE6 }, { "\xC3\xA7", 0xE7 }, { "\xC3\xA8", 0xE8 }, { "\xC3\xA9", 0xE9 }, { "\xC3\xAA", 0xEA }, { "\xC3\xAB", 0xEB }, { "\xC3\xAC", 0xEC }, { "\xC3\xAD", 0xED }, { "\xC3\xAE", 0xEE }, { "\xC3\xAF", 0xEF }, { "\xC3\xB0", 0xF0 }, { "\xC3\xB1", 0xF1 }, { "\xC3\xB2", 0xF2 }, { "\xC3\xB3", 0xF3 }, { "\xC3\xB4", 0xF4 }, { "\xC3\xB5", 0xF5 }, { "\xC3\xB6", 0xF6 }, { "\xC3\xB7", 0xF7 }, { "\xC3\xB8", 0xF8 }, { "\xC3\xB9", 0xF9 }, { "\xC3\xBA", 0xFA }, { "\xC3\xBB", 0xFB }, { "\xC3\xBC", 0xFC }, { "\xC3\xBD", 0xFD }, { "\xC3\xBE", 0xFE }, { "\xC3\xBF", 0xFF }, { "\xC4\x80", 0x00 }, { "\xC4\x81", 0x01 }, { "\xC4\x82", 0x02 }, { "\xC4\x83", 0x03 }, { "\xC4\x84", 0x04 }, { "\xC4\x85", 0x05 }, { "\xC4\x86", 0x06 }, { 
"\xC4\x87", 0x07 }, { "\xC4\x88", 0x08 }, { "\xC4\x89", 0x09 }, { "\xC4\x8A", 0x0A }, { "\xC4\x8B", 0x0B }, { "\xC4\x8C", 0x0C }, { "\xC4\x8D", 0x0D }, { "\xC4\x8E", 0x0E }, { "\xC4\x8F", 0x0F }, { "\xC4\x90", 0x10 }, { "\xC4\x91", 0x11 }, { "\xC4\x92", 0x12 }, { "\xC4\x93", 0x13 }, { "\xC4\x94", 0x14 }, { "\xC4\x95", 0x15 }, { "\xC4\x96", 0x16 }, { "\xC4\x97", 0x17 }, { "\xC4\x98", 0x18 }, { "\xC4\x99", 0x19 }, { "\xC4\x9A", 0x1A }, { "\xC4\x9B", 0x1B }, { "\xC4\x9C", 0x1C }, { "\xC4\x9D", 0x1D }, { "\xC4\x9E", 0x1E }, { "\xC4\x9F", 0x1F }, { "\xC4\xA0", 0x20 }, { "\xC4\xA1", 0x7F }, { "\xC4\xA2", 0x80 }, { "\xC4\xA3", 0x81 }, { "\xC4\xA4", 0x82 }, { "\xC4\xA5", 0x83 }, { "\xC4\xA6", 0x84 }, { "\xC4\xA7", 0x85 }, { "\xC4\xA8", 0x86 }, { "\xC4\xA9", 0x87 }, { "\xC4\xAA", 0x88 }, { "\xC4\xAB", 0x89 }, { "\xC4\xAC", 0x8A }, { "\xC4\xAD", 0x8B }, { "\xC4\xAE", 0x8C }, { "\xC4\xAF", 0x8D }, { "\xC4\xB0", 0x8E }, { "\xC4\xB1", 0x8F }, { "\xC4\xB2", 0x90 }, { "\xC4\xB3", 0x91 }, { "\xC4\xB4", 0x92 }, { "\xC4\xB5", 0x93 }, { "\xC4\xB6", 0x94 }, { "\xC4\xB7", 0x95 }, { "\xC4\xB8", 0x96 }, { "\xC4\xB9", 0x97 }, { "\xC4\xBA", 0x98 }, { "\xC4\xBB", 0x99 }, { "\xC4\xBC", 0x9A }, { "\xC4\xBD", 0x9B }, { "\xC4\xBE", 0x9C }, { "\xC4\xBF", 0x9D }, { "\xC5\x80", 0x9E }, { "\xC5\x81", 0x9F }, { "\xC5\x82", 0xA0 }, { "\xC5\x83", 0xAD }};
    return hex_map;
}
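static std::string cnct_decode_token(const std::string& token)
{
    static const std::unordered_map<std::string, unsigned char> byte_decoder = unicode_to_bytes();
    std::string decoded_token;
    auto unicode_sequences = cnct_split_utf8(token);
    for (const auto& unicode_sequence : unicode_sequences)
    {
        // find() instead of operator[]: operator[] would insert a 0 entry for an
        // unknown sequence and append a NUL byte; unknown sequences pass through
        auto it = byte_decoder.find(unicode_sequence);
        if (it != byte_decoder.end()) {
            decoded_token += static_cast<char>(it->second);
        } else {
            decoded_token += unicode_sequence;
        }
    }

    return decoded_token;
}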

5) Now you can replace the llama_decode_text() function in llama.cpp with the ggllm variant, which uses the new codepoint decoder:

static std::string llama_decode_text(const std::string & text) {
    return cnct_decode_token(text);
}

6) In addition, llama.cpp's default token-to-string function has no option to display special tokens. Some models with Chinese tokens ship the extra vocabulary as "special tokens", so those tokens would be invisible (skipped) in generated text, just like control characters. Here is a PR of mine that adds a bool flag to print them: https://github.com/ggerganov/llama.cpp/pull/4106

This is all quite a lot to do; I'm just showing what I did on my local fork of llama.cpp. I assume there is an easier solution to the codepoint problem, but I didn't want to dig into it, given I had already solved it differently in ggllm. In your case, you'll also need to fix load_vocab, taking a peek at unicode_to_bytes() or cnct_decode_token().

None of this is meant as a fix for the bug itself, just as a workaround and an indication of where the bug is.
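The hard-coded `unicode_to_bytes()` table appears to be the inverse of the GPT-2 byte-level BPE mapping (`bytes_to_unicode` in OpenAI's GPT-2 encoder): the 188 printable bytes map to themselves, and the remaining 68 bytes are shifted to U+0100 onward, in byte order. Assuming that, the table can be generated rather than typed out, and the splitting that `cnct_split_utf8` performs can be approximated in a few lines. This is a sketch with hypothetical names (`make_byte_decoder`, `split_utf8`, `decode_bpe_token`), not the ggllm code:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Encodes a codepoint below U+0800 as UTF-8 (enough here: the highest
// codepoint in this table is U+0143).
static std::string utf8_encode_small(unsigned cp) {
    std::string s;
    if (cp < 0x80) {
        s += static_cast<char>(cp);
    } else {
        s += static_cast<char>(0xC0 | (cp >> 6));
        s += static_cast<char>(0x80 | (cp & 0x3F));
    }
    return s;
}

// Inverse of GPT-2's bytes_to_unicode: visible character -> raw byte.
std::unordered_map<std::string, unsigned char> make_byte_decoder() {
    std::unordered_map<std::string, unsigned char> dec;
    unsigned next = 0x100; // non-printable bytes are shifted to U+0100, U+0101, ...
    for (unsigned b = 0; b < 256; ++b) {
        bool printable = (b >= 0x21 && b <= 0x7E) || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE);
        unsigned cp = printable ? b : next++;
        dec[utf8_encode_small(cp)] = static_cast<unsigned char>(b);
    }
    return dec;
}

// Splits a UTF-8 string into one std::string per codepoint (assumes valid UTF-8).
std::vector<std::string> split_utf8(const std::string & s) {
    std::vector<std::string> out;
    for (std::size_t i = 0; i < s.size();) {
        unsigned char c = static_cast<unsigned char>(s[i]);
        std::size_t len = c < 0x80 ? 1 : (c >> 5) == 0x6 ? 2 : (c >> 4) == 0xE ? 3 : 4;
        out.push_back(s.substr(i, len));
        i += len;
    }
    return out;
}

// Maps each visible character back to its raw byte. Characters outside the
// table (e.g. added tokens that were never byte-encoded) are copied through.
std::string decode_bpe_token(const std::string & token) {
    static const auto decoder = make_byte_decoder();
    std::string out;
    for (const auto & ch : split_utf8(token)) {
        auto it = decoder.find(ch);
        if (it != decoder.end()) {
            out += static_cast<char>(it->second);
        } else {
            out += ch;
        }
    }
    return out;
}
```

For example, `decode_bpe_token("\xC4\xA0" "world")` gives `" world"`, since `Ġ` (U+0120) stands for the space byte. The lookup uses `find` on purpose: with `operator[]`, an unknown character would be inserted with value 0 and a NUL byte appended to the output.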

lhl commented 11 months ago

Maybe of interest: for our extended tokenizer (and maybe other extended SentencePiece tokenizers like ELYZA's), a Japanese dev, @mmnga, was able to make a GGUF quant of our model using a slightly modified convert.py script that just adds the additional vocab in. I thought it would be really hard, but the diff doesn't look so bad:

❯ diff convert.py convert_shisa.py
357c357
< 
---
> from transformers import AutoTokenizer
360a361,364
> 
>         # vocab.json
>         self.vocabjson = json.load(open(Path(fname_tokenizer).parent / "vocab.json", encoding="utf-8"))
> 
373,374c377,378
<         if expected_new_ids != actual_new_ids:
<             raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
---
>         # if expected_new_ids != actual_new_ids:
>         #     raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
379c383,385
<         self.vocab_size         = self.vocab_size_base + len(self.added_tokens_list)
---
>         # self.vocab_size         = self.vocab_size_base + len(self.added_tokens_list)
>         self.vocab_size         = len(self.vocabjson) + len(self.added_tokens_list)
> 
385c391,392
<         for i in range(tokenizer.vocab_size()):
---
>         n_sp_vocab = tokenizer.vocab_size()
>         for i in range(n_sp_vocab):
406a414,431
>         # add extra tokens by vocab.json
>         reversed_vocab = {id: encoded_tok for encoded_tok, id in self.vocabjson.items()}
>         n_hf_tokens = len(self.vocabjson.items())
>         print("n_hf_tokens", n_hf_tokens)
> 
>         for i in range(n_sp_vocab, n_hf_tokens):
>             # text = reversed_vocab[i].encode("utf-8")
>             if i in reversed_vocab:
>                 if reversed_vocab[i].encode("utf-8") == b'\x00':
>                     text = f"[PAD{i}]".encode("utf-8")
>                     print("space token to pad", b'\x00')
>                 else:
>                     text = reversed_vocab[i].encode("utf-8")
>             else:
>                 text = f"[PAD{i}]".encode("utf-8")
> 
>             yield text, 0.0, gguf.TokenType.NORMAL
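
The core of that diff is the loop that fills every ID from the SentencePiece vocab size up to the vocab.json size: IDs missing from vocab.json, and the bare NUL token, become `[PAD<id>]` placeholders, so every ID ends up with non-empty text. A C++ sketch of the same idea; `extra_token_texts` is a hypothetical name, and the actual change lives in convert.py:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical helper mirroring the convert_shisa.py loop: returns the text for
// ids [n_sp_vocab, n_hf_tokens), using "[PAD<id>]" for ids missing from the
// id -> text map and for the token whose text is a single NUL byte.
std::vector<std::string> extra_token_texts(const std::unordered_map<int, std::string> & id_to_text,
                                           int n_sp_vocab, int n_hf_tokens) {
    std::vector<std::string> out;
    for (int i = n_sp_vocab; i < n_hf_tokens; ++i) {
        auto it = id_to_text.find(i);
        bool missing = (it == id_to_text.end());
        bool is_nul  = !missing && it->second == std::string(1, '\0');
        if (missing || is_nul) {
            out.push_back("[PAD" + std::to_string(i) + "]");
        } else {
            out.push_back(it->second);
        }
    }
    return out;
}
```

The NUL check compares against `std::string(1, '\0')` rather than the literal `"\0"`, because comparing a `std::string` with a `const char *` stops at the first NUL and would treat the literal as empty.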
cmp-nct commented 10 months ago

I use a similar hack on models in my convert.py. If I recall correctly, llama.cpp actually HAD that support half a year ago or so, and for some reason it was removed. So there are a few options:
1) pad like you did in convert.py;
2) provide the padding tokens in the HF model's configuration already;
3) fix llama.cpp's internal code to add virtual padding tokens, or to ignore the issue as needed.

I guess the third option is the best; that's likely why the support was dropped (too early).

teleprint-me commented 10 months ago

@cmp-nct I'm adding it back in. Give me about 2 - 3 days.

kalomaze commented 9 months ago

This is still a problem for the new foundational models released by InternLM (which have been Llama-fied by Charles Goddard)

intervitens commented 9 months ago

In the case of the InternLM2 model, the problem is token 354 ("\u0000":354). It gets converted into an empty vector by the codepoints_from_utf8 function, which then triggers the assert. This can be worked around either by modifying the tokenizer to replace this token with a placeholder, or by modifying the code to handle this token, although I'm not sure exactly what the behavior should be.

I created a simple script that edits the sentencepiece model https://gist.github.com/intervitens/d171990ade60afd5dfe51415f6bf8c3b

github-actions[bot] commented 7 months ago

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] commented 7 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.