FighterLYL / GraphNeuralNetwork

《深入浅出图神经网络:GNN原理解析》配套代码
1.72k stars 459 forks source link

GraphSage中multihop_sampling代码 #5

Closed Jerry992 closed 4 years ago

Jerry992 commented 4 years ago

我将main函数中的train部分改为 def train(): model.train() for e in range(EPOCHS): for batch in range(NUM_BATCH_PER_EPOCH): batch_src_index = np.random.choice(train_index, size=(BTACH_SIZE,)) r=batch_src_index

t=len(r)

        print("batch_src_index ")
        print(r)#添加部分,显示1---------------------------------------------------显示1
        batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)
        batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
        batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]
        # r=batch_sampling_x[0],x=3个16的tensor
        # t=len(r)
        # print(t)
        batch_train_logits = model(batch_sampling_x)
        loss = criterion(batch_train_logits, batch_src_label)
        optimizer.zero_grad()
        loss.backward()  # 反向传播计算参数的梯度
        optimizer.step()  # 使用优化方法进行梯度更新
        print("Epoch {:03d} Batch {:03d} Loss: {:.4f}".format(e, batch, loss.item()))
    test()

在multihop_sampling函数中添加显示 def multihop_sampling(src_nodes, sample_nums, neighbor_table): sampling_result = [src_nodes] print(src_nodes)# 显示2------------------------------------------------------显示2 for k, hopk_num in enumerate(sample_nums): r=hopk_num print(k)# 显示3------------------------------------------------------显示3 print(r)# 显示4--------------------------------------------------- 显示4 q=sampling_result[k] print(q)# 显示5---------------------------------------------------显示5 hopk_result = sampling(sampling_result[k], hopk_num, neighbor_table) sampling_result.append(hopk_result) return sampling_result 在sampling函数中添加显示 def sampling(src_nodes, sample_num, neighbor_table): results = [] print(src_nodes)# 显示6---------------------------------------------------显示6 对应的multihop_sampling中的循环结果为: [107 95 10 100 24 118 108 50 83 68 33 19 90 3 92 83]#显示1 [107 95 10 100 24 118 108 50 83 68 33 19 90 3 92 83]#显示2 0#显示3 10#显示4 [107 95 10 100 24 118 108 50 83 68 33 19 90 3 92 83]#显示5 [107 95 10 100 24 118 108 50 83 68 33 19 90 3 92 83]#显示6 1#二轮显示3 10#二轮显示4 [ 541 1650 541 1113 1650 1113 541 541 971 1113 334 1303 456 2182 1580 1580 2199 2200 1628 861 2545 476 476 2545 2545 476 476 2545 476 476 2056 2056 2056 1602 2056 1602 1602 2056 2056 1602 1701 2141 1701 17 17 201 2141 2139 1636 17 1343 842 2165 1690 1507 842 2166 1616 2165 554 2209 1647 2157 1647 2209 2209 2209 1647 2157 2157 1441 1441 1441 1441 1441 1441 1441 1441 1441 1441 2581 2581 2581 2581 2581 2581 2581 2581 1520 2581 391 1986 1358 1986 1358 1358 391 1986 1358 1358 2040 1051 286 2121 1051 2119 286 2120 911 2119 1939 1939 1939 1939 1939 1939 1939 1939 1939 1939 1358 155 155 156 155 1358 155 1358 156 1358 2544 2544 2544 2544 2544 2544 2544 2544 2544 2544 1836 898 1836 898 898 898 1836 1836 1836 1836 1520 2581 2581 2581 1520 1520 1520 1520 2581 2581]#二轮显示5,总共1000维 [ 541 1650 541 1113 1650 1113 541 541 971 1113 334 1303 456 2182 1580 1580 2199 2200 1628 861 2545 476 476 2545 2545 476 476 2545 476 476 2056 2056 2056 1602 2056 1602 1602 2056 2056 1602 1701 2141 1701 17 17 201 2141 2139 1636 17 1343 842 2165 1690 1507 842 2166 1616 2165 554 2209 1647 2157 1647 2209 2209 2209 1647 2157 2157 1441 1441 1441 1441 1441 1441 1441 1441 1441 1441 2581 2581 2581 2581 2581 2581 2581 2581 1520 2581 391 1986 1358 1986 1358 1358 391 1986 1358 1358 2040 1051 286 2121 1051 2119 286 2120 911 2119 1939 1939 1939 1939 1939 1939 1939 1939 1939 1939 1358 155 155 156 155 1358 155 1358 156 1358 2544 2544 2544 2544 2544 2544 2544 2544 2544 2544 1836 898 1836 898 898 898 1836 1836 1836 1836 1520 2581 2581 2581 1520 1520 1520 1520 2581 2581]#二轮显示6 问题1:二轮的数据怎么生成的? 问题2:显示5中的sampling_result[k]不应该为sampling_result的第k个数据,不应是一个数据吗?怎么显示的是一个张量?

FighterLYL commented 4 years ago

问题1: 这里的第二轮数据实际对应的是这批节点[107 95 10 100 24 118 108 50 83 68 33 19 90 3 92 83]的一阶邻居采样结果,由于有16个源节点,1阶采样的个数为10(对应于显示4的打印结果),因此这批节点的一阶采样结果中总共有16 x 10=160个节点,输出的结果将所有这些结果拼接在了一起,即二轮显示5/6的结果

问题2: 变量sampling_result存储的类型是一个由np.ndarray构成的List对象,因此sampling_result[k]取出来的是一个ndarray对象,sampling_result[0]对应的是源节点,其余对应的第k阶采样的结果

你可以仔细看看函数sampling