openlm-research / open_llama

OpenLLaMA, a permissively licensed open source reproduction of Meta AI’s LLaMA 7B trained on the RedPajama dataset
Apache License 2.0

TPU throughput suddenly decreases after 35~40 iterations #28

Closed hxssgaa closed 1 year ago

hxssgaa commented 1 year ago

Hi,

I tried to replicate your pretraining of the 7B LLaMA model on a TPU v4 VM with EasyLM, but I found that training throughput suddenly drops to only about 15% of its original value after 35~40 iterations. It stays low until the trainer pauses for a few seconds to log the metrics, after which throughput temporarily returns to normal, only to fall again after another 35~40 iterations. This pattern repeats every log_freq steps.

I used the following script for pretraining:

python3 -m EasyLM.models.llama.llama_train \
    --total_steps=100000 \
    --mesh_dim='1,1,-1' \
    --load_llama_config='=7b' \
    --tokenizer.vocab_file='./tokenizer.model' \
    --train_dataset.text_processor.fields=text \
    --train_dataset.type=huggingface \
    --train_dataset.huggingface_dataset.path=openwebtext \
    --train_dataset.huggingface_dataset.name=plain_text \
    --train_dataset.huggingface_dataset.batch_size=16 \
    --dtype=bf16

The experiment was done on a single TPU v4-8. I tried multiple combinations of mesh_dim, but all of them show the same issue.

I also attached an SSD to the TPU v4 VM to rule out disk throughput as the cause. Could you share the settings you used to pretrain the LLaMA 7B model? Have you observed a similar issue on TPU v4, and if so, how did you address it?

Thanks!

young-geng commented 1 year ago

What you are observing is a result of JAX's asynchronous dispatch. In a training loop, JAX returns control to Python even if the training step has not finished running on the TPU. However, since we log every 50 iterations, logging forces all pending asynchronous dispatches to complete so that the training metrics can be fetched with jax.device_get, which is why that iteration takes much longer.
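To illustrate the effect, here is a minimal sketch that times a loop with and without synchronizing after each step. The `train_step` and `run` functions are hypothetical toys, not EasyLM's code; only `jax.jit`, `block_until_ready` and `jax.device_get` are real JAX APIs. Without a sync, the per-step timer mostly measures dispatch, and the real device time shows up later at the `device_get` call, much like the logging step described above.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, x):
    # Hypothetical stand-in for a training step: a few matmuls so the
    # device has real work to do. Returning updated params makes each
    # step depend on the previous one, like a real optimizer update.
    h = x
    for _ in range(4):
        h = jnp.tanh(h @ params)
    loss = jnp.mean(h ** 2)
    return params - 1e-3 * loss, loss


def run(num_steps, sync_every_step, size=1024):
    key = jax.random.PRNGKey(0)
    params = jax.random.normal(key, (size, size)) / size ** 0.5
    x = jnp.ones((size, size))
    # Compile once, outside the timed region.
    _, loss = train_step(params, x)
    loss.block_until_ready()

    step_times = []
    for _ in range(num_steps):
        start = time.perf_counter()
        # Without a sync, this mostly measures how long dispatch takes.
        params, loss = train_step(params, x)
        if sync_every_step:
            # Wait until the device has actually finished this step.
            loss.block_until_ready()
        step_times.append(time.perf_counter() - start)

    # Like a logging step: fetching the metric to the host forces every
    # pending dispatch to complete first.
    start = time.perf_counter()
    final_loss = float(jax.device_get(loss))
    log_time = time.perf_counter() - start
    return step_times, log_time, final_loss
```

Comparing `run(50, sync_every_step=False)` with `run(50, sync_every_step=True)` on an accelerator, one would expect the unsynchronized step times to look deceptively small and the final fetch to be slow, whereas syncing every step spreads the same device time evenly across steps. For honest throughput numbers, it is safer to measure over a whole logging window that ends in a sync.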

I should also note that async dispatch is a great feature of JAX: it lets us overlap data loading and training for free, without explicit implementations like FFCV.
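A minimal sketch of why this overlap comes for free, assuming a hypothetical NumPy-based `load_batch` in place of a real data loader and a toy least-squares `train_step` (neither is EasyLM's actual code): because calling the jitted step returns as soon as the work is dispatched, the host prepares the next batch while the device is still computing the current step. The only blocking point is the periodic `jax.device_get` used for logging.

```python
import numpy as np
import jax
import jax.numpy as jnp


def load_batch(rng, batch_size=8, dim=512):
    # Hypothetical host-side data loading: plain NumPy work on the CPU.
    return rng.standard_normal((batch_size, dim)).astype(np.float32)


@jax.jit
def train_step(w, batch):
    # Hypothetical toy objective: drive batch @ w toward zero.
    def loss_fn(w):
        return jnp.mean((batch @ w) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(w)
    return w - 0.1 * grad, loss


def train(num_steps, log_freq, dim=512, seed=0):
    rng = np.random.default_rng(seed)
    w = jnp.eye(dim, dtype=jnp.float32)
    logged = []
    batch = load_batch(rng, dim=dim)
    for step in range(1, num_steps + 1):
        # Returns right after dispatch; the device runs this step in the background.
        w, loss = train_step(w, batch)
        # Runs on the host while the device is still busy with the step above.
        batch = load_batch(rng, dim=dim)
        if step % log_freq == 0:
            # Synchronization point: waits for all queued steps to finish.
            logged.append((step, float(jax.device_get(loss))))
    return logged
```

For example, `train(num_steps=20, log_freq=5, dim=64)` returns the loss at steps 5, 10, 15 and 20. The design choice is simply to avoid touching device values on the host except at logging time, since any such read (printing a loss, converting it with `float`, and so on) becomes a synchronization point.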

hxssgaa commented 1 year ago

Many thanks for your detailed explanation, I understand now :)