Hansyvea closed this issue 3 years ago
Let me check and reproduce the error
Have you solved it yet?
Sorry, I didn't run into any issues using PyTorch 1.8.1 and CUDA 10.2.
Which CUDA version are you using? Could you also share the details of your environment?
Alternatively, you can add a .long() call to the index, i.e. change it to
orig_to_token_index.unsqueeze(-1).expand(batch_size, max_sent_len, rep_size).long()
and see if that solves your issue.
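For the environment details requested above, a small script along these lines may help. It is only a sketch: `describe_environment` is a hypothetical helper name, not part of this repository, and it relies solely on standard `platform`, `sys`, and `torch` attributes.

```python
import platform
import sys

import torch


def describe_environment() -> dict:
    """Collect the version details that are usually requested in a bug report."""
    return {
        "os": platform.platform(),
        "python": sys.version.split()[0],
        "torch": torch.__version__,
        # CUDA version PyTorch was built against; None for CPU-only builds.
        "cuda_build": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Running `python -m torch.utils.collect_env` prints a more complete report and may be easier to paste into the thread.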
Thanks for your reply! I am using CUDA 11.2 and the other packages are all up to date, so I assume the problem is caused by the operating system, since I have been running it on Windows.
return torch.gather(word_rep[:, 1:, :], 1, orig_to_token_index.unsqueeze(-1).expand(batch_size, max_sent_len, rep_size))
RuntimeError: gather_out_cuda(): Expected dtype int64 for index
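The error says that `torch.gather` needs an int64 index tensor. One plausible cause, not confirmed in this thread, is that the index is built from NumPy integers, whose default width was 32 bits on Windows before NumPy 2.0. The sketch below reproduces the situation with made-up shapes and an int32 index, then applies the `.long()` cast suggested earlier. The function name `gather_word_rep` and all tensor values are hypothetical. Only the `gather` call mirrors the line in the traceback.

```python
import torch


def gather_word_rep(word_rep: torch.Tensor, orig_to_token_index: torch.Tensor) -> torch.Tensor:
    """Pick one sub-token representation per original word (hypothetical helper)."""
    batch_size, max_sent_len = orig_to_token_index.size()
    rep_size = word_rep.size(-1)
    # Cast to int64 first: torch.gather rejects int32 index tensors.
    index = orig_to_token_index.long().unsqueeze(-1).expand(batch_size, max_sent_len, rep_size)
    # Skip position 0 (as in the traceback) and gather along the token dimension.
    return torch.gather(word_rep[:, 1:, :], 1, index)


# Made-up shapes: 2 sentences, 6 sub-tokens each, 4-dimensional representations.
word_rep = torch.arange(2 * 6 * 4, dtype=torch.float32).reshape(2, 6, 4)
# An int32 index, which is what triggers the error without the cast.
orig_to_token_index = torch.tensor([[0, 2, 4], [1, 3, 4]], dtype=torch.int32)

gathered = gather_word_rep(word_rep, orig_to_token_index)
print(gathered.shape)
```

Casting inside the function keeps the fix in one place. Creating the index as `torch.long` where it is first built would work as well.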