google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

AttributeError: module 'jax.ops' has no attribute 'index_add' #1773

Open issue, opened by cmosguy 1 year ago

cmosguy commented 1 year ago

Description

I am trying to run these basic imports in my code:

import numpy as np              # regular ol' numpy
from trax import layers as tl   # core building block
from trax import shapes         # data signatures: dimensionality and type
from trax import fastmath       # uses jax, offers numpy on steroids

The code fails during the imports themselves, with the error shown below. What am I doing wrong? Should I pin a different version of one of these packages?

Environment information

OS: CentOS

$ lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0

$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4

$ python -V
Python 3.9.16
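A minimal sketch, assuming Python 3.8+ (where importlib.metadata is in the standard library): the same version information as the pip freeze | grep calls above can be collected in one script. The PACKAGES tuple and the installed_versions helper are illustrative names chosen here, not part of trax or jax.

```python
"""Report installed versions of the packages involved in this issue.

A stdlib-only alternative to running `pip freeze | grep ...` several times.
"""
from importlib import metadata

# Hypothetical selection of distribution names, taken from the listing above.
PACKAGES = ("trax", "jax", "jaxlib", "tensorflow")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Reading versions from package metadata avoids importing trax, so it still works when the import itself is what crashes.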


### For bugs: reproduction and error logs

# Error logs:

...

      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
     16 """Functions and classes for obtaining and preprocesing data.
     17 
     18 The ``trax.data`` module presents a flattened (no subpackages) public API.
   (...)
...
    217     'vjp': jax.vjp,
    218     'vmap': jax.vmap,
    219 }

AttributeError: module 'jax.ops' has no attribute 'index_add'
stephengineer commented 1 year ago

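Downgrade jax to 0.2.21. jax.ops.index_add was deprecated in jax 0.2.22 and is no longer available in the jax 0.4.4 you have installed, which is why trax fails on import. See the changelog entry: https://gitee.com/mirrors/JAX/blob/main/CHANGELOG.md#jax-0222-oct-12-2021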
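As a rough illustration of the version boundary described in the comment above, here is a small sketch that checks whether a jax version string predates the 0.2.22 deprecation, plus a runtime check for the attribute itself. The helper names (parse_version, predates_index_add_deprecation, has_index_add) are hypothetical, and the version parser is a simplified assumption that only looks at the leading MAJOR.MINOR.PATCH numbers.

```python
"""Check a jax version against the jax.ops.index_add deprecation boundary."""
import re

# Per the comment above: jax.ops.index_add was deprecated in jax 0.2.22.
DEPRECATED_IN = (0, 2, 22)


def parse_version(version):
    """Turn '0.4.4' (or '0.2.21.dev0') into a tuple of ints like (0, 4, 4)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def predates_index_add_deprecation(version):
    """True if this jax version is older than the deprecation release."""
    return parse_version(version) < DEPRECATED_IN


def has_index_add():
    """Runtime check: does the installed jax still expose jax.ops.index_add?"""
    try:
        import jax.ops
    except ImportError:
        return None  # jax is not installed at all
    return hasattr(jax.ops, "index_add")


if __name__ == "__main__":
    for v in ("0.2.21", "0.2.22", "0.4.4"):
        print(v, predates_index_add_deprecation(v))
    print("jax.ops.index_add available:", has_index_add())
```

With the environment reported in this issue, predates_index_add_deprecation("0.4.4") is False, matching the AttributeError. The suggested fix corresponds to pip install "jax==0.2.21".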