HHHedo opened 5 years ago
Hi there. Have you solved the compatibility issue with PyTorch?
Use `torch.reshape`:
`out = torch.bmm(weight, torch.reshape(x, (batchSize, inputSize, 1)))`
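Here is a minimal sketch of that one-liner in isolation. It assumes `weight` holds the K+1 sampled memory rows per example with shape `(batchSize, K+1, inputSize)` and `x` has shape `(batchSize, inputSize)`. The sizes below are hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical sizes, for illustration only.
batchSize, K, inputSize = 4, 7, 16

# Assumption: weight has shape (batchSize, K+1, inputSize).
weight = torch.randn(batchSize, K + 1, inputSize)
x = torch.randn(batchSize, inputSize)

# torch.reshape returns a new tensor (a view when possible) and leaves x
# untouched, so nothing has to be resized in place through x.data.
out = torch.bmm(weight, torch.reshape(x, (batchSize, inputSize, 1)))

print(out.shape)  # torch.Size([4, 8, 1])
print(x.shape)    # torch.Size([4, 16])
```

Each entry `out[b, k, 0]` is the inner product of `weight[b, k]` with `x[b]`.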
It seems that robert's suggestion works, but `gradOutput.data.resize_(batchSize, 1, K+1)` around line 60 should be changed to `gradOutput.resize_(batchSize, 1, K + 1)`.
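The gradient step near that line only needs `gradOutput` in shape `(batchSize, 1, K+1)` before the batched product with `weight`. Below is a minimal sketch, assuming the same hypothetical shapes as above. It uses `reshape()` instead of an in-place resize and checks the result against autograd. This is not the backward from NCEAverage.py. It covers only the batched matrix product, not the rest of that function.

```python
import torch

batchSize, K, inputSize = 4, 7, 16  # hypothetical sizes

weight = torch.randn(batchSize, K + 1, inputSize)
x = torch.randn(batchSize, inputSize, requires_grad=True)

# Forward: per-example inner products, shape (batchSize, K+1).
out = torch.bmm(weight, x.unsqueeze(2)).squeeze(2)

# An arbitrary upstream gradient with the same shape as out.
gradOutput = torch.randn(batchSize, K + 1)

# Manual gradient of the linear step: reshape gradOutput to (batchSize, 1, K+1)
# without modifying it in place, multiply by weight, then match x's shape.
gradInput = torch.bmm(gradOutput.reshape(batchSize, 1, K + 1), weight).view_as(x)

# Compare with the gradient autograd computes for the same forward pass.
out.backward(gradOutput)
print(torch.allclose(gradInput, x.grad, atol=1e-5))  # True
```

`view()` would also work here because `gradOutput` is contiguous. `reshape()` additionally copies when a view is not possible.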
My PyTorch version is 1.1.0. `x.data.resize_(batchSize, inputSize, 1)` in NCEAverage.py doesn't work, and `with torch.no_grad(): x.resize_(batchSize, inputSize, 1)` still doesn't work. I replaced it with `x.unsqueeze(2)`. Does this behave the same as `x.data.resize_(batchSize, inputSize, 1)`?
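For a contiguous `x` of shape `(batchSize, inputSize)`, `x.unsqueeze(2)`, `torch.reshape(x, (batchSize, inputSize, 1))` and an in-place `resize_` to the same shape all produce the same values. The difference is that `unsqueeze` and `reshape` return a new tensor and leave `x` unchanged, while `resize_` modifies the tensor it is called on. Note that `resize_` ignores strides and reinterprets storage as C-contiguous, so the equivalence assumes `x` is contiguous. Here is a small check, with hypothetical sizes:

```python
import torch

batchSize, inputSize = 4, 16  # hypothetical sizes
x = torch.randn(batchSize, inputSize)

a = x.unsqueeze(2)                               # new view, x keeps its shape
b = torch.reshape(x, (batchSize, inputSize, 1))  # same shape and values as a
c = x.clone()
c.resize_(batchSize, inputSize, 1)               # in place: changes c itself

print(a.shape, x.shape)                       # torch.Size([4, 16, 1]) torch.Size([4, 16])
print(torch.equal(a, b), torch.equal(a, c))   # True True
```

Cloning before `resize_` keeps `x` intact, so the comparison shows that only the in-place call changes the tensor it acts on.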