stellaraccident closed this issue 7 months ago.
cc @ydshieh
Hi @stellaraccident, thanks for raising this issue. From a quick look, I think you are right.
FYI, for model files we have something like this:

```python
# else-branch of the `if TYPE_CHECKING:` guard in a model's __init__.py
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```

This avoids loading the model objects (and the frameworks they rely on) until they are needed. However, `generic.py` (which lives in `utils`) does not use this mechanism.
We will need to take a deeper look at this topic, however.
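For readers unfamiliar with the `sys.modules[__name__] = _LazyModule(...)` trick quoted above: the package replaces its own entry in `sys.modules` with a module object that defers imports until an attribute is first accessed. Below is a minimal sketch of that idea, assuming a hypothetical `MiniLazyModule` class. It is not transformers' actual `_LazyModule`, which takes an `_import_structure` mapping and a `module_spec` and handles more cases.

```python
import importlib
import types


class MiniLazyModule(types.ModuleType):
    """Hypothetical, simplified lazy module (not transformers' _LazyModule).

    Maps public attribute names to the module that defines them and only
    imports that module when the attribute is first looked up.
    """

    def __init__(self, name: str, attr_to_module: dict[str, str]):
        super().__init__(name)
        self._attr_to_module = dict(attr_to_module)

    def __getattr__(self, attr: str):
        # Python only calls __getattr__ when normal lookup fails, i.e. on the
        # first access to an attribute that has not been cached yet.
        mapping = self.__dict__.get("_attr_to_module", {})
        if attr not in mapping:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(mapping[attr]), attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value

    def __dir__(self):
        # Advertise the lazy names too, so tab completion still works.
        return sorted(set(super().__dir__()) | set(self._attr_to_module))


# In a package's __init__.py one might write (names are hypothetical):
#     sys.modules[__name__] = MiniLazyModule(__name__, {"heavy_fn": "mypkg.heavy"})
lazy = MiniLazyModule("demo", {"sqrt": "math", "dumps": "json"})
print(lazy.sqrt(16.0))  # 4.0; this first access is what calls import_module("math")
```

The key property is that importing the package itself costs almost nothing: the backing module is only imported when someone actually touches a name that lives in it.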
Neat way to do it. `utils/generic.py` mostly uses its own mechanism for lazy loading; I think it just has a bug. Because it is reactive (i.e., it only wants to import a backing framework when it encounters a tensor from that framework), it may need a different, special mechanism, and it mostly has one.
In this case, the local `is_flax_available` function shadows the same-named function from one of its imports, and the local one does it wrong. Maybe it was accidentally left over from a refactoring, but I think it can be deleted, and that may be enough.
(It can be improved further, but that may get it back to what was intended.)
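Shadowing like this is easy to miss in Python: a `def is_flax_available()` placed after `from .import_utils import is_flax_available` silently rebinds the name for the rest of the module. As a rough sketch of what a non-importing availability check looks like, assuming a hypothetical `is_package_installed` helper (the real `import_utils` version also consults environment variables, which this omits):

```python
import importlib.util
import sys


def is_package_installed(name: str) -> bool:
    """Hypothetical helper: report whether a top-level package is installed.

    importlib.util.find_spec only asks the import finders where the package
    would come from; for a top-level name it does not execute the package,
    so calling this never imports it.
    """
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    before = "jax" in sys.modules
    print("jax installed:", is_package_installed("jax"))
    # Should print False: the check itself does not import jax.
    print("jax imported by the check:", ("jax" in sys.modules) != before)
```

Note that this answers "is it installed?", not "is it in use?". That distinction matters for the reactive tensor checks discussed later in this issue.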
I think the PR above fixes it. You can try it in your notebook with:

```
!pip uninstall -y transformers
!git clone https://github.com/huggingface/transformers.git && cd transformers && git fetch origin && git checkout avoid_import_jnp && pip install -e .
```

Run that, then use Runtime -> Restart session, then run the rest of your example.
I tried it, and the issue you mentioned no longer appears.
Thanks. I can also confirm that JAX is not imported and the warning from it goes away with that repro: https://colab.research.google.com/gist/ScottTodd/0631e2ec19d0ce699bb0ff343e2b6543/colab-fork-jax-warning.ipynb
System Info
absl-py==1.4.0 aiohttp==3.9.3 aiosignal==1.3.1 alabaster==0.7.16 albumentations==1.3.1 altair==4.2.2 annotated-types==0.6.0 anyio==3.7.1 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 array_record==0.5.1 arviz==0.15.1 astropy==5.3.4 astunparse==1.6.3 async-timeout==4.0.3 atpublic==4.1.0 attrs==23.2.0 audioread==3.0.1 autograd==1.6.2 Babel==2.14.0 backcall==0.2.0 beautifulsoup4==4.12.3 bidict==0.23.1 bigframes==1.0.0 bleach==6.1.0 blinker==1.4 blis==0.7.11 blosc2==2.0.0 bokeh==3.3.4 bqplot==0.12.43 branca==0.7.1 build==1.2.1 CacheControl==0.14.0 cachetools==5.3.3 catalogue==2.0.10 certifi==2024.2.2 cffi==1.16.0 chardet==5.2.0 charset-normalizer==3.3.2 chex==0.1.86 click==8.1.7 click-plugins==1.1.1 cligj==0.7.2 cloudpathlib==0.16.0 cloudpickle==2.2.1 cmake==3.27.9 cmdstanpy==1.2.2 colorcet==3.1.0 colorlover==0.3.0 colour==0.1.5 community==1.0.0b1 confection==0.1.4 cons==0.4.6 contextlib2==21.6.0 contourpy==1.2.1 cryptography==42.0.5 cufflinks==0.17.3 cupy-cuda12x==12.2.0 cvxopt==1.3.2 cvxpy==1.3.3 cycler==0.12.1 cymem==2.0.8 Cython==3.0.10 dask==2023.8.1 datascience==0.17.6 db-dtypes==1.2.0 dbus-python==1.2.18 debugpy==1.6.6 decorator==4.4.2 defusedxml==0.7.1 distributed==2023.8.1 distro==1.7.0 dlib==19.24.4 dm-tree==0.1.8 docstring_parser==0.16 docutils==0.18.1 dopamine-rl==4.0.6 duckdb==0.10.1 earthengine-api==0.1.397 easydict==1.13 ecos==2.0.13 editdistance==0.6.2 eerepr==0.0.4 en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889 entrypoints==0.4 et-xmlfile==1.1.0 etils==1.7.0 etuples==0.3.9 exceptiongroup==1.2.0 fastcore==1.5.29 fastdownload==0.0.7 fastjsonschema==2.19.1 fastprogress==1.0.3 fastrlock==0.8.2 filelock==3.13.4 fiona==1.9.6 firebase-admin==5.3.0 Flask==2.2.5 flatbuffers==24.3.25 flax==0.8.2 folium==0.14.0 fonttools==4.51.0 frozendict==2.4.1 frozenlist==1.4.1 fsspec==2023.6.0 
future==0.18.3 gast==0.5.4 gcsfs==2023.6.0 GDAL==3.6.4 gdown==4.7.3 geemap==0.32.0 gensim==4.3.2 geocoder==1.38.1 geographiclib==2.0 geopandas==0.13.2 geopy==2.3.0 gin-config==0.5.0 glob2==0.7 google==2.0.3 google-ai-generativelanguage==0.4.0 google-api-core==2.11.1 google-api-python-client==2.84.0 google-auth==2.27.0 google-auth-httplib2==0.1.1 google-auth-oauthlib==1.2.0 google-cloud-aiplatform==1.47.0 google-cloud-bigquery==3.12.0 google-cloud-bigquery-connection==1.12.1 google-cloud-bigquery-storage==2.24.0 google-cloud-core==2.3.3 google-cloud-datastore==2.15.2 google-cloud-firestore==2.11.1 google-cloud-functions==1.13.3 google-cloud-iam==2.14.3 google-cloud-language==2.13.3 google-cloud-resource-manager==1.12.3 google-cloud-storage==2.8.0 google-cloud-translate==3.11.3 google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz#sha256=ed5e8e54679a5e4587c7f6e47886b20c3662957d23519aa55573ca9054e4f274 google-crc32c==1.5.0 google-generativeai==0.3.2 google-pasta==0.2.0 google-resumable-media==2.7.0 googleapis-common-protos==1.63.0 googledrivedownloader==0.4 graphviz==0.20.3 greenlet==3.0.3 grpc-google-iam-v1==0.13.0 grpcio==1.62.1 grpcio-status==1.48.2 gspread==3.4.2 gspread-dataframe==3.3.1 gym==0.25.2 gym-notices==0.0.8 h5netcdf==1.3.0 h5py==3.9.0 holidays==0.46 holoviews==1.17.1 html5lib==1.1 httpimport==1.3.1 httplib2==0.22.0 huggingface-hub==0.20.3 humanize==4.7.0 hyperopt==0.2.7 ibis-framework==8.0.0 idna==3.6 imageio==2.31.6 imageio-ffmpeg==0.4.9 imagesize==1.4.1 imbalanced-learn==0.10.1 imgaug==0.4.0 importlib_metadata==7.1.0 importlib_resources==6.4.0 imutils==0.5.4 inflect==7.0.0 iniconfig==2.0.0 intel-openmp==2023.2.4 ipyevents==2.0.2 ipyfilechooser==0.6.0 ipykernel==5.5.6 ipyleaflet==0.18.2 ipython==7.34.0 ipython-genutils==0.2.0 ipython-sql==0.5.0 ipytree==0.2.2 ipywidgets==7.7.1 iree-compiler==20240410.859 iree-runtime==20240410.859 iree-turbine==2.3.0rc20240410 itsdangerous==2.1.2 jax==0.4.26 jaxlib @ 
https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750 jeepney==0.7.1 jieba==0.42.1 Jinja2==3.1.3 joblib==1.4.0 jsonpickle==3.0.3 jsonschema==4.19.2 jsonschema-specifications==2023.12.1 jupyter-client==6.1.12 jupyter-console==6.1.0 jupyter-server==1.24.0 jupyter_core==5.7.2 jupyterlab_pygments==0.3.0 jupyterlab_widgets==3.0.10 kaggle==1.5.16 kagglehub==0.2.2 keras==2.15.0 keyring==23.5.0 kiwisolver==1.4.5 langcodes==3.3.0 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 lazy_loader==0.4 libclang==18.1.1 librosa==0.10.1 lightgbm==4.1.0 linkify-it-py==2.0.3 llvmlite==0.41.1 locket==1.0.0 logical-unification==0.4.6 lxml==4.9.4 malloy==2023.1067 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.7.1 matplotlib-inline==0.1.6 matplotlib-venn==0.11.10 mdit-py-plugins==0.4.0 mdurl==0.1.2 miniKanren==1.0.3 missingno==0.5.2 mistune==0.8.4 mizani==0.9.3 mkl==2023.2.0 ml-dtypes==0.2.0 mlxtend==0.22.0 more-itertools==10.1.0 moviepy==1.0.3 mpmath==1.3.0 msgpack==1.0.8 multidict==6.0.5 multipledispatch==1.0.0 multitasking==0.0.11 murmurhash==1.0.10 music21==9.1.0 natsort==8.4.0 nbclassic==1.0.0 nbclient==0.10.0 nbconvert==6.5.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.3 nibabel==4.0.2 nltk==3.8.1 notebook==6.5.5 notebook_shim==0.2.4 numba==0.58.1 numexpr==2.10.0 numpy==1.25.2 oauth2client==4.1.3 oauthlib==3.2.2 opencv-contrib-python==4.8.0.76 opencv-python==4.8.0.76 opencv-python-headless==4.9.0.80 openpyxl==3.1.2 opt-einsum==3.3.0 optax==0.2.2 orbax-checkpoint==0.4.4 osqp==0.6.2.post8 packaging==24.0 pandas==2.0.3 pandas-datareader==0.10.0 pandas-gbq==0.19.2 pandas-stubs==2.0.3.230814 pandocfilters==1.5.1 panel==1.3.8 param==2.1.0 parso==0.8.4 parsy==2.1 partd==1.4.1 pathlib==1.0.1 patsy==0.5.6 peewee==3.17.1 pexpect==4.9.0 pickleshare==0.7.5 Pillow==9.4.0 pip-tools==6.13.0 platformdirs==4.2.0 
plotly==5.15.0 plotnine==0.12.4 pluggy==1.4.0 polars==0.20.2 pooch==1.8.1 portpicker==1.5.2 prefetch-generator==1.0.3 preshed==3.0.9 prettytable==3.10.0 proglog==0.1.10 progressbar2==4.2.0 prometheus_client==0.20.0 promise==2.3 prompt-toolkit==3.0.43 prophet==1.1.5 proto-plus==1.23.0 protobuf==3.20.3 psutil==5.9.5 psycopg2==2.9.9 ptyprocess==0.7.0 py-cpuinfo==9.0.0 py4j==0.10.9.7 pyarrow==14.0.2 pyarrow-hotfix==0.6 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycocotools==2.0.7 pycparser==2.22 pydantic==2.6.4 pydantic_core==2.16.3 pydata-google-auth==1.8.2 pydot==1.4.2 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 PyDrive2==1.6.3 pyerfa==2.0.1.3 pygame==2.5.2 Pygments==2.16.1 PyGObject==3.42.1 PyJWT==2.3.0 pymc==5.10.4 pymystem3==0.2.0 PyOpenGL==3.1.7 pyOpenSSL==24.1.0 pyparsing==3.1.2 pyperclip==1.8.2 pyproj==3.6.1 pyproject_hooks==1.0.0 pyshp==2.3.1 PySocks==1.7.1 pytensor==2.18.6 pytest==7.4.4 python-apt @ file:///backend-container/containers/python_apt-0.0.0-cp310-cp310-linux_x86_64.whl#sha256=b209c7165d6061963abe611492f8c91c3bcef4b7a6600f966bab58900c63fefa python-box==7.1.1 python-dateutil==2.8.2 python-louvain==0.16 python-slugify==8.0.4 python-utils==3.8.2 pytz==2023.4 pyviz_comms==3.0.2 PyWavelets==1.6.0 PyYAML==6.0.1 pyzmq==23.2.1 qdldl==0.1.7.post1 qudida==0.0.4 ratelim==0.1.6 referencing==0.34.0 regex==2023.12.25 requests==2.31.0 requests-oauthlib==1.3.1 requirements-parser==0.9.0 rich==13.7.1 rpds-py==0.18.0 rpy2==3.4.2 rsa==4.9 safetensors==0.4.2 scikit-image==0.19.3 scikit-learn==1.2.2 scipy==1.11.4 scooby==0.9.2 scs==3.2.4.post1 seaborn==0.13.1 SecretStorage==3.3.1 Send2Trash==1.8.3 sentencepiece==0.1.99 shapely==2.0.3 six==1.16.0 sklearn-pandas==2.2.0 smart-open==6.4.0 sniffio==1.3.1 snowballstemmer==2.2.0 sortedcontainers==2.4.0 soundfile==0.12.1 soupsieve==2.5 soxr==0.3.7 spacy==3.7.4 spacy-legacy==3.0.12 spacy-loggers==1.0.5 Sphinx==5.0.2 sphinxcontrib-applehelp==1.0.8 sphinxcontrib-devhelp==1.0.6 sphinxcontrib-htmlhelp==2.0.5 
sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.7 sphinxcontrib-serializinghtml==1.1.10 SQLAlchemy==2.0.29 sqlglot==20.11.0 sqlparse==0.4.4 srsly==2.4.8 stanio==0.5.0 statsmodels==0.14.1 sympy==1.12 tables==3.8.0 tabulate==0.9.0 tbb==2021.12.0 tblib==3.0.0 tenacity==8.2.3 tensorboard==2.15.2 tensorboard-data-server==0.7.2 tensorflow @ https://storage.googleapis.com/colab-tf-builds-public-09h6ksrfwbb9g9xv/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=a2ec79931350b378c1ef300ca836b52a55751acb71a433582508a07f0de57c42 tensorflow-datasets==4.9.4 tensorflow-estimator==2.15.0 tensorflow-gcs-config==2.15.0 tensorflow-hub==0.16.1 tensorflow-io-gcs-filesystem==0.36.0 tensorflow-metadata==1.14.0 tensorflow-probability==0.23.0 tensorstore==0.1.45 termcolor==2.4.0 terminado==0.18.1 text-unidecode==1.3 textblob==0.17.1 tf-slim==1.1.0 tf_keras==2.15.1 thinc==8.2.3 threadpoolctl==3.4.0 tifffile==2024.2.12 tinycss2==1.2.1 tokenizers==0.15.2 toml==0.10.2 tomli==2.0.1 toolz==0.12.1 torch==2.3.0+cpu torchsummary==1.5.1 tornado==6.3.3 tqdm==4.66.2 traitlets==5.7.1 traittypes==0.2.1 transformers==4.38.2 triton==2.2.0 tweepy==4.14.0 typer==0.9.4 types-pytz==2024.1.0.20240203 types-setuptools==69.2.0.20240317 typing_extensions==4.11.0 tzdata==2024.1 tzlocal==5.2 uc-micro-py==1.0.3 uritemplate==4.1.1 urllib3==2.0.7 vega-datasets==0.9.0 wadllib==1.3.6 wasabi==1.1.2 wcwidth==0.2.13 weasel==0.3.4 webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 Werkzeug==3.0.2 widgetsnbextension==3.6.6 wordcloud==1.9.3 wrapt==1.14.1 xarray==2023.7.0 xarray-einstats==0.7.0 xgboost==2.0.3 xlrd==2.0.1 xyzservices==2024.4.0 yarl==1.9.4 yellowbrick==1.5 yfinance==0.2.37 zict==3.0.0 zipp==3.18.1
Reproduction
Minimal repro: https://colab.research.google.com/gist/ScottTodd/476cfa0dd6620511ace15ce321b03a4e/colab-fork-jax-warning.ipynb (Note that we found this because JAX issues a warning about something PyTorch is doing, and we traced that back through PyTorch to this root cause.)
Expected behavior
`import transformers` should not import `jax` (or any other ML framework I may have in my venv).

See the related bug in PyTorch: https://github.com/pytorch/pytorch/issues/123954 and the related comment on JAX: https://github.com/google/jax/pull/18989#discussion_r1562799747
I won't repeat the findings/discussion from the PyTorch side here.
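One way to check the expected behavior, sketched under the assumption that a fresh interpreter is the right baseline (a notebook kernel may already have `jax` loaded for unrelated reasons), is to run the import in a subprocess and inspect `sys.modules` there. The helper name `modules_loaded_by` is hypothetical.

```python
import subprocess
import sys


def modules_loaded_by(statement: str, *module_names: str) -> dict[str, bool]:
    """Hypothetical helper: run `statement` in a fresh interpreter and report
    which of `module_names` ended up in that interpreter's sys.modules.

    Raises subprocess.CalledProcessError if the statement itself fails.
    """
    names = list(module_names)
    script = (
        "import sys\n"
        f"{statement}\n"
        f"print(','.join(str(n in sys.modules) for n in {names!r}))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", script], capture_output=True, text=True, check=True
    )
    # Our print is the last line, even if the statement printed something itself.
    flags = result.stdout.strip().splitlines()[-1].split(",")
    return {name: flag == "True" for name, flag in zip(names, flags)}
```

In an environment with transformers installed, `modules_loaded_by("import transformers", "jax", "torch")` would show whether the top-level import pulls in either framework; per this issue, the expectation is that `jax` comes back `False`.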
My expectation is that `import transformers` does not import `jax`. Like `torch`, `jax` is very expensive to import, and it does a lot of non-trivial, process-wide manipulation on import. It should be avoided unless needed.

I've only spent a few minutes looking at this, but the trouble starts here: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/generic.py#L42
There are a couple of things wrong with this:

1. Why is there an `is_flax_available` def right after we import one from `import_utils`? I haven't looked at it deeply, but the `import_utils` one looks like it does the right thing (using `find_spec`, env vars, etc.). Can the local `is_flax_available` be deleted so we use the right, robust one?
2. `_is_jax` checks like https://github.com/huggingface/transformers/blob/main/src/transformers/utils/generic.py#L232 could be made a lot more conservative.
3. There is also a `from jax.core import Tracer` later in the file (https://github.com/huggingface/transformers/blob/main/src/transformers/utils/generic.py#L138C1-L139C1).

I think the underlying issue is that a `find_spec`-level check is not really what you want here: it tells you whether JAX is installed, not whether any JAX tensors can exist. So while the correct `is_flax_available` would be better, it would be better still to check `"jax" in sys.modules` at the point the query actually needs to be made. That indicates `jax` has actually been imported within the process, so there may be a JAX tensor floating around somewhere. If it is not in `sys.modules` when we do instance checks, no one has imported it and there can be no JAX tensors.

I'd suggest reworking this to that more conservative end state.
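A minimal sketch of the `"jax" in sys.modules` approach suggested above, assuming hypothetical helpers `is_instance_if_imported` and `is_jax_array` (these are not existing transformers functions). It also assumes `jax.Array` as the class to test against; the actual `_is_jax` check may use a different class.

```python
import sys


def is_instance_if_imported(obj, module_name: str, class_name: str) -> bool:
    """Hypothetical helper: an isinstance check that never triggers an import.

    If `module_name` is not in sys.modules, nobody in this process has imported
    it, so `obj` cannot be an instance of its classes and we answer False.
    """
    module = sys.modules.get(module_name)
    if module is None:
        return False
    # getattr with a default also covers a module that is mid-import.
    cls = getattr(module, class_name, None)
    return isinstance(cls, type) and isinstance(obj, cls)


def is_jax_array(obj) -> bool:
    # Assumption: recent JAX releases expose `jax.Array`. jax is never imported here.
    return is_instance_if_imported(obj, "jax", "Array")
```

Because the check only consults `sys.modules`, it stays reactive: JAX is looked at only after user code has already paid its import cost, which matches the behavior this issue asks for.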