huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`import transformers` accidentally initializing both torch and jax/xla at startup time #30226

Closed stellaraccident closed 6 months ago

stellaraccident commented 6 months ago

System Info

absl-py==1.4.0 aiohttp==3.9.3 aiosignal==1.3.1 alabaster==0.7.16 albumentations==1.3.1 altair==4.2.2 annotated-types==0.6.0 anyio==3.7.1 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 array_record==0.5.1 arviz==0.15.1 astropy==5.3.4 astunparse==1.6.3 async-timeout==4.0.3 atpublic==4.1.0 attrs==23.2.0 audioread==3.0.1 autograd==1.6.2 Babel==2.14.0 backcall==0.2.0 beautifulsoup4==4.12.3 bidict==0.23.1 bigframes==1.0.0 bleach==6.1.0 blinker==1.4 blis==0.7.11 blosc2==2.0.0 bokeh==3.3.4 bqplot==0.12.43 branca==0.7.1 build==1.2.1 CacheControl==0.14.0 cachetools==5.3.3 catalogue==2.0.10 certifi==2024.2.2 cffi==1.16.0 chardet==5.2.0 charset-normalizer==3.3.2 chex==0.1.86 click==8.1.7 click-plugins==1.1.1 cligj==0.7.2 cloudpathlib==0.16.0 cloudpickle==2.2.1 cmake==3.27.9 cmdstanpy==1.2.2 colorcet==3.1.0 colorlover==0.3.0 colour==0.1.5 community==1.0.0b1 confection==0.1.4 cons==0.4.6 contextlib2==21.6.0 contourpy==1.2.1 cryptography==42.0.5 cufflinks==0.17.3 cupy-cuda12x==12.2.0 cvxopt==1.3.2 cvxpy==1.3.3 cycler==0.12.1 cymem==2.0.8 Cython==3.0.10 dask==2023.8.1 datascience==0.17.6 db-dtypes==1.2.0 dbus-python==1.2.18 debugpy==1.6.6 decorator==4.4.2 defusedxml==0.7.1 distributed==2023.8.1 distro==1.7.0 dlib==19.24.4 dm-tree==0.1.8 docstring_parser==0.16 docutils==0.18.1 dopamine-rl==4.0.6 duckdb==0.10.1 earthengine-api==0.1.397 easydict==1.13 ecos==2.0.13 editdistance==0.6.2 eerepr==0.0.4 en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889 entrypoints==0.4 et-xmlfile==1.1.0 etils==1.7.0 etuples==0.3.9 exceptiongroup==1.2.0 fastcore==1.5.29 fastdownload==0.0.7 fastjsonschema==2.19.1 fastprogress==1.0.3 fastrlock==0.8.2 filelock==3.13.4 fiona==1.9.6 firebase-admin==5.3.0 Flask==2.2.5 flatbuffers==24.3.25 flax==0.8.2 folium==0.14.0 fonttools==4.51.0 frozendict==2.4.1 frozenlist==1.4.1 fsspec==2023.6.0 
future==0.18.3 gast==0.5.4 gcsfs==2023.6.0 GDAL==3.6.4 gdown==4.7.3 geemap==0.32.0 gensim==4.3.2 geocoder==1.38.1 geographiclib==2.0 geopandas==0.13.2 geopy==2.3.0 gin-config==0.5.0 glob2==0.7 google==2.0.3 google-ai-generativelanguage==0.4.0 google-api-core==2.11.1 google-api-python-client==2.84.0 google-auth==2.27.0 google-auth-httplib2==0.1.1 google-auth-oauthlib==1.2.0 google-cloud-aiplatform==1.47.0 google-cloud-bigquery==3.12.0 google-cloud-bigquery-connection==1.12.1 google-cloud-bigquery-storage==2.24.0 google-cloud-core==2.3.3 google-cloud-datastore==2.15.2 google-cloud-firestore==2.11.1 google-cloud-functions==1.13.3 google-cloud-iam==2.14.3 google-cloud-language==2.13.3 google-cloud-resource-manager==1.12.3 google-cloud-storage==2.8.0 google-cloud-translate==3.11.3 google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz#sha256=ed5e8e54679a5e4587c7f6e47886b20c3662957d23519aa55573ca9054e4f274 google-crc32c==1.5.0 google-generativeai==0.3.2 google-pasta==0.2.0 google-resumable-media==2.7.0 googleapis-common-protos==1.63.0 googledrivedownloader==0.4 graphviz==0.20.3 greenlet==3.0.3 grpc-google-iam-v1==0.13.0 grpcio==1.62.1 grpcio-status==1.48.2 gspread==3.4.2 gspread-dataframe==3.3.1 gym==0.25.2 gym-notices==0.0.8 h5netcdf==1.3.0 h5py==3.9.0 holidays==0.46 holoviews==1.17.1 html5lib==1.1 httpimport==1.3.1 httplib2==0.22.0 huggingface-hub==0.20.3 humanize==4.7.0 hyperopt==0.2.7 ibis-framework==8.0.0 idna==3.6 imageio==2.31.6 imageio-ffmpeg==0.4.9 imagesize==1.4.1 imbalanced-learn==0.10.1 imgaug==0.4.0 importlib_metadata==7.1.0 importlib_resources==6.4.0 imutils==0.5.4 inflect==7.0.0 iniconfig==2.0.0 intel-openmp==2023.2.4 ipyevents==2.0.2 ipyfilechooser==0.6.0 ipykernel==5.5.6 ipyleaflet==0.18.2 ipython==7.34.0 ipython-genutils==0.2.0 ipython-sql==0.5.0 ipytree==0.2.2 ipywidgets==7.7.1 iree-compiler==20240410.859 iree-runtime==20240410.859 iree-turbine==2.3.0rc20240410 itsdangerous==2.1.2 jax==0.4.26 jaxlib @ 
https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750 jeepney==0.7.1 jieba==0.42.1 Jinja2==3.1.3 joblib==1.4.0 jsonpickle==3.0.3 jsonschema==4.19.2 jsonschema-specifications==2023.12.1 jupyter-client==6.1.12 jupyter-console==6.1.0 jupyter-server==1.24.0 jupyter_core==5.7.2 jupyterlab_pygments==0.3.0 jupyterlab_widgets==3.0.10 kaggle==1.5.16 kagglehub==0.2.2 keras==2.15.0 keyring==23.5.0 kiwisolver==1.4.5 langcodes==3.3.0 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 lazy_loader==0.4 libclang==18.1.1 librosa==0.10.1 lightgbm==4.1.0 linkify-it-py==2.0.3 llvmlite==0.41.1 locket==1.0.0 logical-unification==0.4.6 lxml==4.9.4 malloy==2023.1067 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.7.1 matplotlib-inline==0.1.6 matplotlib-venn==0.11.10 mdit-py-plugins==0.4.0 mdurl==0.1.2 miniKanren==1.0.3 missingno==0.5.2 mistune==0.8.4 mizani==0.9.3 mkl==2023.2.0 ml-dtypes==0.2.0 mlxtend==0.22.0 more-itertools==10.1.0 moviepy==1.0.3 mpmath==1.3.0 msgpack==1.0.8 multidict==6.0.5 multipledispatch==1.0.0 multitasking==0.0.11 murmurhash==1.0.10 music21==9.1.0 natsort==8.4.0 nbclassic==1.0.0 nbclient==0.10.0 nbconvert==6.5.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.3 nibabel==4.0.2 nltk==3.8.1 notebook==6.5.5 notebook_shim==0.2.4 numba==0.58.1 numexpr==2.10.0 numpy==1.25.2 oauth2client==4.1.3 oauthlib==3.2.2 opencv-contrib-python==4.8.0.76 opencv-python==4.8.0.76 opencv-python-headless==4.9.0.80 openpyxl==3.1.2 opt-einsum==3.3.0 optax==0.2.2 orbax-checkpoint==0.4.4 osqp==0.6.2.post8 packaging==24.0 pandas==2.0.3 pandas-datareader==0.10.0 pandas-gbq==0.19.2 pandas-stubs==2.0.3.230814 pandocfilters==1.5.1 panel==1.3.8 param==2.1.0 parso==0.8.4 parsy==2.1 partd==1.4.1 pathlib==1.0.1 patsy==0.5.6 peewee==3.17.1 pexpect==4.9.0 pickleshare==0.7.5 Pillow==9.4.0 pip-tools==6.13.0 platformdirs==4.2.0 
plotly==5.15.0 plotnine==0.12.4 pluggy==1.4.0 polars==0.20.2 pooch==1.8.1 portpicker==1.5.2 prefetch-generator==1.0.3 preshed==3.0.9 prettytable==3.10.0 proglog==0.1.10 progressbar2==4.2.0 prometheus_client==0.20.0 promise==2.3 prompt-toolkit==3.0.43 prophet==1.1.5 proto-plus==1.23.0 protobuf==3.20.3 psutil==5.9.5 psycopg2==2.9.9 ptyprocess==0.7.0 py-cpuinfo==9.0.0 py4j==0.10.9.7 pyarrow==14.0.2 pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycocotools==2.0.7 pycparser==2.22 pydantic==2.6.4 pydantic_core==2.16.3 pydata-google-auth==1.8.2 pydot==1.4.2 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 PyDrive2==1.6.3 pyerfa==2.0.1.3 pygame==2.5.2 Pygments==2.16.1 PyGObject==3.42.1 PyJWT==2.3.0 pymc==5.10.4 pymystem3==0.2.0 PyOpenGL==3.1.7 pyOpenSSL==24.1.0 pyparsing==3.1.2 pyperclip==1.8.2 pyproj==3.6.1 pyproject_hooks==1.0.0 pyshp==2.3.1 PySocks==1.7.1 pytensor==2.18.6 pytest==7.4.4 python-apt @ file:///backend-container/containers/python_apt-0.0.0-cp310-cp310-linux_x86_64.whl#sha256=b209c7165d6061963abe611492f8c91c3bcef4b7a6600f966bab58900c63fefa python-box==7.1.1 python-dateutil==2.8.2 python-louvain==0.16 python-slugify==8.0.4 python-utils==3.8.2 pytz==2023.4 pyviz_comms==3.0.2 PyWavelets==1.6.0 PyYAML==6.0.1 pyzmq==23.2.1 qdldl==0.1.7.post1 qudida==0.0.4 ratelim==0.1.6 referencing==0.34.0 regex==2023.12.25 requests==2.31.0 requests-oauthlib==1.3.1 requirements-parser==0.9.0 rich==13.7.1 rpds-py==0.18.0 rpy2==3.4.2 rsa==4.9 safetensors==0.4.2 scikit-image==0.19.3 scikit-learn==1.2.2 scipy==1.11.4 scooby==0.9.2 scs==3.2.4.post1 seaborn==0.13.1 SecretStorage==3.3.1 Send2Trash==1.8.3 sentencepiece==0.1.99 shapely==2.0.3 six==1.16.0 sklearn-pandas==2.2.0 smart-open==6.4.0 sniffio==1.3.1 snowballstemmer==2.2.0 sortedcontainers==2.4.0 soundfile==0.12.1 soupsieve==2.5 soxr==0.3.7 spacy==3.7.4 spacy-legacy==3.0.12 spacy-loggers==1.0.5 Sphinx==5.0.2 sphinxcontrib-applehelp==1.0.8 sphinxcontrib-devhelp==1.0.6 sphinxcontrib-htmlhelp==2.0.5 
sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.7 sphinxcontrib-serializinghtml==1.1.10 SQLAlchemy==2.0.29 sqlglot==20.11.0 sqlparse==0.4.4 srsly==2.4.8 stanio==0.5.0 statsmodels==0.14.1 sympy==1.12 tables==3.8.0 tabulate==0.9.0 tbb==2021.12.0 tblib==3.0.0 tenacity==8.2.3 tensorboard==2.15.2 tensorboard-data-server==0.7.2 tensorflow @ https://storage.googleapis.com/colab-tf-builds-public-09h6ksrfwbb9g9xv/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=a2ec79931350b378c1ef300ca836b52a55751acb71a433582508a07f0de57c42 tensorflow-datasets==4.9.4 tensorflow-estimator==2.15.0 tensorflow-gcs-config==2.15.0 tensorflow-hub==0.16.1 tensorflow-io-gcs-filesystem==0.36.0 tensorflow-metadata==1.14.0 tensorflow-probability==0.23.0 tensorstore==0.1.45 termcolor==2.4.0 terminado==0.18.1 text-unidecode==1.3 textblob==0.17.1 tf-slim==1.1.0 tf_keras==2.15.1 thinc==8.2.3 threadpoolctl==3.4.0 tifffile==2024.2.12 tinycss2==1.2.1 tokenizers==0.15.2 toml==0.10.2 tomli==2.0.1 toolz==0.12.1 torch==2.3.0+cpu torchsummary==1.5.1 tornado==6.3.3 tqdm==4.66.2 traitlets==5.7.1 traittypes==0.2.1 transformers==4.38.2 triton==2.2.0 tweepy==4.14.0 typer==0.9.4 types-pytz==2024.1.0.20240203 types-setuptools==69.2.0.20240317 typing_extensions==4.11.0 tzdata==2024.1 tzlocal==5.2 uc-micro-py==1.0.3 uritemplate==4.1.1 urllib3==2.0.7 vega-datasets==0.9.0 wadllib==1.3.6 wasabi==1.1.2 wcwidth==0.2.13 weasel==0.3.4 webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 Werkzeug==3.0.2 widgetsnbextension==3.6.6 wordcloud==1.9.3 wrapt==1.14.1 xarray==2023.7.0 xarray-einstats==0.7.0 xgboost==2.0.3 xlrd==2.0.1 xyzservices==2024.4.0 yarl==1.9.4 yellowbrick==1.5 yfinance==0.2.37 zict==3.0.0 zipp==3.18.1

Who can help?

No response

Reproduction

Minimal repro: https://colab.research.google.com/gist/ScottTodd/476cfa0dd6620511ace15ce321b03a4e/colab-fork-jax-warning.ipynb (note that we found this because Jax issues a warning about something that pytorch is doing, and we traced that back to pytorch and then to this root cause)

Expected behavior

`import transformers` should not import jax (or whatever other ML frameworks I may have in my venv).

See related bug in pytorch: https://github.com/pytorch/pytorch/issues/123954 And related comment on Jax: https://github.com/google/jax/pull/18989#discussion_r1562799747

I won't be repeating the findings/discussion on the pytorch side here.

I think my expectation is that `import transformers` does not import jax. Like torch, jax is very expensive to import, and it does a lot of non-trivial process-wide manipulation on import. Importing it should be avoided unless needed.
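One way to check this expectation outside the notebook is to compare `sys.modules` before and after the import. This is only a diagnostic sketch: `modules_loaded_by` and its `watch` list are hypothetical names chosen for illustration, and it needs to run in a fresh interpreter, because a framework that is already loaded will not show up as newly loaded.

```python
import importlib
import sys


def modules_loaded_by(import_name, watch=("torch", "jax", "tensorflow")):
    """Import `import_name` and list the watched packages it newly loaded.

    Hypothetical helper for illustration; not part of transformers.
    """
    before = set(sys.modules)
    importlib.import_module(import_name)
    newly_loaded = set(sys.modules) - before
    return sorted(name for name in watch if name in newly_loaded)


# In a fresh interpreter with transformers installed, one could run:
#     print(modules_loaded_by("transformers"))
# An empty list would mean none of the watched frameworks were imported.
```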

I've only spent a few minutes looking at this, but the trouble starts here: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/generic.py#L42

There are a couple of things wrong with this:

I think the issue here is that a `find_spec`-level check is not really what you want for an existence check. While `is_flax_available` (the right one) would be better, it would be better still to check, at the point the query needs to be made, whether `"jax" in sys.modules`. That would indicate that jax has actually been imported within the process, and that there may therefore be a Jax tensor floating around somewhere. At the point we are doing instance checks, if jax is not in `sys.modules`, no one has imported it and there can be no Jax tensors.
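A minimal sketch of that check, assuming a recent jax that exposes `jax.Array`. The name `is_jax_array` is hypothetical and this is not the code in generic.py; it only shows the idea of consulting `sys.modules` instead of importing.

```python
import sys


def is_jax_array(x):
    """Return True only if jax is already imported and `x` is a jax array.

    Hypothetical helper: it never imports jax itself.
    """
    jax = sys.modules.get("jax")
    if jax is None:
        # Nobody in this process has imported jax, so `x` cannot be a jax array.
        return False
    # Assumption: jax.Array is available (it is in recent jax releases).
    array_type = getattr(jax, "Array", None)
    return array_type is not None and isinstance(x, array_type)
```

Calling `is_jax_array` on a plain Python list returns `False` and leaves `sys.modules` untouched, which is the property the instance checks need.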

I'd suggest just reworking this to this more conservative end state.

amyeroberts commented 6 months ago

cc @ydshieh

ydshieh commented 6 months ago

Hi @stellaraccident Thanks for raising this issue. Having a quick look, I think you are right.

FYI, for model files, we have something like

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

to avoid loading the model objects (and the frameworks they use) at import time. However, generic.py (which lives in utils) does not use this mechanism.

We will have to take a deeper look at this, however.
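For readers unfamiliar with the pattern, here is a rough sketch of the general idea behind a lazy module, under the assumption that attributes map to the modules that define them. `LazyModule` and `lazy_demo` are hypothetical names; this is not the `_LazyModule` class from transformers, whose internals are not shown in this thread.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Hypothetical stand-in: imports the backing module on first attribute access."""

    def __init__(self, name, attr_to_module):
        super().__init__(name)
        # Maps an exported attribute name to the module that defines it.
        self._attr_to_module = attr_to_module

    def __getattr__(self, attr):
        # Only reached when normal lookup fails, i.e. before the first access.
        try:
            module_name = self._attr_to_module[attr]
        except KeyError:
            raise AttributeError(
                f"module {self.__name__!r} has no attribute {attr!r}"
            ) from None
        value = getattr(importlib.import_module(module_name), attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


# Nothing is imported here; "math" is only imported when `sqrt` is first used.
lazy = LazyModule("lazy_demo", {"sqrt": "math", "dumps": "json"})
```

The first access to `lazy.sqrt` triggers the import and caches the result on the module object, so the cost is paid only by code that actually uses the attribute.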

stellaraccident commented 6 months ago

Neat way to do it. utils/generic.py is mostly using its own mechanism for lazy loading: I think it just has a bug. Because it is reactive (i.e., it only wants to import the backing framework when it encounters a tensor from one of them), it may call for a different mechanism. It may need something special, and it mostly has it.

In this case, the `is_flax_available` function is shadowing the same-named function from one of its imports, and the local one is doing it wrong. Maybe it was accidentally left over from some refactoring, but I think it can be deleted, and that may be enough.

(Can be improved further, but that may get it back to what was intended)

ydshieh commented 6 months ago

I think the above PR fixes it. You can try your notebook with

!pip uninstall -y transformers
!git clone https://github.com/huggingface/transformers.git && cd transformers && git fetch origin && git checkout avoid_import_jnp && pip install -e .

Run it, then go to Runtime -> restart session.

Then run the rest of your example.

I tried it and no longer see the issue you mentioned.

ScottTodd commented 6 months ago

Thanks. I can also confirm that JAX is not imported and the warning from it goes away with that repro: https://colab.research.google.com/gist/ScottTodd/0631e2ec19d0ce699bb0ff343e2b6543/colab-fork-jax-warning.ipynb