Same traceback with
torch==1.7.0a0+fcadca1
torch-xla==1.6+7d5500a
You can use fairseq master and something like:
fairseq-train --distributed-world-size 8 --tpu --task translation --num-batch-buckets 5 ...
It will be slower than GPUs because it will use less efficient batching in order to minimize the number of unique shapes (and thereby avoid XLA recompilations).
I just tested on:
torch==1.7.0a0+4e964f3
torch-xla==1.6+4c86d87
Slower than 1 GPU or 8 GPU?
Some progress. Now on
torch==1.7.0a0+fcadca1
torch-xla==1.6+7d5500a
(don't know how to get your exact versions, but this is closer) with --num-batch-buckets 5.
With --bf16 I get
AttributeError: 'FP16Optimizer' object has no attribute '_multiply_factor'
with or without export XLA_USE_BF16=1.
Without --bf16 it runs, but it seems slow (100 steps in ~30 seconds after a slow startup); maybe that's normal?
(many XLA compilation steps detected)
2020-08-21 17:18:53 | INFO | train_inner | epoch 001: 100 / 289965 loss=14.062, nll_loss=13.875, ppl=15040, wps=0, ups=0, wpb=276, bsz=13, num_updates=100, lr=1.25975e-05, gnorm=7.312, train_wall=340, wall=833
(fewer compilation steps)
2020-08-21 17:20:46 | INFO | train_inner | epoch 001: 200 / 289965 loss=13.25, nll_loss=12.938, ppl=7840, wps=3, ups=0.01, wpb=336, bsz=12, num_updates=200, lr=2.5095e-05, gnorm=4.75, train_wall=61, wall=946
(1 compilation step)
2020-08-21 17:21:20 | INFO | train_inner | epoch 001: 300 / 289965 loss=12.75, nll_loss=12.438, ppl=5536, wps=9.5, ups=0.03, wpb=320, bsz=20, num_updates=300, lr=3.75925e-05, gnorm=3.812, train_wall=23, wall=980
Is it best to give up/wait, or is there some way to use this as-is so that it's faster than 1-GPU training? Or some hardware optimization to get wall closer to train_wall?
Some caveats about the reporting. On TPUs, train_wall is meaningless: because of the way XLA works, no real computation actually happens in the train_step (it's just building the graph).
Compilations are slow, but should be very infrequent after a while (as you observed).
The reported wps is off by a factor of 100 (if using log_interval 100). So you’re getting 950 words per second, which is very slow. This is because you have a batch size of 64 tokens. Try setting --max-tokens 4000.
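To spell out that correction, here is a toy calculation (assuming the factor-of-100 reporting issue described above and log_interval=100, using the numbers from the logs in this thread):

# Toy calculation: correct the reported wps by the log_interval factor
# described above (assumes log_interval=100, as in the logs in this thread).
log_interval = 100
reported_wps = 9.5                      # from the "300 / 289965" log line above
actual_wps = reported_wps * log_interval
print(f"actual throughput ~ {actual_wps:.0f} words/sec")  # -> 950 words/sec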
I don’t recommend ever setting the XLA_USE_BF16 environment variable, it will make all the reported loss values bfloat16, which are essentially meaningless. The --bf16 flag in fairseq uses mixed precision and is much better (but seems like you hit a bug, I’ll look into it).
TPU with --max-tokens 4000: wps=242.5 (--distributed-world-size=8, same GCP zone). V100 fp16: wps=14721.8 (only uses 10.5 GB), so I guess 8 TPU cores ~= 1.67 GPUs! Shocking!
Better 1-GPU command below: wps=21245.7 (removed buckets, added --fp16 --max-source-positions 64 --max-tokens 5000):
fairseq-train \
wmt18_en_de_bpej32k \
--save-interval=1 \
--arch=transformer_vaswani_wmt_en_de_big \
--max-target-positions=64 --max-source-positions 64 --max-tokens=5000 \
--attention-dropout=0.1 \
--no-progress-bar \
--criterion=label_smoothed_cross_entropy \
--source-lang=en \
--lr-scheduler=inverse_sqrt \
--min-lr 1e-09 \
--skip-invalid-size-inputs-valid-test \
--target-lang=de \
--label-smoothing=0.1 \
--update-freq=1 \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--warmup-init-lr 1e-07 \
--lr 0.0005 \
--warmup-updates 4000 \
--share-all-embeddings \
--dropout 0.3 \
--weight-decay 0.0 \
--valid-subset=valid \
--max-epoch=25 \
--fp16
Also, would love any advice to make either command faster, I might train a few of these models soon and have a lot of TPU credits.
Here are some benchmarks that might be helpful. It's always tricky with TPUs because performance can degrade dramatically with any kind of dynamic shapes.
If I benchmark using a dummy task (which just feeds in a random 30 token source and 30 token target), here's what I get:
4 x V100 (in principle comparable to a v3-8 TPU) gives 57k wps:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --task dummy_mt --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3 --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --fp16
v3-8 gives 72k wps:
python train.py --task dummy_mt --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3 --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --tpu --distributed-world-size 8
But with real translation data, where the input sizes are variable, we have to aggressively bucket examples on TPUs (using --num-batch-buckets) to avoid too many compilations.
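To make the shape argument concrete, here is a minimal toy sketch of the idea behind length bucketing; it is not fairseq's actual --num-batch-buckets implementation, just an illustration of how rounding lengths up to a few fixed boundaries bounds the number of unique padded shapes (and hence XLA compilations):

import math

def bucket_length(length, max_len=512, num_buckets=5):
    """Round a sequence length up to the nearest bucket boundary.

    Toy illustration: with num_buckets boundaries, XLA only ever sees
    num_buckets distinct padded shapes, so each graph is compiled once.
    """
    bucket_size = math.ceil(max_len / num_buckets)
    return min(max_len, math.ceil(length / bucket_size) * bucket_size)

lengths = [7, 23, 64, 130, 311, 485]
print([bucket_length(n) for n in lengths])
# -> [103, 103, 103, 206, 412, 512]  (at most 5 distinct padded lengths)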
Here's a benchmark on real data (WMT'16 En-De).
4 x V100 gives ~100k wps (varies a bit):
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py ~/data/data-bin/wmt16_en_de_bpe32k/ --task translation --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3 --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --fp16
v3-8 with --num-batch-buckets 5 gives a peak of ~20k wps (once the compilations stop):
python train.py ~/data/data-bin/wmt16_en_de_bpe32k/ --task translation --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings --criterion=label_smoothed_cross_entropy --label-smoothing=0.1 --dropout 0.3 --attention-dropout=0.1 --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler=inverse_sqrt --lr 0.0005 --warmup-updates 4000 --max-tokens 5000 --log-interval 10 --log-format simple --tpu --distributed-world-size 8 --num-batch-buckets 5
I'll take a look at how we might improve this further, since that does seem like too dramatic of a slowdown. I think it's mostly a problem of padding inefficiency, since the largest "batch bucket" has length 485, but that size is being dominated by outliers. Probably we can speed things up by throwing out data with lengths > 100 tokens.
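A rough back-of-the-envelope illustration of that padding argument (made-up sentence lengths, not real WMT'16 statistics): a couple of long outliers force the padded bucket size up to ~485 tokens, so most of the compute in that bucket is spent on padding, and capping lengths at ~100 removes nearly all of the waste.

# Hypothetical lengths: mostly short sentences plus two long outliers that
# push the padded bucket size up to ~485 tokens.
lengths = [30] * 98 + [480, 485]

def padding_waste(lens):
    padded = max(lens)                  # every example pads to the bucket's max
    return 1 - sum(lens) / (padded * len(lens))

print(f"waste with outliers kept:         {padding_waste(lengths):.0%}")
print(f"waste after dropping >100 tokens: "
      f"{padding_waste([n for n in lengths if n <= 100]):.0%}")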
Hi,
Have there been any updates on improving this slowdown for TPUs using FairSeq?
Thanks, Vinod
This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
Any news on that? Now that TPU v4 is available, it might be worth putting some more effort into TPU compatibility. Any plans on your side to do so?
Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
@myleott I have been struggling with the GCP TPU tutorial at https://cloud.google.com/tpu/docs/tutorials/transformer-pytorch and getting either OOM or "connection to mesh master" failures. In case you are not maintaining that tutorial, is there a working MT command with --tpu checked in? I tried the command below on master and got the following traceback (once for each device):