nlpyang / BertSum

Code for paper Fine-tune BERT for Extractive Summarization
Apache License 2.0
1.47k stars · 423 forks

expected mask dtype to be Bool but got Long #130

Open haidequanbu opened 2 years ago

haidequanbu commented 2 years ago

Hi, when I am trying to reproduce the results using transformers, I get the following error:

```
Traceback (most recent call last):
  File "train.py", line 341, in <module>
    train(args, device_id)
  File "train.py", line 273, in train
    trainer.train(train_iter_fct, args.train_steps)
  File "/root/code/BertSum/src/models/trainer.py", line 155, in train
    self._gradient_accumulation(
  File "/root/code/BertSum/src/models/trainer.py", line 321, in _gradient_accumulation
    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/code/BertSum/src/models/model_builder.py", line 96, in forward
    sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/code/BertSum/src/models/encoder.py", line 97, in forward
    x = self.transformer_inter[i](i, x, x, 1 - mask)  # all_sents * max_tokens * dim
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/code/BertSum/src/models/encoder.py", line 68, in forward
    context = self.self_attn(input_norm, input_norm, input_norm,
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/code/BertSum/src/models/neural.py", line 210, in forward
    scores = scores.masked_fill(mask, -1e18)
RuntimeError: expected mask dtype to be Bool but got Long
```

Environment: cuda 11, torch 1.10.1.

Does my environment have a problem? Thanks!