clovaai / FocusSeq2Seq

[EMNLP 2019] Mixture Content Selection for Diverse Sequence Generation (Question Generation / Abstractive Summarization)
https://arxiv.org/abs/1909.01953
MIT License

RuntimeError: expected device cuda:0 and dtype Byte but got device cuda:0 and dtype Bool #4

Closed riturajkunwar closed 4 years ago

riturajkunwar commented 4 years ago

On executing train.py, I am getting the following error:

Epoch [0/20] | Iteration [1345/1346] | NLL Loss : 3.776 | NLL Loss (running avg) : 3.881 | Focus Loss : 0.176 | Time taken: : 11.62
Epoch Done! It took 338.58s
Evaluation start!

/pytorch/aten/src/ATen/native/cuda/LegacyDefinitions.cpp:14: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated, please use a mask with dtype torch.bool instead.
(the warning above is printed seven times)
Traceback (most recent call last):
  File "train.py", line 440, in <module>
    val_loader, model, epoch, config)
  File "/home/riturajk/my_notebook_env/Q_Gen_exp_1/FocusSeq2Seq/evaluate.py", line 146, in evaluate
    diversity_lambda=config.diversity_lambda)
  File "/home/riturajk/my_notebook_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/riturajk/my_notebook_env/Q_Gen_exp_1/FocusSeq2Seq/models.py", line 272, in forward
    diversity_lambda=diversity_lambda)
  File "/home/riturajk/my_notebook_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/riturajk/my_notebook_env/Q_Gen_exp_1/FocusSeq2Seq/layers/decoder.py", line 472, in forward
    finished += generated_eos
RuntimeError: expected device cuda:0 and dtype Byte but got device cuda:0 and dtype Bool

riturajkunwar commented 4 years ago

Is it because I am using PyTorch 1.2?

j-min commented 4 years ago

Probably. Try adding .byte() to generated_eos so its dtype matches finished (torch.uint8).
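
A minimal sketch of that fix, standing in for the offending line in layers/decoder.py (the tensor values and the EOS id below are illustrative, not taken from the repo):

```python
import torch

# On PyTorch >= 1.2, comparison ops return torch.bool masks, while the
# `finished` buffer is torch.uint8, so the in-place add raised
# "expected ... dtype Byte but got ... dtype Bool" on 1.2.
finished = torch.zeros(4, dtype=torch.uint8)  # running "done" flags
generated = torch.tensor([3, 7, 2, 3])        # hypothetical generated token ids
EOS_ID = 3                                    # hypothetical EOS token id
generated_eos = generated == EOS_ID           # dtype: torch.bool on >= 1.2

# Cast the mask back to uint8 before accumulating, as suggested above:
finished += generated_eos.byte()
print(finished.tolist())  # [1, 0, 0, 1]
```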

j-min commented 4 years ago

I just found that torch.bool was added in PyTorch 1.2. I updated the README to specify PyTorch version 1.1.
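
If the code ever needs to run under both versions instead of pinning 1.1, one option is to branch on the installed version string. A hypothetical helper, not part of this repo:

```python
def mask_dtype_name(torch_version: str) -> str:
    """Return which mask dtype to build: torch.bool exists only from 1.2 on."""
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    return "bool" if (major, minor) >= (1, 2) else "uint8"

print(mask_dtype_name("1.1.0"))  # uint8
print(mask_dtype_name("1.2.0"))  # bool
```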

riturajkunwar commented 4 years ago

I downgraded torch from 1.2 to 1.1, and it's working fine now! Can you please tell me the hyper-parameter combination that worked best for you for question generation?

riturajkunwar commented 4 years ago

Instead of using the GloVe embeddings, if we use ELMo or BERT word embeddings, do you think that would improve the results?

j-min commented 4 years ago

Follow the settings in our EMNLP paper, and use the default configuration if you can't find one in the paper.

j-min commented 4 years ago

I used GloVe embeddings just because the original NQG++ model uses them. Since NQG++ is just a single-layer RNN model, I think ELMo / BERT representations would improve the results.

j-min commented 4 years ago

Also, please open a separate issue for different questions. Thanks :)