🐛 Bug
It seems there is a device mismatch at these lines, (one) and (two).
To Reproduce
During training, calling the sample_negatives() function raises a device mismatch, apparently because the tensor returned by torch.randint() (link) and the new tszs tensor (link) are not created on the same device. It seems the problem can be solved by passing the device= argument when creating them.

Regards
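
To illustrate the suggested fix, here is a minimal sketch of the idea. It is not the library's actual code: `sample_negative_indices` and its parameters are hypothetical names, and the index logic is only assumed to resemble what sample_negatives() does. The point is that both index tensors take their `device=` from the input tensor, so the comparison between them never mixes devices.

```python
import torch


def sample_negative_indices(y: torch.Tensor, n_negatives: int) -> torch.Tensor:
    """Hypothetical helper (not the library's sample_negatives()).

    y is assumed to have shape (batch, time, features), with time >= 2.
    Returns indices of shape (batch, n_negatives * time).
    """
    bsz, tsz, _ = y.shape
    device = y.device  # create every new tensor on the same device as the input

    # Time index of each positive, repeated n_negatives times.
    tszs = (
        torch.arange(tsz, device=device)
        .unsqueeze(-1)
        .expand(-1, n_negatives)
        .flatten()
    )

    # Random candidates in [0, tsz - 2], created on the same device as tszs.
    neg_idxs = torch.randint(
        low=0, high=tsz - 1, size=(bsz, n_negatives * tsz), device=device
    )

    # Shift candidates at or above the positive's index so it is never sampled.
    neg_idxs[neg_idxs >= tszs] += 1
    return neg_idxs
```

Without the `device=` arguments, both tensors would be created on the default device (normally the CPU) regardless of where `y` lives, which is one way such a mismatch can arise.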