Open · aminaqi opened this issue 6 months ago
I'm currently using this for distributed training and I get an error on the attention masks:
```python
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

if torch.cuda.device_count() > 1:
    transformer = nn.DataParallel(transformer)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)
```
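For context on why the shapes disagree: `nn.DataParallel` scatters every positional argument along dim 0, and the tensors in this tutorial are sequence-first (`[seq_len, batch]`), so the sequence and its mask get split instead of the batch. A minimal sketch of that effect, assuming two GPUs and using `torch.chunk` to mimic the scatter (the batch size of 128 is made up; the lengths 41/21 match the error reported below):

```python
import torch

# Sequence-first tensors, as in the translation tutorial
src = torch.zeros(41, 128)      # [src_len, batch]
src_mask = torch.zeros(41, 41)  # [src_len, src_len]

# nn.DataParallel splits positional inputs along dim 0 across devices;
# torch.chunk(2, dim=0) approximates that split for two GPUs
src_rep0, src_rep1 = src.chunk(2, dim=0)         # [21, 128] and [20, 128]
mask_rep0, mask_rep1 = src_mask.chunk(2, dim=0)  # [21, 41] and [20, 41]

print(src_rep0.shape, mask_rep0.shape)
# replica 0 sees a length-21 "sequence" but a (21, 41) mask,
# which is exactly the mismatch reported in the traceback
```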
Could you try using DistributedDataParallel from torch.nn.parallel.distributed instead of nn.DataParallel? https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
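A minimal sketch of that change, assuming one process per GPU launched with torchrun and the same model/hyperparameter names as in the snippet above (not the tutorial's exact code):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
# so the process group can be initialised straight from the environment
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(local_rank)
transformer = DDP(transformer, device_ids=[local_rank])
```

Unlike DataParallel, each DDP process keeps its full-length src/tgt tensors and masks, so nothing is split along the sequence dimension; the DataLoader would also need a DistributedSampler so that each rank trains on a different shard of the data.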
What should I set for master_addr and master_port?
@aminaqi Could you try taking a look at https://pytorch.org/tutorials/intermediate/ddp_tutorial.html?
If you are running on a single host, then using torchrun might be the simplest option.
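Regarding master_addr and master_port: on a single host, torchrun fills those in for you (along with rank and world size) as environment variables, so the training script only needs the default env:// init method. A sketch, assuming a hypothetical script name train.py:

```python
# inside train.py -- sketch only
import torch.distributed as dist

# MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK and WORLD_SIZE are exported by torchrun,
# so init_process_group() can read them from the environment:
dist.init_process_group(backend="nccl", init_method="env://")

# Example launch on one node with 2 GPUs:
#   torchrun --standalone --nproc_per_node=2 train.py
# If you set the variables by hand on a single machine, MASTER_ADDR=localhost and
# any free port (e.g. MASTER_PORT=29500) work, plus RANK/WORLD_SIZE per process.
```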
🐛 Describe the bug
code: (see the snippet quoted at the top of this issue)
error:
```
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 6
      4 for epoch in range(1, NUM_EPOCHS+1):
      5     start_time = timer()
----> 6     train_loss, train_acc = train_epoch(transformer, optimizer)
      7     end_time = timer()
      8     val_loss, val_acc = evaluate(transformer)

Cell In[1], line 278, in train_epoch(model, optimizer)
    274     tgt_input = tgt[:-1, :]
    276     src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
--> 278     logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
    280     optimizer.zero_grad()
    282     tgt_out = tgt[1:, :]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:171, in DataParallel.forward(self, *inputs, **kwargs)
    169     return self.module(*inputs[0], **kwargs[0])
    170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
    172 return self.gather(outputs, self.output_device)

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:181, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
    180 def parallel_apply(self, replicas, inputs, kwargs):
--> 181     return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File /opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:89, in parallel_apply(modules, inputs, kwargs_tup, devices)
     87 output = results[i]
     88 if isinstance(output, ExceptionWrapper):
---> 89     output.reraise()
     90 outputs.append(output)
     91 return outputs

File /opt/conda/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1768/3569133796.py", line 161, in forward
    outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 145, in forward
    memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 315, in forward
    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 591, in forward
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 599, in _sa_block
    x = self.self_attn(x, x, x,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 5251, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([21, 41]), but should be (21, 21).
```
Versions
PyTorch 2.0.1, CUDA 11.7
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225