huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Running the run_mlm_flax on TPU v4 pods #20252

Open peregilk opened 1 year ago

peregilk commented 1 year ago

System Info

transformers 4.24.0

Who can help?

@patil-suraj

I am having problems scaling the run_mlm_flax scripts so that they run on TPU VM v4 pods (i.e. v4-16, v4-32, etc.). When running "out of the box", the performance is exactly the same as when running on a v4-8. To me this indicates that I am feeding a lot of empty data. The max per_device_train_batch_size for 512-token sequences in RoBERTa is 62 in both cases, but since the output is identical, it is obviously not scaling.

From trying to understand the code, it seems logical to multiply the batch size here by jax.process_count() (src example). However, this does not seem to be the right way to approach it.

Any ideas about how to approach this? Is the script tested on v4s?

Information

Tasks

Reproduction

See explanation above.

Expected behavior

Expect the batch size to scale automatically.

sgugger commented 1 year ago

Also cc @sanchit-gandhi

sanchit-gandhi commented 1 year ago

Hey @peregilk! Cool to see that you're using the Flax training scripts! Nice that you have TPU v4 pods as well 🚀

The scripts are only tested on single TPU devices (i.e. TPU v2-8, v3-8 and v4-8); however, they can be made to work in a multi-host set-up.

How are you launching the script on a TPU v4-16/32? Are you SSH'd into worker 0? You'll need to launch the same command on all 2/4 TPU workers for a v4-16/32 respectively.

peregilk commented 1 year ago

Hi @sanchit-gandhi. I am running a slightly altered version of the scripts, based on run_mlm_stream.py. I am both installing the software and starting the training simultaneously on all the TPU VMs, using a script I've made (ttconnect) for experiments like this.

The script also runs without any issues, both on individual TPUs and on pods of any size. However, the result from training on a TPU v4-8 and on a TPU pod v4-32 is exactly the same: the loss is the same, the training time is the same, etc. I really want the batches to scale across the pods. I am doing additional training of XLM-RoBERTa here, and it was trained with batch sizes around 3k; for that you need multiple TPUs. I want to increase batch size, not speed. My theory is that currently the batches do not span across the TPUs.

I made an attempt at simply multiplying the batch size in the script by jax.process_count(). That did not work.
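Concretely, the attempt looked roughly like this (a sketch against the train_batch_size line in the script, not the exact diff; training_args comes from the script itself):

import jax

# Attempted change (did not fix the problem): scale by the number of hosts
# on top of what jax.device_count() already gives.
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()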

sanchit-gandhi commented 1 year ago

Hey @peregilk,

Thanks for sharing those details. Your set-up looks good - the script you've made with ttconnect is super nice! The important thing is to run the same command across devices, which you are doing with that set-up.

The behaviour you have described suggests that you're replicating exactly the same training across all four of your TPU workers. The batch size should scale with the number of TPU devices to give you appropriate data parallelism: https://github.com/huggingface/transformers/blob/4bb07647504a277398856e828fa48ddbec97678e/examples/flax/language-modeling/run_mlm_flax.py#L654
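For reference, the linked line boils down to the following (paraphrasing):

# Data-parallel batch sizing in run_mlm_flax.py (paraphrased): jax.device_count()
# counts devices across *all* hosts, so the global batch should grow with the pod size.
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()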

Could you verify that the number of devices is indeed 32?

import jax
print(jax.device_count())
peregilk commented 1 year ago

This returns 16 on the v4-32. This is correct according to the user guide, since each v4 host has 4 double chips. Could that be the cause of any problems?

Multiplying by jax.process_count() as I suggested is then definitely wrong.

FYI: I did run this code both on a v3-8 and on a v4-8. I could then double my per_device_batch_size before getting OOM errors.

sanchit-gandhi commented 1 year ago

Okay, this could well be part of the problem! Could you try printing out all the different calls from this sub-section of the guides on pmap (except pmap) https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap:

import jax

print(jax.devices())
print(jax.local_devices())
...
print(jax.process_count())

Just to see what the right one is!

peregilk commented 1 year ago

The TPU v4-32 returns the following

jax.devices():
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=4, process_index=1, coords=(0,0,1), core_on_chip=0), TpuDevice(id=5, process_index=1, coords=(1,0,1), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(0,1,1), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(1,1,1), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,0,2), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,0,2), core_on_chip=0), TpuDevice(id=10, process_index=2, coords=(0,1,2), core_on_chip=0), TpuDevice(id=11, process_index=2, coords=(1,1,2), core_on_chip=0), TpuDevice(id=12, process_index=3, coords=(0,0,3), core_on_chip=0), TpuDevice(id=13, process_index=3, coords=(1,0,3), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(0,1,3), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(1,1,3), core_on_chip=0)]

jax.local_devices():
[TpuDevice(id=12, process_index=3, coords=(0,0,3), core_on_chip=0), TpuDevice(id=13, process_index=3, coords=(1,0,3), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(0,1,3), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(1,1,3), core_on_chip=0)]

jax.process_index():
worker-0: 1
worker-1: 2
worker-2: 3
worker-3: 4

jax.device_count():
16

jax.local_device_count():
4

jax.process_count():
4

The TPU v4-8 returns the following:

jax.devices():
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

jax.local_devices():
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

jax.process_index():
0

jax.device_count():
4

jax.local_device_count():
4

jax.process_count():
1

More info:

jax.print_environment_info()                                                                          
jax:    0.3.23                                                                                            
jaxlib: 0.3.22                                                                                            
numpy:  1.22.4                                                                                            
python: 3.8.10 (default, Jun 22 2022, 20:18:18)  [GCC 9.4.0]                                              
jax.devices (16 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=14, process_index=3, coords=(0,1,3), core_on_chip=0) TpuDevice(id=15, process_index=3, coords=(1,1,3), core_on_chip=0)]                 
process_count: 4    
peregilk commented 1 year ago

@sanchit-gandhi I found something very interesting that might be the source of most of my confusion here. When inserting a breakpoint into my code here: breakpoint

I notice that the value of jax.device_count() is actually 4(!!), and jax.process_count() returns 1. Starting Python from the command line, importing jax, and then printing the same jax.device_count(), the value is 16.

I do not have time to dig more into this right now. Just thought that I should mention this in case you decide to look more into this.

peregilk commented 1 year ago

@sanchit-gandhi I think I have been able to isolate the problem. This can be run directly from the command line on a v4-32:

>>> import jax        
>>> jax.device_count()
16  

However, importing TrainingArguments seems to change the number of visible devices:

>>> import jax                                 
>>> from transformers import TrainingArguments 
>>> jax.device_count()                         
4      

I can not see why this should happen. I also see the following error that might give a hint about what is going on:

>>> import jax                                             
>>> jax.device_count()                                     
16                                                         
>>> from transformers import TrainingArguments             
[percpu.cc : 557] RAW: rseq syscall failed with errno 22   
>>> from transformers import TrainingArguments             
>>> jax.device_count()                                     
16  
sanchit-gandhi commented 1 year ago

Great job at tracing it to a JAX-Transformers interaction! That's super weird - does this happen with just TrainingArguments, or with other Transformers modules too (e.g. AutoConfig)? Does swapping the order of your imports change the behaviour?

>>> from transformers import TrainingArguments 
>>> import jax                                 
>>> jax.device_count()

(we need to get a TPU v4 to test these issues!)

peregilk commented 1 year ago

Seems to be a bit of a Schrödinger's cat problem: whether you look at it determines if it is dead... ;) "Looking" at jax.device_count() (which probably activates the devices) seems to let you import TrainingArguments without breaking the pods.

Switching transformers and jax imports does not help. It still reports 4.

I think I have tried all the other transformer modules, and I have not been able to reproduce this with any of them.

skye commented 1 year ago

I'm not able to reproduce this. Running on a v4-16:

In [1]: import jax

In [2]: from transformers import TrainingArguments

In [3]: jax.device_count()
Out[3]: 8

(v4-16 = 8 chips = 8 jax devices)

@peregilk can you share your jax, jaxlib, libtpu-nightly, and transformers versions? Also make sure you're creating the TPUv4 with --version=tpu-vm-v4-base
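If it helps, most of this can be dumped in one go (jax.print_environment_info() covers jax, jaxlib and numpy; the libtpu-nightly version can be checked separately, e.g. with pip show libtpu-nightly):

import jax
import transformers

jax.print_environment_info()                      # jax, jaxlib, numpy, python, device list
print("transformers:", transformers.__version__)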

peregilk commented 1 year ago

Thanks @skye. For reference, in the reported error I was using --runtime-version=v2-alpha-tpuv4-pod with the following libraries.

jax:    0.3.23                                                                                            
jaxlib: 0.3.22
libtpu-nightly: 0.1.dev20221109                                                                                       
transformers: 4.24.0

Not reported above: when debugging, I actually also tried using --runtime-version=tpu-vm-v4-base, but got:

In [1]: import jax
In [2]: jax.device_count()
Out[2]: 4

I might have made a mistake when creating this pod. I will try again from scratch using --runtime-version=tpu-vm-v4-base.

Thanks.

sanchit-gandhi commented 1 year ago

Thank you @skye! 🙌

skye commented 1 year ago

Ah yeah, v2-alpha-tpuv4-pod confusingly was only for running TF on a pod slice, and would prevent jax from running across the slice. So that explains it. You should always use tpu-vm-v4-base with jax now (or tpu-vm-base for v2 and v3).

You can always check the Cloud TPU docs for the latest gcloud commands (I like https://cloud.google.com/tpu/docs/run-calculation-jax and https://cloud.google.com/tpu/docs/jax-pods). I understand it's hard to know when things change; hopefully they won't change very frequently moving forward :)

peregilk commented 1 year ago

Thanks a lot @skye! I can now see the devices after loading Transformers. I have also verified that it calculates the batch size correctly: train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()

With per_device_train_batch_size=62 on a v4-8, this means batch_size=248. This runs on the single TPU.

On a v4-32 this becomes batch_size=992. Here I am still getting OOM errors, even after reducing the batch size.

Are there any other changes that need to be made here?

skye commented 1 year ago

I'm not very familiar with using transformers, but you may need to use jax.local_device_count() instead of jax.device_count() somewhere? See https://jax.readthedocs.io/en/latest/multi_process.html. Let me know if you still have questions, this can be tricky!
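A minimal sketch of that distinction, assuming a v4-32 (4 hosts with 4 local devices each) and per_device_train_batch_size=62:

import jax

per_device_bs = 62
global_bs = per_device_bs * jax.device_count()        # 62 * 16 = 992: what the optimiser sees in total
local_bs = per_device_bs * jax.local_device_count()   # 62 * 4  = 248: what each host should load and pmap
# An OOM on a pod can mean a single host is trying to hold the full global batch
# rather than just its local slice.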

peregilk commented 1 year ago

Thanks a lot @skye and @sanchit-gandhi for assisting with this. Really useful comments. It seems like splitting between the nodes simply isn't implemented in the code I am using. @agemagician actually implemented this in pull #16527, but it is only added to run_mlm_flax_t5.py. It is not implemented for the other run_mlm scripts, nor in run_mlm_flax_streaming.py, which is the one I am using.

I can make a pull request to the other scripts, basically doing this change (sketched below). However, there is one remaining issue that needs to be resolved first.
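Roughly, the change amounts to per-host slicing along these lines (a sketch, not the exact diff from #16527; the helper name is made up):

import jax

def select_local_batch(global_batch):
    # Keep only this host's slice of a globally-indexed batch before shard()/pmap.
    local_size = len(global_batch["input_ids"]) // jax.process_count()
    start = jax.process_index() * local_size
    return {k: v[start : start + local_size] for k, v in global_batch.items()}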

For me (at least when I am using the streaming script), this turns out to be extremely slow on the pods. Here is a speed comparison; all runs use seq_length=512 and per_device_train_batch_size=56.

device    batch_size    seconds per iteration
v4-8      224           1
v4-64     1792          32
v4-128    3584          220

Currently this is way too slow for real training. I have not been able to test this on the non-streaming scripts, and have not made any attempt at understanding where the slowdown is. Do any of you have theories about what could be wrong here? It is also worth noting that starting up training (initialising from a pretrained checkpoint) typically takes 4-5 hours (the same for both single TPUs and pods). This is, however, not a showstopper for doing pretraining.

peregilk commented 1 year ago

@sanchit-gandhi I have not been able to fix this yet, but I think that I at least have been able to pin down the bottleneck here.

This iteration is extremely slow; the entire iteration takes a couple of minutes per training step. I am not sure why it is so slow, and I do not see why "id" and "text" are excluded here. The grouping is done differently for the non-streaming dataset, and those scripts seem to run a lot faster.

This actually also turns out to be the reason for the long startup time. The entire evaluation set is pre-tokenized and grouped, and then iterated over. With 50k steps in the evaluation set, this takes several hours. When reducing the eval set to just a few samples, the startup is almost instant.
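For anyone hitting the same start-up cost: one quick way to cap a streamed eval split is something like the following (dataset name and size are placeholders):

from datasets import load_dataset

eval_dataset = load_dataset("some_dataset", split="validation", streaming=True)
eval_dataset = eval_dataset.take(1000)  # only the first 1000 examples get tokenised/grouped at start-up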

peregilk commented 1 year ago

@sanchit-gandhi: I now have a working version that runs decently fast on the pods! I am down from 220 sec/it to around 10 sec/it on a v4-128.

I made the following change to the streaming code:

# Original (rebuilds every list on each call):
# samples = {
#    k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
# }
# Replacement: extend the existing lists in place
samples["input_ids"] += tokenized_samples["input_ids"]
samples["attention_mask"] += tokenized_samples["attention_mask"]
samples["special_tokens_mask"] += tokenized_samples["special_tokens_mask"]

For some reason this is a lot faster - fast enough to be "useful". I still do not think it is optimal, though. Tokenising and grouping are still slowing down the training considerably when you are using a streaming dataset.

skye commented 1 year ago

I'm guessing using += is a lot faster because Python is smart enough to extend the samples lists in-place, whereas the original implementation will end up completely rewriting each list. If that's right, I think using += is the best you can do short of multi-threading (I'm not a Python performance expert though).
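A tiny illustration of the difference (not a rigorous benchmark; sizes are made up):

import timeit

new_chunk = list(range(1000))

def rebuild(samples):
    # builds a brand-new list (copying everything accumulated so far) on every call
    samples["input_ids"] = samples["input_ids"] + new_chunk

def extend_in_place(samples):
    # list.__iadd__ extends the existing list without copying the old elements
    samples["input_ids"] += new_chunk

for fn in (rebuild, extend_in_place):
    samples = {"input_ids": []}
    print(fn.__name__, timeit.timeit(lambda: fn(samples), number=500))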

peregilk commented 1 year ago

There are a few other things in the script that seem suboptimal. For instance, the tokenization is not split across the VMs.

@skye: Do you have an estimate of what performance should ideally be expected here? Let's say one training step takes 1 second on a v4-8; how long should it take on a v4-128? I guess there is some overhead in dividing the job across the TPUs, right? I am just looking for an estimate of how much the current performance depends on the CPUs.

sanchit-gandhi commented 1 year ago

Hey @peregilk! Sorry for the delayed response.

We can't use multiple processes with Datasets' map method when using a streaming dataset. This is because we read the raw dataset's tar file and iterate over the bytes incrementally, loading the samples into memory one file at a time. This is why we don't pass the num_proc arg to .map when tokenising the dataset.

If your dataset is small, it might be worth downloading it, pre-processing it and saving it to cache (all done under the hood by Datasets for a non-streaming dataset)? Otherwise, this is part of the trade-off of using streaming datasets: no disk space constraints, but the data has to be loaded on the fly.
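As a sketch, that route would look roughly like this (the dataset name is a placeholder and tokenize_function stands in for the tokenisation used in the script; Datasets caches the mapped result on disk automatically):

from datasets import load_dataset

raw_dataset = load_dataset("some_dataset", split="train")   # downloaded once, not streamed
tokenized = raw_dataset.map(
    tokenize_function,                       # same tokenisation as in the script
    batched=True,
    num_proc=32,                             # multiple processes work for map-style datasets
    remove_columns=raw_dataset.column_names,
)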

peregilk commented 1 year ago

Thanks @sanchit-gandhi. In my case, storing the dataset locally is not an option: I would then have to attach a disk to each of the pod's VMs, and for the large pods that is not feasible.

I understand the samples need to be tokenized before it is possible to shard them across the TPUs, and I also understand that in reality this needs to be done on a single TPU VM. However, I still see more than 10 seconds per step here - it just seems like a lot.

Do you know if it is possible to pre-tokenize (or even pre-shard) a dataset and keep it streaming? Is it worth looking into that, or do you think it is better looking closer into what is taking time here?

Each TPU VM is quite a capable machine (200 CPU cores). Even if it is hard to split this across multiple VMs, are there better ways of using the VM that needs to do the processing?

skye commented 1 year ago

Do you have an estimate of what performance should ideally be expected here? Let's say one training step takes 1 second on a v4-8; how long should it take on a v4-128?

Sorry, I missed this earlier. Not sure it's still useful, but for batch parallelism you should expect near-linear scaling if you keep the per-device batch size the same, i.e. if you increase the global batch size 16-fold going from v4-8 -> v4-128, the step time should remain constant. If you keep the global batch size the same (i.e. decrease the per-device batch size as you increase devices), the speedup should be roughly linear until you reach a certain minimum per-device batch size.
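To put numbers on that for the runs above (device counts from jax.device_count(): 4 on a v4-8, 64 on a v4-128; a back-of-the-envelope, not a measurement):

per_device_bs = 56
print("v4-8   global batch:", per_device_bs * 4)    # 224  -> ~1 s/step measured
print("v4-128 global batch:", per_device_bs * 64)   # 3584 -> ideally still ~1 s/step, versus the 220 s observed above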

peregilk commented 1 year ago

Thanks a lot @skye. Great to get this confirmed. Basically, the script today runs 10x slower than it potentially could. Or, put another way, 90% of the time is spent preparing the dataset and only 10% is used efficiently for training.

If I understand correctly, @sanchit-gandhi, there will soon be a Flax implementation of Whisper with streaming datasets. I will test this as well and see if I get the same issues there.

I have a few ideas about how to figure out what is really going on here, and I will start looking into this more thoroughly early next year.

Hope it is OK that I am also tagging @lhoestq .

lhoestq commented 1 year ago

You can use something like a torch DataLoader with num_workers > 0 with your streaming dataset. This way you load and collate the data in parallel to your forward and backward passes.

peregilk commented 1 year ago

Thanks a lot @lhoestq. If I understand correctly, the way this works with streaming datasets is that the DataLoader starts a worker for each of the dataset shards. So if you have the compute capacity, the optimal setting is num_workers=dataset.n_shards (with my test dataset this is 85).

I tried implementing this like:

# Replace
# training_iter = iter(tokenized_datasets)
training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=dataset.n_shards, collate_fn=lambda x: x))

My reference is 1 sec/iteration on a v4-8. According to @skye, this should continue to be 1 sec/iteration on a v4-128 with my setup. As shown above, I started at 220 sec/iteration on a v4-128. Before the suggestion from @lhoestq I was down to 11 sec/iteration; after adding the torch DataLoader this is reduced to 5 sec/iteration.

Even though things are looking way better, I still think this can be improved further. I took a look at the load on the VMs' CPUs, and it is still very low: approx. 10% with some very short peaks. All cores are used.

I am willing to share what I have so far here. @patrickvonplaten: Are you interested in merging the support for the tpu v4-pods into run_mlm_flax_stream.py? Maybe others can contribute and improve on this as well?

sanchit-gandhi commented 1 year ago

Sorry for dropping the ball here @peregilk! I understand that your MLM experiments are working quite well currently?

Are you interested in merging the support for the tpu v4-pods into run_mlm_flax_stream.py? Maybe others can contribute and improve on this as well?

This would be super! We could start with your working MLM streaming script? Feel free to open a PR on transformers if you're interested and tag me 🤗 happy to iterate with you here!

peregilk commented 1 year ago

Yes, @sanchit-gandhi, the training is running at an acceptable speed. I am currently training some larger models. When I get the results from these and am certain that everything really works, I'll open a PR.

peregilk commented 1 year ago

Keeping alive. I will do this together with the Whisper pod support.

Lime-Cakes commented 1 year ago

Kinda unrelated to the issue, since I was working on a diffuser model instead, but I noticed some oddity.

At the moment, it seems like the solution for TPU pods / multiple processes is to divide the global batch into local batches corresponding to each process ( #16527 ). That would mean a single dataloader for all processes if data loading is kept behind a process-index check; otherwise, it means the exact same dataloader on every process, each loading a global batch and discarding non-local data.

Wouldn't it be better for each process to have its own dataloader, streaming from a list of pre-divided datasets?

Instead of having one dataset (an HF streaming dataset from the datasets library), I tested splitting my dataset into multiple ones of equal size under different indices. It seemed to allow faster data loading. I have yet to test it in a pod environment though. The reasoning behind splitting into multiple datasets comes from experience working with tfrecord, which recommends multiple smaller files instead of one massive file, as HF datasets currently do with tar streaming.

lhoestq commented 1 year ago

With the HF datasets library you can already use split_dataset_by_node:

from datasets.distributed import split_dataset_by_node

ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)

this works for regular (="map-style") and iterable datasets (e.g. when streaming).

From the documentation:

For map-style datasets:

Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset. To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

For iterable datasets:

If the dataset has a number of shards that is a factor of world_size (i.e. if dataset.n_shards % world_size == 0), then the shards are evenly assigned across the nodes, which is the most optimized. Otherwise, each node keeps 1 example out of world_size, skipping the other examples.

peregilk commented 1 year ago

@Lime-Cakes Thanks a lot @lhoestq. This turned out to be the way to make this run fast on the TPU v4 pods. The two tricks seem to be using split_dataset_by_node and then using torch.utils.data.DataLoader with a high (30+) number of workers.
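For reference, the combination looks roughly like this (dataset name and worker count are placeholders):

import jax
import torch
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# 1) give each TPU VM (host) its own subset of the streamed shards
ds = load_dataset("some_dataset", split="train", streaming=True)
ds = split_dataset_by_node(ds, rank=jax.process_index(), world_size=jax.process_count())

# 2) load and collate in parallel with the train step using a torch DataLoader
training_iter = iter(
    torch.utils.data.DataLoader(
        ds.with_format("torch"),
        batch_size=1,
        shuffle=False,
        num_workers=32,            # high worker count; effectively capped by the shards available per host
        collate_fn=lambda x: x,
    )
)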

I have mainly been working on getting this to run for Whisper lately. I now have a training script here that I am willing to submit. @sanchit-gandhi, please advise, and I will open a pull request. The implementation for run_mlm_flax_streaming.py should be possible to do the same way.

Attaching a graph showing how it scales between v4-8, v4-16 and v4-32. It also shows how defining too few workers drastically reduces the speed. The scaling is now very close to linear, as @skye commented earlier.

[attached graph: scaling across v4-8, v4-16 and v4-32 for different numbers of DataLoader workers]

To be able to get "enough" workers on the larger pods, the dataset also needs to have a lot of shards. In this example I used 256 shards, giving a maximum of 64 shards on each of the VMs on a v4-32.

peregilk commented 1 year ago

Updating this post: the Whisper Tiny model seems to scale almost perfectly here. I am able to use the pods both for increasing batch size and for increasing speed.

However, for some very strange reason, I am unable to do this for the larger Whisper models. All of them seem to train great for the first few steps, then they simply freeze. I have spent a lot of time debugging this and am a bit lost at the moment. It does, however, seem to be related to updating the model state, and not to the dataset loading.

The training script is available here: https://github.com/NbAiLab/nb-whisper/blob/main/run_flax_speech_recognition_seq2seq_streaming.py

However, we probably need to iron out this bug before making a pull request.

@sanchit-gandhi : I can set up a minimum example for reproducing the bug. Please let me know.

peregilk commented 1 year ago

We now finally have a working training script for Flax Whisper! It uses dataset streaming and runs really fast on TPU pods. It also runs on single TPUs and GPUs. We are making some final modifications and cleanups, and have agreed with @sanchit-gandhi to make a review before making a pull request in a few days.

If anyone following this thread has access to TPUs and wants to train Whisper, please notify me. We would really like to also implement gradient checkpointing to boost batch size on the large models. However, we do not have the capacity or knowledge to implement it ourselves, but we would be happy to contribute by testing it if anyone has the capacity.