keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Circular import on CachedMultiHeadAttention #1193

Closed koaning closed 1 year ago

koaning commented 1 year ago

Describe the bug

I just installed jaxlib with keras-core and figured I'd give keras_nlp a spin.

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras_nlp

However, there seems to be a circular import error.

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[5], line 1
----> 1 import keras_nlp

File ~/Development/embetter/venv/lib/python3.9/site-packages/keras_nlp/__init__.py:8
      1 """DO NOT EDIT.
      2 
      3 This file was autogenerated. Do not edit it by hand,
      4 since your modifications would be overwritten.
      5 """
----> 8 from keras_nlp import layers
      9 from keras_nlp import metrics
     10 from keras_nlp import models

File ~/Development/embetter/venv/lib/python3.9/site-packages/keras_nlp/layers/__init__.py:8
      1 """DO NOT EDIT.
      2 
      3 This file was autogenerated. Do not edit it by hand,
      4 since your modifications would be overwritten.
      5 """
----> 8 from keras_nlp.src.layers.modeling.cached_multi_head_attention import CachedMultiHeadAttention
      9 from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder
     10 from keras_nlp.src.layers.modeling.masked_lm_head import MaskedLMHead

File ~/Development/embetter/venv/lib/python3.9/site-packages/keras_nlp/src/layers/modeling/cached_multi_head_attention.py:16
      1 # Copyright 2023 The KerasNLP Authors
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Cached MHA layer based on `keras.layers.MultiHeadAttention`."""
---> 16 from keras_nlp.src.api_export import keras_nlp_export
     17 from keras_nlp.src.backend import keras
     18 from keras_nlp.src.backend import ops

File ~/Development/embetter/venv/lib/python3.9/site-packages/keras_nlp/src/__init__.py:23
     20 except ImportError:
     21     pass
---> 23 from keras_nlp.src import layers
     24 from keras_nlp.src import metrics
     25 from keras_nlp.src import models

File ~/Development/embetter/venv/lib/python3.9/site-packages/keras_nlp/src/layers/__init__.py:15
      1 # Copyright 2023 The KerasNLP Authors
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from keras_nlp.src.layers.modeling.cached_multi_head_attention import (
     16     CachedMultiHeadAttention,
     17 )
     18 from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder
     19 from keras_nlp.src.layers.modeling.masked_lm_head import MaskedLMHead

ImportError: cannot import name 'CachedMultiHeadAttention' from partially initialized module 'keras_nlp.src.layers.modeling.cached_multi_head_attention' (most likely due to a circular import) (/home/vincent/Development/embetter/venv/lib/python3.9/site-packages/keras_nlp/src/layers/modeling/cached_multi_head_attention.py)

Additional context

I am using Python 3.9 on an M1 Mac.
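The traceback shows the usual shape of this error. Python registers a module in sys.modules before running its body. Here, cached_multi_head_attention.py starts executing, its line 16 imports keras_nlp.src.api_export, that runs keras_nlp/src/__init__.py, which imports keras_nlp.src.layers, which in turn asks for CachedMultiHeadAttention from the module that has not finished executing. A minimal sketch of the same mechanism, using two hypothetical modules mod_a and mod_b (not part of keras-nlp), on Python 3.8 or newer:

```python
import importlib
import sys
import tempfile
from pathlib import Path


def reproduce_circular_import() -> str:
    """Create two hypothetical modules that import each other; return the error text."""
    tmp = Path(tempfile.mkdtemp())
    # mod_a imports from mod_b *before* it has defined class A.
    (tmp / "mod_a.py").write_text("from mod_b import B\n\nclass A:\n    pass\n")
    # mod_b asks for A while mod_a is still only partially initialized.
    (tmp / "mod_b.py").write_text("from mod_a import A\n\nclass B:\n    pass\n")
    sys.path.insert(0, str(tmp))
    importlib.invalidate_caches()
    try:
        importlib.import_module("mod_a")
        return "no error"
    except ImportError as err:
        return str(err)
    finally:
        sys.path.remove(str(tmp))
        # Forget any half-built modules so later imports start clean.
        sys.modules.pop("mod_a", None)
        sys.modules.pop("mod_b", None)


print(reproduce_circular_import())
```

The message has the same form as the one above: "cannot import name 'A' from partially initialized module 'mod_a' (most likely due to a circular import)".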

koaning commented 1 year ago

Oh, and to be clear: when I don't set os.environ["KERAS_BACKEND"] = "jax", the import works fine.

abheesht17 commented 1 year ago

I don't see this error on Colab. If possible, can you share a Colab notebook where you've reproduced it?

I don't see it locally either.

(keras_nlp_temp) abheesht@Abheeshts-MacBook-Air envs % pip install tensorflow torch jaxlib jax keras-core keras-nlp -q
(keras_nlp_temp) abheesht@Abheeshts-MacBook-Air envs % python3
Python 3.9.6 (default, May  7 2023, 23:32:44) 
[Clang 14.0.3 (clang-1403.0.22.14.1)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import os
>>> os.environ["KERAS_BACKEND"] = "jax"
>>> import keras_nlp
Using JAX backend.

mattdangerw commented 1 year ago

Interesting. I'm betting this has to do with the specific dev environment you have, as we can't reproduce it naively.

You are running on macOS? How are you installing tensorflow-text, or are you not? What Python environment are you using (pyenv, conda, nothing)? What versions of keras-nlp and keras-core do you have (e.g. print(keras_nlp.__version__))?

Not really sure what would cause this, but hopefully with some more info we can get to the bottom of it.
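Because import keras_nlp is exactly what fails, one way to answer these questions without triggering the import is to read installed versions from package metadata. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); the package list and the collect_env_report name are assumptions chosen for illustration from the packages mentioned in this thread:

```python
import sys
from importlib import metadata

# Hypothetical list: the packages that come up in this thread.
PACKAGES = ["keras-nlp", "keras-core", "keras", "tensorflow",
            "tensorflow-text", "jax", "jaxlib"]


def collect_env_report(packages=PACKAGES) -> dict:
    """Report Python/env details and installed versions without importing the packages."""
    report = {
        "python": sys.version.split()[0],
        "platform": sys.platform,
        "prefix": sys.prefix,
        # In a venv, sys.prefix points at the venv while sys.base_prefix does not.
        "in_venv": sys.prefix != sys.base_prefix,
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


for key, value in collect_env_report().items():
    print(f"{key}: {value}")
```

Reading metadata only inspects the installed *.dist-info directories, so it still works when importing the package itself raises.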

koaning commented 1 year ago

My bad folks!

It seems that I was using a stale venv. I removed the venv and reinstalled everything from scratch, which makes the issue go away. I can't say for sure what exactly the culprit was, but since a fresh venv is clean, I'd call it a false alarm.

Thanks for the quick response though!
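The thread never pinned down what was stale. One hypothesis worth checking (an assumption, not something confirmed here) is leftover module files that the current install does not own: pip lists every file it installs from a wheel in the distribution's RECORD, so a .py file under the package directory that RECORD does not mention did not come from the current install. A sketch using importlib.metadata; unrecorded_py_files and check_distribution are hypothetical helper names:

```python
from importlib import metadata
from pathlib import Path


def unrecorded_py_files(on_disk, recorded: set) -> list:
    """Return the .py paths from on_disk whose resolved path is not in recorded."""
    return sorted(
        str(path)
        for path in on_disk
        if path.suffix == ".py" and path.resolve() not in recorded
    )


def check_distribution(dist_name: str, top_level: str) -> list:
    """List .py files in an installed package that its RECORD does not mention."""
    dist = metadata.distribution(dist_name)
    # dist.files is parsed from RECORD; it is None if RECORD is missing.
    recorded = {Path(dist.locate_file(f)).resolve() for f in (dist.files or [])}
    package_root = Path(dist.locate_file(top_level))
    return unrecorded_py_files(package_root.rglob("*.py"), recorded)


try:
    leftovers = check_distribution("keras-nlp", "keras_nlp")
    print("\n".join(leftovers) or "no unrecorded .py files")
except metadata.PackageNotFoundError:
    print("keras-nlp is not installed in this environment")
```

An empty result does not rule out other kinds of staleness (for example, mismatched versions of keras-core), so recreating the venv, as done here, remains the simpler fix.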

abheesht17 commented 1 year ago

Awesome, have fun with KerasNLP!

SangameswaranRS commented 9 months ago

Installation

!pip install keras-nlp
!pip install tensorflow --upgrade

Error

Tf version: 2.14.0

ImportError                               Traceback (most recent call last)
in <cell line: 2>()
      1 print(tf.__version__)
----> 2 import keras_nlp

4 frames
/usr/local/lib/python3.10/dist-packages/keras_nlp/src/layers/__init__.py
     13 # limitations under the License.
     14 
---> 15 from keras_nlp.src.layers.modeling.cached_multi_head_attention import (
     16     CachedMultiHeadAttention,
     17 )

ImportError: cannot import name 'CachedMultiHeadAttention' from partially initialized module 'keras_nlp.src.layers.modeling.cached_multi_head_attention' (most likely due to a circular import) (/usr/local/lib/python3.10/dist-packages/keras_nlp/src/layers/modeling/cached_multi_head_attention.py)

Env

Google Colab, Python 3.10.12 (from sys.version)

SangameswaranRS commented 9 months ago

I also tried

!pip install tensorflow torch jaxlib jax keras-core keras-nlp -q

as @abheesht17 suggested in the comments, but I still get the same error.

chozzz commented 8 months ago

The same issue happened to me on Google Colab. As @koaning mentioned, it looks like a stale environment. This is what fixed it for me:

  1. !pip uninstall keras keras-core keras-nlp
  2. !pip install keras==2.15.0 keras-core==0.1.7 keras-nlp==0.6.3
  3. Restarted the session on Colab

After that, importing keras_nlp no longer shows the error. Weird.
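Step 3 (restarting the session) likely matters because pip only changes files on disk: a running interpreter keeps every module it has already imported in sys.modules and keeps serving that cached copy. A small illustration with a hypothetical module named stale_demo (not part of any library) shows the effect; deleting the sys.modules entry only approximates a restart, and packages with compiled extensions or cross-module references generally cannot be reloaded cleanly this way, which is why restarting the runtime is the dependable step.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def import_twice_after_rewrite() -> tuple:
    """Import a hypothetical module, rewrite it on disk, and import it again."""
    tmp = Path(tempfile.mkdtemp())
    mod_file = tmp / "stale_demo.py"
    mod_file.write_text('VERSION = "old"\n')
    sys.path.insert(0, str(tmp))
    importlib.invalidate_caches()
    try:
        first = importlib.import_module("stale_demo").VERSION
        # Simulate `pip install` replacing the file (different size, so any .pyc is invalidated).
        mod_file.write_text('VERSION = "newer"\n')
        # Served from sys.modules: the old code is still in use.
        second = importlib.import_module("stale_demo").VERSION
        # Dropping the cached module forces a fresh load from disk.
        del sys.modules["stale_demo"]
        importlib.invalidate_caches()
        third = importlib.import_module("stale_demo").VERSION
    finally:
        sys.path.remove(str(tmp))
        sys.modules.pop("stale_demo", None)
    return first, second, third


print(import_twice_after_rewrite())  # ('old', 'old', 'newer')
```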