huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Multi-TPU training uses just 1 out of 8 cores. #9841

Closed: avacaondata closed this issue 3 years ago

avacaondata commented 3 years ago

Environment info

Who can help

@patrickvonplaten, @LysandreJik @sgugger

Information

Model I am using (Bert, XLNet ...): ALBERT base

The problem arises when using:

The problem occurs when I try to train with run_mlm_wwm.py through xla_spawn.py. I've checked that when xla_spawn.py calls run_mlm_wwm.py, xm.xrt_world_size() is 8, as it should be. However, when the Trainer starts training, its batch size is only 64, but it should be 64 * num_cores = 512. I've printed out the parameters sent by xla_spawn.py and those received by run_mlm_wwm.py, and they match, so I don't understand why, in line 690 of trainer.py, `total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()` does not give 512...

This is the full call:

XRT_TPU_CONFIG="tpu_worker;0;10.44.99.146:8470" python -u transformers/examples/xla_spawn.py --num_cores 8 \
    transformers/examples/language-modeling/run_mlm_wwm.py \
    --model_type albert \
    --config_name ./config/albert-base-v2.json \
    --tokenizer_name ./tokenizer_2912 \
    --train_file ./train_texts_1_percent.txt \
    --validation_file ./validation_data/good_texts.csv \
    --output_dir ./models/model_1_percent \
    --overwrite_output_dir \
    --do_train \
    --do_eval \
    --evaluation_strategy steps \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 128 \
    --gradient_accumulation_steps 8 \
    --learning_rate 0.00176 \
    --save_steps 1000 \
    --logging_steps 1000 \
    --overwrite_cache \
    --max_seq_length 512 \
    --eval_accumulation_steps 10 \
    --load_best_model_at_end \
    --run_name model_1_percent \
    --save_total_limit 20 --tpu_metrics_debug

The model starts to train, but it doesn't take into account that it has 8 TPU cores:

[INFO|trainer.py:662] 2021-01-27 12:22:50,282 >> ***** Running training *****
[INFO|trainer.py:663] 2021-01-27 12:22:50,282 >>   Num examples = 5835032
[INFO|trainer.py:664] 2021-01-27 12:22:50,282 >>   Num Epochs = 3
[INFO|trainer.py:665] 2021-01-27 12:22:50,282 >>   Instantaneous batch size per device = 64
[INFO|trainer.py:666] 2021-01-27 12:22:50,282 >>   Total train batch size (w. parallel, distributed & accumulation) = 512
[INFO|trainer.py:667] 2021-01-27 12:22:50,282 >>   Gradient Accumulation steps = 8
[INFO|trainer.py:668] 2021-01-27 12:22:50,282 >>   Total optimization steps = 4272
  0%|                                                              | 3/4272 [04:18<113:20:52, 95.58s/it]
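As a cross-check on the log above: with data parallelism plus gradient accumulation, the number of samples consumed per optimizer step is the product of the per-device batch size, the number of cores, and the accumulation steps. The sketch below is plain Python, and `effective_batch_size` and `quoted_total` are hypothetical helpers, not Transformers APIs. It contrasts that product with the expression quoted from trainer.py, which has no accumulation term. Because the core count and the accumulation steps are both 8 in this run, the logged 512 is ambiguous on its own: it equals 64 × 8 for either factor.

```python
def effective_batch_size(per_device: int, num_cores: int, grad_accum: int) -> int:
    """Samples per optimizer step: per-device batch x replicas x accumulation steps."""
    return per_device * num_cores * grad_accum


def quoted_total(per_device: int, num_cores: int) -> int:
    """Mirrors the quoted expression train_batch_size * xrt_world_size() (no accumulation)."""
    return per_device * num_cores


# Values from the command in this issue: 64 per device, 8 TPU cores, 8 accumulation steps.
print(quoted_total(64, 8))             # 512
print(effective_batch_size(64, 8, 8))  # 4096
```

With both factors applied, each optimizer step in this run would consume 4,096 samples rather than 512.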

The task I am working on is:

To reproduce

Steps to reproduce the behavior:

  1. Instantiate a Google Cloud v3-8 TPU and an n1-standard-64 Google Cloud instance.
  2. Use any toy text dataset and any tokenizer and model name from the ones available in Transformers (these won't change the problem, so it's not necessary to have your own pretrained tokenizer or own dataset).
  3. Run the command posted above, setting XRT_TPU_CONFIG to the IP address of your TPU.

Expected behavior

xla_spawn.py is expected to run the Python file passed to it in multiple processes, distributing the batches and the model over the TPU cores. However, at some point xrt_world_size() turns to 1, and the script no longer sees all the available devices, only one.

sgugger commented 3 years ago

Hi there. It's just a logging problem in the reporting of the total batch size. If we do the math, your 5,835,032 samples give 91,172 batches of 64, hence 11,396 batches per core (dividing by the number of cores) and 1,424 optimization steps per epoch (dividing by the accumulation steps), which, multiplied by the 3 epochs, gives the 4,272 steps you see.
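The arithmetic above can be reproduced in a few lines of plain Python. This is a sketch that follows the same order of floor divisions as the comment; `total_optimization_steps` is a hypothetical helper, and the Trainer's own dataloader may round partial batches differently, so treat it as a sanity check rather than a reimplementation.

```python
def total_optimization_steps(num_examples: int, per_device_batch: int,
                             num_cores: int, grad_accum: int, num_epochs: int) -> int:
    """Floor-division estimate of optimizer steps, following the order used above."""
    batches = num_examples // per_device_batch         # batches of per_device_batch samples
    batches_per_core = batches // num_cores            # each core sees 1/num_cores of them
    steps_per_epoch = batches_per_core // grad_accum   # one optimizer step per grad_accum batches
    return steps_per_epoch * num_epochs


# Values from the log and command in this issue.
print(total_optimization_steps(5_835_032, 64, 8, 8, 3))  # 4272
print(total_optimization_steps(5_835_032, 64, 1, 8, 3))  # 34188
```

Under the same arithmetic, a run that only counted one core would report 34,188 steps, about eight times the 4,272 in the log.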

So the number of cores is indeed taken into account.

avacaondata commented 3 years ago

Ahh, I see, my bad; I didn't calculate the number of steps correctly (what a data scientist :P). Thank you very much, @sgugger!