tensorflow / gnn

TensorFlow GNN is a library to build Graph Neural Networks on the TensorFlow platform.
Apache License 2.0

Problems with binary heterogeneous graph classification using gat_v2.GATv2MPNNGraphUpdate #257

Closed: Hanc1999 closed this issue 1 year ago

Hanc1999 commented 1 year ago

Following the pipeline shown in the input pipeline and modeling docs, I built a binary classification model for a heterogeneous directed graph with gat_v2.GATv2MPNNGraphUpdate, but training raises errors. This is my GNN model:

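import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gat_v2

def gnn(graph):
    # Two rounds of GATv2 message passing; each edge sends messages to its SOURCE endpoint.
    for _ in range(2):
        graph = gat_v2.GATv2MPNNGraphUpdate(units=5, message_dim=5, num_heads=1, receiver_tag=tfgnn.SOURCE)(graph)
    return graph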

The graph spec produced by the preprocessing model is:

GraphTensorSpec({'context': ContextSpec({'features': {'hidden_state': TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, None), 
'node_sets': 
{'operate': NodeSetSpec({'features': {'hidden_state': TensorSpec(shape=(None, 9), dtype=tf.float32, name=None)}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, None), 
'source': NodeSetSpec({'features': {'hidden_state': TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, None)}, 
'edge_sets': 
{'op2op': EdgeSetSpec({'features': {}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'adjacency': AdjacencySpec({'#index.0': TensorSpec(shape=(None,), dtype=tf.int32, name=None), '#index.1': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, {'#index.0': 'operate', '#index.1': 'operate'})}, TensorShape([]), tf.int32, None), 
'src2op': EdgeSetSpec({'features': {}, 'sizes': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'adjacency': AdjacencySpec({'#index.0': TensorSpec(shape=(None,), dtype=tf.int32, name=None), '#index.1': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, {'#index.0': 'source', '#index.1': 'operate'})}, TensorShape([]), tf.int32, None)}}, TensorShape([]), tf.int32, None)

There are two node sets, 'operate' and 'source', and each graph always contains exactly one 'source' node. The classification model pools features from all nodes (both node sets) and is defined as:

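import tensorflow as tf
import tensorflow_gnn as tfgnn

# `dataset` yields (GraphTensor, label) pairs produced by the preprocessing model.
model_input_spec, _ = dataset.element_spec
model_input = tf.keras.layers.Input(type_spec=model_input_spec)
graph = gnn(model_input)  # apply the GNN model defined above
# Mean-pool the node states of each node set into the context, then concatenate.
pooled_features_s = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean", node_set_name="source")(graph)
pooled_features_op = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean", node_set_name="operate")(graph)
pooled_features = tf.keras.layers.concatenate([pooled_features_s, pooled_features_op])
logits = tf.keras.layers.Dense(1)(pooled_features)  # single logit for binary classification
model = tf.keras.Model(model_input, logits)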

But training raises the following error:

File "/.../lib/python3.8/site-packages/tensorflow_gnn/graph/graph_tensor_ops.py", line 106, in broadcast_node_to_edges
      return tf.gather(node_value, adjacency[node_tag])
Node: 'model_6/graph_update_6/node_set_update_1/gat_v2_conv/GatherV2_1'
indices[157] = 224 is not in [0, 224)
     [[{{node model_6/graph_update_6/node_set_update_1/gat_v2_conv/GatherV2_1}}]] [Op:__inference_train_function_30804]

The error always has the form 'indices[xx] = yy is not in [0, yy)'. Could anyone please take a look? Thanks.

Hanc1999 commented 1 year ago

I found that this error is similar to part of another issue, #230.

Hanc1999 commented 1 year ago

It turns out the problem was in my own GraphTensor construction. I am still posting the explanation because it involves GraphTensor.merge_batch_to_components() and may help others.

My original graph samples contained an invalid edge pointing to a non-existent node: for a node set with nodes [1, ..., n], one edge pointed to node n+1. When the graphs of a batch are merged into a single GraphTensor, each graph's node indices are offset by the number of nodes of that node set in the preceding graphs. As a result, the first graph's invalid edge silently points to one of the second graph's nodes, and so on. Only the invalid edge of the last graph points past the end of the merged node set. Its index equals the total number of nodes in the batch, which is why the error always reads 'indices[xx] = yy is not in [0, yy)'.
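To illustrate the re-indexing described above, here is a simplified plain-Python model of batching; it is not TF-GNN code. The helpers merge_edges and find_invalid_edges are hypothetical. The sketch assumes a single node set and 0-based indices (a graph with n nodes uses 0..n-1, so the bad edge points to n). In the issue's heterogeneous graph, each node set ('operate', 'source') would be offset by its own cumulative size.

```python
# Hypothetical helpers for illustration only: a plain-Python model of how merging
# a batch offsets node indices, not TF-GNN's actual implementation.

def merge_edges(graphs):
    """Merge per-graph edge lists into one graph by offsetting node indices.

    Each graph is a dict with "num_nodes" and "edges", a list of
    (source, target) pairs using 0-based indices local to that graph.
    """
    merged_edges = []
    offset = 0
    for g in graphs:
        for src, tgt in g["edges"]:
            merged_edges.append((src + offset, tgt + offset))
        offset += g["num_nodes"]
    return offset, merged_edges  # total node count and re-indexed edges


def find_invalid_edges(graph):
    """Return the edges of a single graph that point outside [0, num_nodes)."""
    n = graph["num_nodes"]
    return [(s, t) for s, t in graph["edges"] if not (0 <= s < n and 0 <= t < n)]


# Three graphs with 3 nodes each (indices 0..2); each has one bad edge to node 3.
graphs = [{"num_nodes": 3, "edges": [(0, 1), (1, 3)]} for _ in range(3)]

total, edges = merge_edges(graphs)
print(total)  # 9
print(edges)  # [(0, 1), (1, 3), (3, 4), (4, 6), (6, 7), (7, 9)]
# Graph 0's bad edge (1, 3) now points to graph 1's first node (index 3).
# Graph 1's bad edge becomes (4, 6), pointing to graph 2's first node.
# Graph 2's bad edge becomes (7, 9): index 9 is outside [0, 9).
out_of_range = [e for e in edges if max(e) >= total]
print(out_of_range)  # [(7, 9)]

# Checking each graph before batching reports all three bad edges.
print([find_invalid_edges(g) for g in graphs])  # [[(1, 3)], [(1, 3)], [(1, 3)]]
```

Only one of the three bad edges is visible after merging, which matches the single out-of-range index in the reported error. Validating each graph before batching, as find_invalid_edges does here, catches all of them; with real GraphTensors the analogous check would compare each edge set's adjacency indices against the size of the node set they reference.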