Everything works well in the main training loop, but I hit an error as soon as it reaches the logging step (logging_steps):
Traceback (most recent call last):
File "/home/jnguan/code/zett/train.py", line 1605, in <module>
main()
File "/home/jnguan/code/zett/train.py", line 1516, in main
lambda x: x.flatten(), stack_forest(train_metrics)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 69, in stack_forest
return jax.tree_util.tree_map(stack_args, *forest)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 68, in <lambda>
stack_args = lambda *args: np.stack(args)
^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in stack
arrays = [asanyarray(arr) for arr in arrays]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in <listcomp>
arrays = [asanyarray(arr) for arr in arrays]
^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 390, in __array__
return np.asarray(self._value, dtype=dtype)
^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 588, in _value
if self.is_fully_replicated:
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 354, in is_fully_replicated
return self.sharding.is_fully_replicated
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'UnspecifiedValue' object has no attribute 'is_fully_replicated'
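For reference, the call path in the traceback reduces to roughly the sketch below. The `stack_forest` body copies the two lines of `flax/training/common_utils.py` shown above. The contents of `train_metrics` are a hypothetical stand-in (I am assuming it is a list with one dict of metric arrays per step). With ordinary arrays like these the snippet runs without error, so it only illustrates where the failure happens: `np.stack` calls `__array__` on each leaf, and that is where the AttributeError is raised in my run.

```python
import jax
import jax.numpy as jnp
import numpy as np


def stack_forest(forest):
    # Same two lines as flax.training.common_utils.stack_forest in the traceback:
    # stack the matching leaves of every tree along a new leading axis.
    stack_args = lambda *args: np.stack(args)
    return jax.tree_util.tree_map(stack_args, *forest)


# Hypothetical stand-in for train_metrics (assumed: one dict of metrics per step).
train_metrics = [
    {"loss": jnp.array([0.9, 1.1]), "learning_rate": jnp.array([1e-4, 1e-4])},
    {"loss": jnp.array([0.8, 1.0]), "learning_rate": jnp.array([1e-4, 1e-4])},
]

# np.stack converts every leaf to a NumPy array via __array__ before stacking.
stacked = stack_forest(train_metrics)

# Equivalent of the tree_map(lambda x: x.flatten(), ...) call at train.py line 1516.
flat = jax.tree_util.tree_map(lambda x: x.flatten(), stacked)

print({k: v.shape for k, v in flat.items()})
```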
I am trying to train a hypernetwork on English and Chinese datasets, and then transfer a bilingual tokenizer to TinyLlama.
My devices are 2 x A100 80GB, with CUDA driver version 12.2.
My config is:
data/langs.txt is:
Full log: zett-142044.log
My environment: