memray / OpenNMT-kpg-release

Keyphrase Generation
MIT License

RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #3 'other' in call to _th_ior_ #14

Closed: arikanev closed this issue 4 years ago

arikanev commented 4 years ago

Hi, when running the eval script presented in the README, I get the above error. Stack trace below:

var = torch.tensor(arr, dtype=self.dtype, device=device)
Translating 32/308
Traceback (most recent call last):
  File "kp_gen_eval.py", line 164, in <module>
    opt=opt
  File "/home/ari/DKG/OpenNMT-kpg-release/onmt/translate/translator.py", line 372, in translate
    batch, data.src_vocabs, attn_debug
  File "/home/ari/DKG/OpenNMT-kpg-release/onmt/translate/translator.py", line 593, in translate_batch
    return_attention=attn_debug or self.replace_unk)
  File "/home/ari/DKG/OpenNMT-kpg-release/onmt/translate/translator.py", line 766, in _translate_batch
    beam.update_finished(last_step=(step+1==max_length))
  File "/home/ari/DKG/OpenNMT-kpg-release/onmt/translate/beam_search.py", line 217, in update_finished
    self.top_beam_finished |= self.is_finished[:, 0].eq(1)
RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #3 'other' in call to _th_ior_

memray commented 4 years ago

Ah, I'm not sure where this comes from. I suspect it is due to a difference in PyTorch versions. Could you try checking the data types of those tensors and making them consistent? I hope this helps.
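A minimal sketch of that dtype check, under stated assumptions: since PyTorch 1.2, comparison ops such as `Tensor.eq` return `torch.bool` rather than `torch.uint8`. The error message implies that `self.top_beam_finished` is a Byte (uint8) tensor while the right-hand side of the `|=` on line 217 is Bool. The tensors below are hypothetical stand-ins for the beam search state (shapes and values are made up); the snippet only prints the dtypes and does not try to reproduce the error, since whether the mixed `|=` fails depends on the installed PyTorch version.

```python
import torch

# Hypothetical stand-ins for the beam search state in beam_search.py:
# a per-example "top beam finished" flag stored as uint8 (Byte), and an
# is_finished matrix of shape (batch, beam).
top_beam_finished = torch.zeros(4, dtype=torch.uint8)
is_finished = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.uint8)

# Since PyTorch 1.2, eq() and other comparisons return torch.bool.
other = is_finished[:, 0].eq(1)

print(top_beam_finished.dtype)  # torch.uint8
print(other.dtype)              # torch.bool
```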

Rui

arikanev commented 4 years ago

Yeah, thanks. So one of them is a tensor of Bools and the other is of type uint8. I'm thinking I should convert the Falses to 0 and the Trues to 1?
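For reference, that False-to-0 / True-to-1 mapping is exactly what PyTorch's built-in cast does, so no manual conversion is needed. A small sketch (the `flags` values are arbitrary examples): `.byte()` is shorthand for `.to(torch.uint8)`, and the reverse cast `.bool()` maps 0 to False and any nonzero value to True, so the round trip is lossless for 0/1 flags.

```python
import torch

flags = torch.tensor([True, False, True])

# Bool -> uint8: False becomes 0, True becomes 1.
as_byte = flags.byte()            # equivalent to flags.to(torch.uint8)
print(as_byte.tolist())           # [1, 0, 1]

# uint8 -> Bool: 0 becomes False, any nonzero value becomes True.
back = as_byte.bool()
print(torch.equal(back, flags))   # True
```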

arikanev commented 4 years ago

Changing the Bool tensor to Byte with .byte() seems to have worked; at least it's running. I'm not sure whether it messes the model up somehow. I'll close the issue for now since it runs. Thanks!
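On the question of whether the cast changes behavior: a hedged sketch that checks the flag logic in isolation. It compares the `.byte()` workaround against a hypothetical alternative that keeps the flag as `torch.bool` throughout (for example, by creating `top_beam_finished` as a bool tensor; that change is an assumption, not something confirmed in the repository). The values in `prev` and `is_finished` are made up. Both variants mark the same entries as finished, which suggests the cast itself does not alter which beams count as done; it does not verify the rest of `beam_search.py`.

```python
import torch

# Hypothetical beam state: is_finished is (batch, beam); prev is an earlier
# value of the per-example "top beam finished" flag.
is_finished = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.uint8)
prev = torch.tensor([0, 1, 0, 0], dtype=torch.uint8)

# Variant 1 (the workaround from this thread): keep the flag as uint8 and
# cast the Bool comparison result to Byte before the in-place OR.
byte_flags = prev.clone()
byte_flags |= is_finished[:, 0].eq(1).byte()

# Variant 2 (hypothetical alternative): keep the flag as Bool throughout.
bool_flags = prev.bool()
bool_flags |= is_finished[:, 0].eq(1)

# Both variants should agree on which entries are finished.
print(byte_flags.bool().tolist() == bool_flags.tolist())  # True
```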