kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars 891 forks source link

Import optax on Colab gives: cannot import name 'flags' from 'jax.config' #53

Status: Closed (davidlee321 closed this issue 3 years ago)

davidlee321 commented 3 years ago

Did anyone else encounter this error while running demo.ipynb on Colab?

ImportError                               Traceback (most recent call last)

<ipython-input-9-72cd76e3a907> in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8 

6 frames

/usr/local/lib/python3.7/dist-packages/optax/__init__.py in <module>()
     16 """Optax: composable gradient processing and optimization, in JAX."""
     17 
---> 18 from optax._src.alias import adabelief
     19 from optax._src.alias import adagrad
     20 from optax._src.alias import adam

/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py in <module>()
     20 import jax.numpy as jnp
     21 
---> 22 from optax._src import combine
     23 from optax._src import privacy
     24 from optax._src import schedule

/usr/local/lib/python3.7/dist-packages/optax/_src/combine.py in <module>()
     16 """Flexibly compose gradient transformations."""
     17 
---> 18 from optax._src import transform
     19 GradientTransformation = transform.GradientTransformation
     20 

/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py in <module>()
     18 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
     19 
---> 20 import chex
     21 import jax
     22 import jax.numpy as jnp

/usr/local/lib/python3.7/dist-packages/chex/__init__.py in <module>()
     15 """Chex: Testing made fun, in JAX!"""
     16 
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_gt
     19 from chex._src.asserts import assert_devices_available

/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py in <module>()
     29 import jax
     30 import jax.numpy as jnp
---> 31 import jax.test_util as jax_test
     32 import numpy as np
     33 import tree as dm_tree

/usr/local/lib/python3.7/dist-packages/jax/test_util.py in <module>()
     33 from . import dtypes as _dtypes
     34 from . import lax
---> 35 from .config import flags, bool_env, config
     36 from ._src.util import partial, prod
     37 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

ImportError: cannot import name 'flags' from 'jax.config' (/usr/local/lib/python3.7/dist-packages/jax/config.py)
kingoflolz commented 3 years ago

I cannot reproduce this; try factory-resetting your Colab runtime.