acbull / pyHGT

Code for "Heterogeneous Graph Transformer" (WWW'20), which is based on pytorch_geometric
MIT License

HGT for a many(source)-to-one(target) node classification task #36

Open shaanchandra opened 3 years ago

shaanchandra commented 3 years ago

Hi @acbull, thank you for the great work here! I really appreciate it! I had the following question:

I am trying to benchmark HGT on the Online Retail-2 dataset, which has transaction records of the form (invoice_id, stock_id, customer_id, country_id, quantity, price). I made an artificial node classification task for it, where we want to classify the customer nodes as low, mid, or high activity (based on a segmentation of their spending behavior).
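To make the labeling concrete, here is a minimal sketch of one possible way to derive such labels. The equal-size bucketing by total spend, the helper name label_customers, and the toy records are all assumptions for illustration, not necessarily the exact rule used for the benchmark.

```python
from collections import defaultdict

# Toy records: (invoice_id, stock_id, customer_id, country_id, quantity, price).
# All values are made up for illustration.
transactions = [
    ("i1", "s1", "c1", "UK", 2, 5.0),
    ("i2", "s2", "c1", "UK", 1, 10.0),
    ("i3", "s1", "c2", "FR", 10, 5.0),
    ("i4", "s3", "c3", "UK", 1, 2.0),
    ("i5", "s2", "c4", "DE", 30, 10.0),
    ("i6", "s3", "c5", "UK", 4, 2.0),
    ("i7", "s1", "c6", "FR", 20, 5.0),
]


def label_customers(records, names=("low", "mid", "high")):
    """Hypothetical helper: bucket customers into equal-size groups by total spend."""
    spend = defaultdict(float)
    for _, _, customer_id, _, quantity, price in records:
        spend[customer_id] += quantity * price
    # Rank by spend (customer id breaks ties so the result is deterministic).
    ranked = sorted(spend, key=lambda c: (spend[c], c))
    n = len(ranked)
    return {
        c: names[min(i * len(names) // n, len(names) - 1)]
        for i, c in enumerate(ranked)
    }


labels = label_customers(transactions)
```

Each customer node then carries one of three class labels, which is what the classifier on top of HGT would be trained to predict.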

So in this case, the customer nodes are the target nodes. For the source nodes, we have 2 options:

I wanted to know your thoughts on how the code can be adapted for this (if at all), since the current tasks seem quite limiting: each assumes just one type of source node and one type of target node, even though a heterogeneous graph contains many kinds of nodes.

Specifically, I planned to add multiple source nodes to each target node in this way:

from collections import defaultdict

train_pairs = defaultdict(list)
valid_pairs = defaultdict(list)
test_pairs  = defaultdict(list)

# Nesting as used here: [target_id][source_id] -> _time
cu_st_edges = self.config['graph'].edge_list['customer']['stock']['ST_CU']
for target_id, sources in cu_st_edges.items():
    for source_id, _time in sources.items():
        # No "target_id not in ..." guard: with defaultdict(list) it would
        # keep only the first source per target, defeating the purpose.
        if _time in self.train_range:
            train_pairs[target_id].append([source_id, _time])
        elif _time in self.valid_range:
            valid_pairs[target_id].append([source_id, _time])
        else:
            test_pairs[target_id].append([source_id, _time])

NOTE: train_pairs, valid_pairs, and test_pairs are now defaultdict(list) objects, and each [source_id, _time] pair is appended under its target_id key. However, this would also require changing sample_sub_graph() and the methods that follow it, such as to_torch(). In particular, how would you propose changing this line in sample_sub_graph(): layer_data[_type][_id] = [len(layer_data[_type]), _time]? We now have multiple _time values per target, one from each source edge.

Essentially, I am asking how to support the general case of many source nodes connected to one target node, and how to determine _time in that case. Does the rest of the framework support this more practical use case?
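To make the _time question concrete, here is a minimal standalone sketch with toy data. The helper names collect_pairs and seed_time, the IDs, and the years are hypothetical. Collapsing the edge times with max is only one option I can think of, not something the current code does.

```python
from collections import defaultdict

# Toy stand-in for edge_list['customer']['stock']['ST_CU']:
# target (customer) id -> {source (stock) id -> _time}. Values are made up.
cu_st_edges = {
    0: {10: 2009, 11: 2010, 12: 2011},
    1: {10: 2010},
}
train_range = range(2009, 2011)  # hypothetical training years: 2009 and 2010


def collect_pairs(edges, time_range):
    """Group every in-range [source_id, _time] edge under its target id."""
    pairs = defaultdict(list)
    for target_id, sources in edges.items():
        for source_id, _time in sources.items():
            if _time in time_range:
                pairs[target_id].append([source_id, _time])
    return pairs


def seed_time(source_edges, agg=max):
    """Collapse the edge times of one target into a single value (assumption)."""
    return agg(_time for _, _time in source_edges)


train_pairs = collect_pairs(cu_st_edges, train_range)

# Mimics the layer_data[_type][_id] = [index, _time] bookkeeping for target seeds.
seeds = {}
for target_id, source_edges in train_pairs.items():
    seeds[target_id] = [len(seeds), seed_time(source_edges)]
```

With this choice each target keeps a single _time (its latest in-range edge), so the [index, _time] format stays unchanged. Whether that is the right semantics for the time-aware sampling is exactly what I am unsure about.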

I am happy to share more information or explain better if this is not clear already!

I also wanted to know why you call the PV and PF tasks node classification. They look more like edge prediction between nodes of the graph. This is also apparent from your code, where you mask the edges between target and source nodes. In node classification tasks, the labels are not themselves nodes of the graph. If you are predicting whether two nodes of a graph are connected, then it is an edge (link) prediction task, isn't it? Please correct me if I am wrong!

Looking forward to your reply!