bminixhofer / zett

Code for Zero-Shot Tokenizer Transfer
https://arxiv.org/abs/2405.07883

Error when training a hypernetwork #6

Open jubgjf opened 5 months ago

jubgjf commented 5 months ago

I tried to train a hypernetwork on English and Chinese datasets in order to transfer TinyLlama to a bilingual tokenizer.

My hardware is 2x A100 80GB, with CUDA driver version 12.2.

My config is:

{
    "output_dir": "output-debug",
    "train_directory": "data/train",
    "valid_directory": "data/valid",
    "langs": "data/langs.txt",
    "model_name_or_path": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    "revision": "refs/pr/8",
    "loss": "clm",
    "n_embd": 2048,
    "n_token_subsample": null,
    "random_warmup_steps": 0,
    "identity_n_subsample": 16384,
    "identity_steps": 0,
    "warmup_steps": [
        10000
    ],
    "steps": 200000,
    "dtype": "bfloat16",
    "use_unigram_bias": true,
    "learning_rate": [
        6e-5
    ],
    "max_grad_norm": 0.1,
    "extra_valid_tokenizer_names": [
        "models/TinyLlama-1.1B-intermediate-step-1431k-3T-Ext"
    ],
    "extra_valid_files": [
        "data/valid/en.parquet",
        "data/valid/zh.parquet"
    ],
    "extra_lang_codes": [
        "en",
        "zh"
    ],
    "n_valid_subsample": 4000,
    "do_tokenizer_sampling": true,
    "hn_rescale_embeddings": true,
    "hn_surface_maxlen": 15,
    "tokenizer_sample_mean": 32768,
    "tokenizer_sample_max": 32768,
    "tokenizer_sample_std": 0,
    "tokenizer_batch_size": 32,
    "weight_decay": 0.01,
    "adam_beta2": 0.95,
    "hn_model_name_or_path": "roberta-base",
    "tokenizer_noise_mean": 1e-5,
    "tokenizer_noise_std": 4,
    "hn_embed_lang_id": true,
    "hn_add_inter_token_attention": false,
    "hn_embed_target_priors": false,
    "hn_inter_token_attention_bias_by_priors": true,
    "hn_embed_using_source_embeddings": true,
    "train_batch_size": 2,
    "eval_batch_size": 2,
    "hn_hidden_size": 2048,
    "hn_intermediate_size": 4096,
    "gradient_accumulation_steps": 1,
    "learnable_bias": false,
    "add_target_priors_to_bias": false,
    "lexical_loss_weight": 0.5,
    "debug": false,
    "dataloader_num_workers": 64,
    "mix_languages": false,
    "logging_steps": 10
}

data/langs.txt contains:

en,1
zh,3

The main training loop runs without problems, but I get an error when training reaches a logging step (as set by logging_steps):

Traceback (most recent call last):
  File "/home/jnguan/code/zett/train.py", line 1605, in <module>
    main()
  File "/home/jnguan/code/zett/train.py", line 1516, in main
    lambda x: x.flatten(), stack_forest(train_metrics)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 68, in <lambda>
    stack_args = lambda *args: np.stack(args)
                               ^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in stack
    arrays = [asanyarray(arr) for arr in arrays]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
              ^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 390, in __array__
    return np.asarray(self._value, dtype=dtype)
                      ^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 588, in _value
    if self.is_fully_replicated:
       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 354, in is_fully_replicated
    return self.sharding.is_fully_replicated
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

AttributeError: 'UnspecifiedValue' object has no attribute 'is_fully_replicated'
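
For readers following the traceback: the two flax frames show that stack_forest maps np.stack over matching leaves of the per-step metric trees, and the exception is raised while NumPy converts a single JAX array leaf to a host array (`__array__` -> `_value` -> `is_fully_replicated`), before any stacking takes place. Below is a minimal NumPy-only sketch of that stacking step. It is not the flax implementation; it assumes flat metric dicts, and the metric names and values are made up, since the real structure of train_metrics is not shown in this issue.

```python
import numpy as np


def stack_forest_sketch(forest):
    """Stack matching leaves of a list of flat metric dicts along a new axis 0."""
    stacked = {}
    for key in forest[0].keys():
        # np.stack converts every leaf with np.asanyarray first. In the
        # traceback above, this conversion is the step that fails for one
        # of the JAX array leaves.
        leaves = [np.asanyarray(tree[key]) for tree in forest]
        stacked[key] = np.stack(leaves)
    return stacked


# Hypothetical per-step metrics (names and values are illustrative only).
train_metrics = [
    {"loss": np.float32(2.5), "learning_rate": np.float32(6e-5)},
    {"loss": np.float32(2.4), "learning_rate": np.float32(6e-5)},
]

stacked = stack_forest_sketch(train_metrics)
print({name: values.shape for name, values in stacked.items()})
```

Because the conversion happens leaf by leaf, converting each metric individually (for example in a debugger) may help narrow down which leaf carries the unspecified sharding.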

Full log: zett-142044.log

My environment:

Package                  Version
------------------------ -----------
absl-py                  2.1.0
accelerate               0.30.1
aiohttp                  3.9.5
aiosignal                1.3.1
appdirs                  1.4.4
attrs                    23.2.0
certifi                  2024.2.2
charset-normalizer       3.3.2
chex                     0.1.86
click                    8.1.7
cmake                    3.29.3
contourpy                1.2.1
cycler                   0.12.1
datasets                 2.19.1
dill                     0.3.8
docker-pycreds           0.4.0
etils                    1.8.0
filelock                 3.14.0
flax                     0.8.0
fonttools                4.52.4
frozenlist               1.4.1
fsspec                   2024.5.0
gitdb                    4.0.11
GitPython                3.1.43
h5py                     3.8.0
huggingface-hub          0.23.2
idna                     3.7
importlib_resources      6.4.0
jax                      0.4.23
jax-cuda12-pjrt          0.4.23
jax-cuda12-plugin        0.4.23
jaxlib                   0.4.23
Jinja2                   3.1.4
joblib                   1.4.2
kiwisolver               1.4.5
lit                      18.1.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.7.2
maturin                  1.3.0
mdurl                    0.1.2
ml-dtypes                0.4.0
mpmath                   1.3.0
msgpack                  1.0.8
multidict                6.0.5
multiprocess             0.70.16
nest-asyncio             1.6.0
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvcc-cu12    12.5.40
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11        8.5.0.96
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu11        10.9.0.58
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu11       10.2.10.91
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu11     11.7.4.91
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu11         2.14.3
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.5.40
nvidia-nvtx-cu11         11.7.91
nvidia-nvtx-cu12         12.1.105
opt-einsum               3.3.0
optax                    0.1.5
orbax-checkpoint         0.5.14
packaging                24.0
pandas                   2.0.3
pathtools                0.1.2
pillow                   10.3.0
pip                      24.0
protobuf                 4.25.3
psutil                   5.9.8
pyahocorasick            2.0.0
pyarrow                  16.1.0
pyarrow-hotfix           0.6
Pygments                 2.18.0
pyparsing                3.0.9
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.15
requests                 2.32.3
rich                     13.7.1
rust_utils               0.14.1.dev0
safetensors              0.4.3
scikit-learn             1.4.2
scipy                    1.10.1
sentry-sdk               2.3.1
setproctitle             1.3.3
setuptools               69.5.1
six                      1.16.0
smmap                    5.0.1
sympy                    1.12.1
tensorstore              0.1.60
threadpoolctl            3.5.0
tokenizers               0.19.1
toolz                    0.12.1
torch                    2.3.0
tqdm                     4.66.4
transformers             4.41.1
triton                   2.3.0
typing_extensions        4.12.0
tzdata                   2024.1
urllib3                  2.2.1
wandb                    0.15.4
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4
zipp                     3.19.0
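
The list above contains NVIDIA runtime wheels for both CUDA 11 (`-cu11`) and CUDA 12 (`-cu12`). Whether that mix is related to the error is not established in this thread. The sketch below is only a standard-library helper for spotting such duplicates in `pip list` output; the function name and the sample input are illustrative.

```python
import re
from collections import defaultdict


def find_mixed_cuda_packages(pip_list_text):
    """Return {base_name: {cuda_tag: version}} for NVIDIA packages that are
    installed under more than one CUDA tag (e.g. both cu11 and cu12)."""
    pattern = re.compile(r"^(nvidia-[a-z0-9-]+)-(cu\d+)\s+(\S+)\s*$")
    found = defaultdict(dict)
    for line in pip_list_text.splitlines():
        match = pattern.match(line.strip())
        if match:
            base, tag, version = match.groups()
            found[base][tag] = version
    # Keep only packages that appear with at least two different CUDA tags.
    return {base: tags for base, tags in found.items() if len(tags) > 1}


# A few lines copied from the environment listing above.
sample = """\
nvidia-cublas-cu11       11.10.3.66
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-nvcc-cu12    12.5.40
nvidia-cudnn-cu11        8.5.0.96
nvidia-cudnn-cu12        8.9.2.26
"""

mixed = find_mixed_cuda_packages(sample)
for base, tags in sorted(mixed.items()):
    print(base, tags)
```

To check a live environment, the output of `pip list` can be passed to the function in place of `sample`.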

kdcyberdude commented 5 months ago

Hi @jubgjf, can you try the branch mentioned in this issue: https://github.com/bminixhofer/zett/issues/8