goodbai-nlp / AMRBART

Code for our paper "Graph Pre-training for AMR Parsing and Generation" in ACL2022
MIT License

InPlace Operation Error #16

Closed: keenjo closed this issue 1 year ago

keenjo commented 1 year ago

Hello,

I'm trying to reproduce your pretraining with run_multitask_unified_pretraining.py, and I have been struggling with the following error:

/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/autograd/__init__.py:145: UserWarning: Error detected in NllLossBackward. Traceback of forward call that caused the error:
  File "training.py", line 1320, in <module>
    main()
  File "training.py", line 1257, in main
    global_step, tr_loss = train(
  File "training.py", line 399, in train
    outputs = model(
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jkeenan/amrbart/pre-train/model_interface/modeling_bart.py", line 1375, in forward
    masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1047, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/nn/functional.py", line 2693, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/nn/functional.py", line 2388, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1616554793803/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Iteration:   0%|          | 0/50 [00:01<?, ?it/s, lm_loss=16.9, lr=0]
Epoch:   0%|          | 0/4001 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "training.py", line 1320, in <module>
    main()
  File "training.py", line 1257, in main
    global_step, tr_loss = train(
  File "training.py", line 529, in train
    loss.backward()
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/jkeenan/.conda/envs/amrbart/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [42]] is at version 7; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
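
(The forward-call traceback at the top of that output only appears because PyTorch's anomaly detection was on; for anyone else debugging this, it can be enabled before training starts:)

import torch
torch.autograd.set_detect_anomaly(True)  # makes backward errors report the forward op that produced them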

I've installed all of the packages into my own conda environment using your requirements.yml file.

The error seems to be tied to the labels and to the fact that we are doing multitask training here: I can get the script to run without any issues if I only define one task. I have tried many different options to solve the issue, but nothing seems to work. Do you have any ideas?
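
For reference, I think the failure mode can be reproduced outside the repo: nll_loss saves its target tensor for the backward pass, so an in-place edit to the labels between the forward and the backward call trips the version check (which would match the torch.LongTensor [42] in the message above). A minimal sketch with made-up shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)   # (batch, vocab) scores
labels = torch.randint(0, 10, (4,))                # LongTensor targets, like the lm labels

loss = F.cross_entropy(logits, labels)  # nll_loss saves `labels` for its backward pass
labels[0] = 0                           # in-place edit bumps the saved tensor's version counter
loss.backward()                         # RuntimeError: ... modified by an inplace operation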

keenjo commented 1 year ago

It seems like in the original masking functions the labels are cloned when they are created from the input ids, as in line 668 of pre-train/common/utils.py -> labels = batch["input_ids"].clone()

However, when the labels are created from the AMR graphs (which I believe are always stored in batch['labels']), they are just assigned by reference, as in line 674 of pre-train/common/utils.py -> labels = batch['labels']

Cloning the AMR-graph labels as well seems to have solved the issue: anywhere I saw the assignment from line 674 of pre-train/common/utils.py, I changed it to labels = batch['labels'].clone()
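
In other words, the change amounts to the following (the comments are just my reading of why the copy matters):

# before (pre-train/common/utils.py, line 674): the labels tensor is shared by reference, so a
# later in-place edit, e.g. by the masking for another task, also mutates the tensor that
# nll_loss saved for its backward pass
labels = batch["labels"]

# after: take a private copy; the tensor saved in the autograd graph keeps its original version
labels = batch["labels"].clone()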