bjascob / amrlib

A python library that makes AMR parsing, generation and visualization simple.
MIT License

distributed training of BART #56

Closed haixpham closed 1 year ago

haixpham commented 2 years ago

Hello,

I tried to train Model_Parse_XFM with a BART-base backbone using torch.distributed.launch with nproc_per_node=2. The following error occurs:

  File "/home/SERILOCAL/hai.xuanpham/amr_train/scripts/33_Model_Parse_XFM/20_Train_Model.py", line 22, in <module>
    trainer.train()
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/amrlib/models/parse_xfm/trainer.py", line 64, in train
    trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/trainer.py", line 1498, in train
    return inner_training_loop(
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/accelerate/utils/memory.py", line 79, in decorator
    return function(batch_size, *args, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/trainer.py", line 1740, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/trainer.py", line 2470, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/trainer.py", line 2502, in compute_loss
    outputs = model(**inputs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 1353, in forward
    outputs = self.model(
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 1222, in forward
    encoder_outputs = self.encoder(
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/transformers/models/bart/modeling_bart.py", line 799, in forward
    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/home/SERILOCAL/hai.xuanpham/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2199, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper__index_select)

The traceback shows that the call to torch.embedding() raised the exception. I'm not sure what went wrong with BART here. I also tried training the T5 model and it worked fine, even though T5 and BART have broadly similar architectures.

Has anyone experienced that issue before?
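
For anyone narrowing this down: the error message says the embedding weight and the input_ids tensor ended up on different GPUs. The sketch below is not from amrlib or transformers; `devices_by_param` is a hypothetical helper name. It groups a model's parameters by the device they live on, so the result can be printed on each rank just before trainer.train(). It assumes only that the object exposes `named_parameters()` the way `torch.nn.Module` does.

```python
from collections import defaultdict


def devices_by_param(model):
    """Group parameter names by the device each parameter lives on.

    `model` is assumed to behave like a torch.nn.Module, i.e. it has a
    named_parameters() method yielding (name, parameter) pairs where each
    parameter has a .device attribute. This helper is illustrative only.
    """
    groups = defaultdict(list)
    for name, param in model.named_parameters():
        groups[str(param.device)].append(name)
    return dict(groups)
```

If a rank reports more than one key, or a device that does not match the one its input batches are on, that would point to where the mismatch comes from. It would not by itself explain why BART fails while T5 does not.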

bjascob commented 2 years ago

I've never tried distributed training, so I'm not sure what's going on, but it looks like the problem is in the transformers lib itself. Everything below trainer.train() is transformers lib code. Maybe something in the trainer config arguments needs to be set up differently? You might look through the hf_args section of the config file and compare it to a distributed example in the transformers lib. I'd also check the transformers bug list to see if anyone else has had issues training BART with distributed.
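
One thing worth checking when comparing against a distributed example is how each process learns its local rank. By default torch.distributed.launch passes it to the script as a `--local_rank` argument (`--local-rank` in newer PyTorch releases), while torchrun and `--use_env` set the `LOCAL_RANK` environment variable instead. The sketch below reads either form; `resolve_local_rank` is a hypothetical helper, not part of amrlib, and whether this is related to the error above is only a guess.

```python
import argparse
import os


def resolve_local_rank(argv=None, environ=None):
    """Return the local rank given by the launcher, or -1 if none was given.

    Checks the --local_rank / --local-rank command-line argument first and
    falls back to the LOCAL_RANK environment variable. Illustrative helper.
    """
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("--local_rank", "--local-rank", type=int, default=None)
    # parse_known_args ignores any other arguments the training script takes.
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    return int(environ.get("LOCAL_RANK", -1))
```

Printing this value at the top of the training script on each process shows whether the two workers actually see ranks 0 and 1. The transformers Trainer normally takes care of device placement on its own, so this is meant as a diagnostic rather than a fix.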

haixpham commented 2 years ago

Indeed, the error originates in transformers. I'll head over there to look for the cause.

bjascob commented 1 year ago

Closing. No activity on issue for 1 month.