joshcarty / tfgnn-ogb

Examples using TensorFlow GNN with Open Graph Benchmark datasets.
MIT License

Add model for GAT #4

Closed. SidneyLann closed this issue 2 years ago

SidneyLann commented 2 years ago

Could you please add a training model for GAT also?

joshcarty commented 2 years ago

Hey! Have you been able to test by replacing SimpleConvolution with a GAT layer here?

joshcarty commented 2 years ago

@SidneyLann, yeah this seems to train:

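import tensorflow as tf
import tensorflow_gnn as tfgnn

# NOTE: `model` below is assumed to be this repository's own module that
# defines NodeClassificationModel; import it as appropriate for your setup.


def OgbnArxivModel(
    num_classes: int,
    num_neighbourhoods: int,
    type_spec: tfgnn.GraphTensorSpec,
) -> tf.keras.Model:
    # Named `inputs` to avoid shadowing the Python builtin `input`.
    inputs = tf.keras.layers.Input(type_spec=type_spec)
    gnn = tfgnn.keras.ConvGNNBuilder(
        # Edge convolution: GATv2 attention in place of SimpleConvolution.
        lambda edge: tfgnn.keras.layers.GATv2Convolution(
            num_heads=4,
            per_head_channels=16,
            receiver_tag=tfgnn.TARGET,
        ),
        # Node update: concatenate old state and pooled messages, then Dense.
        lambda node: tfgnn.keras.layers.NextStateFromConcat(
            tf.keras.layers.Dense(32)
        ),
    )
    # One round of message passing per neighbourhood hop.
    hidden = gnn.Convolve()(inputs)
    for _ in range(num_neighbourhoods - 1):
        hidden = gnn.Convolve()(hidden)
    # Read out the state of the first "paper" node and classify it.
    hidden = tfgnn.keras.layers.ReadoutFirstNode(node_set_name="paper")(hidden)
    output = tf.keras.layers.Dense(num_classes, activation="softmax")(hidden)
    return model.NodeClassificationModel(
        inputs, output, target_node="paper", label_name="label"
    )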
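
For readers wondering what the attention convolution in the model above does per node, here is a rough, library-free sketch of a single GATv2-style attention head for one receiver node. It follows the general GATv2 formulation (score = attn · LeakyReLU(W_query·receiver + W_key·sender), softmax over senders, weighted sum of the transformed sender states). It is not TF-GNN's implementation, and every function and parameter name here is hypothetical.

```python
import math


def leaky_relu(x, alpha=0.2):
    # alpha=0.2 is an assumed slope, chosen only for illustration.
    return x if x > 0 else alpha * x


def matvec(w, v):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def gatv2_head(receiver, senders, w_query, w_key, attn):
    """One GATv2-style attention head for a single receiver node (toy sketch)."""
    query = matvec(w_query, receiver)
    keys = [matvec(w_key, s) for s in senders]
    # Unnormalised score per sender: attn . LeakyReLU(query + key)
    scores = [
        sum(a * leaky_relu(q + k) for a, q, k in zip(attn, query, key))
        for key in keys
    ]
    # Softmax over the senders, subtracting the max for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Attention-weighted sum of the transformed sender states.
    pooled = [
        sum(w * key[d] for w, key in zip(weights, keys))
        for d in range(len(query))
    ]
    return weights, pooled


# Toy inputs: identity transforms and two one-hot sender states.
identity = [[1.0, 0.0], [0.0, 1.0]]
weights, pooled = gatv2_head(
    receiver=[1.0, 0.0],
    senders=[[1.0, 0.0], [0.0, 1.0]],
    w_query=identity,
    w_key=identity,
    attn=[1.0, 0.0],
)
```

With `num_heads=4` and `per_head_channels=16`, the layer in the model would run four such heads with 16 output channels each; how the heads are merged is up to the library.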

SidneyLann commented 2 years ago

Thanks.

SidneyLann commented 2 years ago

What's the purpose of GATv2GraphUpdate? How can I include it in the model?

joshcarty commented 2 years ago

I've only really interacted with ConvGNNBuilder, but it looks like you can use GraphUpdate layers independently. Here's a test using GATv2GraphUpdate that seems to do that. Let me know if you get that to work.
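
To illustrate the idea of applying graph updates directly, without a builder: a graph update takes a graph and returns the same graph with new node states, so rounds can simply be chained. The toy sketch below uses plain Python and a mean-pooling update rather than attention. Its names are hypothetical and it does not use the TF-GNN API.

```python
def mean_pool_update(states, edges):
    """Toy graph update: each node averages its own state with incoming sender states."""
    incoming = {i: [] for i in range(len(states))}
    for source, target in edges:
        incoming[target].append(states[source])
    return [
        (state + sum(incoming[i])) / (1 + len(incoming[i]))
        for i, state in enumerate(states)
    ]


# A three-node chain 0 -> 1 -> 2, with a signal only on node 0.
edges = [(0, 1), (1, 2)]
states = [1.0, 0.0, 0.0]

# Each call is one round of message passing; chaining calls widens the neighbourhood.
after_one = mean_pool_update(states, edges)
after_two = mean_pool_update(after_one, edges)
```

After one round, node 2 has seen nothing of node 0. After two rounds, node 0's signal has reached it, which is why the model earlier in the thread applies one round per neighbourhood hop.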

SidneyLann commented 2 years ago

done