young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
Apache License 2.0

TPU specific flags produce errors #102

Closed: sjw8793 closed this issue 8 months ago

sjw8793 commented 8 months ago

Hi, thank you so much for this amazing codebase and the wonderful sample scripts. I'm trying to pre-train LLaMA on a TPU VM, referring to this example script and this issue. My pre-training setup is below:

Environment

- TPU: `v3-8`
- TPU VM software: `tf-vm-base`
- Terminal: the SSH shell that GCP offers with the TPU VM.

Command

```shell
./scripts/tpu_vm_setup.sh
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='fp32' \
    --total_steps=10000 \
    --log_freq=50 \
    --load_llama_config='1b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
    --optimizer.adamw_optimizer.lr_decay_steps=10000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='data_sample.jsonl' \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --logger.output_dir="gs://my-bucket/openllama/" \
    |& tee $HOME/output.txt
```

This produces these errors:

  1. Accessing retired flag 'jax_enable_async_collective_offload'
  2. Fatal Python error: Floating point exception

Removing the `--jax_enable_async_collective_offload` flag solved error 1 (I'm not sure it is truly fixed, but at least the error message no longer appeared). Error 2 pointed to line 60 of `llama_train.py`, and it did not appear when I unset `LIBTPU_INIT_ARGS`. Is this error expected on a v3 TPU? If so, is there any way to improve training throughput on a v3 TPU?
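The workaround for error 1 amounts to deleting one entry from the space-separated `LIBTPU_INIT_ARGS` string. A minimal bash sketch of doing that in a script rather than by hand, assuming every entry is a single whitespace-free token (true for the command above); `drop_flag` is a hypothetical helper, not part of EasyLM or libtpu:

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of EasyLM): print a LIBTPU_INIT_ARGS-style
# string with every occurrence of one flag removed. Assumes entries are
# whitespace-free tokens such as --name=value, --name, or NAME=value.
drop_flag() {
  local name="$1" out="" tok
  local -a toks
  read -r -a toks <<< "$2"            # split on whitespace without globbing
  for tok in "${toks[@]}"; do
    case "$tok" in
      "--$name" | "--$name="*) ;;     # the flag to drop: skip it
      *) out="${out:+$out }$tok" ;;   # keep everything else, space-separated
    esac
  done
  printf '%s\n' "$out"
}

# Example: strip the retired flag while keeping the rest of the list intact.
LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
LIBTPU_INIT_ARGS="$(drop_flag jax_enable_async_collective_offload "$LIBTPU_INIT_ARGS")"
export LIBTPU_INIT_ARGS
echo "$LIBTPU_INIT_ARGS"
```

This only edits the string; in this thread, error 2 went away only after `LIBTPU_INIT_ARGS` was unset entirely.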

young-geng commented 8 months ago

That is indeed expected. Those flags target some new features of TPU-v4 that are not available on TPU-v3.
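Since the flags are aimed at TPU-v4, one way to keep a single launch script for both generations is to export them only on v4 slices. A minimal sketch, assuming you record the slice type yourself in a hypothetical `TPU_TYPE` variable (e.g. `v3-8`, `v4-8`); the v4 flag list is the one from the original command with `--jax_enable_async_collective_offload` dropped, since that flag triggered the retired-flag error reported in this issue:

```shell
#!/usr/bin/env bash
# Sketch: export the TPU-v4-oriented libtpu flags only on v4 slices.
# TPU_TYPE is a hypothetical variable you set yourself (e.g. v3-8, v4-8);
# neither EasyLM nor the TPU VM is assumed to define it.
configure_libtpu_args() {
  local tpu_type="$1"
  case "$tpu_type" in
    v4-*)
      # Flags from the original command, minus the retired
      # --jax_enable_async_collective_offload.
      export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
      ;;
    *)
      # v3 and anything else: leave the variable unset, which is the
      # configuration that ran without the floating point exception here.
      unset LIBTPU_INIT_ARGS
      ;;
  esac
}

# Default to v3-8, the hardware used in this issue, when TPU_TYPE is not set.
configure_libtpu_args "${TPU_TYPE:-v3-8}"
echo "LIBTPU_INIT_ARGS=${LIBTPU_INIT_ARGS:-<unset>}"
```

Run this (or `source` it) before the `python -m EasyLM.models.llama.llama_train ...` line; it does not make v3 any faster, it only avoids passing v4-only flags to a v3 runtime.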

sjw8793 commented 8 months ago

Oh, I see. That must be why I couldn't find any information when I searched for those flags. Thank you for the explanation :)