kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

Text Tokenizer is fragmenting words #32

Closed: Kreevoz closed this issue 2 years ago

Kreevoz commented 2 years ago

I'm seeing unexpected behavior from the text tokenizer when running the supplied image_from_text.py script on Windows with Python 3.7 in a virtual environment.

The input text is tokenized in a way that breaks the words apart, which prevents the output from depicting what was requested:

'a comfy chair that looks like an avocado' ->

tokenizing text
['Ġ', 'a']
['Ġ', 'com', 'fy']
['Ġ', 'chair']
['Ġ', 'th', 'at']
['Ġ', 'look', 's']
['Ġ', 'like']
['Ġ', 'an']
['Ġ', 'av', 'oc', 'ado']
text tokens [0, 3, 28, 3, 157, 10065, 3, 10022, 3, 184, 73, 3, 7003, 46, 3, 19831, 3, 65, 3, 178, 158, 1165, 2]

'alien life' ->

tokenizing text
['Ġ', 'al', 'ien']
['Ġ', 'life']
text tokens [0, 3, 71, 1385, 3, 3210, 2]

Because the wrong tokens are chosen, the model returns a generic gamer chair for the first prompt and some kind of petri dish for the second, which is what you'd expect from the garbled tokens.

I checked that the tokenizer.json files were downloaded correctly for both the mini and mega models, and they were: searching them manually finds the words without any issue.

Is there a specific dependency for the text tokenizer that I'm unaware of or is this simply a bug?

kuprel commented 2 years ago

Strange, this is what I get from the tokenizer for that prompt:

tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]

Kreevoz commented 2 years ago

Strange indeed... I can't explain this behavior either. It doesn't seem to be an issue with parsing the text from the command line. I'll try setting up a few virtual environments with different Python versions. 🤔

Kreevoz commented 2 years ago

Exact same result using Python 3.9 in a fresh virtual environment. Edit: I tested 3.9.13 and 3.9.7 and got the same behavior, so I suppose it is not a recent bug?

Do any packages here look out of order (other than jaxlib, which I installed manually from https://github.com/cloudhan/jax-windows-builder)?

(vENV39) PS Y:\min-dalle> pip freeze
absl-py==1.1.0
certifi==2022.6.15
charset-normalizer==2.0.12
chex==0.1.3
cycler==0.11.0
dm-tree==0.1.7
etils==0.6.0
flatbuffers==2.0
flax==0.4.2
fonttools==4.33.3
idna==3.3
importlib-resources==5.8.0
jax==0.3.14
jaxlib @ file:///Y:/jaxlib-0.3.14%2Bcuda11.cudnn82-cp39-none-win_amd64.whl
kiwisolver==1.4.3
matplotlib==3.5.2
msgpack==1.0.4
numpy==1.23.0
opt-einsum==3.3.0
optax==0.1.2
packaging==21.3
Pillow==9.1.1
pyparsing==3.0.9
python-dateutil==2.8.2
requests==2.28.0
scipy==1.8.1
six==1.16.0
toolz==0.11.2
torch==1.12.0+cu116
torchaudio==0.12.0+cu116
torchvision==0.13.0+cu116
typing_extensions==4.2.0
urllib3==1.26.9
zipp==3.8.0

Kreevoz commented 2 years ago

I've added a couple more print statements to see what the tokenizer is up to:

tokenizing text
['Ġ', 'a']
['Ġ', 'a']
['Ġ', 'a']
['Ġ', 'c', 'o', 'm', 'f', 'y']
['Ġ', 'c', 'o', 'm', 'f', 'y']
['Ġ', 'c', 'om', 'f', 'y']
['Ġ', 'com', 'f', 'y']
['Ġ', 'com', 'fy']
['Ġ', 'com', 'fy']
['Ġ', 'c', 'h', 'a', 'i', 'r']
['Ġ', 'c', 'h', 'a', 'i', 'r']
['Ġ', 'ch', 'a', 'i', 'r']
['Ġ', 'ch', 'a', 'ir']
['Ġ', 'ch', 'air']
['Ġ', 'chair']
['Ġ', 'chair']
['Ġ', 't', 'h', 'a', 't']
['Ġ', 't', 'h', 'a', 't']
['Ġ', 't', 'h', 'at']
['Ġ', 'th', 'at']
['Ġ', 'th', 'at']
['Ġ', 'l', 'o', 'o', 'k', 's']
['Ġ', 'l', 'o', 'o', 'k', 's']
['Ġ', 'l', 'o', 'ok', 's']
['Ġ', 'l', 'ook', 's']
['Ġ', 'look', 's']
['Ġ', 'look', 's']
['Ġ', 'l', 'i', 'k', 'e']
['Ġ', 'l', 'i', 'k', 'e']
['Ġ', 'l', 'ik', 'e']
['Ġ', 'l', 'ike']
['Ġ', 'like']
['Ġ', 'like']
['Ġ', 'a', 'n']
['Ġ', 'a', 'n']
['Ġ', 'an']
['Ġ', 'an']
['Ġ', 'a', 'v', 'o', 'c', 'a', 'd', 'o']
['Ġ', 'a', 'v', 'o', 'c', 'a', 'd', 'o']
['Ġ', 'a', 'v', 'o', 'c', 'ad', 'o']
['Ġ', 'a', 'v', 'oc', 'ad', 'o']
['Ġ', 'av', 'oc', 'ad', 'o']
['Ġ', 'av', 'oc', 'ado']
['Ġ', 'av', 'oc', 'ado']
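
The log above looks like a byte-level BPE merge loop: each pass merges the best-ranked adjacent pair of symbols until no known merge applies. Below is a minimal sketch of such a loop, assuming a made-up five-rule merge table (this is not min-dalle's actual TextTokenizer or its merges.txt). Dropping the rules that involve 'Ġ' reproduces the stranded 'Ġ' seen in the report.

```python
# Minimal sketch of a greedy BPE merge loop (GPT-2 style), using a toy
# merge table; NOT min-dalle's TextTokenizer code.
SPACE = chr(ord(" ") + 256)  # 'Ġ', byte-level BPE's stand-in for a leading space


def bpe(word, ranks):
    """Repeatedly merge the adjacent pair with the lowest (best) rank."""
    symbols = list(word)
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        # Pick the best-ranked pair; unknown pairs rank as infinity.
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no known merge left: the word stays fragmented
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Hypothetical merge rules, listed from highest to lowest priority.
merges = [("c", "h"), ("a", "i"), ("ai", "r"), ("ch", "air"), (SPACE, "chair")]
good = {pair: rank for rank, pair in enumerate(merges)}
# Simulate the bug: every rule involving 'Ġ' can no longer be matched.
broken = {p: r for p, r in good.items() if SPACE not in p[0] + p[1]}

print(bpe(SPACE + "chair", good))    # ['Ġchair']
print(bpe(SPACE + "chair", broken))  # ['Ġ', 'chair']
```

In GPT-2-style byte-level BPE vocabularies, an entry starting with 'Ġ' stands for a token preceded by a space (the start of a word), while the bare entry stands for the same characters inside a word, which is why both forms exist.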

Why would your tokenizer fail to complete the long words when executing on my hardware? 😵

kuprel commented 2 years ago

Not sure. It works properly in Colab too: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb

alexx-km commented 2 years ago

I've got the same issue, also on Windows (running on CPU only, since I have an AMD GPU). I'll check my packages to see whether there are any other similarities between your setup and mine!

Kreevoz commented 2 years ago

I found a pattern for this bug!

Vocabulary entries that begin with Ġ are not being matched by the tokenizer. Tokens that don't start with that symbol assemble into long words just fine. (I don't know why the list of tokens contains both types of entries.)

For instance, "project", "record" and "management" assemble into valid tokens, but "projections", "recordings" and "manage" do not, because they are only listed as "Ġprojections", "Ġrecordings" and "Ġmanage" in the JSON files.

That is why it breaks words into such odd chunks: it can only use the entries that don't start with that special character!

So there must be a platform difference between Linux and Windows in how that accented Ġ is parsed. Can you account for this in your tokenizer? Can we strip it out?

kuprel commented 2 years ago

Do you get that Ġ character when you run this in Python? print(chr(ord(" ") + 256))

Kreevoz commented 2 years ago

Affirmative.

>>> print(chr(ord(" ") + 256))
Ġ
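
For context, the + 256 offset matches the byte-to-character table used by GPT-2-style byte-level BPE (the tokenizer family that BART, and hence DALL·E Mini, appears to use): bytes that aren't printable, including the space 0x20, are shifted to code points from 256 upward so that every vocabulary entry is a printable string. The sketch below rebuilds that table; for the space character it agrees with min-dalle's one-liner.

```python
def bytes_to_unicode():
    """GPT-2-style map from each byte value (0-255) to a printable character."""
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values, chars = printable[:], printable[:]
    n = 0
    for b in range(256):
        if b not in printable:
            # Non-printable bytes (control chars, space, ...) move up to 256 + n.
            byte_values.append(b)
            chars.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(byte_values, chars)}


table = bytes_to_unicode()
print(table[ord(" ")])                         # Ġ
print(table[ord(" ")] == chr(ord(" ") + 256))  # True
```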

kuprel commented 2 years ago

If you can figure out what will make it work on Windows, let me know. I don't have any Windows machines.

Kreevoz commented 2 years ago

Yes, I got a fix. It was one of those annoying OS-specific things indeed.

You need to explicitly specify that the JSON files (and merges.txt) are read as UTF-8.

On Windows, Python's open() defaults to the system locale's encoding unless one is specified (usually cp1252 or similar on English installs, other code pages for other languages). This causes the accented G to get lost or garbled when the files are read.
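
The garbling is easy to reproduce without the model files. In UTF-8, 'Ġ' (U+0120) is stored as the two bytes C4 A0; decoded as cp1252, those become 'Ä' followed by a no-break space, so a lookup for the correctly spelled key misses. A minimal sketch, assuming cp1252 as the Windows default (other locales produce different junk):

```python
import locale

token = "Ġchair"
raw = token.encode("utf-8")  # how the entry is stored in a UTF-8 file
print(raw)  # b'\xc4\xa0chair'

# Decoding the same bytes as cp1252 yields 'Ä' plus a no-break space.
garbled = raw.decode("cp1252")
print(repr(garbled))     # 'Ä\xa0chair'
print(garbled == token)  # False, so a lookup for "Ġchair" misses

# The encoding open() falls back to when none is passed (varies by machine).
print(locale.getpreferredencoding(False))
```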

The fix is easy to add on lines 16, 18 and 20 of the ./min_dalle/min_dalle.py file:

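        # Read as UTF-8 explicitly: without encoding=..., open() uses the
        # locale's code page on Windows (e.g. cp1252) and garbles 'Ġ'.
        with open(os.path.join(model_path, 'config.json'), 'r', encoding='utf8') as f:
            self.config = json.load(f)
        with open(os.path.join(model_path, 'vocab.json'), 'r', encoding='utf8') as f:
            vocab = json.load(f)
        with open(os.path.join(model_path, 'merges.txt'), 'r', encoding='utf8') as f:
            ...  # body unchanged: the BPE merge rules are read from f here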

This should not negatively affect how the code runs under Linux. With this change, the output on Windows matches your examples and the tokens are correct.
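
A quick way to confirm the diagnosis, or to catch a regression, is to count the vocabulary keys that start with 'Ġ' under each encoding. The sketch below writes a made-up three-entry vocab.json to a temporary directory (not the real model files), and count_space_tokens is a hypothetical helper. The problem only arises when the file stores 'Ġ' as raw UTF-8 bytes rather than as a \u0120 escape, which the behavior in this thread implies.

```python
import json
import os
import tempfile

SPACE = chr(ord(" ") + 256)  # 'Ġ', the word-boundary marker


def count_space_tokens(path, encoding):
    """Hypothetical helper: count vocab keys that start with 'Ġ'."""
    with open(path, "r", encoding=encoding) as f:
        vocab = json.load(f)
    return sum(key.startswith(SPACE) for key in vocab)


# Made-up stand-in for vocab.json; ensure_ascii=False writes 'Ġ' as raw UTF-8 bytes.
toy_vocab = {SPACE + "chair": 0, "chair": 1, SPACE + "a": 2}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vocab.json")
    with open(path, "w", encoding="utf8") as f:
        json.dump(toy_vocab, f, ensure_ascii=False)
    utf8_count = count_space_tokens(path, "utf8")      # as in the fix
    cp1252_count = count_space_tokens(path, "cp1252")  # simulated Windows default

print(utf8_count, cp1252_count)  # 2 0
```

A count of zero on a real vocabulary would be a strong hint that the file was decoded with the wrong encoding.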

kuprel commented 2 years ago

Awesome, thanks. I just updated it. Does it work now?

Kreevoz commented 2 years ago

Yep, just freshly cloned your repo to make sure it's all okay, and it is. Windows users may rejoice now!