LargeWorldModel / LWM

Large World Model -- Modeling Text and Video with Millions Context
https://largeworldmodel.github.io/
Apache License 2.0

ValueError: bytes is too large when running scripts/run_train_text.sh #48

Open liu21yd opened 8 months ago

liu21yd commented 8 months ago

Detail message:

Traceback (most recent call last):
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/user_work_path/train/codes/LWM/lwm/train.py", line 396, in <module>
    run(main)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/user_work_path/train/codes/LWM/lwm/train.py", line 387, in main
    save_checkpoint(train_state, milestone=True)
  File "/user_work_path/train/codes/LWM/lwm/train.py", line 325, in save_checkpoint
    checkpointer.save_all(
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 102, in save_all
    self.save_checkpoint(
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 46, in save_checkpoint
    self.save_train_state_to_file(
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 78, in save_train_state_to_file
    fout.write(packer.pack((key, to_bytes(value))))
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 826, in pack
    self._pack(obj)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 803, in _pack
    self._pack(obj[i], nest_limit - 1)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 750, in _pack
    raise ValueError("%s is too large" % type(obj).__name__)
ValueError: bytes is too large

Train script:

#!/bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export LLAMA_TOKENIZER_PATH=/user_work_path/tokenizer.model
export DATASET_PATH=/user_work_path/sample.jsonl
export SEED=1025

export PROJECT_ID='lwm'
export EXPERIMENT_NOTE=''
export EXPERIMENT_ID='example-text-train'
export OUTPUT_DIR=${PROJECT_DIR}/output

export COORDINATOR_ADDRESS=localhost:12345
export NUM_PROCESSES=1
export PROCESS_ID=0
export INITIALIZE_JAX_DISTRIBUTED=true

python3 -u -m lwm.train \
    --jax_distributed.coordinator_address ${COORDINATOR_ADDRESS} \
    --jax_distributed.initialize_jax_distributed ${INITIALIZE_JAX_DISTRIBUTED} \
    --jax_distributed.num_processes ${NUM_PROCESSES} \
    --jax_distributed.process_id ${PROCESS_ID} \
    --modality='text' \
    --mesh_dim='1,1,1,8' \
    --dtype='bf16' \
    --seed=${SEED} \
    --total_steps=10 \
    --log_freq=1 \
    --save_model_freq=0 \
    --save_milestone_freq=5 \
    --load_llama_config='13b' \
    --update_llama_config="dict(theta=10000,max_sequence_length=4096,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
    --tokenizer.vocab_file="$LLAMA_TOKENIZER_PATH" \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=8e-5 \
    --optimizer.adamw_optimizer.end_lr=8e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=5 \
    --optimizer.adamw_optimizer.lr_decay_steps=200 \
    --use_data_sharded_loader=True \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path="$DATASET_PATH" \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=8 \
    --train_dataset.json_dataset.tokenizer_processes=4 \
    --train_dataset.json_dataset.tokenizer_parallel_chunk_size=2 \
    --train_dataset.json_dataset.tokenizer_parallel_batch_size=8 \
    --train_dataset.json_dataset.use_data_sharded_loader=True \
    --checkpointer.save_optimizer_state=True \
    --autoresume=False \
    --logger.append_uuid=False \
    --logger.online=False \
    --logger.project_id="$PROJECT_ID" \
    --logger.experiment_id="$EXPERIMENT_ID" \
    --logger.experiment_note="$EXPERIMENT_NOTE" \
    --logger.output_dir="$OUTPUT_DIR" \
    --logger.wandb_dir="$HOME/experiment_output/$PROJECT_ID"
read

Environment (pip list):

Package Version
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.3
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.17.1
decorator 5.1.1
decord 0.6.0
dill 0.3.8
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2023.10.0
gcsfs 2023.10.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.1
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib_resources 6.1.2
ipdb 0.13.13
ipython 8.22.1
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.16
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.1
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.54
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.10.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0

Could someone help me?

Het-Joshi commented 5 months ago

The error comes from msgpack: it cannot pack a single bytes object of 2^32 bytes (4 GiB) or more, and one of the serialized train-state entries exceeds that. To see which one, modify the save_train_state_to_file function in tux/checkpoint.py to log the size of each value before packing and to handle large objects:


# Add logging to identify large objects
import logging

def save_train_state_to_file(self, key, value, fout):
    # to_bytes and packer are the serializer and msgpack Packer already
    # used by tux/checkpoint.py (see the traceback above)
    packed_value = to_bytes(value)
    logging.info(f"Saving key: {key}, size: {len(packed_value)} bytes")
    try:
        fout.write(packer.pack((key, packed_value)))
    except ValueError:
        logging.error(f"Failed to save key: {key}, size: {len(packed_value)} bytes")
        # Handle the error, e.g. by splitting the object into chunks (see below)
        raise
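
Before patching anything, it may also help to confirm which leaf of the train state is actually over the limit. The snippet below is only a rough diagnostic sketch: it assumes the train state is a JAX pytree of arrays (as in lwm/train.py), and report_leaf_sizes / MSGPACK_MAX are hypothetical names, not part of LWM or tux. With scan_layers=True and --checkpointer.save_optimizer_state=True, the per-layer weights and optimizer moments are typically stacked into single arrays, which for a 13B model can push individual leaves past msgpack's ~4 GiB per-object cap.

import jax
import numpy as np

# msgpack cannot pack a single object of 2**32 bytes or more
MSGPACK_MAX = 2**32 - 1

def report_leaf_sizes(train_state):
    # Walk the pytree and print the raw size of each leaf; anything above the
    # limit is what makes packer.pack() raise "bytes is too large".
    # (np.asarray materializes each leaf on the host, so this is slow and
    # memory-heavy, but no worse than what checkpoint saving already does.)
    leaves, _ = jax.tree_util.tree_flatten_with_path(train_state)
    for path, leaf in leaves:
        nbytes = np.asarray(leaf).nbytes
        flag = "  <-- over the msgpack limit" if nbytes > MSGPACK_MAX else ""
        print(f"{jax.tree_util.keystr(path)}: {nbytes:,} bytes{flag}")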

If a single value is too large for msgpack, consider splitting its serialized bytes into smaller chunks and writing each chunk as a separate record. Here's a simplified example:


import msgpack
import logging

# Hypothetical limits -- tune them to your setup; keep both well below
# msgpack's 2**32 - 1 byte cap for a single packed object
SOME_THRESHOLD = 2**31  # split anything larger than ~2 GiB
CHUNK_SIZE = 2**30      # ~1 GiB per chunk

def save_train_state_to_file(self, key, value, fout):
    # to_bytes and packer as in tux/checkpoint.py; packer is a msgpack.Packer()
    packed_value = to_bytes(value)
    try:
        if len(packed_value) > SOME_THRESHOLD:
            logging.info(f"Splitting large object: {key}")
            chunks = [packed_value[i:i + CHUNK_SIZE] for i in range(0, len(packed_value), CHUNK_SIZE)]
            for idx, chunk in enumerate(chunks):
                fout.write(packer.pack((f"{key}_part_{idx}", chunk)))
        else:
            fout.write(packer.pack((key, packed_value)))
    except ValueError:
        logging.error(f"Error saving key: {key}, size: {len(packed_value)} bytes")
        raise
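
If you do split values this way, whatever later loads the checkpoint has to reassemble the parts. Below is a rough sketch of the reading side, assuming string keys and the same "{key}_part_{idx}" naming as above; load_and_merge_chunks is a hypothetical helper, not something tux provides.

import msgpack

def load_and_merge_chunks(path):
    # Stream records back and reassemble any "<key>_part_<idx>" entries
    # into a single bytes object per original key.
    merged, parts = {}, {}
    with open(path, "rb") as fin:
        for key, value in msgpack.Unpacker(fin, raw=False):
            if "_part_" in key:
                base, idx = key.rsplit("_part_", 1)
                parts.setdefault(base, {})[int(idx)] = value
            else:
                merged[key] = value
    for base, chunks in parts.items():
        merged[base] = b"".join(chunks[i] for i in sorted(chunks))
    return merged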

Also make sure msgpack itself is up to date:

pip install --upgrade msgpack