INK-USC / RE-Net

Recurrent Event Network: Autoregressive Structure Inference over Temporal Knowledge Graphs (EMNLP 2020)
http://inklab.usc.edu/renet/

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target' in call to _thnn_nll_loss_forward #68

Open SunforYou0 opened 1 year ago

SunforYou0 commented 1 year ago

Error information:

Using backend: pytorch
Namespace(backup=False, batch_size=128, dataset='YAGO', dropout=0.5, gpu=0, grad_norm=1.0, lr=0.001, max_epochs=20, maxpool=1, model=0, n_hidden=200, num_k=1000, raw=False, rnn_layers=1, seq_len=10, valid=False, valid_every=1)
start training...
D:\MyProjects\PycharmProjects\RE-Net-master\RE-Net-master\utils.py:214: UserWarning: This overload of nonzero is deprecated:
    nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
    nonzero(Tensor input, *, bool as_tuple) (Triggered internally at ..\torch\csrc\utils\python_arg_parser.cpp:766.)
  num_non_zero = len(torch.nonzero(s_len))
Epoch 0001 | Loss 8.2041 | time 1205.2197
Epoch 0002 | Loss 3.9099 | time 1204.5098
Epoch 0003 | Loss 3.3134 | time 1195.3128
Epoch 0004 | Loss 3.0940 | time 1160.4676
Epoch 0005 | Loss 2.9595 | time 1142.0331
Epoch 0006 | Loss 2.8683 | time 1171.5391
Epoch 0007 | Loss 2.8052 | time 1289.8578
Epoch 0008 | Loss 2.7475 | time 1266.8438
Epoch 0009 | Loss 2.7018 | time 1227.6809
Epoch 0010 | Loss 2.6605 | time 1246.5386
Traceback (most recent call last):
  File "D:\AnacondaEnv\renet\lib\code.py", line 91, in runcode
    exec(code, self.locals)
  File "", line 1, in
  File "D:\ProgramFiles\PyCharm 2020.3.4\plugins\python\helpers\pydev_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
  File "D:\ProgramFiles\PyCharm 2020.3.4\plugins\python\helpers\pydev_pydev_imps_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "D:/MyProjects/PycharmProjects/RE-Net-master/RE-Net-master/train.py", line 241, in
    train(args)
  File "D:/MyProjects/PycharmProjects/RE-Net-master/RE-Net-master/train.py", line 172, in train
    ranks, loss = model.evaluate_filter(batch_data, (s_hist, s_hist_t), (o_hist, o_hist_t), global_model, total_data)
  File "D:\MyProjects\PycharmProjects\RE-Net-master\RE-Net-master\model.py", line 388, in evaluate_filter
    loss, sub_pred, ob_pred = self.predict(triplet, s_hist, o_hist, global_model)
  File "D:\MyProjects\PycharmProjects\RE-Net-master\RE-Net-master\model.py", line 358, in predict
    loss_sub = self.criterion(ob_pred.view(1, -1), o.view(-1))
  File "D:\AnacondaEnv\renet\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\AnacondaEnv\renet\lib\site-packages\torch\nn\modules\loss.py", line 948, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "D:\AnacondaEnv\renet\lib\site-packages\torch\nn\functional.py", line 2422, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "D:\AnacondaEnv\renet\lib\site-packages\torch\nn\functional.py", line 2218, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target' in call to _thnn_nll_loss_forward

GreenerZ commented 1 year ago

I made the following changes and got training to run to completion. However, I reduced some parameters because of my hardware, so my results differ from the paper, and I am not sure this is the correct fix.

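1. In "model.py", line 358, change

"loss_sub = self.criterion(ob_pred.view(1, -1), o.view(-1))" to "loss_sub = self.criterion(ob_pred.view(1, -1), o.view(-1).type(torch.cuda.LongTensor))". Make the same change on line 359. The cast has to be applied to the target argument inside the call to self.criterion; casting the returned loss would not change the dtype that nll_loss receives.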

2. After change 1, I got another error: "tensors used as indices must be long, byte or bool tensors". So I made a second change.

In "model.py", line 399, change "idx = all_triplets[idx, 2]" to "idx = all_triplets[idx, 2].type(torch.LongTensor)". Make the same change on line 412.
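
For anyone who wants to try the two casts outside the repo, here is a minimal CPU-only sketch. It is not the RE-Net code: the tensors `ob_pred`, `o`, `all_triplets`, `rows` and `scores` are small made-up stand-ins, and it uses `.long()` rather than `.type(torch.cuda.LongTensor)` as an alternative that changes only the dtype and leaves the tensor on whatever device it is already on. Whether this is the right fix for the repo itself is still unconfirmed.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Hypothetical stand-ins for the tensors in model.py: scores over 5 entities
# and a target entity id that happens to be stored as int32.
ob_pred = torch.randn(5)
o = torch.tensor(3, dtype=torch.int32)

# Cast the target (not the returned loss) to int64 before computing the loss.
loss_sub = criterion(ob_pred.view(1, -1), o.view(-1).long())

# Same idea for index tensors: convert int32 ids to int64 before indexing.
all_triplets = torch.tensor([[0, 1, 2], [3, 0, 4], [1, 2, 0]], dtype=torch.int32)
rows = torch.tensor([0, 2])                  # rows to select (already int64)
idx = all_triplets[rows, 2].long()           # object ids from column 2, now int64
scores = torch.arange(5, dtype=torch.float32)
picked = scores[idx]                         # index with the int64 ids

print(loss_sub.dtype, idx.dtype, picked)
```

If the int32 ids come from the data-loading step, converting them once there (instead of at every call site) would probably be cleaner, but I have not checked where RE-Net creates them.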