acbull / pyHGT

Code for "Heterogeneous Graph Transformer" (WWW'20), which is based on pytorch_geometric
MIT License
775 stars 162 forks source link

Disambiguation link prediction Paper -> Author fails to run on OAG dataset #46

Closed vldbnc closed 2 years ago

vldbnc commented 2 years ago

Default Settings

class DefaultArguments:
    """Fixed set of run settings, exposed as plain instance attributes.

    Used as a stand-in for parsed command-line arguments: every setting is
    readable as ``args.<name>``.  Only the values that actually take effect
    are assigned; earlier drafts assigned ``data_dir`` and ``domain`` twice,
    with the first value being immediately overwritten.
    """

    def __init__(self):
        # Paths and task selection.
        # ('./dataset/process_output' was a previously tried data_dir.)
        self.data_dir = './dataset/oag_output'
        self.model_dir = './model_save'
        self.task_name = 'AD'
        # NOTE(review): -1 presumably selects CPU; confirm against the
        # training script that consumes this value.
        self.cuda = -1
        # Dataset domain suffix ('_CS' is the other domain that was tried).
        self.domain = '_ML'

        # Model arguments
        self.conv_name = 'hgt'
        self.n_hid = 400
        self.n_heads = 8
        self.n_layers = 3
        self.dropout = 0.2
        self.sample_depth = 6
        self.sample_width = 128

        # Optimization arguments
        self.optimizer = 'adamw'
        self.data_percentage = 1.0
        self.n_epoch = 100
        self.n_pool = 8
        self.n_batch = 32
        self.repeat = 2
        self.batch_size = 256
        self.clip = 0.25


# Module-level settings instance read by the rest of the script/notebook.
args = DefaultArguments()

I'm not able to run the train_author_disambiguation.py task for either '_ML' or '_CS' (OAG dataset); it fails with the following error in the gnn function. Is there a known issue for this?


<ipython-input-113-17fa53d2f5f9> in <module>
      9 #     print(ylabel) #dic
     10     gnn.forward(node_feature.to(device), node_type.to(device), edge_time.to(device), edge_index.to(device),
---> 11                                    edge_type.to(device))

~/SageMaker/datascience/pyHGT/OAG/pyHGT/model.py in forward(self, node_feature, node_type, edge_time, edge_index, edge_type)
     77         del res
     78         for gc in self.gcs:
---> 79             meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
     80         return meta_xs

~/anaconda3/envs/gnn_hgt_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/SageMaker/datascience/pyHGT/OAG/pyHGT/conv.py in forward(self, meta_xs, node_type, edge_index, edge_type, edge_time)
    316     def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
    317         if self.conv_name == 'hgt':
--> 318             return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
    319         elif self.conv_name == 'gcn':
    320             return self.base_conv(meta_xs, edge_index)

~/anaconda3/envs/gnn_hgt_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/SageMaker/datascience/pyHGT/OAG/pyHGT/conv.py in forward(self, node_inp, node_type, edge_index, edge_type, edge_time)
     57     def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
     58         return self.propagate(edge_index, node_inp=node_inp, node_type=node_type, \
---> 59                               edge_type=edge_type, edge_time = edge_time)
     60 
     61     def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time):

~/anaconda3/envs/gnn_hgt_gpu/lib/python3.6/site-packages/torch_geometric/nn/conv/message_passing.py in propagate(self, edge_index, size, **kwargs)
    308 
    309                 coll_dict = self.__collect__(self.__user_args__, edge_index,
--> 310                                              size, kwargs)
    311 
    312                 msg_kwargs = self.inspector.distribute('message', coll_dict)

~/anaconda3/envs/gnn_hgt_gpu/lib/python3.6/site-packages/torch_geometric/nn/conv/message_passing.py in __collect__(self, args, edge_index, size, kwargs)
    198 
    199                 if isinstance(data, Tensor):
--> 200                     self.__set_size__(size, dim, data)
    201                     data = self.__lift__(data, edge_index,
    202                                          j if arg[-2:] == '_j' else i)

~/anaconda3/envs/gnn_hgt_gpu/lib/python3.6/site-packages/torch_geometric/nn/conv/message_passing.py in __set_size__(self, size, dim, src)
    161         if the_size is None:
    162             size[dim] = src.size(self.node_dim)
--> 163         elif the_size != src.size(self.node_dim):
    164             raise ValueError(
    165                 (f'Encountered tensor with size {src.size(self.node_dim)} in '

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)