codeKgu / Text-GCN

A PyTorch implementation of "Graph Convolutional Networks for Text Classification." (AAAI 2019)
MIT License

Error in model_text_gnn.py when using the expand method #11

Open youb1997 opened 1 year ago

youb1997 commented 1 year ago

I encountered an error in the code when running the program, specifically in the model_text_gnn.py file. The error message I received is as follows:

```
Traceback (most recent call last):
  File "main.py", line 44, in <module>
    main()
  File "main.py", line 20, in main
    saved_model, model = train(train_data, val_data, saver)
  File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\train.py", line 25, in train
    loss, preds_train = model(pyg_graph, train_data)
  File "C:\Users\hm\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 36, in forward
    outs = layer(ins, pyg_graph)
  File "C:\Users\hm\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 106, in forward
    x = self.conv(ins, pyg_graph.edge_index)
  File "C:\Users\hm\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 258, in forward
    edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
  File "C:\Users\hm\Downloads\Text-GCN-master\Text-GCN-master\model_text_gnn.py", line 244, in norm
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
  File "C:\Users\hm\anaconda3\lib\site-packages\torch_scatter\scatter.py", line 29, in scatter_add
    return scatter_sum(src, index, dim, out, dim_size)
  File "C:\Users\hm\anaconda3\lib\site-packages\torch_scatter\scatter.py", line 11, in scatter_sum
    index = broadcast(index, src, dim)
  File "C:\Users\hm\anaconda3\lib\site-packages\torch_scatter\utils.py", line 12, in broadcast
    src = src.expand(other.size())
RuntimeError: expand(torch.LongTensor{[2, 30947]}, size=[30947]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)
```
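The failure can be reproduced without the rest of the project. The sketch below is a simplified copy of the `dim=0` path of torch_scatter's `broadcast()` (the last frame in the traceback); `broadcast_index` and the toy tensors are illustrative and not code from Text-GCN. It shows that an index of shape `[2, E]` cannot be expanded to the `[E]` shape of `edge_weight`, while a 1-D row of shape `[E]` can. Because `scatter_add(edge_weight, row, ...)` passes `row` as the index, the `[2, 30947]` shape in the error means `row` holds a full two-row edge tensor at line 244 rather than a single row.

```python
import torch


def broadcast_index(index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Simplified copy of torch_scatter's broadcast() for dim=0: add trailing
    dimensions to `index` only if `src` has more, then expand to src's shape."""
    while index.dim() < src.dim():
        index = index.unsqueeze(-1)
    return index.expand(src.size())


num_edges = 5
edge_index = torch.zeros(2, num_edges, dtype=torch.long)  # shape [2, E]
edge_weight = torch.ones(num_edges)                         # shape [E]

# Using the whole [2, E] tensor as `row` raises the same RuntimeError.
try:
    broadcast_index(edge_index, edge_weight)
except RuntimeError as err:
    print("fails:", err)

# A 1-D row of source nodes (shape [E]) broadcasts without error.
print("works:", broadcast_index(edge_index[0], edge_weight).shape)
```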

Kaushal-13 commented 11 months ago

Hey there, could you try changing the `row, col = edge_index` declaration on line 243 to `row, col = edge_index[0]`? I think this would work.
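One plausible reason the `[0]` helps: if `edge_index` at that point is a tuple, for example the `(edge_index, edge_weight)` pair that recent PyTorch Geometric versions return from `torch_geometric.utils.add_self_loops`, then `row, col = edge_index` makes `row` the full `[2, E]` tensor, which matches the shape in the error. This is an assumption that has not been checked against the repository's code. Below is a minimal sketch of a normalization helper that tolerates either input. `gcn_norm_sketch` is a hypothetical name, it uses torch's built-in `Tensor.scatter_add_` instead of torch_scatter, it does not add self-loops, and the symmetric D^-1/2 A D^-1/2 weighting is the standard GCN normalization rather than a copy of the project's `norm`.

```python
from typing import Optional, Tuple, Union

import torch

# Hypothetical input type: a [2, E] edge_index tensor, or a tuple whose first
# element is that tensor (e.g. an (edge_index, edge_weight) pair).
EdgeInput = Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]


def gcn_norm_sketch(
    edge_index: EdgeInput,
    num_nodes: int,
    edge_weight: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (edge_index, norm) with norm[e] = w_e / sqrt(deg[row_e] * deg[col_e])."""
    if isinstance(edge_index, tuple):
        # Same idea as the suggested fix: keep only the [2, E] tensor.
        edge_index = edge_index[0]
    if edge_index.dim() != 2 or edge_index.size(0) != 2:
        raise ValueError(f"expected edge_index of shape [2, E], got {tuple(edge_index.shape)}")
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1))
    if edge_weight.shape != (edge_index.size(1),):
        raise ValueError(f"expected edge_weight of shape [E], got {tuple(edge_weight.shape)}")

    row, col = edge_index  # each has shape [E]
    # Weighted degree of each source node via a 1-D scatter-add along dim 0.
    deg = torch.zeros(num_nodes, dtype=edge_weight.dtype)
    deg.scatter_add_(0, row, edge_weight)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.0  # isolated nodes get weight 0
    norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    return edge_index, norm


edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
_, norm_plain = gcn_norm_sketch(edge_index, num_nodes=3)
_, norm_tuple = gcn_norm_sketch((edge_index, None), num_nodes=3)
print(norm_plain, norm_tuple)
```

Passing either the bare tensor or the tuple yields the same weights, and a wrongly shaped input fails with a readable message instead of the `expand` error deep inside torch_scatter.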

gkysaad commented 7 months ago

You're likely using newer versions of torch and torch-scatter than the ones the code was written against. Installing the versions the project specifies fixed this issue for me.
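Before pinning versions, it helps to see what is actually installed. A small standard-library sketch using `importlib.metadata`; the names are the PyPI distribution names, and including `torch-geometric` assumes the project also relies on PyTorch Geometric (suggested by `pyg_graph` in the traceback). Compare the output with the versions the repository specifies.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# PyPI distribution names; torch-geometric is an assumption about the project's deps.
PACKAGES = ("torch", "torch-scatter", "torch-geometric")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return the installed version string of each distribution, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```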