lonicera-yx / MGT

MIT License

Code in MGT.py raises an error #3

Open shuqincao opened 2 years ago

shuqincao commented 2 years ago

Lines 87 and 88 of MGT.py, "inputs_extras_embedding = torch.cat([self.embedding_modulesi for i in range(len(self.num_embeddings))] + [inputs_pe], dim=-1)", raise an error. The details are as follows:

---------- Training ----------
num_samples: 1188, num_batches: 594
0%| | 0/594 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "F:/orginalCode/MGT-main/main.py", line 236, in <module>
    train(args, logger)
  File "F:/orginalCode/MGT-main/main.py", line 156, in train
    criterion, optimizer, scheduler, args)
  File "F:/orginalCode/MGT-main/main.py", line 91, in train_epoch
    outputs = model(inputs, targets, *extras, statics)
  File "D:\programfiles\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "F:\orginalCode\MGT-main\models\MGT.py", line 503, in forward
    z_inputs, z_targets = self.temporal_embedding(extras)  # (B, P, d_model), (B, Q, d_model)
  File "D:\programfiles\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "F:\orginalCode\MGT-main\models\MGT.py", line 88, in forward
    for i in range(len(self.num_embeddings))] + [inputs_pe], dim=-1)
  File "F:\orginalCode\MGT-main\models\MGT.py", line 88, in <listcomp>
    for i in range(len(self.num_embeddings))] + [inputs_pe], dim=-1)
  File "D:\programfiles\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\programfiles\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py", line 126, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "D:\programfiles\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1852, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list
0%| | 0/594 [00:13<?, ?it/s]

Process finished with exit code 1
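The last line of the traceback says that an index argument reaching one of the `nn.Embedding` lookups inside `torch.cat([...])` is a plain Python list rather than a tensor. The standalone sketch below reproduces the same `TypeError` with the same concatenate-embeddings-plus-positional-encoding pattern, then shows the usual fix of converting the indices to an integer tensor before the lookup. The category counts, shapes, `embed` helper, and nested-list inputs are all hypothetical and not taken from the MGT repository; the sketch does not establish why `extras` arrives as a list on the reporter's machine (a data-loading or package-version difference is one possibility).

```python
import torch
import torch.nn as nn

# Hypothetical setup (not from the MGT repo): two categorical "extra"
# features with 24 and 7 categories, plus a positional-encoding tensor.
num_embeddings = [24, 7]
embedding_dim = 4
batch, seq_len, pe_dim = 2, 3, 8

embedding_modules = nn.ModuleList(
    [nn.Embedding(n, embedding_dim) for n in num_embeddings]
)
inputs_pe = torch.zeros(batch, seq_len, pe_dim)


def embed(extras):
    """Look up each extra feature and concatenate with inputs_pe on the last dim."""
    return torch.cat(
        [embedding_modules[i](extras[i]) for i in range(len(num_embeddings))]
        + [inputs_pe],
        dim=-1,
    )


# Index data as plain nested Python lists, shape (B, P) each.
extras_as_lists = [
    [[1, 2, 3], [4, 5, 6]],  # indices for the 24-category feature
    [[0, 1, 2], [3, 4, 5]],  # indices for the 7-category feature
]

caught = None
try:
    embed(extras_as_lists)  # nn.Embedding receives a list, not a tensor
except TypeError as err:
    caught = err
    print("TypeError:", err)

# Converting each index list to an int64 tensor makes the lookup succeed.
extras_as_tensors = [torch.as_tensor(x, dtype=torch.long) for x in extras_as_lists]
out = embed(extras_as_tensors)
print(out.shape)  # torch.Size([2, 3, 16])
```

Each lookup yields a (2, 3, 4) tensor, so concatenating two of them with the (2, 3, 8) positional encoding gives a last dimension of 4 + 4 + 8 = 16. Where such a conversion would belong in MGT (dataset, collate function, or `forward`) is an open assumption, not something confirmed in this thread.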

lonicera-yx commented 2 years ago

I've updated requirements.txt to pin the version of each package; please use it as a reference. The code runs without problems on my machine.
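The thread does not quote the contents of requirements.txt, so the helper below is only a hypothetical, stdlib-only sketch for comparing a pinned requirements file with the packages actually installed. The names `parse_pins`, `mismatched_pins`, and `check_requirements` are illustrative; the sketch assumes simple `name==version` lines and compares version strings exactly.

```python
from importlib import metadata


def parse_pins(lines):
    """Return {package: version} for requirement lines of the form 'name==version'."""
    pins = {}
    for line in lines:
        # Drop comments and environment markers, then surrounding whitespace.
        line = line.split("#", 1)[0].split(";", 1)[0].strip()
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def mismatched_pins(pins):
    """Return (package, pinned, installed) for each pin the environment does not match.

    `installed` is None when the package is not installed at all.
    """
    mismatches = []
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches


def check_requirements(path="requirements.txt"):
    """Print every mismatch between a pinned requirements file and the current environment."""
    with open(path, encoding="utf-8") as fh:
        for name, pinned, installed in mismatched_pins(parse_pins(fh)):
            print(f"{name}: pinned {pinned}, installed {installed}")
```

Calling `check_requirements()` from the repository root would print one line per differing package. Because the comparison is an exact string match, a local build tag such as `+cu110` on an installed PyTorch wheel is also reported as a mismatch.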

shuqincao commented 2 years ago

I still can't get this code to run. Could I add you on WeChat? Mine is 1033898863.