mivanit / tabGPT

use GPT to classify a bunch of your open tabs!

transformers requiring jax? #2

Closed mivanit closed 1 year ago

mivanit commented 1 year ago

when running py classify_tabs.py gen "Lorem Ipsum", which should generate a completion using GPT-2, I instead get an error:

Traceback (most recent call last):
  File "C:\Python\Python3_10\lib\site-packages\jax\_src\lib\__init__.py", line 37, in <module>
    import jaxlib
ModuleNotFoundError: No module named 'jaxlib'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "F:\projects\tools\tabGPT\classify_tabs.py", line 7, in <module>
    from transformers import AutoTokenizer, AutoModelForCausalLM
  File "C:\Python\Python3_10\lib\site-packages\transformers\__init__.py", line 30, in <module>
    from . import dependency_versions_check
  File "C:\Python\Python3_10\lib\site-packages\transformers\dependency_versions_check.py", line 17, in <module>
    from .utils.versions import require_version, require_version_core
  File "C:\Python\Python3_10\lib\site-packages\transformers\utils\__init__.py", line 34, in <module>
    from .generic import (
  File "C:\Python\Python3_10\lib\site-packages\transformers\utils\generic.py", line 36, in <module>
    import jax.numpy as jnp
  File "C:\Python\Python3_10\lib\site-packages\jax\__init__.py", line 37, in <module>
    from . import config as _config_module
  File "C:\Python\Python3_10\lib\site-packages\jax\config.py", line 18, in <module>
    from jax._src.config import config
  File "C:\Python\Python3_10\lib\site-packages\jax\_src\config.py", line 27, in <module>
    from jax._src import lib
  File "C:\Python\Python3_10\lib\site-packages\jax\_src\lib\__init__.py", line 39, in <module>
    raise ModuleNotFoundError(
ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.

an easy workaround is to install jax (which is difficult on Windows; I will try on WSL shortly), but jax shouldn't be required by transformers when we are only using the PyTorch-based models. Also, getting jax working on Colab is a pain (compared to PyTorch, at any rate).
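To illustrate the "shouldn't be required" point: an optional dependency is usually guarded so that a broken install degrades to "not available" instead of crashing the importing script. The sketch below is a generic pattern, not code taken from transformers; the helper name try_import is made up for this example.

```python
import importlib


def try_import(module_name):
    """Return the imported module, or None if it cannot be imported.

    ModuleNotFoundError is a subclass of ImportError, so this also covers
    the case in the traceback above, where the jax package is present but
    importing it fails because jaxlib is missing.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


if __name__ == "__main__":
    # Hypothetical usage: treat jax as optional and fall back to PyTorch-only.
    jax = try_import("jax")
    if jax is None:
        print("jax not usable; continuing with the PyTorch models only")
    else:
        print("jax is usable")
```

With a guard like this, a half-installed jax would be treated the same as a missing one.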

next steps:

rusheb commented 1 year ago

Cannot repro this on mac.

I'm going to spin up a windows VM and try there.

rusheb commented 1 year ago

I can't reproduce on windows either.

Steps taken:

Output:

$ python3 classify_tabs.py gen "lorem ipsum"
Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████| 762/762 [00:00<00:00, 112kB/s]
C:\Users\Student\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\huggingface_hub\file_download.py:129: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\Users\Student\.cache\huggingface\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  warnings.warn(message)
Downloading (…)olve/main/vocab.json: 100%|█████████████████████████████████████████| 1.04M/1.04M [00:01<00:00, 799kB/s]
Downloading (…)olve/main/merges.txt: 100%|███████████████████████████████████████████| 456k/456k [00:00<00:00, 675kB/s]
Downloading (…)/main/tokenizer.json: 100%|█████████████████████████████████████████| 1.36M/1.36M [00:01<00:00, 842kB/s]
Downloading (…)"pytorch_model.bin";: 100%|██████████████████████████████████████████| 353M/353M [03:02<00:00, 1.93MB/s]
Downloading (…)neration_config.json: 100%|█████████████████████████████████████████████| 124/124 [00:00<00:00, 122kB/s]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
lorem ipsum ipsum ip
tokens: ['l', 'orem', ' ', 'ips', 'um', ' ', 'ips', 'um', ' ip']
negative prob: 1.8566004655440338e-05
positive prob: 1.1448397344793193e-05

mivanit commented 1 year ago

wow, ok. this must be a problem with my installation, then. I'll take a look today and try to get it resolved.

rusheb commented 1 year ago

Output of pip show on mac:

❯ pip show jax
WARNING: Package(s) not found: jax
❯ pip show jaxlib
WARNING: Package(s) not found: jaxlib

mivanit commented 1 year ago

issue resolved. Solution: if you happen to have jax installed, make sure jaxlib is also installed correctly.

For some reason, I had jax installed in that env but not jaxlib.
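For anyone hitting the same traceback, here is a minimal sketch of a check for the "jax present, jaxlib missing" state described above. It uses only the standard library (importlib.metadata, Python 3.8+); the helper names installed_version and jax_is_half_installed are made up for this example.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def jax_is_half_installed(jax_version, jaxlib_version):
    """True when jax is installed but jaxlib is not."""
    return jax_version is not None and jaxlib_version is None


if __name__ == "__main__":
    jax_v = installed_version("jax")
    jaxlib_v = installed_version("jaxlib")
    print(f"jax: {jax_v}, jaxlib: {jaxlib_v}")
    if jax_is_half_installed(jax_v, jaxlib_v):
        print("jax is installed without jaxlib; install jaxlib or uninstall jax")
```

Run it with the same interpreter that runs classify_tabs.py, since the check only sees packages in the active environment. If both values are None (as in the pip show output on mac above), transformers has no jax to trip over.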