facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

TPU + MT Command #2506

Closed · sshleifer · closed 2 years ago

sshleifer commented 4 years ago

@myleott I have been struggling with the GCP TPU tutorial at https://cloud.google.com/tpu/docs/tutorials/transformer-pytorch, getting either OOM errors or a "connection to mesh master failed" error. In case you are not maintaining that tutorial, is there a working MT command with --tpu checked in? I tried the command below on master with

torch==1.5.0a0+ab660ae
torch-xla==1.5
fairseq-train \
  $HOME/pytorch-tutorial-data/wmt18_en_de_bpej32k \
  --save-interval=1 \
  --arch=transformer_vaswani_wmt_en_de_big \
  --max-target-positions=64  --max-tokens=64 \
  --attention-dropout=0.1 \
  --no-progress-bar \
  --criterion=label_smoothed_cross_entropy \
  --source-lang=en \
  --lr-scheduler=inverse_sqrt \
  --min-lr 1e-09 \
  --skip-invalid-size-inputs-valid-test \
  --target-lang=de \
  --label-smoothing=0.1 \
  --update-freq=1 \
  --optimizer adam \
  --adam-betas '(0.9, 0.98)' \
  --warmup-init-lr 1e-07 \
  --lr 0.0005 \
  --warmup-updates 4000 \
  --share-all-embeddings \
  --dropout 0.3 \
  --weight-decay 0.0 \
  --valid-subset=valid \
  --max-epoch=25 \
  --tpu --bf16 --distributed-world-size 8

and got the following traceback (once for each device)

Traceback (most recent call last):
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", lin
e 119, in _start_fn
    fn(gindex, *args)
  File "/home/shleifer/fairseq/fairseq/distributed_utils.py", line 156, in distributed_main
    main(args, **kwargs)
  File "/home/shleifer/fairseq/fairseq_cli/train.py", line 119, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/home/shleifer/fairseq/fairseq_cli/train.py", line 199, in train
    log_output = trainer.train_step(samples)
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/home/shleifer/fairseq/fairseq/trainer.py", line 491, in train_step
    logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
  File "/home/shleifer/fairseq/fairseq/trainer.py", line 820, in _aggregate_logging_outputs
    logging_outputs, *extra_stats_to_sum, ignore=ignore
  File "/home/shleifer/fairseq/fairseq/trainer.py", line 883, in _fast_stat_sync_sum
    group=self.data_parallel_process_group
  File "/home/shleifer/fairseq/fairseq/distributed_utils.py", line 312, in all_reduce_dict
    cpu_data = _all_reduce_dict(cpu_data)
  File "/home/shleifer/fairseq/fairseq/distributed_utils.py", line 309, in _all_reduce_dict
    all_reduce(buf, group=group)
  File "/home/shleifer/fairseq/fairseq/distributed_utils.py", line 207, in all_reduce
    return xm.all_reduce('sum', [tensor], groups=group[1])
  File "/anaconda3/envs/torch-xla-1.5/lib/python3.6/site-packages/torch_xla/core/xla_model.py", line 360, in all_red
uce
    reduce_type, inputs, _get_all_reduce_token(), scale, groups)
TypeError: _xla_all_reduce(): incompatible function arguments. The following argument types are supported:
    1. (arg0: str, arg1: List[at::Tensor], arg2: _XLAC.IrValue, arg3: float, arg4: list) -> _XLAC.IrValue

Invoked with: 'sum', [tensor([51.0000,  0.0000, 40.2500, 51.0000,  3.0000, 51.0000], device='xla:0',
       dtype=torch.float64)], <_XLAC.IrValue object at 0x7ff4085356c0>, 1.0, None
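
For reference, the failure is in the last frame: fairseq passes groups=group[1], which is None here, into xm.all_reduce, and this torch_xla build's _xla_all_reduce binding only accepts a list for that argument. Below is a hedged, untested sketch of a local workaround, not the official fix (the thread below ends up on fairseq master plus newer torch/torch_xla builds, where this call succeeds): coerce a None groups into an empty list, which XLA treats as a single group containing all replicas.

# Hedged, untested workaround sketch (not the upstream fix): avoid passing
# groups=None down to _xla_all_reduce, which in this torch_xla build requires
# a list. An empty replica-groups list means one group spanning all replicas.
import torch_xla.core.xla_model as xm

def safe_all_reduce(tensor, group=None):
    # `group` follows the ("tpu", groups) convention seen in the traceback.
    groups = group[1] if group is not None else None
    return xm.all_reduce('sum', [tensor], groups=groups if groups is not None else [])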
sshleifer commented 4 years ago

Same traceback with

torch==1.7.0a0+fcadca1
torch-xla==1.6+7d5500a
myleott commented 4 years ago

You can use fairseq master and something like:

fairseq-train --distributed-world-size 8 --tpu --task translation --num-batch-buckets 5 ...

It will be slower than GPUs because it will use less efficient batching in order to minimize the number of unique shapes (and thereby avoid XLA recompilations).
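
To illustrate why the bucketing helps, here is a minimal standalone sketch of the idea (not fairseq's actual --num-batch-buckets implementation): if every batch is padded up to one of a few fixed bucket lengths, XLA only ever sees that many distinct input shapes, so it compiles a handful of graphs instead of one per unique batch shape.

# Minimal sketch of length bucketing (not fairseq's implementation): pad each
# batch to the smallest of `num_buckets` fixed lengths, so XLA sees at most
# `num_buckets` distinct sequence lengths instead of one shape per batch.
import math

def bucket_length(length, max_len=512, num_buckets=5):
    bucket_size = math.ceil(max_len / num_buckets)
    return min(math.ceil(length / bucket_size) * bucket_size, max_len)

def pad_batch(batch, pad_idx=1, num_buckets=5):
    target = bucket_length(max(len(seq) for seq in batch), num_buckets=num_buckets)
    return [seq + [pad_idx] * (target - len(seq)) for seq in batch]

print(bucket_length(37), bucket_length(210))  # -> 103 309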

myleott commented 4 years ago

I just tested on:

torch==1.7.0a0+4e964f3
torch-xla==1.6+4c86d87
sshleifer commented 4 years ago

Slower than 1 GPU or 8 GPU?

sshleifer commented 4 years ago

Some progress.

Now on

torch==1.7.0a0+fcadca1
torch-xla==1.6+7d5500a

(I don't know how to get your exact versions, but this is closer), with --num-batch-buckets 5.

With --bf16 I get AttributeError: 'FP16Optimizer' object has no attribute '_multiply_factor' with or without export XLA_USE_BF16=1.

Without --bf16 it runs, but seems slow (100 steps in 30 seconds after a slow startup); maybe that's normal?

(many XLA compilation steps detected)

2020-08-21 17:18:53 | INFO | train_inner | epoch 001:    100 / 289965 loss=14.062, nll_loss=13.875, ppl=15040, wps=0, ups=0, wpb=276, bsz=13, num_updates=100, lr=1.25975e-05, gnorm=7.312, train_wall=340, wall=833
(fewer compilation steps)
2020-08-21 17:20:46 | INFO | train_inner | epoch 001:    200 / 289965 loss=13.25, nll_loss=12.938, ppl=7840, wps=3, ups=0.01, wpb=336, bsz=12, num_updates=200, lr=2.5095e-05, gnorm=4.75, train_wall=61, wall=946
(1 compilation step)
2020-08-21 17:21:20 | INFO | train_inner | epoch 001:    300 / 289965 loss=12.75, nll_loss=12.438, ppl=5536, wps=9.5, ups=0.03, wpb=320, bsz=20, num_updates=300, lr=3.75925e-05, gnorm=3.812, train_wall=23, wall=980

Is it best to give up/wait, or is there some way to use this as-is so that it's faster than 1-GPU training?

Or some hardware optimization to get wall closer to train_wall?

myleott commented 4 years ago

Some caveats about the reporting. On TPUs, train_wall is meaningless because of the way XLA works: no real computation happens in train_step (it's just building the graph).
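
To make the lazy-execution point concrete, here is a minimal standalone sketch (assuming only the standard torch_xla API, not fairseq's trainer): timing the forward/backward calls mostly measures graph tracing; the real device work happens when xm.mark_step() materializes and runs the graph.

# Minimal sketch of XLA lazy execution (standard torch_xla API, not fairseq
# code): the step body below mostly records the graph, which is roughly what
# train_wall captures; the actual compile/execute happens at xm.mark_step().
import time
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(512, 512).to(device)
x = torch.randn(64, 512, device=device)

t0 = time.time()
loss = model(x).sum()
loss.backward()
t1 = time.time()              # graph traced; little real work done yet

xm.mark_step()                # compile (first time) and execute on the device
xm.wait_device_ops()          # block until the device work actually finishes
t2 = time.time()

print(f"trace: {t1 - t0:.4f}s, execute: {t2 - t1:.4f}s")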

Compilations are slow, but should be very infrequent after a while (as you observed).

The reported wps is off by a factor of 100 (when using --log-interval 100). So you're getting about 950 words per second, which is very slow. This is because you have a batch size of 64 tokens. Try setting --max-tokens 4000.

I don’t recommend ever setting the XLA_USE_BF16 environment variable; it will make all the reported loss values bfloat16, which are then essentially meaningless. The --bf16 flag in fairseq uses mixed precision and is much better (but it seems like you hit a bug; I’ll look into it).
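
As a rough illustration of the difference (a conceptual sketch only, not fairseq's FP16Optimizer / --bf16 code path, and assuming a PyTorch build with bf16 matmul support): mixed precision keeps fp32 master weights and an fp32 loss, running only the heavy ops in bfloat16, whereas XLA_USE_BF16=1 forces every tensor, including the reported losses, into bfloat16.

# Conceptual sketch (not fairseq's FP16Optimizer): fp32 "master" weights and an
# fp32 loss, with only the matmul done in bfloat16. Forcing everything to bf16
# (roughly what XLA_USE_BF16=1 does) would make the reported loss itself a
# low-precision bfloat16 value. Assumes bf16 matmul support in your PyTorch.
import torch
import torch.nn.functional as F

master = torch.nn.Linear(256, 256)                 # fp32 master weights
opt = torch.optim.Adam(master.parameters(), lr=1e-3)

x = torch.randn(32, 256)
y = F.linear(x.bfloat16(), master.weight.bfloat16(), master.bias.bfloat16())
loss = y.float().pow(2).mean()                     # loss tracked/reported in fp32
loss.backward()                                    # grads land on the fp32 master params
opt.step()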

sshleifer commented 4 years ago

TPU with --max-tokens 4000: wps=242.5 (--distributed-world-size=8, same GCP zone). V100 fp16: wps=14721.8 (only uses 10.5 GB). Applying the 100x correction to the reported TPU wps, that works out to 8 TPU cores ~= 1.67 GPUs! Shocking!

Better 1 GPU command below: wps=21245.7 (removed buckets, added --fp16 --max-source-positions 64 --max-tokens 5000)

1 GPU Command:

fairseq-train \
  wmt18_en_de_bpej32k \
  --save-interval=1 \
  --arch=transformer_vaswani_wmt_en_de_big \
  --max-target-positions=64  --max-source-positions 64 --max-tokens=5000 \
  --attention-dropout=0.1 \
  --no-progress-bar \
  --criterion=label_smoothed_cross_entropy \
  --source-lang=en \
  --lr-scheduler=inverse_sqrt \
  --min-lr 1e-09 \
  --skip-invalid-size-inputs-valid-test \
  --target-lang=de \
  --label-smoothing=0.1 \
  --update-freq=1 \
  --optimizer adam \
  --adam-betas '(0.9, 0.98)' \
  --warmup-init-lr 1e-07 \
  --lr 0.0005 \
  --warmup-updates 4000 \
  --share-all-embeddings \
  --dropout 0.3 \
  --weight-decay 0.0 \
  --valid-subset=valid \
  --max-epoch=25 \
  --fp16
sshleifer commented 4 years ago

Also, I would love any advice on making either command faster; I might train a few of these models soon and have a lot of TPU credits.

myleott commented 4 years ago

Here are some benchmarks that might be helpful. It's always tricky with TPUs because performance can degrade dramatically with any kind of dynamic shapes.

If I benchmark using a dummy task (which just feeds in a random 30 token source and 30 token target), here's what I get:

4 x V100 (in principle comparable to a v3-8 TPU) gives 57k wps:

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --task dummy_mt --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3  --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --fp16

v3-8 gives 72k wps:

python train.py --task dummy_mt --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3  --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --tpu --distributed-world-size 8

But with real translation data, where the input sizes are variable, we have to aggressively bucket examples on TPUs (using --num-batch-buckets) to avoid too many compilations.

Here's a benchmark on real data (WMT'16 En-De).

4 x V100 gives ~100k wps (varies a bit):

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py ~/data/data-bin/wmt16_en_de_bpe32k/ --task translation --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3  --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --fp16

v3-8 with --num-batch-buckets 5 gives a peak of ~20k wps (once the compilations stop):

python train.py ~/data/data-bin/wmt16_en_de_bpe32k/ --task translation --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3  --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --tpu --distributed-world-size 8 --num-batch-buckets 5

I'll take a look at how we might improve this further, since that does seem like too dramatic a slowdown. I think it's mostly a padding-inefficiency problem: the largest "batch bucket" has length 485, but that size is dominated by outliers. We can probably speed things up by throwing out data with lengths > 100 tokens.
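
If anyone wants to try that last idea before a proper fix lands, one hedged way to approximate it with flags already used in this thread (an assumption on my part, not necessarily what was meant) is to cap the source/target positions and skip oversized examples:

fairseq-train ... --max-source-positions 100 --max-target-positions 100 --skip-invalid-size-inputs-valid-test --num-batch-buckets 5 --tpu --distributed-world-size 8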

vinodganesan commented 3 years ago

Hi,

Have there been any updates on improving this slowdown for TPUs using FairSeq?

Thanks, Vinod

stale[bot] commented 3 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

scheiblr commented 2 years ago

Any news on that? Now that TPU v4 is available, it might be worth putting some more effort into TPU compatibility. Are there any plans on your side to do so?

stale[bot] commented 2 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

stale[bot] commented 2 years ago

Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!