ModuleNotFoundError Traceback (most recent call last)
in <cell line: 1>()
----> 1 import t5x
2 from t5x import partitioning
3 from t5x import train_state as train_state_lib
4 from t5x import utils
5 from t5x.examples.t5 import network
5 frames
/content/t5x/t5x/__init__.py in
15 """Import API modules."""
16
---> 17 import t5x.adafactor
18 import t5x.checkpoints
19 import t5x.decoding
/content/t5x/t5x/adafactor.py in
63 import jax.numpy as jnp
64 import numpy as np
---> 65 from t5x import utils
66 from t5x.optimizers import OptimizerDef
67 from t5x.optimizers import OptimizerState
/content/t5x/t5x/utils.py in
44 import jax.numpy as jnp
45 import numpy as np
---> 46 import orbax.checkpoint
47 import seqio
48 from t5x import checkpoints
ModuleNotFoundError Traceback (most recent call last) in <cell line: 1>()
----> 1 import t5x
2 from t5x import partitioning
3 from t5x import train_state as train_state_lib
4 from t5x import utils
5 from t5x.examples.t5 import network
5 frames /content/t5x/t5x/__init__.py in
15 """Import API modules."""
16
---> 17 import t5x.adafactor
18 import t5x.checkpoints
19 import t5x.decoding
/content/t5x/t5x/adafactor.py in
63 import jax.numpy as jnp
64 import numpy as np
---> 65 from t5x import utils
66 from t5x.optimizers import OptimizerDef
67 from t5x.optimizers import OptimizerState
/content/t5x/t5x/utils.py in
44 import jax.numpy as jnp
45 import numpy as np
---> 46 import orbax.checkpoint
47 import seqio
48 from t5x import checkpoints
/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py in
17 import functools
18
---> 19 from orbax.checkpoint import checkpoint_utils
20 from orbax.checkpoint import lazy_utils
21 from orbax.checkpoint import test_utils
/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py in
23 from jax.sharding import Mesh
24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
26 from orbax.checkpoint import utils
27
/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py in
22 from etils import epath
23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
26 import jax.numpy as jnp
ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'