facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

distributed training for transformer OOM #277

Closed vsuthichai closed 6 years ago

vsuthichai commented 6 years ago

I'm attempting distributed training of a big Transformer model in fp16 using the following script, and I receive CUDA out-of-memory errors. I'm using a p3.16xl on AWS: 8 Volta V100 GPUs with 16 GB each, on a single node. I know I can do the same training with a different distributed launch technique that spawns child processes through multiprocessing, but my end goal is to benchmark this on multiple nodes. I don't have Slurm set up for this, so I'm following the instructions laid out at the end of the getting started guide and manually starting one process per GPU: https://github.com/pytorch/fairseq/blob/master/docs/getting_started.rst

HOST_PORT="tcp://10.0.0.168:13333"

kill_children() {
  for PID in ${PIDS[*]}; do
    kill -TERM $PID
  done
}

for i in $(seq 0 7); do
  RANK=$i
  python train.py data-bin/wmt14_en_de_joined_dict   \
       --arch transformer_vaswani_wmt_en_de_big    \
       --share-all-embeddings                      \
       --optimizer adam                            \
       --adam-betas '(0.9, 0.98)'                  \
       --clip-norm 0.0                             \
       --lr-scheduler inverse_sqrt                 \
       --warmup-init-lr 1e-07                      \
       --warmup-updates 4000                       \
       --lr 0.0005                                 \
       --min-lr 1e-09                              \
       --dropout 0.3                               \
       --weight-decay 0.0                          \
       --criterion label_smoothed_cross_entropy    \
       --label-smoothing 0.1                       \
       --max-tokens 3584   --fp16                  \
       --distributed-world-size 8                  \
       --distributed-init-method $HOST_PORT        \
       --distributed-rank $RANK &
  PIDS[$RANK]=$!
done

trap kill_children SIGTERM SIGINT

for PID in ${PIDS[*]}; do
  wait $PID
done

This is the output:

| distributed init (rank 7): tcp://10.0.0.168:13333
| distributed init (rank 1): tcp://10.0.0.168:13333
| distributed init (rank 0): tcp://10.0.0.168:13333
| distributed init (rank 4): tcp://10.0.0.168:13333
| distributed init (rank 6): tcp://10.0.0.168:13333
| distributed init (rank 5): tcp://10.0.0.168:13333
| distributed init (rank 2): tcp://10.0.0.168:13333
| distributed init (rank 3): tcp://10.0.0.168:13333
| initialized host ip-10-0-0-168 as rank 0
Namespace(adam_betas='(0.9, 0.98)', adam_eps=1e-08, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='transformer_vaswani_wmt_en_de_big', attention_dropout=0.0, clip_norm=0.0, criterion='label_smoothed_cross_entropy', data='data-bin/wmt14_en_de_joined_dict', decoder_attention_heads=16, decoder_embed_dim=1024, decoder_embed_path=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layers=6, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, distributed_backend='nccl', distributed_init_method='tcp://10.0.0.168:13333', distributed_port=-1, distributed_rank=0, distributed_world_size=8, dropout=0.3, encoder_attention_heads=16, encoder_embed_dim=1024, encoder_embed_path=None, encoder_ffn_embed_dim=4096, encoder_layers=6, encoder_learned_pos=False, encoder_normalize_before=False, fp16=True, keep_interval_updates=-1, label_smoothing=0.1, left_pad_source='True', left_pad_target='False', log_format=None, log_interval=1000, lr=[0.0005], lr_scheduler='inverse_sqrt', lr_shrink=0.1, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=1024, max_target_positions=1024, max_tokens=3584, max_update=0, min_loss_scale=0.0001, min_lr=1e-09, momentum=0.99, no_epoch_checkpoints=False, no_progress_bar=False, no_save=False, no_token_positional_embeddings=False, optimizer='adam', optimizer_overrides='{}', raw_text=False, relu_dropout=0.0, reset_lr_scheduler=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='checkpoints', save_interval=1, save_interval_updates=0, seed=1, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=False, skip_invalid_size_inputs_valid_test=False, source_lang=None, target_lang=None, task='translation', train_subset='train', update_freq=[1], upsample_primary=1, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0)
| [en] dictionary: 32768 types
| [de] dictionary: 32768 types
| data-bin/wmt14_en_de_joined_dict train 4528446 examples
| data-bin/wmt14_en_de_joined_dict valid 3000 examples
| model transformer_vaswani_wmt_en_de_big, criterion LabelSmoothedCrossEntropyCriterion
| num. model params: 209911808
| training on 8 GPUs
| max tokens per GPU = 3584 and max sentences per GPU = None
| epoch 001:   0%|                                                                                                                                                                                 | 0/5492 [00:00<?, ?it/s]THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCTensorMath.cu line=15 error=2 : out of memory
Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 95, in main
    train(args, trainer, task, epoch_itr)
  File "/home/ubuntu/github/fairseq/train.py", line 133, in train
    log_output = trainer.train_step(sample, update_params=True)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 144, in train_step
    agg_logging_output = self._update_params()
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 163, in _update_params
    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
  File "/home/ubuntu/github/fairseq/fairseq/distributed_utils.py", line 73, in all_gather_list
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCTensorMath.cu:15
Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 83, in main
    trainer.dummy_train_step(dummy_batch)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 342, in dummy_train_step
    self.train_step(dummy_batch, update_params=False, dummy_batch=True)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 133, in train_step
    loss, sample_size, logging_output, oom_fwd = self._forward(sample)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 235, in _forward
    raise e
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 227, in _forward
    loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
  File "/home/ubuntu/github/fairseq/fairseq/tasks/fairseq_task.py", line 157, in get_loss
    return criterion(model, sample)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/github/fairseq/fairseq/criterions/label_smoothed_cross_entropy.py", line 36, in forward
    net_output = model(**sample['net_input'])
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/github/fairseq/fairseq/models/fairseq_model.py", line 159, in forward
    encoder_out = self.encoder(src_tokens, src_lengths)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/github/fairseq/fairseq/models/transformer.py", line 290, in forward
    x = layer(x, encoder_padding_mask)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/github/fairseq/fairseq/models/transformer.py", line 549, in forward
    x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/github/fairseq/fairseq/modules/multihead_attention.py", line 80, in forward
    q, k, v = self.in_proj_qkv(query)
  File "/home/ubuntu/github/fairseq/fairseq/modules/multihead_attention.py", line 150, in in_proj_qkv
    return self._in_proj(query).chunk(3, dim=-1)
  File "/home/ubuntu/github/fairseq/fairseq/modules/multihead_attention.py", line 170, in _in_proj
    return F.linear(input, weight, bias)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/functional.py", line 1026, in linear
    output = input.matmul(weight.t())
RuntimeError: cublas runtime error : resource allocation failed at /pytorch/aten/src/THC/THCGeneral.cpp:333
| WARNING: ran out of memory, skipping batch
Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 83, in main
    trainer.dummy_train_step(dummy_batch)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 343, in dummy_train_step
    self.zero_grad()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 94, in zero_grad
    self.optimizer.zero_grad()  # FP32
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 68, in optimizer
    self._build_optimizer()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 66, in _build_optimizer
    self.fp32_params = params[0].new(0).float().new(total_param_size)
RuntimeError: CUDA error: out of memory
| epoch 001:   0%|                                                                                                                                                                                 | 0/5492 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 83, in main
    trainer.dummy_train_step(dummy_batch)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 343, in dummy_train_step
    self.zero_grad()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 94, in zero_grad
    self.optimizer.zero_grad()  # FP32
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 68, in optimizer
    self._build_optimizer()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 73, in _build_optimizer
    self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
RuntimeError: CUDA error: out of memory
Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 83, in main
    trainer.dummy_train_step(dummy_batch)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 343, in dummy_train_step
    self.zero_grad()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 94, in zero_grad
    self.optimizer.zero_grad()  # FP32
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 68, in optimizer
    self._build_optimizer()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 73, in _build_optimizer
    self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
RuntimeError: CUDA error: out of memory
Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 83, in main
    trainer.dummy_train_step(dummy_batch)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 343, in dummy_train_step
    self.zero_grad()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 94, in zero_grad
    self.optimizer.zero_grad()  # FP32
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 68, in optimizer
    self._build_optimizer()
  File "/home/ubuntu/github/fairseq/fairseq/fp16_trainer.py", line 73, in _build_optimizer
    self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
RuntimeError: CUDA error: out of memory
| epoch 001:   0%|                                                                                                                                                                                 | 0/5492 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 95, in main
    train(args, trainer, task, epoch_itr)
  File "/home/ubuntu/github/fairseq/train.py", line 133, in train
    log_output = trainer.train_step(sample, update_params=True)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 144, in train_step
    agg_logging_output = self._update_params()
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 163, in _update_params
    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
  File "/home/ubuntu/github/fairseq/fairseq/distributed_utils.py", line 77, in all_gather_list
    torch.distributed.all_gather(out_buffers, in_buffer.cuda())
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/distributed/__init__.py", line 439, in all_gather
    return all_gather_multigpu([tensor_list], [tensor], group)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/distributed/__init__.py", line 413, in all_gather_multigpu
    group)
RuntimeError: Connection reset by peer
Traceback (most recent call last):
  File "train.py", line 356, in <module>
    distributed_main(args)
  File "/home/ubuntu/github/fairseq/distributed_train.py", line 39, in main
    single_process_main(args)
  File "/home/ubuntu/github/fairseq/train.py", line 95, in main
    train(args, trainer, task, epoch_itr)
  File "/home/ubuntu/github/fairseq/train.py", line 133, in train
    log_output = trainer.train_step(sample, update_params=True)
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 144, in train_step
    agg_logging_output = self._update_params()
  File "/home/ubuntu/github/fairseq/fairseq/trainer.py", line 163, in _update_params
    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
  File "/home/ubuntu/github/fairseq/fairseq/distributed_utils.py", line 77, in all_gather_list
    torch.distributed.all_gather(out_buffers, in_buffer.cuda())
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/distributed/__init__.py", line 439, in all_gather
    return all_gather_multigpu([tensor_list], [tensor], group)
  File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/distributed/__init__.py", line 413, in all_gather_multigpu
    group)
RuntimeError: Connection reset by peer
edunov commented 6 years ago

You may need to force each worker onto its own GPU, e.g. by setting CUDA_VISIBLE_DEVICES=$RANK in your for loop. Otherwise, I suspect they all start on GPU 0; you can confirm this by running nvidia-smi when the job starts.
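
For example, the loop might look like this (an untested sketch; all training flags are copied unchanged from your script above, and HOST_PORT is the same variable you already define):

# Untested sketch of the suggestion above: restrict each worker process to a
# single GPU via CUDA_VISIBLE_DEVICES so the eight workers don't all allocate
# memory on GPU 0.
for i in $(seq 0 7); do
  RANK=$i
  CUDA_VISIBLE_DEVICES=$RANK python train.py data-bin/wmt14_en_de_joined_dict \
       --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
       --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
       --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
       --lr 0.0005 --min-lr 1e-09 --dropout 0.3 --weight-decay 0.0 \
       --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
       --max-tokens 3584 --fp16 \
       --distributed-world-size 8 \
       --distributed-init-method $HOST_PORT \
       --distributed-rank $RANK &
  PIDS[$RANK]=$!
done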

vsuthichai commented 6 years ago

@edunov Appreciate the quick response. That got past the OOM error, but now it runs into RuntimeError: NCCL error in: /pytorch/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:322, unhandled cuda error.

Are there any environment variables I need to set? I'm currently on CUDA 9.2, NCCL 2.2.13, driver version 396.37.

vsuthichai commented 6 years ago

@edunov An update: with NCCL debugging turned on, I can see this error now:

ip-10-0-0-168:55603:55931 [0] transport/p2p.cu:526 WARN failed to open CUDA IPC handle : 11 invalid argument
ip-10-0-0-168:55603:55931 [0] INFO init.cu:475 -> 1
ip-10-0-0-168:55603:55931 [0] INFO init.cu:536 -> 1
ip-10-0-0-168:55603:55931 [0] INFO misc/group.cu:70 -> 1 [Async thread]

ip-10-0-0-168:55604:55932 [0] transport/p2p.cu:526 WARN failed to open CUDA IPC handle : 11 invalid argument
ip-10-0-0-168:55604:55932 [0] INFO init.cu:475 -> 1
ip-10-0-0-168:55604:55932 [0] INFO init.cu:536 -> 1
ip-10-0-0-168:55604:55932 [0] INFO misc/group.cu:70 -> 1 [Async thread]
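
For reference, the debug output above was enabled by exporting NCCL's standard debug variable in the launch shell before starting the workers, roughly:

# Assumption: NCCL debug logging was enabled via the standard NCCL env var,
# exported in the launch shell before starting the train.py processes.
export NCCL_DEBUG=INFO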
edunov commented 6 years ago

Ouch, sorry, I think I was wrong about CUDA_VISIBLE_DEVICES. Can you try this instead: 1) remove CUDA_VISIBLE_DEVICES, and 2) add --device-id $RANK to train.py?

Sorry for the confusion

vsuthichai commented 6 years ago

@edunov that seems to have done the trick. If I could request that the getting_started.rst link be updated, that'd be great :)

The next issue is that the throughput (wps) has dropped significantly. Running on a single node with 8 GPUs, I can get close to the published 143k wps. Running the job on two nodes with 16 GPUs, the wps is just over 10k.

The two scripts used to launch are below (the first for node 0, the second for node 1):

#!/bin/bash

HOST_PORT="tcp://10.0.0.168:13333"

kill_children() {
  for PID in ${PIDS[*]}; do
    kill -TERM $PID
  done
}

NODE=0
RANKS_PER_NODE=8

for i in $(seq 0 7); do
  LOCAL_RANK=$i
  DISTRIBUTED_RANK=$((RANKS_PER_NODE * NODE + LOCAL_RANK))
  python train.py data-bin/wmt14_en_de_joined_dict   \
       --arch transformer_vaswani_wmt_en_de_big    \
       --share-all-embeddings                      \
       --optimizer adam                            \
       --adam-betas '(0.9, 0.98)'                  \
       --clip-norm 0.0                             \
       --lr-scheduler inverse_sqrt                 \
       --warmup-init-lr 1e-07                      \
       --warmup-updates 4000                       \
       --lr 0.0005                                 \
       --min-lr 1e-09                              \
       --dropout 0.3                               \
       --weight-decay 0.0                          \
       --criterion label_smoothed_cross_entropy    \
       --label-smoothing 0.1                       \
       --max-tokens 3584   --fp16                  \
       --distributed-world-size 16                 \
       --distributed-init-method $HOST_PORT        \
       --device-id $LOCAL_RANK                     \
       --distributed-rank $DISTRIBUTED_RANK &
  PIDS[$LOCAL_RANK]=$!
done

trap kill_children SIGTERM SIGINT

for PID in ${PIDS[*]}; do
  wait $PID
done

#!/bin/bash

HOST_PORT="tcp://10.0.0.168:13333"

kill_children() {
  for PID in ${PIDS[*]}; do
    kill -TERM $PID
  done
}

NODE=1
RANKS_PER_NODE=8

for i in $(seq 0 7); do
  LOCAL_RANK=$i
  DISTRIBUTED_RANK=$((RANKS_PER_NODE * NODE + LOCAL_RANK))
  python train.py data-bin/wmt14_en_de_joined_dict   \
       --arch transformer_vaswani_wmt_en_de_big    \
       --share-all-embeddings                      \
       --optimizer adam                            \
       --adam-betas '(0.9, 0.98)'                  \
       --clip-norm 0.0                             \
       --lr-scheduler inverse_sqrt                 \
       --warmup-init-lr 1e-07                      \
       --warmup-updates 4000                       \
       --lr 0.0005                                 \
       --min-lr 1e-09                              \
       --dropout 0.3                               \
       --weight-decay 0.0                          \
       --criterion label_smoothed_cross_entropy    \
       --label-smoothing 0.1                       \
       --max-tokens 3584   --fp16                  \
       --distributed-world-size 16                 \
       --distributed-init-method $HOST_PORT        \
       --device-id $LOCAL_RANK                     \
       --distributed-rank $DISTRIBUTED_RANK &
  PIDS[$LOCAL_RANK]=$!
done

trap kill_children SIGTERM SIGINT

for PID in ${PIDS[*]}; do
  wait $PID
done
vsuthichai commented 6 years ago

I've increased NCCL_MIN_NRINGS to 5; this helps just a tiny bit.
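
For anyone following along, this is just an environment variable exported before launching each worker; a minimal sketch:

# NCCL_MIN_NRINGS asks NCCL to build at least this many communication rings,
# which can help saturate the inter-node links. Export it on every node before
# starting the train.py processes.
export NCCL_MIN_NRINGS=5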

vsuthichai commented 6 years ago

@edunov I'm also wondering if there is a simpler way to start multi-node distributed training than what I've done so far. I don't have Slurm installed on my cluster, so I've been resorting to starting every process manually.

edunov commented 6 years ago

10k is pretty small. Do you know what the network speed between two nodes on AWS is? The results we reported in the paper were obtained on an InfiniBand-connected cluster, so it was very fast. On Ethernet, we also observe a slowdown when two machines are used compared to one.

We're working on a new version at the moment that should help a bit with network latency. And yes, let me update the wiki and the start script. In theory, we should be able to use multiprocessing_train.py to launch distributed jobs.

vsuthichai commented 6 years ago

@edunov 25 Gbps between the two nodes, 10 Gbps per connection. I can look at bandwidth usage, but I'm not convinced this is the problem. Decreasing from 143k wps (1 node / 8 GPUs) to 10k wps (2 nodes / 16 GPUs) is a bit much; I would have expected to see something greater than 143k.

Ultimately I'm trying to reproduce the results published in the recent paper: https://arxiv.org/pdf/1806.00187.pdf

myleott commented 6 years ago

Hmm, we haven't experimented on AWS infrastructure before, but we get much better results than this over Ethernet. For example, using a batch size of 5120, FP16, and Ethernet, we get 167k wps on a single node and 237k wps on two nodes. Can you confirm that your AWS instances are in the same placement group?
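
One way to check, assuming the AWS CLI is configured (a sketch; the instance IDs below are placeholders for the two p3.16xlarge instances):

# Sketch: show the placement group of each training instance.
aws ec2 describe-instances \
  --instance-ids i-0123456789abcdef0 i-0fedcba9876543210 \
  --query 'Reservations[].Instances[].[InstanceId,Placement.GroupName]' \
  --output table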

myleott commented 6 years ago

README was updated, closing for now.