abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, .... #49

Closed: shi-kejian closed this issue 10 months ago.

shi-kejian commented 11 months ago

Hi,

Thank you for this great effort.

I'm running into an issue with multi-GPU training. Some context first:

  1. I'm using local data files.
  2. base_training_args.json is the default one.

Here's my entry command:

python src/run.py \
    src/configs/training/base_training_args.json \
    --model_name_or_path facebook/bart-large \
    --train_file ... \
    --validation_file ... \
    --test_file ... \
    --input_column ... \
    --input_prefix_column ... \
    --output_column ... \
    --overwrite_cache \
    --output_dir ... \
    --overwrite_output_dir \
    --max_source_length 1024 \
    --eval_max_source_length 999999 \
    --generation_max_length 640 \
    --max_target_length 640 \
    --max_prefix_length 96 \
    --pad_prefix=True \
    --do_eval=True \
    --learning_rate 1e-5 \
    --per_device_eval_batch_size 1 \
    --per_device_train_batch_size 2 \
    --unlimiformer_training=True \
    --test_unlimiformer \
    --eval_steps 30 --save_steps 30 \
    --num_train_epochs 10 \
    --metric_names rouge \
    --extra_metrics bertscore \
    --metric_for_best_model bertscore
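The stack trace further down passes through torch/nn/parallel/data_parallel.py, so with this plain `python` launch the Hugging Face Trainer wrapped the model in torch.nn.DataParallel, which it does when a single non-distributed process sees more than one GPU. One quick check of whether the failure is specific to that path is to run the same command with only one GPU visible. A minimal sketch, assuming run.py is launched from Python; `single_gpu_env` is a hypothetical helper, while CUDA_VISIBLE_DEVICES is the standard CUDA environment variable:

```python
import os


def single_gpu_env(gpu_index: str = "0") -> dict[str, str]:
    """Hypothetical helper: copy the current environment, exposing only one GPU."""
    env = dict(os.environ)  # copy, so the parent process's environment is untouched
    # CUDA only enumerates the listed device(s); the chosen GPU shows up as cuda:0
    # inside the child process, so Trainer sees n_gpu == 1 and skips DataParallel.
    env["CUDA_VISIBLE_DEVICES"] = gpu_index
    return env
```

Pass the result as `env=` to `subprocess.run` together with the same arguments as the command above; it is equivalent to prefixing the shell command with `CUDA_VISIBLE_DEVICES=0`. The variable must be set before CUDA is initialized in the child, which is why it goes into the launch environment rather than into run.py itself.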

The error is raised in the encoder's forward pass, in File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 810, in forward, at:

    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

The error is:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)
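The message names wrapper_CUDA__index_select and "argument index" even though the failing line is an embedding lookup. A dense embedding lookup is a row gather on the weight matrix, and as far as I know PyTorch's dense embedding forward is implemented with index_select, so input_ids is passed as its `index` argument and must be on the same device as the weight. A CPU-only sketch of that equivalence; `embedding_via_index_select` is a hypothetical name for illustration:

```python
import torch
import torch.nn.functional as F


def embedding_via_index_select(input_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Look up one row of `weight` per id; same result as F.embedding(input_ids, weight)."""
    # index_select gathers rows along dim 0. `input_ids` is the "index" argument here,
    # so on GPU it has to live on the same device as `weight`.
    rows = weight.index_select(0, input_ids.reshape(-1))
    return rows.view(*input_ids.shape, weight.size(1))


weight = torch.randn(10, 4)                  # vocab of 10, hidden size 4
ids = torch.tensor([[1, 5, 7], [0, 2, 9]])   # batch of 2 sequences, length 3
assert torch.equal(embedding_via_index_select(ids, weight), F.embedding(ids, weight))
```

In the trace, this means input_ids reached the embedding on one GPU while embed_tokens.weight sat on the other.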

Could you give me some pointers on where to look when debugging this? I don't think it's related to the custom dataset itself. I'm aware the problem could come from the index, the datastore, batching, and so on; this part of the work is fairly complex, and unfortunately my knowledge of it is limited.
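One place to start is the module that fails in the trace, the encoder's embed_tokens. A forward pre-hook can record, on every call, which device the incoming ids are on and which device the weight is on. A minimal sketch using torch's `register_forward_pre_hook`; `attach_device_probe` is a hypothetical helper, and on the real model you would attach it to `model.model.encoder.embed_tokens` (the BART encoder's embedding) before calling `trainer.train()`:

```python
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle


def attach_device_probe(embedding: nn.Embedding, log: list) -> RemovableHandle:
    """Hypothetical helper: append (input device, weight device) on every call."""
    def probe(module: nn.Module, args: tuple) -> None:
        input_ids = args[0]  # BART calls embed_tokens(input_ids) positionally
        log.append((str(input_ids.device), str(module.weight.device)))
        # Returning None leaves the inputs unchanged.
    return embedding.register_forward_pre_hook(probe)


# CPU-only usage example.
emb = nn.Embedding(10, 4)
log = []
handle = attach_device_probe(emb, log)
emb(torch.tensor([1, 2]))
handle.remove()
emb(torch.tensor([3]))  # not recorded: the hook was removed
print(log)  # [('cpu', 'cpu')]
```

If the probe logs ('cuda:1', 'cuda:0') during training, the lookup for replica 1 is running against weights that stayed on cuda:0.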

Thank you very much!

Attached is a full stack trace:

Traceback (most recent call last):
  File "unlimiformer/src/run.py", line 1183, in <module>
    main()
  File "unlimiformer/src/run.py", line 803, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/_utils.py", line 693, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/ks4765/research/unlimiformer_ODMDS/src/unlimiformer.py", line 551, in pre_forward_hook
    result = self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 1380, in forward
    outputs = self.model(
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 1248, in forward
    encoder_outputs = self.encoder(
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 810, in forward
    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py", line 2235, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
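In the trace, replica 1's call enters unlimiformer.py's pre_forward_hook and then `self.original_forward_func`. One mechanism consistent with this, which is an assumption not verified against the unlimiformer source: nn.DataParallel builds each replica from a shallow copy of the module's instance attributes, so a forward method stored as an instance attribute stays bound to the original model (weights on cuda:0) while replica 1 receives its inputs on cuda:1. A plain-Python analogy; `ToyModel` and `patch_forward` are hypothetical stand-ins, and `copy.copy` plays the role of the replica's shallow attribute copy:

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a model whose weights live on one device."""

    def __init__(self, device: str):
        self.device = device

    def forward(self, input_device: str) -> tuple[str, str]:
        # Report where the "weights" are and where the inputs came from.
        return (self.device, input_device)


def patch_forward(model: ToyModel) -> None:
    # Mimic a wrapper that saves the original *bound* forward and installs
    # a replacement as an instance attribute.
    model.original_forward_func = model.forward
    model.forward = lambda input_device: model.original_forward_func(input_device)


original = ToyModel("cuda:0")
patch_forward(original)

# Shallow copy of instance attributes, similar in spirit to DataParallel replication.
replica = copy.copy(original)
replica.device = "cuda:1"

# The replica's inputs are on cuda:1, but the call still runs on the original's "weights".
print(replica.forward("cuda:1"))  # ('cuda:0', 'cuda:1')
```

Under this reading, every replica would funnel into the original model on cuda:0, which would produce exactly a cuda:0 vs cuda:1 mismatch at the first lookup that touches a weight, i.e. embed_tokens.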