Closed sjw8793 closed 8 months ago
That is indeed expected. Those flags target some new features of TPU v4 that are not available on TPU v3.
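Since the flags are generation-specific, one option is to check the TPU generation programmatically before assembling the flag string. A minimal sketch, assuming the `"TPU v3"` / `"TPU v4"` device-kind strings that Cloud TPU reports; `parse_tpu_version`, `build_libtpu_args`, and the choice of which flags to gate are hypothetical, for illustration only (not EasyLM or JAX API):

```python
import re

def parse_tpu_version(device_kind: str) -> int:
    """Extract the generation number from a device-kind string like 'TPU v3'."""
    match = re.search(r"TPU v(\d+)", device_kind)
    if match is None:
        raise ValueError(f"not a TPU device: {device_kind!r}")
    return int(match.group(1))

# Assumed to be the v4-only flags from the command in this thread; which flags
# actually require v4 is an assumption here, not confirmed by the maintainers.
V4_ONLY_FLAGS = [
    "--xla_enable_async_all_gather=true",
    "--jax_enable_async_collective_offload=true",
]

def build_libtpu_args(device_kind: str, base_flags: list[str]) -> str:
    """Join the base flags, adding the v4-only ones only on TPU v4 or later."""
    flags = list(base_flags)
    if parse_tpu_version(device_kind) >= 4:
        flags += V4_ONLY_FLAGS
    return " ".join(flags)
```

On a real TPU VM, the device-kind string could come from `jax.devices()[0].device_kind` and the result be exported as `LIBTPU_INIT_ARGS` before launching training.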
Oh, I see. That must be why I couldn't find any information when I searched for those flags. Thank you for the explanation :)
Hi, thank you so much for this amazing codebase and the wonderful sample scripts. I'm trying to pre-train LLaMA on a TPU VM, referring to this example script and this issue. My pre-training setup is below:
Environment
- TPU: `v3-8`
- TPU VM software: `tf-vm-base`
- Terminal: the SSH shell that GCP offers with the TPU VM
Command
```shell
./scripts/tpu_vm_setup.sh

export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'

python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='fp32' \
    --total_steps=10000 \
    --log_freq=50 \
    --load_llama_config='1b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
    --optimizer.adamw_optimizer.lr_decay_steps=10000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='data_sample.jsonl' \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --logger.output_dir="gs://my-bucket/openllama/" \
    |& tee $HOME/output.txt
```
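One way to reuse the same launch script on v3 hardware is to filter the retired flag out of `LIBTPU_INIT_ARGS` before starting training. A shell sketch; the flag string here is abbreviated from the full command above for illustration:

```shell
# Abbreviated version of the flag string used in the command above.
LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true'

# Split on whitespace, drop the retired flag, and rejoin; all other flags
# pass through unchanged.
filtered=$(printf '%s\n' $LIBTPU_INIT_ARGS | grep -v 'jax_enable_async_collective_offload' | tr '\n' ' ')
export LIBTPU_INIT_ARGS="${filtered% }"

echo "$LIBTPU_INIT_ARGS"
# → --xla_enable_async_all_gather=true
```

The unquoted `$LIBTPU_INIT_ARGS` relies on word splitting, which is safe here because the flags contain no spaces or glob characters.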
This produces these errors:

Error 1:
```
Accessing retired flag 'jax_enable_async_collective_offload'
```

Error 2:
```
Fatal Python error: Floating point exception
```
Removing the flag `--jax_enable_async_collective_offload` solved error 1 (I'm not sure, but at least the error message no longer appeared). Error 2 pointed to line 60 of `llama_train.py`, and it did not appear when I unset `LIBTPU_INIT_ARGS`. Is it normal to get these error messages on a v3 TPU? If so, is there any way to improve training throughput on a v3 TPU?