google-research / t5x

Apache License 2.0

Installation Error #1512

Closed: jntdst closed this issue 6 months ago

jntdst commented 7 months ago

```
ModuleNotFoundError                       Traceback (most recent call last)
in <cell line: 1>()
----> 1 import t5x
      2 from t5x import partitioning
      3 from t5x import train_state as train_state_lib
      4 from t5x import utils
      5 from t5x.examples.t5 import network

/content/t5x/t5x/__init__.py
     15 """Import API modules."""
     16 
---> 17 import t5x.adafactor
     18 import t5x.checkpoints
     19 import t5x.decoding

/content/t5x/t5x/adafactor.py
     63 import jax.numpy as jnp
     64 import numpy as np
---> 65 from t5x import utils
     66 from t5x.optimizers import OptimizerDef
     67 from t5x.optimizers import OptimizerState

/content/t5x/t5x/utils.py
     44 import jax.numpy as jnp
     45 import numpy as np
---> 46 import orbax.checkpoint
     47 import seqio
     48 from t5x import checkpoints

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py
     17 import functools
     18 
---> 19 from orbax.checkpoint import checkpoint_utils
     20 from orbax.checkpoint import lazy_utils
     21 from orbax.checkpoint import test_utils

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py
     23 from jax.sharding import Mesh
     24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
     26 from orbax.checkpoint import utils
     27 

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py
     22 from etils import epath
     23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
     25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
     26 import jax.numpy as jnp

ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'
```
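The failure happens inside orbax, not t5x: `orbax/checkpoint/type_handlers.py` imports `jax.experimental.gda_serialization`, and the installed JAX does not provide that module. To check this in a given environment, a minimal diagnostic sketch using only the standard library might report the installed versions and whether that module can be found. The distribution names `orbax` and `orbax-checkpoint`, and the alternative module name `jax.experimental.array_serialization` that newer JAX releases are assumed to use, are assumptions, not details from this issue.

```python
import importlib.metadata
import importlib.util
from typing import Dict, Optional

# Assumed distribution names: orbax has been published as "orbax" and
# as "orbax-checkpoint", so both are checked.
DISTRIBUTIONS = ["jax", "jaxlib", "orbax", "orbax-checkpoint", "t5x"]

# The module orbax failed to import in the traceback, plus the module
# name that newer JAX releases are assumed to use instead.
MODULES = [
    "jax.experimental.gda_serialization",
    "jax.experimental.array_serialization",
]


def module_available(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be located.

    find_spec imports the parent packages of a dotted name, so a missing
    parent (e.g. jax not installed) raises ImportError; treat it as absent.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of distribution `dist`, or None."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def report() -> Dict[str, Dict[str, object]]:
    """Collect package versions and module availability into one dict."""
    return {
        "versions": {d: installed_version(d) for d in DISTRIBUTIONS},
        "modules": {m: module_available(m) for m in MODULES},
    }


if __name__ == "__main__":
    for section, entries in report().items():
        print(section)
        for key, value in entries.items():
            print(f"  {key}: {value}")
```

If `jax.experimental.gda_serialization` reports `False`, the installed orbax is importing a module that the installed JAX does not ship, which matches the last line of the traceback. Aligning the JAX and orbax versions is the likely direction for a fix, but this issue does not establish which version pairing works.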