FighterLYL / GraphNeuralNetwork

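Companion code for the book 《深入浅出图神经网络:GNN原理解析》 (roughly, "Graph Neural Networks Explained: An Analysis of GNN Principles")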

Running the code for the Chapter 8 graph classification example #31

Closed. Zengyf-CVer closed this issue 4 years ago.

Zengyf-CVer commented 4 years ago

Hello, while running the Chapter 8 SAGPool code, I got the following error at dataset = DDDataset():

Loading DD_A.txt
Loading DD_node_labels.txt
Loading DD_graph_indicator.txt
Loading DD_graph_labels.txt
Number of nodes: 334925
Traceback (most recent call last):
  File "self_attn_pool.py", line 347, in <module>
    dataset = DDDataset()
  File "self_attn_pool.py", line 83, in __init__
    self.train_index, self.test_index = self.split_data(train_size)
  File "self_attn_pool.py", line 91, in split_data
    random_state=1234)
  File "/home/zyf/anaconda3/envs/pt36/lib/python3.6/site-packages/sklearn/model_selection/_split.py", line 2096, in train_test_split
    arrays = indexable(*arrays)
  File "/home/zyf/anaconda3/envs/pt36/lib/python3.6/site-packages/sklearn/utils/validation.py", line 230, in indexable
    check_consistent_length(*result)
  File "/home/zyf/anaconda3/envs/pt36/lib/python3.6/site-packages/sklearn/utils/validation.py", line 201, in check_consistent_length
    lengths = [_num_samples(X) for X in arrays if X is not None]
  File "/home/zyf/anaconda3/envs/pt36/lib/python3.6/site-packages/sklearn/utils/validation.py", line 201, in <listcomp>
    lengths = [_num_samples(X) for X in arrays if X is not None]
  File "/home/zyf/anaconda3/envs/pt36/lib/python3.6/site-packages/sklearn/utils/validation.py", line 146, in _num_samples
    " a valid collection." % x)
TypeError: Singleton array array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, (middle omitted), 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177}, dtype=object) cannot be considered a valid collection.

I can't tell whether this is a problem with the dataset or something else. Could you please help me fix it? Thanks.
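
The error message hints at the cause: `array({0, 1, 2, ...}, dtype=object)` is a NumPy array wrapping a whole Python `set` as a single object. Assuming the original line passed the set straight to `np.asarray` (the line itself is not quoted in this thread), NumPy does not iterate the set and instead builds a zero-dimensional object array, which scikit-learn's length check rejects as a "singleton". A minimal NumPy-only sketch, using a made-up `graph_indicator`:

```python
import numpy as np

# Hypothetical per-node graph ids: node i belongs to graph graph_indicator[i].
graph_indicator = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])

# A set is not a sequence, so NumPy stores the whole set as one
# Python object inside a 0-d array instead of iterating over it.
wrapped = np.asarray(set(graph_indicator))
print(wrapped.ndim, wrapped.dtype)  # 0 object

# Converting the set to a list first yields a proper 1-D array of graph ids.
fixed = np.asarray(list(set(graph_indicator)))
print(fixed.ndim, len(fixed))  # 1 3
```
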

FighterLYL commented 4 years ago

The code you're using has a small bug. Change this line https://github.com/FighterLYL/GraphNeuralNetwork/blob/master/chapter8/self_attn_pool.py#L88 to:

unique_indicator = np.asarray(list(set(self.graph_indicator)))
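
With that change, the split receives one entry per graph. The sketch below assumes, based only on the traceback, that `split_data` passes `unique_indicator` to `sklearn.model_selection.train_test_split` with `train_size` and `random_state=1234`; the standalone `split_data` function and the toy `graph_indicator` are hypothetical stand-ins, not the book's code:

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_data(graph_indicator, train_size):
    # One entry per graph id; list(...) avoids the 0-d object array.
    unique_indicator = np.asarray(list(set(graph_indicator)))
    # Split graph ids (not nodes) into train and test sets.
    train_index, test_index = train_test_split(
        unique_indicator, train_size=train_size, random_state=1234)
    return train_index, test_index


# Hypothetical toy data: 10 graphs with 3 nodes each.
graph_indicator = np.repeat(np.arange(10), 3)
train_index, test_index = split_data(graph_indicator, train_size=0.8)
print(len(train_index), len(test_index))  # 8 2
```

Splitting on graph ids rather than node rows keeps every node of a graph on the same side of the split, which is what graph classification needs.
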
Zengyf-CVer commented 4 years ago

Problem solved. Thank you very much.