utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

Error when using multi_gpu = True #214

Open nectario opened 4 years ago

nectario commented 4 years ago

I get the following error when I use multi_gpu=True:

Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/apex-0.1-py3.7-linux-x86_64.egg/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
  File "/opt/conda/lib/python3.7/site-packages/fast_bert/modeling.py", line 116, in forward
    head_mask=head_mask,
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py", line 783, in forward
    input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/modeling_roberta.py", line 65, in forward
    input_ids, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds
  File "/opt/conda/lib/python3.7/site-packages/transformers/modeling_bert.py", line 173, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 1484, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: arguments are located on different GPUs at /pytorch/aten/src/THC/generic/THCTensorIndex.cu:400
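
The last frame is the `torch.embedding(weight, input, ...)` call, so the error suggests that the embedding weight and `input_ids` were on different GPUs inside one replica. Below is a minimal sketch for checking that. It is not part of fast-bert, and the helper name `devices_in_use` is hypothetical. It only assumes the model exposes `named_parameters()` and that tensors have a `.device` attribute, as PyTorch modules and tensors do.

```python
from collections import defaultdict


def devices_in_use(module, **tensors):
    """Group parameter names and input tensor names by the device they live on.

    `module` is anything exposing named_parameters() (e.g. a torch.nn.Module);
    `tensors` are keyword arguments such as input_ids=input_ids.
    Hypothetical debugging helper, not a fast-bert or PyTorch API.
    """
    by_device = defaultdict(list)
    for name, param in module.named_parameters():
        by_device[str(param.device)].append(f"param:{name}")
    for name, tensor in tensors.items():
        by_device[str(tensor.device)].append(f"input:{name}")
    return dict(by_device)
```

Called from inside a `forward`, for example `print(devices_in_use(self, input_ids=input_ids))`, a result with more than one key would show which parameters and inputs are on different devices in that replica.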