redbrain opened 9 months ago
This is expected. A v4-256 is actually only 128 TPU-v4 chips (a quirk of the naming convention: before v4, the two TensorCores on a chip were counted as separate devices), so our OpenLLaMA 7B configuration actually uses a v4-512. If you want to train on a v4-256, consider using batch size 1024 and `--mesh_dim='-1,64,1'`.
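For anyone hitting this, a quick way to check the topology is to ask JAX how many devices the slice exposes before choosing a mesh. A minimal sketch, assuming a standard TPU runtime and that `mesh_dim` lists three axis sizes with `-1` inferred:

```python
# Minimal sketch: confirm the device count before picking mesh_dim.
# On v4, one chip == one JAX device, so a v4-256 slice reports 128.
import jax

n_devices = jax.device_count()
print(n_devices)  # 128 on a v4-256

# With mesh_dim='-1,64,1', the -1 axis is inferred from the device count,
# so the requested logical mesh is (128 // (64 * 1), 64, 1) == (2, 64, 1).
print((n_devices // (64 * 1), 64, 1))
```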
This results in `NotImplementedError: Failed to find assignment for logical_axis_index 1 of size 64 with remaining assignable mesh [4, 4, 8]`.
Any clue what went wrong?
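For context, the shape in that message is the physical torus of the slice: a v4-256's 128 chips are arranged as `[4, 4, 8]`. By default the mesh builder composes each logical axis out of whole physical axes, and no subset of `[4, 4, 8]` multiplies to 64. A back-of-envelope illustration (my own, not EasyLM code):

```python
# Which logical axis sizes can be formed from whole physical torus axes?
from itertools import combinations
from math import prod

torus = [4, 4, 8]  # physical topology of a v4-256 slice, per the error
sizes = sorted({prod(c) for r in range(1, len(torus) + 1)
                for c in combinations(torus, r)})
print(sizes)  # [4, 8, 16, 32, 128] -- no 64, hence the NotImplementedError
```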
It appears that since a v4-256 has half the chips of a v4-512, the appropriate mesh topology would be `-1,32,1`. But running with that mesh and with batch sizes of 1024 and even 512 still produces OOM errors. Any advice on how to fix this?
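For reference, a rough back-of-envelope on the memory side (my own arithmetic; I'm assuming EasyLM's three mesh axes are data, fully-sharded-data, and model parallel, so `-1,32,1` on 128 chips means dp=4, fsdp=32, mp=1). The fp32 parameter and optimizer state alone is modest once sharded, which suggests the OOM is driven mostly by activations at large batch sizes:

```python
# Rough estimate, not a measurement: fp32 param/optimizer bytes per chip
# under an assumed (dp=4, fsdp=32, mp=1) mesh. Activations are excluded.
params = 7e9                      # LLaMA 7B
bytes_fp32 = 4
copies = 4                        # weights + grads + Adam m + Adam v
state = params * bytes_fp32 * copies
per_chip_gb = state / 32 / 1e9    # sharded across the fsdp axis of size 32
print(f"{per_chip_gb:.1f} GB of the ~32 GB HBM on a v4 chip")  # ~3.5 GB
```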
Oh, this is a known problem: by default, JAX does not want to split a physical axis into multiple logical axes. However, we can force it to do so by specifying `--mesh_dim='!-1,64,1'`.
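For anyone curious what the `!` prefix changes, here is a hedged sketch (inferred from the comment above, not from reading EasyLM's source): without `!`, the mesh is built by `jax.experimental.mesh_utils.create_device_mesh`, which maps logical axes onto whole physical torus axes; with `!`, the flat device list is simply reshaped, which freely splits a physical axis at some potential cost in communication locality. The axis names below are illustrative:

```python
import numpy as np
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

shape = (2, 64, 1)                 # what '-1,64,1' resolves to on 128 devices
axis_names = ('dp', 'fsdp', 'mp')  # illustrative names, assumed from EasyLM

# Default path: honors the physical [4, 4, 8] torus and raises
# NotImplementedError for shapes it cannot assign, such as (2, 64, 1).
#   devices = mesh_utils.create_device_mesh(shape)

# Forced path (what '!' plausibly enables): a plain reshape always succeeds,
# but logical neighbors may no longer be physical neighbors on the torus.
devices = np.array(jax.devices()).reshape(shape)
mesh = Mesh(devices, axis_names)
```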
Still not working, even with the parameters you suggested for `mesh_dim` and `batch_size`.
This is quite strange. Maybe XLA is not being smart about memory allocation here. In that case I'd recommend tweaking the batch size and mesh dimensions; for example, try an even smaller batch size of 512, or use `!-1,128,1` as the mesh dim.
I was able to run 7B model training on a TPU v4-256 with `mesh_dim='!-1,16,4'` and `batch_size=64`, at 115,000 tokens per second.
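As a sanity check on that figure (my own arithmetic, assuming `batch_size=64` is the global batch and the sequence length is 2048 as elsewhere in the thread):

```python
tokens_per_step = 64 * 2048            # global batch x sequence length
step_time = tokens_per_step / 115_000  # reported tokens per second
print(f"{step_time:.2f} s/step")       # ~1.14 s per optimizer step
```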
**Command**
```sh
python -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,32,1' \
    --dtype='fp32' \
    --total_steps=250000 \
    --log_freq=50 \
    --save_model_freq=0 \
    --save_milestone_freq=2500 \
    --load_llama_config='7b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='gs://.../tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=3e-4 \
    --optimizer.adamw_optimizer.end_lr=3e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=250000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='gs://.../dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=2048 \
    --train_dataset.json_dataset.tokenizer_processes=16 \
    --checkpointer.save_optimizer_state=True \
    --logger.online=True \
    --logger.prefix='devingulliver' \
    --logger.project="sl_llama_7b" \
    --logger.output_dir="gs://.../output/" \
    --logger.wandb_dir="$HOME/experiment_output/sl_llama_7b"
```

**Log**
```
I1008 02:58:09.536914 139894565414912 mesh_utils.py:282] _create_device_mesh_for_nd_torus assignment: [(1,), (0, 2), ()]
  0% 0/250000 [02:56, ?it/s]
Traceback (most recent call last):
  File "/home/redbrain/EasyLM/EasyLM/models/llama/llama_train.py", line 267, in