rai-llc / LanguageModels.jl

Load nanoGPT-style transformers in Julia. Code ported from @karpathy's llama2.c
MIT License

Manually detect special tokens #6

Closed: jiahao closed this issue 1 year ago

jiahao commented 1 year ago

The current tokenizer has trouble automatically detecting the <s> token. I think it will be necessary to scan the input prompt string manually for the special tokens (1: <unk>, 2: <s>, 3: </s>, and possibly 25581: INST as well).
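A manual sweep could look roughly like the sketch below: it splits the prompt into plain-text chunks, which would still go through the ordinary tokenizer, and the ids of any special tokens found along the way. The names `split_special` and `SPECIAL_TOKENS` are hypothetical, and the dictionary holds only the three 1-based ids listed above; an INST entry could be added the same way if it turns out to be needed.

```julia
# Sketch only: `split_special` and `SPECIAL_TOKENS` are hypothetical names, and the
# ids are the 1-based ones listed in this issue.
const SPECIAL_TOKENS = Dict("<unk>" => 1, "<s>" => 2, "</s>" => 3)

# Split `prompt` into plain-text chunks (String) and special-token ids (Int).
function split_special(prompt::AbstractString, specials::AbstractDict)
    pieces = Union{String,Int}[]
    pos = firstindex(prompt)
    while pos <= lastindex(prompt)
        # Earliest occurrence of any special token at or after `pos`;
        # on a tie at the same position, prefer the longer token.
        best = nothing
        for (tok, id) in specials
            r = findnext(tok, prompt, pos)
            r === nothing && continue
            if best === nothing || first(r) < first(best[1]) ||
               (first(r) == first(best[1]) && length(r) > length(best[1]))
                best = (r, Int(id))
            end
        end
        if best === nothing
            push!(pieces, String(SubString(prompt, pos)))  # no special tokens left
            break
        end
        r, id = best
        if first(r) > pos  # plain text before the special token
            push!(pieces, String(SubString(prompt, pos, prevind(prompt, first(r)))))
        end
        push!(pieces, id)
        pos = nextind(prompt, last(r))  # resume right after the match
    end
    return pieces
end

split_special("<s>[INST] Hello [/INST]", SPECIAL_TOKENS)  # → [2, "[INST] Hello [/INST]"]
```

Searching with `findnext` for each special string keeps the sweep independent of the tokenizer internals, at the cost of one scan per special token per chunk, which should be negligible for a handful of tokens.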

For reference, here is some code to look up tokens in the current tokenizer:

julia> for (i, v) in enumerate(LanguageModels.repl_tokenizer.alphabet)
           # print every token that is a substring of "INST"
           if contains("INST", v)
               println(i, ":***", v, "***")
           end
       end

77:***I***
82:***N***
87:***S***
88:***T***
1178:***IN***
1255:***ST***
3060:***NS***
25581:***INST***
29903:***I***
29904:***S***
29912:***T***
29941:***N***

Interestingly, the tokenizer has learned a second copy of the individual letters in the 299xx range, even though every byte is already hard-coded as its own token in the 4-259 range. This is probably because, in the raw sentencepiece vocabulary, the individual byte tokens are stored as placeholder strings such as <0x87> rather than as the characters themselves.
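The byte range can be expressed as a fixed offset. The offset of 4 in this sketch is read off the listings in this thread (byte 0x00 has id 4, and 'I' = 0x49 has id 77), with ids counted from 1; the helper names are hypothetical.

```julia
# Hypothetical helpers; the offset of 4 is taken from the listings in this issue
# (byte 0x00 -> id 4, 'I' = 0x49 -> id 77), with ids counted from 1.
byte_token_id(b::UInt8) = Int(b) + 4

# Inverse mapping: the raw byte for a byte-token id, or `nothing` for any other id.
token_byte(id::Integer) = 4 <= id <= 259 ? UInt8(id - 4) : nothing

byte_token_id(UInt8('I'))  # → 77
token_byte(29903)          # → nothing (the duplicate "I" is an ordinary piece)
```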

I expect some Unicode errors downstream when the model tries to emit emojis and other multi-byte characters.
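A small illustration of why: every UTF-8 byte of an emoji is at least 0x80, so, assuming the id = byte + 4 mapping from the listings above, each one would be generated as a separate byte token. Decoding those tokens one at a time gives strings that are not valid UTF-8 on their own; whether that surfaces as an exception or as garbled output depends on the code that consumes them.

```julia
# Illustration; the id = byte + 4 mapping is an assumption taken from the listings above.
emoji = "😀"
bytes = collect(codeunits(emoji))      # [0xf0, 0x9f, 0x98, 0x80]
ids = Int.(bytes) .+ 4                 # [244, 163, 156, 132], all byte tokens

# Decoding each token on its own gives four invalid one-byte strings...
pieces = [String(UInt8[b]) for b in bytes]
all(!isvalid, pieces)                  # → true

# ...while the same bytes decoded together are the original character.
# (`copy` because `String(::Vector{UInt8})` takes ownership of, and empties, its argument.)
String(copy(bytes)) == emoji           # → true
```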

jiahao commented 1 year ago

Further investigation shows that the duplicate encoding of characters happens only for the printable ASCII code points 0x20:0x7e.

julia> for ch in 0x00:0xff
           char = String(UInt8[ch])   # one-byte string for byte `ch`
           for (id, tok) in enumerate(LanguageModels.repl_tokenizer.alphabet)
               # print every token contained in that one-byte string
               if contains(char, tok)
                   println(ch, '\t', char, '\t', id, '\t', tok)
               end
           end
       end
0       4   
1       5   
2       6   
...
32      36  
32      29872   
33  !   37  !
33  !   29992   !
...
126 ~   130 ~
126 ~   30023   ~
127     131 
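The loop above prints its matches. A variant that returns them instead, sketched below, compares token bytes exactly, which also avoids `contains(char, tok)` matching an empty token if the vocabulary ever had one. The function name is hypothetical, and the test vocabulary is a toy one built only to exercise the code; on the real alphabet the observation above would predict keys in 0x20:0x7e.

```julia
# Hypothetical helper: map each byte to the ids of all tokens whose text is exactly
# that one byte, keeping only the bytes that have more than one such token.
function duplicate_byte_tokens(alphabet::AbstractVector{<:AbstractString})
    ids = Dict{UInt8,Vector{Int}}()
    for (id, tok) in enumerate(alphabet)
        cu = codeunits(tok)
        length(cu) == 1 || continue        # only one-byte tokens can duplicate a byte
        push!(get!(ids, cu[1], Int[]), id)
    end
    return filter(p -> length(p.second) > 1, ids)
end
```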

So a hacky way to handle emoji generation is to watch for generated tokens in the range 0x80:0xff .+ 4 = 132:259 (the non-ASCII bytes), accumulate them in a raw byte buffer, and decode the buffer into a string once a token outside this range is generated or the last token has been produced.
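A minimal sketch of that buffering, assuming `alphabet[id]` holds the decoded text of token `id` (1-based) for every id outside 132:259; `stream_decode` and `RAW_BYTE_IDS` are hypothetical names, not part of LanguageModels.jl.

```julia
# Hypothetical: token ids for the raw bytes 0x80:0xff (byte + 4, counted from 1).
const RAW_BYTE_IDS = 132:259

# Write decoded text for `ids` to `io`, holding back raw high bytes until a
# non-byte token arrives or the sequence ends.
function stream_decode(io::IO, ids, alphabet::AbstractVector{<:AbstractString})
    pending = IOBuffer()                   # bytes of a possibly incomplete character
    for id in ids
        if id in RAW_BYTE_IDS
            write(pending, UInt8(id - 4))  # buffer instead of printing a partial character
        else
            write(io, take!(pending))      # flush buffered bytes as one chunk
            print(io, alphabet[id])
        end
    end
    write(io, take!(pending))              # flush after the last token
    return io
end
```

Writing the buffered bytes in a single `write` call means a terminal receives the whole UTF-8 sequence at once. An incomplete sequence left at the end is passed through unchanged rather than repaired.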