google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention' #1766

Status: Open. jpontalba opened this issue 1 year ago.


Description

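Importing Trax fails. The first `from trax import ...` statement (here `from trax import fastmath`) raises `ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'`.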

Environment information

trax 1.4.1

OS: Ubuntu 

$ pip freeze | grep trax
trax                         1.4.1

$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensor2tensor==1.15.7
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-addons==0.18.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.11.0
tensorflow-gan==2.1.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.7.0
tensorflow-text==2.11.0
tensorstore==0.1.28

$ pip freeze | grep jax
jax==0.3.25
jaxlib==0.3.25

$ python -V
Python 3.8.10
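The version listing above was collected with separate `pip freeze | grep` commands. A minimal sketch of gathering the same information from inside the failing notebook, using only the standard-library `importlib.metadata` (Python 3.8+). The helper name `collect_versions` and the prefix list are illustrative assumptions, not part of Trax:

```python
import platform
from importlib import metadata
from typing import Dict, Iterable


def collect_versions(prefixes: Iterable[str]) -> Dict[str, str]:
    """Hypothetical helper: return {distribution name: version} for every
    installed distribution whose lower-cased name starts with a prefix."""
    wanted = tuple(p.lower() for p in prefixes)
    found: Dict[str, str] = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if not name:  # skip distributions with missing or broken metadata
            continue
        if name.lower().startswith(wanted):
            found[name] = dist.version
    # Sort by name so the output is stable between runs.
    return dict(sorted(found.items()))


if __name__ == "__main__":
    print("Python", platform.python_version())
    for name, version in collect_versions(["trax", "tensor", "jax"]).items():
        print(f"{name}=={version}")
```

Reading the metadata directly avoids shelling out and reports exactly what the running interpreter sees, which matters when a notebook kernel and a terminal use different environments.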

For bugs: reproduction and error logs

# Steps to reproduce:
!pip install -q -U trax

import numpy as np  # regular ol' numpy

from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature
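Because the error is raised during `import trax` itself, the installed `trax.layers.attention` module cannot simply be imported and inspected. A minimal sketch, assuming a standard on-disk install, that locates the package with `importlib.util.find_spec` (which does not execute `trax/__init__.py` for a top-level name) and parses `layers/attention.py` with `ast` to see whether it defines `MergeHeads`. The helpers `top_level_names` and `check_trax_attention` are hypothetical:

```python
import ast
import importlib.util
import os
from typing import Dict, Iterable, Optional, Set


def top_level_names(source: str) -> Set[str]:
    """Names bound by direct top-level statements (def, class, simple
    assignment, import). Names bound inside if/try blocks are not seen."""
    names: Set[str] = set()
    for node in ast.parse(source).body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    names.add(target.id)
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                if alias.name != "*":
                    names.add(alias.asname or alias.name.split(".")[0])
    return names


def check_trax_attention(symbols: Iterable[str]) -> Optional[Dict[str, bool]]:
    """Report which symbols the installed trax/layers/attention.py defines.
    Returns None if trax (or that file) cannot be found."""
    spec = importlib.util.find_spec("trax")  # top-level: trax is not executed
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = list(spec.submodule_search_locations)[0]
    path = os.path.join(package_dir, "layers", "attention.py")
    if not os.path.exists(path):
        return None
    with open(path, encoding="utf-8") as f:
        defined = top_level_names(f.read())
    return {s: s in defined for s in symbols}


if __name__ == "__main__":
    print(check_trax_attention(["MergeHeads", "SplitIntoHeads"]))
```

If `MergeHeads` comes back `False` while some other installed Trax module imports it, that would suggest the installed files do not all come from the same Trax source; treat this as a hypothesis to check, not a confirmed diagnosis.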
# Error logs:
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[22], line 3
      1 import numpy as np  # regular ol' numpy
----> 3 from trax import fastmath
      4 from trax import layers as tl
      5 from trax import shapes

File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/__init__.py:18
      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/data/__init__.py:70
     67 from trax.data.inputs import UnBatch
     68 from trax.data.inputs import UniformlySeek
---> 70 from trax.data.tf_inputs import add_eos_to_output_features
     71 from trax.data.tf_inputs import BertGlueEvalStream
...
     35 from trax.layers.attention import SplitIntoHeads
     38 # Layers are always CamelCase, but functions in general are snake_case
     39 # pylint: disable=invalid-name

ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'
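The traceback above is elided at `...`, so the file whose lines 35 to 39 perform the failing import is not named. A minimal sketch for finding every installed module that contains `from trax.layers.attention import MergeHeads`, by walking the package directory and parsing each `.py` file with `ast` (nothing is imported or executed). The helper `find_importers` is hypothetical, and the sketch assumes Trax is installed as ordinary files on disk:

```python
import ast
import importlib.util
import os
from typing import List


def find_importers(package_dir: str, module: str, name: str) -> List[str]:
    """Return sorted paths of .py files under package_dir that contain an
    absolute `from <module> import <name>` statement."""
    hits: List[str] = []
    for root, _dirs, files in os.walk(package_dir):
        for filename in files:
            if not filename.endswith(".py"):
                continue
            path = os.path.join(root, filename)
            try:
                with open(path, encoding="utf-8") as f:
                    tree = ast.parse(f.read(), filename=path)
            except (OSError, SyntaxError, UnicodeDecodeError, ValueError):
                continue  # unreadable or unparsable file: skip it
            for node in ast.walk(tree):
                if (isinstance(node, ast.ImportFrom)
                        and node.level == 0
                        and node.module == module
                        and any(a.name == name for a in node.names)):
                    hits.append(path)
                    break  # one match per file is enough
    return sorted(hits)


if __name__ == "__main__":
    spec = importlib.util.find_spec("trax")  # does not run trax/__init__.py
    if spec is not None and spec.submodule_search_locations:
        trax_dir = list(spec.submodule_search_locations)[0]
        for p in find_importers(trax_dir, "trax.layers.attention", "MergeHeads"):
            print(p)
```

Comparing the files this reports against the result of checking `attention.py` itself shows which side of the import is out of step with the other.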