acbull / pyHGT

Code for "Heterogeneous Graph Transformer" (WWW'20), which is based on pytorch_geometric
MIT License
775 stars, 162 forks

Bug in eval_ogbn_mag.py #34

Closed: lingfanyu closed this issue 3 years ago

lingfanyu commented 3 years ago

There are not enough values to unpack at https://github.com/acbull/pyHGT/blob/16547c0c0a6977c40b8efa88a7f3e40cf1955362/ogbn-mag/eval_ogbn_mag.py#L140: yindxs is not among the values returned by ogbn_sample (https://github.com/acbull/pyHGT/blob/16547c0c0a6977c40b8efa88a7f3e40cf1955362/ogbn-mag/eval_ogbn_mag.py#L81).
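A minimal reproduction of this class of error, for readers unfamiliar with it. The functions sample_old and sample_new are hypothetical stand-ins, not code from pyHGT: when a function returns fewer values than the caller unpacks, Python raises ValueError.

```python
# Hypothetical stand-ins for ogbn_sample: the old version omits yindxs.
def sample_old():
    features, labels = [1.0, 2.0], [0, 1]
    return features, labels            # 2 values

def sample_new():
    features, labels = [1.0, 2.0], [0, 1]
    yindxs = [7, 9]                    # e.g. original ids of test nodes
    return features, labels, yindxs    # 3 values

def try_unpack(fn):
    """Unpack three values the way the caller expects; report any error."""
    try:
        features, labels, yindxs = fn()
        return f"ok: yindxs={yindxs}"
    except ValueError as err:
        return f"ValueError: {err}"

print(try_unpack(sample_old))  # ValueError: not enough values to unpack (expected 3, got 2)
print(try_unpack(sample_new))  # ok: yindxs=[7, 9]
```

The fix is the same in both cases: make the return tuple match the unpacking at the call site.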

acbull commented 3 years ago

Sorry, I've already fixed it.

acbull commented 3 years ago
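import numpy as np

# Excerpt from ogbn-mag/eval_ogbn_mag.py: graph, args, sample_subgraph,
# feature_MAG and to_torch are defined elsewhere in the script / pyHGT.
def ogbn_sample(seed, samp_nodes):
    np.random.seed(seed)
    # Sample a subgraph around the target papers; each input row is (paper id, year).
    feature, times, edge_list, indxs, _ = sample_subgraph(
        graph,
        inp={'paper': np.concatenate([samp_nodes, graph.years[samp_nodes]]).reshape(2, -1).transpose()},
        sampled_depth=args.sample_depth,
        sampled_number=args.sample_width,
        feature_extractor=feature_MAG)
    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
        to_torch(feature, times, edge_list, graph)
    # indxs['paper'] indexes the graph-level arrays, aligning them with subgraph papers.
    train_mask = graph.train_mask[indxs['paper']]
    valid_mask = graph.valid_mask[indxs['paper']]
    test_mask  = graph.test_mask[indxs['paper']]
    ylabel     = graph.y[indxs['paper']]
    # Ids of the sampled test papers; returned so the caller's 8-value unpacking succeeds.
    yindxs     = indxs['paper'][test_mask]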
    return node_feature, node_type, edge_time, edge_index, edge_type, (train_mask, valid_mask, test_mask), ylabel, yindxs
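The NumPy steps in the fix can be checked on toy data. This is a sketch: years, test_mask, samp_nodes and indxs_paper are hypothetical values, and it assumes (as the mask indexing above implies) that indxs['paper'] holds the original paper ids in subgraph order. It shows how the (paper id, year) seed rows are built and how yindxs is derived from the test mask.

```python
import numpy as np

# Toy stand-ins (hypothetical values) for graph.years, graph.test_mask and the
# indxs['paper'] array returned by sample_subgraph.
years = np.array([2015, 2016, 2017, 2018, 2019, 2020])
test_mask = np.array([False, False, False, True, True, True])
samp_nodes = np.array([3, 5])

# Seed input as built in ogbn_sample: one (paper id, year) row per paper.
inp = np.concatenate([samp_nodes, years[samp_nodes]]).reshape(2, -1).transpose()
print(inp.tolist())  # [[3, 2018], [5, 2020]]

# Suppose the sampled subgraph contains these original paper ids, in subgraph order.
indxs_paper = np.array([3, 5, 0, 4, 1])

sub_test_mask = test_mask[indxs_paper]   # mask aligned with subgraph positions
yindxs = indxs_paper[sub_test_mask]      # original ids of sampled test papers
print(sub_test_mask.tolist())  # [True, True, False, True, False]
print(yindxs.tolist())         # [3, 5, 4]
```

The concatenate/reshape/transpose expression yields the same (N, 2) array as np.stack([samp_nodes, years[samp_nodes]], axis=1), which some readers may find easier to follow.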
lingfanyu commented 3 years ago

Thanks!