facebookresearch / XLM

PyTorch original implementation of Cross-lingual Language Model Pretraining.

Multi-GPU training extremely slow #299

Jamiroquai88 opened this issue 4 years ago


Hi,

See #189. I have a very similar problem with speed.

CPU: Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz
Motherboard: SuperMicro X10DRG-Q

1x RTX 2080

export CUDA_VISIBLE_DEVICES=0
export NGPU=1

python -W ignore::UserWarning -m torch.distributed.launch --nproc_per_node=$NGPU train.py \
    --exp_name xlm_cs-en \
    --dump_path ./dumped \
    --data_path $OUTPATH \
    --lgs 'cs-en' \
    --clm_steps '' \
    --mlm_steps 'cs,en' \
    --emb_dim 1024 \
    --n_layers 12 \
    --n_heads 8 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    --gelu_activation true \
    --batch_size 4 \
    --bptt 256 \
    --optimizer adam,lr=0.0001 \
    --epoch_size 300000 \
    --max_epoch 100000 \
    --validation_metrics _valid_mlm_ppl \
    --stopping_criterion _valid_mlm_ppl,25 \
    --max_vocab 150000 \
    --amp 01 \
    --fp16 true
INFO - 06/22/20 10:51:39 - 0:01:18 -      20 -   14.63 sent/s -   511.97 words/s - MLM-cs:  6.2584 || MLM-en:  7.7495 -  - model LR: 1.0000e-04

10x RTX 2080

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9
export NGPU=10

python -W ignore::UserWarning -m torch.distributed.launch --nproc_per_node=$NGPU train.py \
    --exp_name xlm_cs-en \
    --dump_path ./dumped \
    --data_path $OUTPATH \
    --lgs 'cs-en' \
    --clm_steps '' \
    --mlm_steps 'cs,en' \
    --emb_dim 1024 \
    --n_layers 12 \
    --n_heads 8 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    --gelu_activation true \
    --batch_size 4 \
    --bptt 256 \
    --optimizer adam,lr=0.0001 \
    --epoch_size 300000 \
    --max_epoch 100000 \
    --validation_metrics _valid_mlm_ppl \
    --stopping_criterion _valid_mlm_ppl,25 \
    --max_vocab 150000 \
    --amp 01 \
    --fp16 true \
    --accumulate_gradients 4
INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    96.99 words/s - MLM-cs:  7.2759 || MLM-en:  8.0421 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    95.45 words/s - MLM-cs:  7.0937 || MLM-en:  7.9573 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    93.41 words/s - MLM-cs:  7.6418 || MLM-en:  8.1558 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    94.94 words/s - MLM-cs:  7.5107 || MLM-en:  8.1327 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    98.00 words/s - MLM-cs:  7.4999 || MLM-en:  8.1016 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    94.94 words/s - MLM-cs:  7.7264 || MLM-en:  7.9569 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    95.96 words/s - MLM-cs:  7.1065 || MLM-en:  8.0713 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    95.95 words/s - MLM-cs:  7.4804 || MLM-en:  8.2623 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    96.47 words/s - MLM-cs:  7.4324 || MLM-en:  8.1150 -  - model LR: 1.0000e-04
INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    94.42 words/s - MLM-cs:  7.4249 || MLM-en:  7.7899 -  - model LR: 1.0000e-04
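Each of the 10 processes logs its own throughput, so the aggregate rate of the 10-GPU run is the sum over the lines above. The following is a minimal sketch that parses and sums them. The regex is written against these particular lines and is an assumption, not a documented XLM log format.

```python
import re

# Per-worker log lines from the 10x RTX 2080 run above, trimmed after the
# words/s field (the loss values are not needed here).
LOG_LINES = [
    "INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    96.99 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    95.45 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    93.41 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    94.94 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    98.00 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    94.94 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    95.96 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    95.95 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:48 -      25 -    2.55 sent/s -    96.47 words/s",
    "INFO - 06/22/20 10:47:45 - 0:02:49 -      25 -    2.55 sent/s -    94.42 words/s",
]

# Captures the numbers in "<x> sent/s -   <y> words/s".
PATTERN = re.compile(r"([\d.]+) sent/s -\s+([\d.]+) words/s")


def total_throughput(lines):
    """Sum sent/s and words/s over all worker log lines."""
    sents, words = 0.0, 0.0
    for line in lines:
        match = PATTERN.search(line)
        if match is None:
            continue  # skip lines that do not report throughput
        sents += float(match.group(1))
        words += float(match.group(2))
    return sents, words


sents, words = total_throughput(LOG_LINES)
print(f"10 GPUs in total: {sents:.2f} sent/s, {words:.2f} words/s")
```

For comparison, the single-GPU run logs 14.63 sent/s and 511.97 words/s.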

I printed the following after every iteration, and it returns the same value on all GPUs:

logger.info(sum(p.sum().item() for p in model.parameters()))

Also, I am getting this warning:

/tmp/pip-req-build-ocx5vxk7/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

I downgraded to pytorch==1.0.1 and the warning disappeared, but the speed problem is still present.

Overall, training on 10x RTX 2080 is only slightly faster than on 1x RTX 2080: each of the 10 workers reports 2.55 sent/s, about 25.5 sent/s in total, versus 14.63 sent/s on a single GPU. I see very similar results on another server with 5x 1080. I have used Horovod for parallel training, and scaling was always fine (not linear, but close enough). Is there anything I can try?
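The comparison above can be expressed as a scaling efficiency. This rough sketch assumes two things: the aggregate rate is 10 × the per-worker rate, and `--batch_size` counts sequences per GPU. The helper names `scaling_efficiency` and `sequences_per_update` are hypothetical.

```python
def scaling_efficiency(single_gpu_rate, per_worker_rate, n_gpus):
    """Return (speedup, efficiency) of n_gpus workers relative to one GPU.

    Assumes every worker reports only its own rate, so the aggregate rate
    is n_gpus * per_worker_rate.
    """
    aggregate = n_gpus * per_worker_rate
    speedup = aggregate / single_gpu_rate
    return speedup, speedup / n_gpus


def sequences_per_update(batch_size, n_gpus, accumulate_gradients):
    """Sequences contributing to one optimizer step.

    Assumes batch_size is per GPU and that gradients are accumulated
    over accumulate_gradients iterations before each update.
    """
    return batch_size * n_gpus * accumulate_gradients


# Numbers taken from the sent/s fields of the logs above.
speedup, eff = scaling_efficiency(14.63, 2.55, 10)
print(f"speedup {speedup:.2f}x, efficiency {eff:.0%}")
print("sequences per update, 1 GPU:  ", sequences_per_update(4, 1, 1))
print("sequences per update, 10 GPUs:", sequences_per_update(4, 10, 4))
```

Under these assumptions the 10-GPU run also makes one optimizer update per 160 sequences rather than per 4, because it adds `--accumulate_gradients 4`. The loss values at the same step count are therefore not directly comparable between the two runs.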

Thanks.

Originally posted by @Jamiroquai88 in https://github.com/facebookresearch/XLM/issues/189#issuecomment-647415523