Closed: davidlee321 closed this issue 3 years ago.
Has anyone else encountered this error while running demo.ipynb on Colab?
```
ImportError                               Traceback (most recent call last)
<ipython-input-9-72cd76e3a907> in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8 

6 frames
/usr/local/lib/python3.7/dist-packages/optax/__init__.py in <module>()
     16 """Optax: composable gradient processing and optimization, in JAX."""
     17 
---> 18 from optax._src.alias import adabelief
     19 from optax._src.alias import adagrad
     20 from optax._src.alias import adam

/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py in <module>()
     20 import jax.numpy as jnp
     21 
---> 22 from optax._src import combine
     23 from optax._src import privacy
     24 from optax._src import schedule

/usr/local/lib/python3.7/dist-packages/optax/_src/combine.py in <module>()
     16 """Flexibly compose gradient transformations."""
     17 
---> 18 from optax._src import transform
     19 GradientTransformation = transform.GradientTransformation
     20 

/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py in <module>()
     18 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
     19 
---> 20 import chex
     21 import jax
     22 import jax.numpy as jnp

/usr/local/lib/python3.7/dist-packages/chex/__init__.py in <module>()
     15 """Chex: Testing made fun, in JAX!"""
     16 
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_gt
     19 from chex._src.asserts import assert_devices_available

/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py in <module>()
     29 import jax
     30 import jax.numpy as jnp
---> 31 import jax.test_util as jax_test
     32 import numpy as np
     33 import tree as dm_tree

/usr/local/lib/python3.7/dist-packages/jax/test_util.py in <module>()
     33 from . import dtypes as _dtypes
     34 from . import lax
---> 35 from .config import flags, bool_env, config
     36 from ._src.util import partial, prod
     37 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

ImportError: cannot import name 'flags' from 'jax.config' (/usr/local/lib/python3.7/dist-packages/jax/config.py)
```
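Cannot reproduce. Try factory-resetting your Colab runtime: the traceback shows `jax/test_util.py` failing to import `flags` from `jax/config.py`. Two modules of the same JAX package disagreeing like this suggests a stale or partially upgraded install in the current runtime.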
Has anyone else encountered this error while running demo.ipynb on Colab?