mukel / llama3.java

Practical Llama 3 inference in Java
MIT License
514 stars 61 forks

Support for split UTF-8 sequences #5

Open srogmann opened 3 months ago

srogmann commented 3 months ago

Hi @mukel,

I like your Llama3 implementation using the Vector API.

Here is a pull request to handle split UTF-8 sequences.

An example is the prompt "How to write 'three little cats' in chinese? Add an emoji.". Here Llama 3 may split the four UTF-8 bytes of the cat emoji U+1F638 (240, 159, 152, 184) across two events: 240, 159, 152 arrive in the first event and the missing 184 in the next one.
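To make the split concrete, the sketch below encodes U+1F638 with the JDK's UTF-8 charset and then decodes only the first three bytes, which is what happens if the first event is printed on its own. The class and method names are hypothetical. The JDK replaces the truncated sequence with U+FFFD, but the exact number of replacement characters is left to the decoder, so the sketch only checks that one appears.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical demo class: shows the UTF-8 bytes of U+1F638 and what
// happens when only a prefix of them is decoded.
class SplitEmojiDemo {

    // Returns the UTF-8 bytes of a code point as unsigned values (0..255).
    static int[] unsignedUtf8(int codePoint) {
        byte[] utf8 = new String(Character.toChars(codePoint))
                .getBytes(StandardCharsets.UTF_8);
        int[] unsigned = new int[utf8.length];
        for (int i = 0; i < utf8.length; i++) {
            unsigned[i] = utf8[i] & 0xFF; // Java bytes are signed
        }
        return unsigned;
    }

    // Decodes only the first n UTF-8 bytes of a code point.
    static String decodePrefix(int codePoint, int n) {
        byte[] utf8 = new String(Character.toChars(codePoint))
                .getBytes(StandardCharsets.UTF_8);
        return new String(utf8, 0, n, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(unsignedUtf8(0x1F638))); // [240, 159, 152, 184]
        // First event printed alone: the truncated sequence becomes U+FFFD.
        System.out.println(decodePrefix(0x1F638, 3).contains("\uFFFD")); // true
        // All four bytes together decode to the emoji.
        System.out.println(decodePrefix(0x1F638, 4));
    }
}
```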

mukel commented 3 months ago

Thanks for the PR! I was looking for a general fix that also works for streaming; I think this only works when decoding full token sequences. When streaming tokens, it's possible to get a partial codepoint. I think the fix should be similar: hold the partial codepoint until it is complete and can be printed. Also, the UTF-8 bytes cannot be trusted to be valid. I'll take a closer look tomorrow.
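The behaviour described here (hold an incomplete codepoint, tolerate invalid bytes) can be sketched with the JDK's own `CharsetDecoder`. This is a swapped-in technique, not what the PR or llama3.java does: when `decode` is called with `endOfInput = false`, an incomplete trailing sequence is left in the input buffer for the next call, and `CodingErrorAction.REPLACE` turns invalid bytes into U+FFFD instead of throwing. The class `StreamingUtf8Decoder` and its `feed`/`finish` methods are hypothetical, and the sketch assumes each token's raw bytes are available as a `byte[]`.

```java
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch: streaming UTF-8 decoding that holds back an incomplete
// trailing sequence until later tokens supply the missing bytes.
// Invalid bytes are replaced with U+FFFD.
class StreamingUtf8Decoder {
    private final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder()
            .onMalformedInput(CodingErrorAction.REPLACE)
            .onUnmappableCharacter(CodingErrorAction.REPLACE);
    // Bytes left over from the previous token (an incomplete sequence).
    private ByteBuffer pending = ByteBuffer.allocate(0);

    // Feeds the raw bytes of one token; returns the text that can be printed now.
    String feed(byte[] tokenBytes) {
        ByteBuffer in = ByteBuffer.allocate(pending.remaining() + tokenBytes.length);
        in.put(pending).put(tokenBytes).flip();
        // n bytes never decode to more than n chars, so this cannot overflow.
        CharBuffer out = CharBuffer.allocate(in.remaining() + 1);
        decoder.decode(in, out, false); // incomplete tail stays in 'in'
        pending = in.slice();           // keep it for the next call
        return out.flip().toString();
    }

    // Call at the end of generation: an incomplete tail becomes U+FFFD.
    String finish() {
        CharBuffer out = CharBuffer.allocate(pending.remaining() + 1);
        decoder.decode(pending, out, true);
        decoder.flush(out);
        decoder.reset();
        pending = ByteBuffer.allocate(0);
        return out.flip().toString();
    }

    public static void main(String[] args) {
        StreamingUtf8Decoder d = new StreamingUtf8Decoder();
        System.out.print(d.feed(new byte[]{(byte) 240, (byte) 159, (byte) 152})); // nothing yet
        System.out.print(d.feed(new byte[]{(byte) 184}));                         // the cat emoji
        System.out.println(d.finish());
    }
}
```

Delegating to `CharsetDecoder` keeps the validity rules in the JDK rather than in hand-written bit tests, at the cost of some buffer bookkeeping per token.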

srogmann commented 3 months ago

> When streaming tokens, it's possible to get a partial codepoint

The byte array in the fix collects a partial codepoint, so it does support streaming.
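The byte-array idea can be sketched as follows: the lead byte determines how many bytes the sequence needs, bytes are collected until that count is reached, and only then is the character emitted. This is a hypothetical reconstruction, not the actual patch; `PartialCodepointBuffer`, `accept` and `lengthOf` are illustrative names, and the length lookup is written as an if-chain of the kind discussed later in the thread. Like the original fix, it assumes valid UTF-8 input.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical sketch (not the PR's code): collect the bytes of a multi-byte
// UTF-8 sequence across tokens and emit text only once it is complete.
// Assumes valid UTF-8, like the original fix.
class PartialCodepointBuffer {
    private final byte[] buf = new byte[4]; // a UTF-8 sequence has at most 4 bytes
    private int index = 0;                  // bytes collected so far
    private int expected = 0;               // total length of the current sequence

    // Sequence length derived from the lead byte.
    private static int lengthOf(int b) {
        if ((b & 0b11100000) == 0b11000000) return 2;
        if ((b & 0b11110000) == 0b11100000) return 3;
        if ((b & 0b11111000) == 0b11110000) return 4;
        return 1; // ASCII; an invalid byte is also handled as a single byte
    }

    // Accepts one byte (0..255); returns printable text, or "" while incomplete.
    String accept(int b) {
        if (index == 0) {
            expected = lengthOf(b);
        }
        buf[index++] = (byte) b;
        if (index < expected) {
            return ""; // hold the partial codepoint
        }
        String s = new String(buf, 0, index, StandardCharsets.UTF_8);
        index = 0;
        return s;
    }

    public static void main(String[] args) {
        PartialCodepointBuffer p = new PartialCodepointBuffer();
        StringBuilder sb = new StringBuilder();
        for (int b : new int[]{240, 159, 152}) sb.append(p.accept(b)); // first event
        sb.append(p.accept(184));                                       // next event
        System.out.println(sb);
    }
}
```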

> Also, the UTF-8 bytes cannot be trusted to be valid.

None of my examples contained invalid UTF-8 bytes, so there is no check for the continuation-byte bit pattern 0b10xxxxxx in bytes 2, 3 and 4.
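The missing check is a single bit test: a continuation byte satisfies `(b & 0b11000000) == 0b10000000`. The helper below is a hypothetical sketch, not code from the patch. It checks only the continuation bytes; a full validator would also have to reject overlong encodings and surrogate code points, which this sketch does not do.

```java
// Hypothetical helper (not in the PR): checks the 0b10xxxxxx pattern that
// bytes 2, 3 and 4 of a multi-byte UTF-8 sequence must have.
// Does NOT detect overlong encodings or surrogate code points.
class Utf8Check {
    // True if b (0..255) is a continuation byte 0b10xxxxxx.
    static boolean isContinuation(int b) {
        return (b & 0b11000000) == 0b10000000;
    }

    // True if seq[1] .. seq[len - 1] are all continuation bytes.
    static boolean hasValidContinuations(int[] seq, int len) {
        for (int i = 1; i < len; i++) {
            if (!isContinuation(seq[i])) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(hasValidContinuations(new int[]{240, 159, 152, 184}, 4)); // true
        System.out.println(hasValidContinuations(new int[]{240, 159, 65, 184}, 4));  // false
    }
}
```

When the check fails, one option is to emit U+FFFD for the bytes buffered so far and treat the offending byte as the start of a new sequence, so a single bad byte does not swallow the valid text after it.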

srogmann commented 3 months ago

I was wondering if using a record array could be an alternative to the if-chain:

record Utf8Mask(int mask, int pattern, int len) {
    static final Utf8Mask[] MASKS = {
            new Utf8Mask(0b11100000, 0b11000000, 2),
            new Utf8Mask(0b11110000, 0b11100000, 3),
            new Utf8Mask(0b11111000, 0b11110000, 4)
    };
}

[...]

                for (Utf8Mask utf8Mask : Utf8Mask.MASKS) {
                    if ((b & utf8Mask.mask()) == utf8Mask.pattern()) {
                        currUtf8Mask = utf8Mask;
                        bufUtf8[currUtf8Index++] = b;
                        continue loopDecoded;
                    }
                }

patch_record_Utf8Mask.txt
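For reference, here is a self-contained, runnable version of the record-table lookup. The `Utf8Mask` record and its `MASKS` array are copied from the comment above. `forLeadByte` and `Utf8MaskDemo` are hypothetical names standing in for the elided loop, which in the patch uses a labeled `continue loopDecoded` rather than returning a value. Records require Java 16 or later.

```java
// Demo class declared first so the single-file source launcher finds main().
class Utf8MaskDemo {
    public static void main(String[] args) {
        System.out.println(Utf8Mask.forLeadByte(240).len());  // 4
        System.out.println(Utf8Mask.forLeadByte(0xC3).len()); // 2
        System.out.println(Utf8Mask.forLeadByte('A'));        // null
    }
}

// Lead-byte table from the comment above.
record Utf8Mask(int mask, int pattern, int len) {
    static final Utf8Mask[] MASKS = {
            new Utf8Mask(0b11100000, 0b11000000, 2),
            new Utf8Mask(0b11110000, 0b11100000, 3),
            new Utf8Mask(0b11111000, 0b11110000, 4)
    };

    // Hypothetical helper: returns the entry matching lead byte b (0..255),
    // or null if b is ASCII, a continuation byte, or not a valid lead byte.
    static Utf8Mask forLeadByte(int b) {
        for (Utf8Mask m : MASKS) {
            if ((b & m.mask()) == m.pattern()) {
                return m;
            }
        }
        return null;
    }
}
```

Returning `null` for ASCII and stray continuation bytes leaves the caller to decide whether to print the byte directly or replace it with U+FFFD. Adding a sequence length then only requires a new table entry, not another branch.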