Open youb1997 opened 1 year ago
Hey there, could you try changing the row, col = edge_index declaration on line 243 to row, col = edge_index[0]? I think this would work.
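One possible reading of why this suggestion helps, offered as a sketch rather than a confirmed diagnosis: the thread does not show line 243, so the snippet below assumes that `edge_index` arrives there as a pair (index tensor, edge weights) instead of a bare `[2, E]` tensor. Under that assumption, `row, col = edge_index` binds `row` to the whole 2-D tensor, whereas unpacking `edge_index[0]` yields the 1-D row and column vectors. The tensor values are made up for illustration.

```python
import torch

# Hypothetical 3-edge graph; the values are illustrative only.
index = torch.tensor([[0, 1, 2],
                      [1, 2, 0]])       # shape [2, E]
weight = torch.ones(index.size(1))      # shape [E]

# Assumed situation: edge_index is a (tensor, weights) pair, not a tensor.
edge_index = (index, weight)

row, col = edge_index                   # unpacks the pair itself
bad_row_shape = tuple(row.shape)        # row is the full [2, E] index tensor

row, col = edge_index[0]                # unpacks the [2, E] tensor along dim 0
good_row_shape = tuple(row.shape)       # row now has one entry per edge

print(bad_row_shape, good_row_shape)
```

If `edge_index` is already a plain tensor at that line, `edge_index[0]` would instead select a single 1-D row and the two-name unpacking would not work as intended, so printing `type(edge_index)` first is a cheap way to check which case applies.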
You're likely using newer versions of torch and torch-scatter than the ones the code was implemented against. Using the specified versions fixed this issue for me.
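Before pinning anything, it can help to see which versions are actually installed. This is a minimal sketch using only the standard library; the helper name `installed_versions` is made up here, and the package names are assumed to be the PyPI distribution names `torch` and `torch-scatter`. The versions the project expects are not stated in this thread, so they are not filled in.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Assumed PyPI distribution names for the two libraries mentioned above.
report = installed_versions(["torch", "torch-scatter"])
for name, ver in report.items():
    print(f"{name}: {ver or 'not installed'}")
```

The printed versions can then be compared by hand against whatever the repository specifies.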
I encountered an error in the code when running the program, specifically in the model_text_gnn.py file. The error message I received is as follows:
Traceback (most recent call last):
File "main.py", line 44, in
main()
File "main.py", line 20, in main
saved_model, model = train(train_data, val_data, saver)
File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\train.py", line 25, in train
loss, preds_train = model(pyg_graph, train_data)
File "C:\Users\hm\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 36, in forward
outs = layer(ins, pyg_graph)
File "C:\Users\hm\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 106, in forward
x = self.conv(ins, pyg_graph.edge_index)
File "C:\Users\hm\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 258, in forward
edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 244, in norm
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
File "C:\Users\hm\anaconda3\lib\site-packages\torch_scatter\scatter.py", line 29, in scatter_add
return scatter_sum(src, index, dim, out, dim_size)
File "C:\Users\hm\anaconda3\lib\site-packages\torch_scatter\scatter.py", line 11, in scatter_sum
index = broadcast(index, src, dim)
File "C:\Users\hm\anaconda3\lib\site-packages\torch_scatter\utils.py", line 12, in broadcast
src = src.expand(other.size())
RuntimeError: expand(torch.LongTensor{[2, 30947]}, size=[30947]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)
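The last line of the traceback can be read as follows: the index passed to `scatter_add` has shape `[2, 30947]`, while `edge_weight` has a single dimension of size 30947, and `Tensor.expand` cannot turn a 2-D tensor into a 1-D size. The sketch below reproduces the same kind of failure in plain PyTorch, without torch-scatter; the edge count of 5 and the variable names are arbitrary stand-ins.

```python
import torch

num_edges = 5                                         # arbitrary stand-in for 30947
index = torch.zeros(2, num_edges, dtype=torch.long)   # 2-D, like a full edge_index
src = torch.ones(num_edges)                           # 1-D, like edge_weight

try:
    # Same shape combination as the failing call in the traceback:
    # a 2-D tensor asked to expand to a 1-D size.
    index.expand(src.size())
    message = None
except RuntimeError as err:
    message = str(err)

print(message)

# A 1-D index of matching length expands without error.
row = index[0]
expanded = row.expand(src.size())
print(tuple(expanded.shape))
```

So the thing to check near line 243 is that `row` ends up 1-D with one entry per edge before it reaches `scatter_add`.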