Closed hxssgaa closed 1 year ago
What you are observing is a result of JAX's async dispatch. In a training loop, JAX returns control to Python even if the training step has not finished running on the TPU. However, since we log every 50 iterations, logging forces all pending async dispatches to complete so the training metrics can be fetched with jax.device_get, which is why that iteration appears to take much longer.
I should also note that async dispatch is a great feature of JAX, since it lets us overlap data loading and training for free, without the need for explicit implementations like FFCV.
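A minimal sketch of this behavior, runnable on any JAX backend (CPU included): the jitted call returns almost immediately because the work is dispatched asynchronously, while blocking on the result (as jax.device_get or logging metrics does) waits for the computation to actually finish.

```python
import time
import jax
import jax.numpy as jnp

# A jitted step: JAX dispatches the computation asynchronously and
# immediately hands back an array that acts like a future.
@jax.jit
def step(x):
    return jnp.dot(x, x)

x = jnp.ones((2000, 2000))
step(x).block_until_ready()  # warm-up: compile once before timing

t0 = time.perf_counter()
y = step(x)                  # returns quickly: only dispatches the work
dispatch_time = time.perf_counter() - t0

t0 = time.perf_counter()
y.block_until_ready()        # forces completion, like jax.device_get would
block_time = time.perf_counter() - t0

print(f"dispatch: {dispatch_time:.6f}s, block: {block_time:.6f}s")
```

This is why the logging iteration looks slow: the cost of the previous ~50 steps is paid all at once when the metrics are materialized, not because that one step is actually slower.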
Many thanks for your detailed explanation, I understand now :)
Hi,
I tried to replicate your pretraining of a 7B LLaMA model on a TPU v4 VM with EasyLM, but I found that the training throughput suddenly drops to only 15% of the original after 35–40 iterations; it then waits a few seconds to log the metrics, after which the throughput temporarily returns to normal before falling again after another 35–40 iterations. The whole pattern repeats every log_freq iterations.
I used the following script for pretraining:
The experiment was done on a single TPU v4-8. Multiple combinations of mesh_dim were tried, but all showed the same issue. I also attached an SSD disk to the TPU v4 VM to rule out disk throughput as the cause. May I ask what your setting was for pretraining the LLaMA 7B model, whether you have observed similar issues on TPU v4, and if so, how you addressed them?
Thanks!